From 8369cf74128519e49ccbafe1311c89c498bef778 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 19 Jun 2026 18:16:53 +0200 Subject: [PATCH] Allow graph reuse for Gemma4 MTP (#1996) --- src/llama.cpp | 80 ++++++++++++++++++++++++++++----------------------- 1 file changed, 44 insertions(+), 36 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 198d4e45..b71eed23 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -572,42 +572,47 @@ void llama_context::reset_scheduler() { prev_mtp.reset(); } -bool llama_context::can_reuse_graph(const llama_batch & u_batch) { - if (!cparams.graph_reuse) return false; - //if (kv_self.save_per_step_ssm) return false; - if ((model.arch == LLM_ARCH_GEMMA4_MTP || model.arch == LLM_ARCH_GEMMA4_ASSISTANT) && mtp_target_ctx != nullptr) return false; - auto the_prev = cparams.mtp_op_type == MTP_OP_NONE ? prev.get() : prev_mtp.get(); - if (!the_prev || !the_prev->graph) return false; - //if (u_batch.n_tokens > 1) return false; - if (u_batch.embd) return false; - if (the_prev->save_per_step_ssm != kv_self.save_per_step_ssm || - the_prev->per_step_max_allocated != kv_self.ckpt.per_step_max_allocated) return false; - return u_batch.all_seq_id == the_prev->all_seq_id && - kv_self.head > 0 && - kv_self.n == the_prev->n_kv && - n_outputs == the_prev->n_outputs && - u_batch.n_tokens == the_prev->n_tokens && - cparams.mtp_op_type == the_prev->mtp_op_type && - update_cache_copies(); -} - -/* static void why_not_reuse_previous(const llama_batch & u_batch, const llama_context & ctx, const llama_context::Prev * the_prev) { if (!the_prev) { printf(" previous is null\n"); return; } if (!the_prev->graph) { printf(" previous graph is null\n"); return; } if (!ctx.cparams.graph_reuse) { printf(" graph_reuse is false\n"); return; } if (u_batch.embd) { printf(" ubatch.embd is not null\n"); return; } if (u_batch.all_seq_id != the_prev->all_seq_id) { printf(" all_seq_id is not the same\n"); return; } - if (ctx.kv_self.head == 0) { printf(" kv_self.head = 0\n"); return; } - if (ctx.kv_self.n != the_prev->n_kv) { printf(" kv_self.n is not the same\n"); return; } + auto & kv_self_used = (ctx.model.arch == LLM_ARCH_GEMMA4_MTP || ctx.model.arch == LLM_ARCH_GEMMA4_ASSISTANT) && + ctx.mtp_target_ctx != nullptr ? ctx.mtp_target_ctx->kv_self : ctx.kv_self; + if (kv_self_used.head == 0) { printf(" kv_self.head = 0\n"); return; } + if (kv_self_used.n != the_prev->n_kv) { printf(" kv_self.n is not the same\n"); return; } if (ctx.n_outputs != the_prev->n_outputs) { printf(" n_outputs is not the same\n"); return; } if (u_batch.n_tokens != the_prev->n_tokens) { printf(" n_tokens is not the same\n"); return; } if (ctx.cparams.mtp_op_type != the_prev->mtp_op_type) { printf(" mtp_op_type is not the same\n"); return; } printf(" update_cache_copies() must have failed\n"); } -*/ + +bool llama_context::can_reuse_graph(const llama_batch & u_batch) { + if (!cparams.graph_reuse) return false; + auto the_prev = cparams.mtp_op_type == MTP_OP_NONE ? prev.get() : prev_mtp.get(); + if (!the_prev || !the_prev->graph) return false; + if (u_batch.embd) return false; + auto & kv_self_used = (model.arch == LLM_ARCH_GEMMA4_MTP || model.arch == LLM_ARCH_GEMMA4_ASSISTANT) && + mtp_target_ctx != nullptr ? mtp_target_ctx->kv_self : kv_self; + if (the_prev->save_per_step_ssm != kv_self_used.save_per_step_ssm || + the_prev->per_step_max_allocated != kv_self_used.ckpt.per_step_max_allocated) return false; + bool result = u_batch.all_seq_id == the_prev->all_seq_id && + kv_self_used.head > 0 && + kv_self_used.n == the_prev->n_kv && + n_outputs == the_prev->n_outputs && + u_batch.n_tokens == the_prev->n_tokens && + cparams.mtp_op_type == the_prev->mtp_op_type && + update_cache_copies(); + if (false && !result) { + printf("%s(%d):", __func__, cparams.mtp_op_type); + why_not_reuse_previous(u_batch, *this, the_prev); + } + return result; +} bool llama_context::update_cache_copies() { + if (model.arch == LLM_ARCH_GEMMA4_MTP || model.arch == LLM_ARCH_GEMMA4_ASSISTANT) return true; const int n_layer = model.mtp && cparams.mtp_op_type != MTP_OP_NONE ? model.hparams.n_layer : model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2; auto layer_has_attention_kv = [&](int il) { @@ -5395,17 +5400,17 @@ static int llama_decode_internal( tim2 = ggml_time_us(); printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1)); #endif - //if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) { - if (u_batch.embd == nullptr && lctx.cparams.graph_reuse && - !((lctx.model.arch == LLM_ARCH_GEMMA4_MTP || lctx.model.arch == LLM_ARCH_GEMMA4_ASSISTANT) && lctx.mtp_target_ctx != nullptr)) { + if (u_batch.embd == nullptr && lctx.cparams.graph_reuse) { + auto & kv_self_used = (model.arch == LLM_ARCH_GEMMA4_MTP || model.arch == LLM_ARCH_GEMMA4_ASSISTANT) && + lctx.mtp_target_ctx != nullptr ? lctx.mtp_target_ctx->kv_self : lctx.kv_self; prev = std::make_unique(llama_context::Prev{ - (int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n, + (int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)kv_self_used.n, (int)u_batch.n_tokens, - lctx.kv_self.save_per_step_ssm, lctx.kv_self.ckpt.per_step_max_allocated, + kv_self_used.save_per_step_ssm, kv_self_used.ckpt.per_step_max_allocated, cparams.mtp_op_type, gf}); } } else { - //printf("Reusing graph with n_kv = %d, n_tokens = %d\n", (int)prev->n_kv, (int)prev->n_tokens); + //printf("Reusing graph with type = %d, n_kv = %d, n_tokens = %d\n", cparams.mtp_op_type, (int)prev->n_kv, (int)prev->n_tokens); gf = prev->graph; } @@ -5483,9 +5488,9 @@ static int llama_decode_internal( #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif - if (lctx.dflash.kv.workspace_sync_pending) { - llama_sync_dflash_workspace_if_pending(lctx); - } + if (lctx.dflash.kv.workspace_sync_pending) { + llama_sync_dflash_workspace_if_pending(lctx); + } llama_graph_compute(lctx, gf, n_threads); #if IK_PRINT_TIMING llama_synchronize(&lctx); @@ -5662,12 +5667,15 @@ static int llama_decode_internal( #if IK_PRINT_TIMING auto tim1 = ggml_time_us(); #endif - if (!lctx.prev) { - lctx.reset_scheduler(); + if (lctx.cparams.mtp_op_type == MTP_OP_NONE && !lctx.prev) { + ggml_backend_sched_reset(lctx.sched); + } + else if (lctx.cparams.mtp_op_type != MTP_OP_NONE && !lctx.prev_mtp) { + ggml_backend_sched_reset(lctx.sched); } #if IK_PRINT_TIMING - auto tim2 = ggml_time_us(); - printf("sched_reset(...): %d us\n", int(tim2-tim1)); + auto tim2 = ggml_time_us(); + printf("sched_reset(...): %d us\n", int(tim2-tim1)); #endif return 0;