mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Allow graph reuse for Gemma4 MTP
This commit is contained in:
parent
4bcfe5b872
commit
d1692e1951
@ -576,42 +576,47 @@ void llama_context::reset_scheduler() {
|
||||
prev_mtp.reset();
|
||||
}
|
||||
|
||||
bool llama_context::can_reuse_graph(const llama_batch & u_batch) {
|
||||
if (!cparams.graph_reuse) return false;
|
||||
//if (kv_self.save_per_step_ssm) return false;
|
||||
if ((model.arch == LLM_ARCH_GEMMA4_MTP || model.arch == LLM_ARCH_GEMMA4_ASSISTANT) && mtp_target_ctx != nullptr) return false;
|
||||
auto the_prev = cparams.mtp_op_type == MTP_OP_NONE ? prev.get() : prev_mtp.get();
|
||||
if (!the_prev || !the_prev->graph) return false;
|
||||
//if (u_batch.n_tokens > 1) return false;
|
||||
if (u_batch.embd) return false;
|
||||
if (the_prev->save_per_step_ssm != kv_self.save_per_step_ssm ||
|
||||
the_prev->per_step_max_allocated != kv_self.ckpt.per_step_max_allocated) return false;
|
||||
return u_batch.all_seq_id == the_prev->all_seq_id &&
|
||||
kv_self.head > 0 &&
|
||||
kv_self.n == the_prev->n_kv &&
|
||||
n_outputs == the_prev->n_outputs &&
|
||||
u_batch.n_tokens == the_prev->n_tokens &&
|
||||
cparams.mtp_op_type == the_prev->mtp_op_type &&
|
||||
update_cache_copies();
|
||||
}
|
||||
|
||||
/*
|
||||
static void why_not_reuse_previous(const llama_batch & u_batch, const llama_context & ctx, const llama_context::Prev * the_prev) {
|
||||
if (!the_prev) { printf(" previous is null\n"); return; }
|
||||
if (!the_prev->graph) { printf(" previous graph is null\n"); return; }
|
||||
if (!ctx.cparams.graph_reuse) { printf(" graph_reuse is false\n"); return; }
|
||||
if (u_batch.embd) { printf(" ubatch.embd is not null\n"); return; }
|
||||
if (u_batch.all_seq_id != the_prev->all_seq_id) { printf(" all_seq_id is not the same\n"); return; }
|
||||
if (ctx.kv_self.head == 0) { printf(" kv_self.head = 0\n"); return; }
|
||||
if (ctx.kv_self.n != the_prev->n_kv) { printf(" kv_self.n is not the same\n"); return; }
|
||||
auto & kv_self_used = (ctx.model.arch == LLM_ARCH_GEMMA4_MTP || ctx.model.arch == LLM_ARCH_GEMMA4_ASSISTANT) &&
|
||||
ctx.mtp_target_ctx != nullptr ? ctx.mtp_target_ctx->kv_self : ctx.kv_self;
|
||||
if (kv_self_used.head == 0) { printf(" kv_self.head = 0\n"); return; }
|
||||
if (kv_self_used.n != the_prev->n_kv) { printf(" kv_self.n is not the same\n"); return; }
|
||||
if (ctx.n_outputs != the_prev->n_outputs) { printf(" n_outputs is not the same\n"); return; }
|
||||
if (u_batch.n_tokens != the_prev->n_tokens) { printf(" n_tokens is not the same\n"); return; }
|
||||
if (ctx.cparams.mtp_op_type != the_prev->mtp_op_type) { printf(" mtp_op_type is not the same\n"); return; }
|
||||
printf(" update_cache_copies() must have failed\n");
|
||||
}
|
||||
*/
|
||||
|
||||
bool llama_context::can_reuse_graph(const llama_batch & u_batch) {
|
||||
if (!cparams.graph_reuse) return false;
|
||||
auto the_prev = cparams.mtp_op_type == MTP_OP_NONE ? prev.get() : prev_mtp.get();
|
||||
if (!the_prev || !the_prev->graph) return false;
|
||||
if (u_batch.embd) return false;
|
||||
auto & kv_self_used = (model.arch == LLM_ARCH_GEMMA4_MTP || model.arch == LLM_ARCH_GEMMA4_ASSISTANT) &&
|
||||
mtp_target_ctx != nullptr ? mtp_target_ctx->kv_self : kv_self;
|
||||
if (the_prev->save_per_step_ssm != kv_self_used.save_per_step_ssm ||
|
||||
the_prev->per_step_max_allocated != kv_self_used.ckpt.per_step_max_allocated) return false;
|
||||
bool result = u_batch.all_seq_id == the_prev->all_seq_id &&
|
||||
kv_self_used.head > 0 &&
|
||||
kv_self_used.n == the_prev->n_kv &&
|
||||
n_outputs == the_prev->n_outputs &&
|
||||
u_batch.n_tokens == the_prev->n_tokens &&
|
||||
cparams.mtp_op_type == the_prev->mtp_op_type &&
|
||||
update_cache_copies();
|
||||
if (false && !result) {
|
||||
printf("%s(%d):", __func__, cparams.mtp_op_type);
|
||||
why_not_reuse_previous(u_batch, *this, the_prev);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool llama_context::update_cache_copies() {
|
||||
if (model.arch == LLM_ARCH_GEMMA4_MTP || model.arch == LLM_ARCH_GEMMA4_ASSISTANT) return true;
|
||||
const int n_layer = model.mtp && cparams.mtp_op_type != MTP_OP_NONE ?
|
||||
model.hparams.n_layer : model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
|
||||
auto layer_has_attention_kv = [&](int il) {
|
||||
@ -5409,17 +5414,17 @@ static int llama_decode_internal(
|
||||
tim2 = ggml_time_us();
|
||||
printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
//if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
|
||||
if (u_batch.embd == nullptr && lctx.cparams.graph_reuse &&
|
||||
!((lctx.model.arch == LLM_ARCH_GEMMA4_MTP || lctx.model.arch == LLM_ARCH_GEMMA4_ASSISTANT) && lctx.mtp_target_ctx != nullptr)) {
|
||||
if (u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
|
||||
auto & kv_self_used = (model.arch == LLM_ARCH_GEMMA4_MTP || model.arch == LLM_ARCH_GEMMA4_ASSISTANT) &&
|
||||
lctx.mtp_target_ctx != nullptr ? lctx.mtp_target_ctx->kv_self : lctx.kv_self;
|
||||
prev = std::make_unique<llama_context::Prev>(llama_context::Prev{
|
||||
(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n,
|
||||
(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)kv_self_used.n,
|
||||
(int)u_batch.n_tokens,
|
||||
lctx.kv_self.save_per_step_ssm, lctx.kv_self.ckpt.per_step_max_allocated,
|
||||
kv_self_used.save_per_step_ssm, kv_self_used.ckpt.per_step_max_allocated,
|
||||
cparams.mtp_op_type, gf});
|
||||
}
|
||||
} else {
|
||||
//printf("Reusing graph with n_kv = %d, n_tokens = %d\n", (int)prev->n_kv, (int)prev->n_tokens);
|
||||
//printf("Reusing graph with type = %d, n_kv = %d, n_tokens = %d\n", cparams.mtp_op_type, (int)prev->n_kv, (int)prev->n_tokens);
|
||||
gf = prev->graph;
|
||||
}
|
||||
|
||||
@ -5497,9 +5502,9 @@ static int llama_decode_internal(
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
if (lctx.dflash.kv.workspace_sync_pending) {
|
||||
llama_sync_dflash_workspace_if_pending(lctx);
|
||||
}
|
||||
if (lctx.dflash.kv.workspace_sync_pending) {
|
||||
llama_sync_dflash_workspace_if_pending(lctx);
|
||||
}
|
||||
llama_graph_compute(lctx, gf, n_threads);
|
||||
#if IK_PRINT_TIMING
|
||||
llama_synchronize(&lctx);
|
||||
@ -5676,12 +5681,15 @@ static int llama_decode_internal(
|
||||
#if IK_PRINT_TIMING
|
||||
auto tim1 = ggml_time_us();
|
||||
#endif
|
||||
if (!lctx.prev) {
|
||||
lctx.reset_scheduler();
|
||||
if (lctx.cparams.mtp_op_type == MTP_OP_NONE && !lctx.prev) {
|
||||
ggml_backend_sched_reset(lctx.sched);
|
||||
}
|
||||
else if (lctx.cparams.mtp_op_type != MTP_OP_NONE && !lctx.prev_mtp) {
|
||||
ggml_backend_sched_reset(lctx.sched);
|
||||
}
|
||||
#if IK_PRINT_TIMING
|
||||
auto tim2 = ggml_time_us();
|
||||
printf("sched_reset(...): %d us\n", int(tim2-tim1));
|
||||
auto tim2 = ggml_time_us();
|
||||
printf("sched_reset(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user