Fix DFlash oerformance with split mode graph

This commit is contained in:
Kawrakow 2026-06-17 05:46:05 +00:00
parent 064d23a6f8
commit 5b9c3bbc3b
7 changed files with 44 additions and 21 deletions

View File

@ -319,22 +319,22 @@ static std::vector<std::string> ctrlvec_load_prompt_file(std::string path, bool
////////////////////////////////////////////////// //////////////////////////////////////////////////
static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { static int cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
auto * cb_data = (callback_data *) user_data; auto * cb_data = (callback_data *) user_data;
static const char * l_out_name = "l_out"; static const char * l_out_name = "l_out";
const bool is_l_out = strncmp(t->name, l_out_name, strlen(l_out_name)) == 0; const bool is_l_out = strncmp(t->name, l_out_name, strlen(l_out_name)) == 0;
if (ask) { if (ask) {
return is_l_out; return is_l_out ? 1 : 0;
} }
if (!is_l_out || t->ne[1] != cb_data->n_tokens) { if (!is_l_out || t->ne[1] != cb_data->n_tokens) {
return true; return 1;
} }
// save the tensor to current context // save the tensor to current context
cb_data->save_tensor_for_layer(t); cb_data->save_tensor_for_layer(t);
return true; return 1;
} }
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) { static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {

View File

@ -87,14 +87,14 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
* @param user_data user data to pass at each call back * @param user_data user data to pass at each call back
* @return true to receive data or continue the graph, false otherwise * @return true to receive data or continue the graph, false otherwise
*/ */
static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { static int ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
auto * cb_data = (callback_data *) user_data; auto * cb_data = (callback_data *) user_data;
const struct ggml_tensor * src0 = t->src[0]; const struct ggml_tensor * src0 = t->src[0];
const struct ggml_tensor * src1 = t->src[1]; const struct ggml_tensor * src1 = t->src[1];
if (ask) { if (ask) {
return true; // Always retrieve data return 1; // Always retrieve data
} }
char src1_str[128] = {0}; char src1_str[128] = {0};
@ -123,7 +123,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
ggml_print_tensor(data, t->type, t->ne, t->nb, 3); ggml_print_tensor(data, t->type, t->ne, t->nb, 3);
} }
return true; return 1;
} }
static bool run(llama_context * ctx, const gpt_params & params) { static bool run(llama_context * ctx, const gpt_params & params) {

View File

@ -791,8 +791,8 @@ static IMatrixCollector * ik_get_imatrix_collector(void * user_data) {
return user_data != nullptr ? static_cast<IMatrixCollector *>(user_data) : &g_target_collector; return user_data != nullptr ? static_cast<IMatrixCollector *>(user_data) : &g_target_collector;
} }
static bool ik_collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) { static int ik_collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) {
return ik_get_imatrix_collector(user_data)->collect_imatrix(t, ask, user_data); return ik_get_imatrix_collector(user_data)->collect_imatrix(t, ask, user_data) ? 1 : 0;
} }

View File

@ -1968,6 +1968,23 @@ void server_context::kv_cache_clear() {
clean_kv_cache = false; clean_kv_cache = false;
} }
static inline int server_decode(llama_context * ctx, const llama_batch & batch) {
#if 0
static int64_t tot_time = 0;
static int64_t ncalls = 0;
auto tim1 = ggml_time_us();
int ret = llama_decode(ctx, batch);
llama_synchronize(ctx);
auto tim2 = ggml_time_us();
tot_time += tim2 - tim1;
++ncalls;
LOG_INF("%s: %ld calls, %g ms, %g us/call\n", __func__, ncalls, 1e-3*tot_time, 1.*tot_time/ncalls);
return ret;
#else
return llama_decode(ctx, batch);
#endif
}
void server_context::system_prompt_update() { void server_context::system_prompt_update() {
LOG_VERBOSE("system prompt update", { LOG_VERBOSE("system prompt update", {
{"system_prompt", system_prompt}, {"system_prompt", system_prompt},
@ -1991,7 +2008,7 @@ void server_context::system_prompt_update() {
common_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false); common_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
} }
if (llama_decode(ctx, batch) != 0) { if (server_decode(ctx, batch) != 0) {
LOG_ERROR("llama_decode() failed", {}); LOG_ERROR("llama_decode() failed", {});
return; return;
} }
@ -4414,7 +4431,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
0, 0, 0, // unused 0, 0, 0, // unused
}; };
const int ret = llama_decode(ctx, batch_view); const int ret = server_decode(ctx, batch_view);
if (ret != 0) { if (ret != 0) {
if (n_batch == 1 || ret < 0) { if (n_batch == 1 || ret < 0) {
int user_cancel = -3; int user_cancel = -3;

View File

@ -175,7 +175,7 @@ extern "C" {
// when ask == false, the scheduler is passing the node tensor to the user for observation // when ask == false, the scheduler is passing the node tensor to the user for observation
// if the user returns false, the scheduler will cancel the graph compute // if the user returns false, the scheduler will cancel the graph compute
// //
typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); typedef int (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
// Initialize a backend scheduler // Initialize a backend scheduler
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel); GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);

View File

@ -2127,7 +2127,7 @@ static ggml_status ggml_backend_sched_eval(ggml_backend_sched_t sched, ggml_back
struct ggml_tensor * t = split->graph.nodes[j0]; struct ggml_tensor * t = split->graph.nodes[j0];
// check if the user needs data from this node // check if the user needs data from this node
bool need = sched->callback_eval(t, true, sched->callback_eval_user_data); int need = sched->callback_eval(t, true, sched->callback_eval_user_data);
int j1 = j0; int j1 = j0;
@ -2150,7 +2150,9 @@ static ggml_status ggml_backend_sched_eval(ggml_backend_sched_t sched, ggml_back
} }
// TODO: pass backend to the callback, then the user can decide if they want to synchronize // TODO: pass backend to the callback, then the user can decide if they want to synchronize
ggml_backend_synchronize(split_backend); if (need == 1) {
ggml_backend_synchronize(split_backend);
}
if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) { if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {
break; break;

View File

@ -365,7 +365,7 @@ static int32_t llama_dflash_find_layer_index(const struct llama_context * ctx, i
return it == layer_ids.end() ? -1 : (int32_t) std::distance(layer_ids.begin(), it); return it == layer_ids.end() ? -1 : (int32_t) std::distance(layer_ids.begin(), it);
} }
static bool llama_dflash_capture_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { static int llama_dflash_capture_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
auto * ctx = static_cast<llama_context *>(user_data); auto * ctx = static_cast<llama_context *>(user_data);
if (ctx == nullptr || !ctx->dflash.capture) { if (ctx == nullptr || !ctx->dflash.capture) {
return false; return false;
@ -373,22 +373,24 @@ static bool llama_dflash_capture_eval_callback(struct ggml_tensor * tensor, bool
int32_t layer_id = -1; int32_t layer_id = -1;
if (!llama_dflash_parse_layer_id(tensor, layer_id)) { if (!llama_dflash_parse_layer_id(tensor, layer_id)) {
return false; return 0;
} }
const int32_t layer_idx = llama_dflash_find_layer_index(ctx, layer_id); const int32_t layer_idx = llama_dflash_find_layer_index(ctx, layer_id);
if (layer_idx < 0) { if (layer_idx < 0) {
return false; return 0;
} }
//printf("%s -> %d, %d\n", tensor->name, layer_id, layer_idx);
if (ask) { if (ask) {
return true; return 2;
} }
const int32_t row_width = (int32_t) tensor->ne[0]; const int32_t row_width = (int32_t) tensor->ne[0];
const int32_t row_count = row_width > 0 ? (int32_t) (ggml_nelements(tensor) / (int64_t) row_width) : 0; const int32_t row_count = row_width > 0 ? (int32_t) (ggml_nelements(tensor) / (int64_t) row_width) : 0;
if (row_width <= 0 || row_count <= 0) { if (row_width <= 0 || row_count <= 0) {
return false; return 0;
} }
auto & capture = *ctx->dflash.capture; auto & capture = *ctx->dflash.capture;
@ -401,11 +403,13 @@ static bool llama_dflash_capture_eval_callback(struct ggml_tensor * tensor, bool
auto & rows = capture.layer_rows[(size_t) layer_idx]; auto & rows = capture.layer_rows[(size_t) layer_idx];
rows.resize((size_t) row_count * (size_t) row_width); rows.resize((size_t) row_count * (size_t) row_width);
ggml_backend_tensor_get(tensor, rows.data(), 0, ggml_nbytes(tensor)); auto backend = ggml_backend_sched_get_tensor_backend(ctx->sched, tensor);
GGML_ASSERT(backend);
ggml_backend_tensor_get_async(backend, tensor, rows.data(), 0, ggml_nbytes(tensor));
capture.row_width = row_width; capture.row_width = row_width;
capture.row_count = row_count; capture.row_count = row_count;
capture.layer_seen_batch_id[(size_t) layer_idx] = capture.capture_batch_id; capture.layer_seen_batch_id[(size_t) layer_idx] = capture.capture_batch_id;
return true; return 2;
} }
bool llama_set_dflash_capture_layers( bool llama_set_dflash_capture_layers(