From 1c4a91c0f3af274a0294ceddfe59e48d08399a7e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 6 Jun 2026 10:48:36 +0300 Subject: [PATCH] wip --- common/speculative.cpp | 18 +++- src/llama-batch.cpp | 18 ++-- src/llama-context.cpp | 76 +++------------ src/llama-context.h | 12 +-- src/llama-ext.h | 5 +- src/llama-graph.cpp | 167 ++++++-------------------------- src/llama-graph.h | 61 ------------ src/llama-kv-cache.cpp | 4 + src/llama-memory.h | 2 +- src/llama-model.cpp | 9 ++ src/llama-model.h | 1 + src/models/gemma4-assistant.cpp | 21 +--- tools/server/server-context.cpp | 17 +--- 13 files changed, 97 insertions(+), 314 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index e4fe2a86eb..e899be804e 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -491,11 +491,16 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { } } - kv_shared_with_target = llama_get_memory(ctx_dft) == nullptr; - llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false); llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true); - llama_set_mtp_source(ctx_dft, ctx_tgt); + + kv_shared_with_target = llama_get_memory(ctx_dft) == nullptr; + if (kv_shared_with_target) { + llama_set_memory(ctx_dft, ctx_tgt); + + // TODO: avoid the const cast + llama_model_set_tok_embd(const_cast(llama_get_model(ctx_dft)), llama_model_get_tok_embd(llama_get_model(ctx_tgt))); + } pending_h.assign(n_seq, std::vector(n_embd, 0.0f)); @@ -640,6 +645,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { void draft(common_speculative_draft_params_vec & dparams) override { auto & ctx_dft = params.ctx_dft; + auto & ctx_tgt = params.ctx_tgt; common_batch_clear(batch); @@ -665,6 +671,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { h_row = pending_h[seq_id].data(); std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); + + llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, dp.n_past, -1); } int ret = llama_decode(ctx_dft, batch); @@ -725,6 +733,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true); std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); + + llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, dp.n_past, -1); } if (batch.n_tokens == 0) { @@ -741,6 +751,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { ++i; } + llama_synchronize(ctx_dft); + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 6bf76939cd..06ebd32312 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -301,16 +301,16 @@ bool llama_batch_allocr::init( ok = false; } - if (!ok) { - LLAMA_LOG_ERROR( - "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n" - " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n" - " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n" - " it is required that the sequence positions remain consecutive: Y = X + 1\n", - __func__, s, s, p0, s, seq_pos_min(s)); + //if (!ok) { + // LLAMA_LOG_ERROR( + // "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n" + // " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n" + // " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n" + // " it is required that the sequence positions remain consecutive: Y = X + 1\n", + // __func__, s, s, p0, s, seq_pos_min(s)); - return false; - } + // return false; + //} } if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 5daa80cfe8..96720e3d67 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -38,12 +38,6 @@ static uint32_t ctx_type_to_embd_inp(const llama_hparams & hparams, llama_contex throw std::runtime_error("Unsupported ctx type"); } -namespace { -struct src_mctx_reset_on_exit { - llama_memory_context_ptr * slot; - ~src_mctx_reset_on_exit() { if (slot) slot->reset(); } -}; - static void llama_assert_gemma4_mtp_source_placement( const llama_context * ctx, const llama_context * src) { @@ -92,7 +86,6 @@ static void llama_assert_gemma4_mtp_source_placement( } } } -} llama_context::llama_context( const llama_model & model, @@ -514,23 +507,6 @@ void llama_context::sched_reserve() { } } - // When called from decode(), src_mctx_for_decode is already populated and - // we must not drop it on exit (process_ubatch still needs it). Snapshot - // only when sched_reserve runs standalone (e.g. lazy first-decode reserve - // when set_mtp_source flipped sched_need_reserve). - const bool owns_src_snapshot = src_ctx && !src_mctx_for_decode; - if (owns_src_snapshot) { - auto * src_memory = src_ctx->get_memory(); - if (!src_memory) { - throw std::runtime_error("MTP source context has no memory module"); - } - src_mctx_for_decode = src_memory->init_full(); - if (!src_mctx_for_decode) { - throw std::runtime_error("failed to initialize MTP source memory snapshot"); - } - } - src_mctx_reset_on_exit reserve_src_drop{owns_src_snapshot ? &src_mctx_for_decode : nullptr}; - // avoid reserving graphs with zero outputs - assume one output per sequence const int n_outputs = n_seqs; @@ -816,8 +792,14 @@ uint32_t llama_context::n_threads_batch() const { return cparams.n_threads_batch; } -llama_memory_t llama_context::get_memory() const { - return memory.get(); +llama_memory_ptr llama_context::get_memory() const { + return memory; +} + +void llama_context::set_memory(llama_memory_ptr memory) { + this->memory = std::move(memory); + + sched_need_reserve = true; } bool llama_context::memory_update(bool optimize) { @@ -1198,18 +1180,6 @@ void llama_context::set_embeddings_nextn(bool value, bool masked) { cparams.embeddings_nextn_masked = masked; } -void llama_context::set_mtp_source(llama_context * src) { - if (src_ctx == src) { - return; - } - llama_assert_gemma4_mtp_source_placement(this, src); - src_ctx = src; - src_mctx_for_decode.reset(); - // worst-case compute buffers were reserved without knowing about the source - // memory; force a re-reserve so the next decode sees src views - sched_need_reserve = true; -} - void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -1807,20 +1777,6 @@ int llama_context::decode(const llama_batch & batch_inp) { embd_seq.clear(); output_swaps.clear(); - src_mctx_reset_on_exit decode_src_drop{&src_mctx_for_decode}; - if (src_ctx) { - auto * src_memory = src_ctx->get_memory(); - if (!src_memory) { - LLAMA_LOG_ERROR("%s: MTP source context has no memory module\n", __func__); - return -2; - } - src_mctx_for_decode = src_memory->init_full(); - if (!src_mctx_for_decode) { - LLAMA_LOG_ERROR("%s: failed to snapshot MTP source memory\n", __func__); - return -2; - } - } - sched_reserve(); bool did_optimize = false; @@ -2414,8 +2370,6 @@ llm_graph_params llama_context::graph_params( /*.cvec =*/ cvec.get(), /*.loras =*/ loras.get(), /*.mctx =*/ mctx, - /*.src_mctx =*/ src_mctx_for_decode.get(), - /*.src_model =*/ src_ctx ? &src_ctx->get_model() : nullptr, /*.cross =*/ &cross, /*.samplers =*/ sampling.samplers, /*.n_outputs =*/ n_outputs, @@ -3704,8 +3658,12 @@ void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) { ctx->set_embeddings_nextn(value, masked); } -void llama_set_mtp_source(llama_context * ctx, llama_context * src) { - ctx->set_mtp_source(src); +void llama_set_memory(llama_context * ctx, llama_context * src) { + ctx->set_memory(src->get_memory()); +} + +llama_memory_t llama_get_memory(const struct llama_context * ctx) { + return ctx->get_memory().get(); } float * llama_get_embeddings_nextn(llama_context * ctx) { @@ -3771,7 +3729,7 @@ struct ggml_cgraph * llama_graph_reserve( uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs) { - auto * memory = ctx->get_memory(); + auto memory = ctx->get_memory(); llama_memory_context_ptr mctx; if (memory) { mctx = memory->init_full(); @@ -3811,10 +3769,6 @@ int32_t llama_set_adapter_cvec( // memory // -llama_memory_t llama_get_memory(const struct llama_context * ctx) { - return ctx->get_memory(); -} - void llama_memory_clear(llama_memory_t mem, bool data) { if (!mem) { return; diff --git a/src/llama-context.h b/src/llama-context.h index 935fa75af8..af26186714 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -71,7 +71,8 @@ struct llama_context { uint32_t n_threads() const; uint32_t n_threads_batch() const; - llama_memory_t get_memory() const; + llama_memory_ptr get_memory() const; + void set_memory(llama_memory_ptr memory); // return true if the memory was updated bool memory_update(bool optimize); @@ -112,7 +113,6 @@ struct llama_context { void set_embeddings (bool value); void set_embeddings_nextn(bool value, bool masked); - void set_mtp_source(llama_context * src); void set_causal_attn(bool value); void set_warmup(bool value); @@ -275,13 +275,7 @@ private: llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably - std::unique_ptr memory; - - // external KV source used by MTP draft contexts. src_ctx is the target - // context whose memory we read; src_mctx_for_decode is a per-decode - // snapshot held for the duration of one decode/sched_reserve call. - llama_context * src_ctx = nullptr; - llama_memory_context_ptr src_mctx_for_decode; + llama_memory_ptr memory; // decode output (2-dimensional array: [n_outputs][n_vocab]) buffer_view logits = {nullptr, 0}; diff --git a/src/llama-ext.h b/src/llama-ext.h index 92f8bfffa4..e6faa4f26d 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -93,7 +93,7 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c // If masked == true, output the embeddings only for the tokens with batch.logits != 0 // If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked); -LLAMA_API void llama_set_mtp_source(struct llama_context * ctx, struct llama_context * src); +LLAMA_API void llama_set_memory(struct llama_context * ctx, struct llama_context * src); // mirrors: // LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); @@ -101,3 +101,6 @@ LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx); // LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i); + +LLAMA_API ggml_tensor * llama_model_get_tok_embd(const struct llama_model * model); +LLAMA_API void llama_model_set_tok_embd( struct llama_model * model, ggml_tensor * tensor); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 8bb28e6960..18f5d24ae4 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -397,7 +397,7 @@ static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break; }; - LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); + LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__); LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__); @@ -565,18 +565,18 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { if (self_k_idxs && self_k_idxs->buffer) { mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); - - mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + // swa tensors may not be allocated if there are no SWA attention layers if (self_k_idxs_swa && self_k_idxs_swa->buffer) { mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); - - mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + if (self_k_rot) { mctx->get_base()->set_input_k_rot(self_k_rot); } @@ -605,47 +605,18 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { if (self_k_idxs && self_k_idxs->buffer) { res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - - res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); } + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + // swa tensors may not be allocated if there are no SWA attention layers if (self_k_idxs_swa && self_k_idxs_swa->buffer) { res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - - res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); } - return res; -} + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); -void llm_graph_input_attn_src_kv_iswa::set_input(const llama_ubatch * ubatch) { - src_mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); - src_mctx->get_swa() ->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); - - if (self_k_rot) { - src_mctx->get_base()->set_input_k_rot(self_k_rot); - } - if (self_v_rot) { - src_mctx->get_base()->set_input_v_rot(self_v_rot); - } - if (self_k_rot_swa) { - src_mctx->get_swa()->set_input_k_rot(self_k_rot_swa); - } - if (self_v_rot_swa) { - src_mctx->get_swa()->set_input_v_rot(self_v_rot_swa); - } -} - -bool llm_graph_input_attn_src_kv_iswa::can_reuse(const llm_graph_params & params) { - const auto * mctx = static_cast(params.src_mctx); - - this->src_mctx = mctx; - - bool res = true; - res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); - res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); return res; } @@ -785,18 +756,18 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch); - - attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); } + attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + // swa tensors may not be allocated if there are no SWA attention layers if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch); attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch); - - attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn); } + attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn); + if (inp_attn->self_k_rot) { attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot); } @@ -839,18 +810,18 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - - res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams); } + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams); + // swa tensors may not be allocated if there are no SWA attention layers if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - - res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams); } + res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams); + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; @@ -1063,8 +1034,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : cvec (params.cvec), loras (params.loras), mctx (params.mctx), - src_mctx (params.src_mctx), - src_model (params.src_model), cross (params.cross), samplers (params.samplers), cb_func (params.cb), @@ -2574,6 +2543,16 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, v_cur); } + int il_save = il; + + if (arch == LLM_ARCH_GEMMA4_ASSISTANT) { + if (il == n_layer - 1) { + il = 59; + } else { + il = 58; + } + } + const auto * mctx_iswa = inp->mctx; const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base(); @@ -2597,6 +2576,8 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k = mctx_cur->get_k(ctx0, il); ggml_tensor * v = mctx_cur->get_v(ctx0, il); + il = il_save; + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); @@ -2635,98 +2616,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { return (llm_graph_input_attn_cross *) res->add_input(std::move(inp)); } -llm_graph_input_attn_src_kv_iswa * llm_graph_context::build_attn_inp_src_kv_iswa() const { - GGML_ASSERT(src_mctx && "MTP draft graph requires src_mctx (set via llama_set_mtp_source)"); - - const auto * src_iswa = static_cast(src_mctx); - - auto inp = std::make_unique(hparams, cparams, src_iswa); - - inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, src_iswa->get_base(), ubatch, cparams); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; - - inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, src_iswa->get_swa(), ubatch, cparams); - inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; - - inp->self_k_rot = src_iswa->get_base()->build_input_k_rot(ctx0); - inp->self_v_rot = src_iswa->get_base()->build_input_v_rot(ctx0); - inp->self_k_rot_swa = src_iswa->get_swa()->build_input_k_rot(ctx0); - inp->self_v_rot_swa = src_iswa->get_swa()->build_input_v_rot(ctx0); - - return (llm_graph_input_attn_src_kv_iswa *) res->add_input(std::move(inp)); -} - -ggml_tensor * llm_graph_context::build_attn( - llm_graph_input_attn_src_kv_iswa * inp, - ggml_tensor * wo, - ggml_tensor * wo_b, - ggml_tensor * wo_s, - ggml_tensor * q_cur, - ggml_tensor * kq_b, - ggml_tensor * sinks, - ggml_tensor * v_mla, - float kq_scale, - int il_assist, - int il_src) const { - const bool is_swa = hparams.is_swa(il_assist); - - const auto * src_iswa = inp->src_mctx; - const auto * src_cur = is_swa ? src_iswa->get_swa() : src_iswa->get_base(); - - const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); - - auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot; - auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot; - - if (k_rot) { - q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot); - } - - ggml_build_forward_expand(gf, q_cur); - - ggml_tensor * q = q_cur; - ggml_tensor * k = src_cur->get_k(ctx0, il_src); - ggml_tensor * v = src_cur->get_v(ctx0, il_src); - - // build_attn_mha splits q across k->ne[3] (the trunk's stream count). When the - // trunk runs kv_unified=false the assistant's ubatch only references a subset - // of streams (one per active draft seq); q->ne[2] is not divisible by the full - // n_stream and the view collapses tokens. Slice k/v down to exactly the streams - // referenced by this ubatch. Requires those streams to form a contiguous range. - if (k->ne[3] > 1 && (uint32_t) k->ne[3] != ubatch.n_seqs_unq) { - GGML_ASSERT(ubatch.n_seqs_unq > 0 && ubatch.seq_id_unq); - llama_seq_id min_s = ubatch.seq_id_unq[0]; - llama_seq_id max_s = ubatch.seq_id_unq[0]; - for (uint32_t s = 1; s < ubatch.n_seqs_unq; ++s) { - min_s = std::min(min_s, ubatch.seq_id_unq[s]); - max_s = std::max(max_s, ubatch.seq_id_unq[s]); - } - GGML_ASSERT((uint32_t)(max_s - min_s + 1) == ubatch.n_seqs_unq && - "MTP src-kv attn requires the active draft seq_ids to be contiguous"); - GGML_ASSERT((int64_t) max_s < k->ne[3] && "MTP assistant seq_id beyond trunk stream count"); - - k = ggml_view_4d(ctx0, k, k->ne[0], k->ne[1], k->ne[2], (int64_t) ubatch.n_seqs_unq, - k->nb[1], k->nb[2], k->nb[3], (size_t) min_s * k->nb[3]); - v = ggml_view_4d(ctx0, v, v->ne[0], v->ne[1], v->ne[2], (int64_t) ubatch.n_seqs_unq, - v->nb[1], v->nb[2], v->nb[3], (size_t) min_s * v->nb[3]); - } - - ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il_assist); - cb(cur, "kqv_out", il_assist); - - if (v_rot) { - cur = ggml_mul_mat_aux(ctx0, cur, v_rot); - } - - if (wo) { - cur = build_lora_mm(wo, cur, wo_s); - } - if (wo_b) { - cur = ggml_add(ctx0, cur, wo_b); - } - return cur; -} - ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_cross * inp, ggml_tensor * wo, diff --git a/src/llama-graph.h b/src/llama-graph.h index 8d6b88fe01..bf5be09ac7 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -459,42 +459,6 @@ public: const llama_kv_cache_iswa_context * mctx; }; -// mask-only input for attention against an external (read-only) ISWA KV cache. -// used by MTP draft graphs that attend to the target's KV without owning any. -class llm_graph_input_attn_src_kv_iswa : public llm_graph_input_i { -public: - llm_graph_input_attn_src_kv_iswa( - const llama_hparams & hparams, - const llama_cparams & cparams, - const llama_kv_cache_iswa_context * src_mctx) : - hparams(hparams), - cparams(cparams), - src_mctx(src_mctx) { - } - ~llm_graph_input_attn_src_kv_iswa() = default; - - void set_input(const llama_ubatch * ubatch) override; - bool can_reuse(const llm_graph_params & params) override; - - ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } - ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } - - ggml_tensor * self_kq_mask = nullptr; - ggml_tensor * self_kq_mask_cnv = nullptr; - ggml_tensor * self_kq_mask_swa = nullptr; - ggml_tensor * self_kq_mask_swa_cnv = nullptr; - - ggml_tensor * self_k_rot = nullptr; - ggml_tensor * self_v_rot = nullptr; - ggml_tensor * self_k_rot_swa = nullptr; - ggml_tensor * self_v_rot_swa = nullptr; - - const llama_hparams hparams; - const llama_cparams cparams; - - const llama_kv_cache_iswa_context * src_mctx; -}; - class llm_graph_input_attn_cross : public llm_graph_input_i { public: llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {} @@ -637,11 +601,6 @@ struct llm_graph_params { const llama_adapter_cvec * cvec; const llama_adapter_loras * loras; const llama_memory_context_i * mctx; - // per-decode snapshot of an external memory module the graph reads from - // (never writes) — e.g. ctx_dft reading target KV during MTP draft. - // nullptr for a main decode. Rebound inside reuse-aware input classes. - const llama_memory_context_i * src_mctx; - const llama_model * src_model; const llama_cross * cross; std::map samplers; @@ -859,8 +818,6 @@ struct llm_graph_context { const llama_adapter_cvec * cvec; const llama_adapter_loras * loras; const llama_memory_context_i * mctx; - const llama_memory_context_i * src_mctx; - const llama_model * src_model; const llama_cross * cross; std::map samplers; @@ -1090,24 +1047,6 @@ struct llm_graph_context { float kq_scale, int il) const; - llm_graph_input_attn_src_kv_iswa * build_attn_inp_src_kv_iswa() const; - - // Q-only attention against an external ISWA KV cache (no K/V projections, - // no writes). il_assist labels the attention block in the local graph for - // logging; il_src indexes the source K/V layer to attend to. - ggml_tensor * build_attn( - llm_graph_input_attn_src_kv_iswa * inp, - ggml_tensor * wo, - ggml_tensor * wo_b, - ggml_tensor * wo_s, - ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] - ggml_tensor * kq_b, - ggml_tensor * sinks, // [n_head_q] - ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] - float kq_scale, - int il_assist, - int il_src) const; - llm_graph_input_attn_cross * build_attn_inp_cross() const; ggml_tensor * build_attn( diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 4bbec21a8a..c2fda0e9d3 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1552,6 +1552,10 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data goto skip; } + if (ubatch->n_tokens == 1 && p0 == p1) { + goto skip; + } + // M-RoPE causal mask if (is_2d) { if (p0 == p1) { diff --git a/src/llama-memory.h b/src/llama-memory.h index 4ad1612e45..c058299ff3 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -122,4 +122,4 @@ struct llama_memory_i { virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0; }; -using llama_memory_ptr = std::unique_ptr; +using llama_memory_ptr = std::shared_ptr; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index db3219f18f..057a4f85d8 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1987,6 +1987,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_MODERN_BERT: case LLM_ARCH_GEMMA_EMBEDDING: + case LLM_ARCH_GEMMA4_ASSISTANT: case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: case LLM_ARCH_LLADA_MOE: @@ -2630,3 +2631,11 @@ void llama_model_base::create_tensor_qkv(llama_layer & layer, int bid, layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); } } + +ggml_tensor * llama_model_get_tok_embd(const struct llama_model * model) { + return model->tok_embd; +} + +void llama_model_set_tok_embd(struct llama_model * model, ggml_tensor * tensor) { + model->tok_embd = tensor; +} diff --git a/src/llama-model.h b/src/llama-model.h index a3c068b84d..3c9ce4d582 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -706,6 +706,7 @@ const char * llm_type_name(llm_type type); #define LLAMA_LOAD_LOCALS \ const int n_layer = hparams.n_layer(); GGML_UNUSED(n_layer); \ const int n_layer_all = hparams.n_layer_all; GGML_UNUSED(n_layer_all); \ + const int n_layer_nextn = hparams.n_layer_nextn; GGML_UNUSED(n_layer_nextn); \ const int64_t n_head = hparams.n_head(); GGML_UNUSED(n_head); \ const int64_t n_head_kv = hparams.n_head_kv(); GGML_UNUSED(n_head_kv); \ const int64_t n_embd = hparams.n_embd; GGML_UNUSED(n_embd); \ diff --git a/src/models/gemma4-assistant.cpp b/src/models/gemma4-assistant.cpp index 8c274e0cbd..491fbd9ddc 100644 --- a/src/models/gemma4-assistant.cpp +++ b/src/models/gemma4-assistant.cpp @@ -82,18 +82,6 @@ std::unique_ptr llama_model_gemma4_assistant::build_arch_grap llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - GGML_ASSERT(src_mctx && "Gemma 4 assistant graph requires an MTP source (llama_set_mtp_source)"); - GGML_ASSERT(src_model && "Gemma 4 assistant graph requires a source model"); - GGML_ASSERT(src_model->tok_embd && "source model missing tok_embd"); - - const auto & src_hparams = src_model->hparams; - - // By convention the MTP draft reads from the trunk's final SWA and full layers. - const int32_t src_layer_full = (int32_t) src_hparams.n_layer() - 1; - const int32_t src_layer_swa = (int32_t) src_hparams.n_layer() - 2; - GGML_ASSERT(!src_hparams.is_swa(src_layer_full) && "trunk's last layer must be full attention"); - GGML_ASSERT( src_hparams.is_swa(src_layer_swa) && "trunk's penultimate layer must be SWA"); - const int64_t n_embd_backbone = hparams.n_embd_out(); ggml_tensor * inp_tokens; @@ -116,7 +104,7 @@ llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_ res->add_input(std::move(inp)); } - ggml_tensor * x = ggml_get_rows(ctx0, src_model->tok_embd, inp_tokens); + ggml_tensor * x = ggml_get_rows(ctx0, model.tok_embd, inp_tokens); x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone)); cb(x, "inp_embd_target", -1); @@ -126,15 +114,14 @@ llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_ ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_pre_proj, xh); cb(cur, "pre_proj", -1); - auto * inp_attn = build_attn_inp_src_kv_iswa(); + auto * inp_attn = build_attn_inp_kv_iswa(); ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); ggml_tensor * inpL = cur; for (int il = 0; il < n_layer; ++il) { - const bool is_swa = hparams.is_swa(il); - const int32_t il_src = is_swa ? src_layer_swa : src_layer_full; + const bool is_swa = hparams.is_swa(il); const int64_t n_embd_head = hparams.n_embd_head_k(il); const int64_t n_head = hparams.n_head(il); @@ -157,7 +144,7 @@ llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_ cb(Qcur, "Qcur_pos", il); cur = build_attn(inp_attn, model.layers[il].wo, nullptr, nullptr, - Qcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il, il_src); + Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 80119c8c83..47489e8eff 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -974,13 +974,6 @@ private: cparams.n_rs_seq = 0; ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); - if (spec_mtp) { - // MTP draft must know its target before the first decode - llama_set_mtp_source(ctx_dft.get(), ctx_tgt); - } - - ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); - params_base.speculative.draft.ctx_tgt = ctx_tgt; params_base.speculative.draft.ctx_dft = ctx_dft.get(); } else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(), @@ -1001,12 +994,6 @@ private: return false; } - // wire the source before any decode (the seq-rm probe below - // triggers sched_reserve which needs src for Gemma4-style MTP) - llama_set_mtp_source(ctx_dft.get(), ctx_tgt); - - ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); - params_base.speculative.draft.ctx_tgt = ctx_tgt; params_base.speculative.draft.ctx_dft = ctx_dft.get(); } @@ -1094,6 +1081,10 @@ private: } } + if (ctx_dft) { + ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); + } + if (spec) { SRV_INF("%s", "speculative decoding context initialized\n"); } else {