From 37c56c245e93b94572c656c15df738e873a0c802 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 6 Jun 2026 16:30:41 +0300 Subject: [PATCH] wip --- common/speculative.cpp | 23 +++------- include/llama.h | 2 + src/llama-context.cpp | 25 ++++------- src/llama-context.h | 3 +- src/llama-cparams.h | 2 + src/llama-ext.h | 4 -- src/llama-graph.cpp | 12 ----- src/llama-kv-cache-dsa.cpp | 4 +- src/llama-kv-cache-iswa.cpp | 18 ++++++-- src/llama-kv-cache-iswa.h | 4 +- src/llama-kv-cache.cpp | 77 +++++++++++++++++++++++++++++--- src/llama-kv-cache.h | 9 +++- src/llama-memory-hybrid-iswa.cpp | 2 + src/llama-memory-hybrid.cpp | 2 + src/llama-memory.h | 6 ++- src/llama-model.cpp | 75 +++++++++++++++++++++---------- src/models/gemma4-assistant.cpp | 5 ++- tools/server/server-context.cpp | 25 +++-------- 18 files changed, 188 insertions(+), 110 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index e899be804e..3e959afecc 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -418,7 +418,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { int32_t n_embd = 0; - bool kv_shared_with_target = false; + bool is_mem_shared = false; // Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1. // The last h-row of one process() call needs the first token of the NEXT @@ -494,13 +494,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false); llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true); - kv_shared_with_target = llama_get_memory(ctx_dft) == nullptr; - if (kv_shared_with_target) { - llama_set_memory(ctx_dft, ctx_tgt); - - // TODO: avoid the const cast - llama_model_set_tok_embd(const_cast(llama_get_model(ctx_dft)), llama_model_get_tok_embd(llama_get_model(ctx_tgt))); - } + // TODO: fix this + is_mem_shared = true; pending_h.assign(n_seq, std::vector(n_embd, 0.0f)); @@ -541,7 +536,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { auto * ctx_dft = this->params.ctx_dft; const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); - if (pos_max < N - 1 && !kv_shared_with_target) { + + if (pos_max < N - 1 && !is_mem_shared) { LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - " "process() hook may not have run on every prefill ubatch " "(need_embd / logits=1 on every prompt position?). " @@ -585,7 +581,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { const size_t row_bytes = (size_t) n_embd * sizeof(float); // if kv is shared with target (e.g Gemma4), then we can skip this catch-up decode - if (!kv_shared_with_target) { + if (!is_mem_shared) { common_batch_clear(batch); for (int k = 0; k < n_tokens; ++k) { @@ -645,7 +641,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { void draft(common_speculative_draft_params_vec & dparams) override { auto & ctx_dft = params.ctx_dft; - auto & ctx_tgt = params.ctx_tgt; common_batch_clear(batch); @@ -671,8 +666,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { h_row = pending_h[seq_id].data(); std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); - - llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, dp.n_past, -1); } int ret = llama_decode(ctx_dft, batch); @@ -733,8 +726,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true); std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); - - llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, dp.n_past, -1); } if (batch.n_tokens == 0) { @@ -751,8 +742,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { ++i; } - llama_synchronize(ctx_dft); - for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { diff --git a/include/llama.h b/include/llama.h index 9f78aa9a05..c475bab2b9 100644 --- a/include/llama.h +++ b/include/llama.h @@ -388,6 +388,8 @@ extern "C" { // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init) struct llama_sampler_seq_config * samplers; size_t n_samplers; + + struct llama_context * ctx_src; }; struct llama_model_tensor_override { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 96720e3d67..b8a7d35f3a 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -126,9 +126,10 @@ llama_context::llama_context( cparams.embeddings_nextn_masked = false; cparams.offload_kqv = params.offload_kqv; cparams.no_perf = params.no_perf; - cparams.pooling_type = params.pooling_type; cparams.warmup = false; + cparams.ctx_type = params.ctx_type; + cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; @@ -141,7 +142,7 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; - cparams.ctx_type = params.ctx_type; + cparams.ctx_src = params.ctx_src; // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later @@ -360,7 +361,8 @@ llama_context::llama_context( /*.type_k =*/ params.type_k, /*.type_v =*/ params.type_v, /*.swa_full =*/ params.swa_full, - /*.ctx_type= */ cparams.ctx_type, + /*.ctx_type =*/ cparams.ctx_type, + /*.mem_src =*/ params.ctx_src ? params.ctx_src->memory.get() : nullptr, }; memory.reset(model.create_memory(params_mem, cparams)); @@ -792,14 +794,8 @@ uint32_t llama_context::n_threads_batch() const { return cparams.n_threads_batch; } -llama_memory_ptr llama_context::get_memory() const { - return memory; -} - -void llama_context::set_memory(llama_memory_ptr memory) { - this->memory = std::move(memory); - - sched_need_reserve = true; +llama_memory_t llama_context::get_memory() const { + return memory.get(); } bool llama_context::memory_update(bool optimize) { @@ -3441,6 +3437,7 @@ llama_context_params llama_context_default_params() { /*.kv_unified =*/ false, /*.sampler =*/ nullptr, /*.n_sampler =*/ 0, + /*.ctx_src =*/ nullptr, }; return result; @@ -3658,12 +3655,8 @@ void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) { ctx->set_embeddings_nextn(value, masked); } -void llama_set_memory(llama_context * ctx, llama_context * src) { - ctx->set_memory(src->get_memory()); -} - llama_memory_t llama_get_memory(const struct llama_context * ctx) { - return ctx->get_memory().get(); + return ctx->get_memory(); } float * llama_get_embeddings_nextn(llama_context * ctx) { diff --git a/src/llama-context.h b/src/llama-context.h index af26186714..6f8f59a22a 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -71,8 +71,7 @@ struct llama_context { uint32_t n_threads() const; uint32_t n_threads_batch() const; - llama_memory_ptr get_memory() const; - void set_memory(llama_memory_ptr memory); + llama_memory_t get_memory() const; // return true if the memory was updated bool memory_update(bool optimize); diff --git a/src/llama-cparams.h b/src/llama-cparams.h index fd227ee5a2..8bb75aff35 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -49,4 +49,6 @@ struct llama_cparams { ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; + + llama_context * ctx_src; }; diff --git a/src/llama-ext.h b/src/llama-ext.h index e6faa4f26d..7ad6125fad 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -93,7 +93,6 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c // If masked == true, output the embeddings only for the tokens with batch.logits != 0 // If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked); -LLAMA_API void llama_set_memory(struct llama_context * ctx, struct llama_context * src); // mirrors: // LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); @@ -101,6 +100,3 @@ LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx); // LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i); - -LLAMA_API ggml_tensor * llama_model_get_tok_embd(const struct llama_model * model); -LLAMA_API void llama_model_set_tok_embd( struct llama_model * model, ggml_tensor * tensor); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 18f5d24ae4..4af710a145 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2543,16 +2543,6 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, v_cur); } - int il_save = il; - - if (arch == LLM_ARCH_GEMMA4_ASSISTANT) { - if (il == n_layer - 1) { - il = 59; - } else { - il = 58; - } - } - const auto * mctx_iswa = inp->mctx; const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base(); @@ -2576,8 +2566,6 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k = mctx_cur->get_k(ctx0, il); ggml_tensor * v = mctx_cur->get_v(ctx0, il); - il = il_save; - ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); diff --git a/src/llama-kv-cache-dsa.cpp b/src/llama-kv-cache-dsa.cpp index e44004b558..916ab65375 100644 --- a/src/llama-kv-cache-dsa.cpp +++ b/src/llama-kv-cache-dsa.cpp @@ -32,7 +32,7 @@ llama_kv_cache_dsa::llama_kv_cache_dsa( kv_mla = std::make_unique( model, model.hparams, type_k, type_v, v_trans, offload, unified, kv_size, n_seq_max, n_pad, - n_swa, swa_type, filter, reuse); + n_swa, swa_type, nullptr, filter, reuse, nullptr); // we use llama_kv_cache for caching indexer keys // by hand-tweaking some hparams we fool it to create @@ -49,7 +49,7 @@ llama_kv_cache_dsa::llama_kv_cache_dsa( kv_lid = std::make_unique( model, hparams_lid, type_k, type_v, v_trans, offload, unified, kv_size, n_seq_max, n_pad, - n_swa, swa_type, filter, reuse); + n_swa, swa_type, nullptr, filter, reuse, nullptr); } void llama_kv_cache_dsa::clear(bool data) { diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp index 9b9f179036..54694d4a7e 100644 --- a/src/llama-kv-cache-iswa.cpp +++ b/src/llama-kv-cache-iswa.cpp @@ -23,8 +23,10 @@ llama_kv_cache_iswa::llama_kv_cache_iswa( uint32_t n_seq_max, uint32_t n_ubatch, uint32_t n_pad, + llama_memory_t mem_src, const layer_filter_cb & filter, - const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) { + const layer_reuse_cb & reuse, + const layer_share_cb & share) : hparams(model.hparams), unified(unified) { // chain filters const layer_filter_cb filter_base = [&](int32_t il) { @@ -59,17 +61,27 @@ llama_kv_cache_iswa::llama_kv_cache_iswa( LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); + llama_memory_t mem_src_base = nullptr; + if (mem_src) { + mem_src_base = static_cast(mem_src)->get_base(); + } + + llama_memory_t mem_src_swa = nullptr; + if (mem_src) { + mem_src_swa = static_cast(mem_src)->get_swa(); + } + kv_base = std::make_unique( model, hparams, type_k, type_v, v_trans, offload, unified, size_base, n_seq_max, n_pad, - 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse); + 0, LLAMA_SWA_TYPE_NONE, mem_src_base, filter_base, reuse, share); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); kv_swa = std::make_unique( model, hparams, type_k, type_v, v_trans, offload, unified, size_swa, n_seq_max, n_pad, - hparams.n_swa, hparams.swa_type, filter_swa, reuse); + hparams.n_swa, hparams.swa_type, mem_src_swa, filter_swa, reuse, share); } void llama_kv_cache_iswa::clear(bool data) { diff --git a/src/llama-kv-cache-iswa.h b/src/llama-kv-cache-iswa.h index 70ab22f0d6..0206dd27e6 100644 --- a/src/llama-kv-cache-iswa.h +++ b/src/llama-kv-cache-iswa.h @@ -25,8 +25,10 @@ public: uint32_t n_seq_max, uint32_t n_ubatch, uint32_t n_pad, + llama_memory_t mem_src, const layer_filter_cb & filter, - const layer_reuse_cb & reuse); + const layer_reuse_cb & reuse, + const layer_share_cb & share); ~llama_kv_cache_iswa() = default; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index c2fda0e9d3..d27fc44541 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -90,8 +90,10 @@ llama_kv_cache::llama_kv_cache( uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type, + llama_memory_t mem_src, const layer_filter_cb & filter, - const layer_reuse_cb & reuse) : + const layer_reuse_cb & reuse, + const layer_share_cb & share) : model(model), hparams(hparams), v_trans(v_trans), n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { @@ -160,6 +162,8 @@ llama_kv_cache::llama_kv_cache( const bool is_mla = hparams.is_mla(); + other = static_cast(mem_src); + for (uint32_t il = 0; il < n_layer; il++) { if (!hparams.has_kv(il)) { LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il); @@ -171,6 +175,24 @@ llama_kv_cache::llama_kv_cache( continue; } + if (share && other) { + const int32_t il_share = share(il); + + if (il_share >= 0) { + const auto & layer_share = other->layers[other->map_layer_ids[il_share]]; + + LLAMA_LOG_WARN("%s: layer %3d: sharing with layer %d. k = %p, v = %p\n", __func__, il, il_share, + layer_share.k->data, layer_share.v->data); + + map_layer_ids[il] = layers.size(); + + layers.push_back(layer_share); + layers.back().il = il; + + continue; + } + } + if (n_embd_head_k_all == 0) { n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il); } else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) { @@ -347,6 +369,11 @@ void llama_kv_cache::clear(bool data) { } bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return true; + } + GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); if (p0 < 0) { @@ -410,6 +437,11 @@ bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { } void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size()); GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size()); @@ -497,6 +529,11 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll } void llama_kv_cache::seq_keep(llama_seq_id seq_id) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -519,6 +556,11 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) { } void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1"); @@ -564,6 +606,11 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll } void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1"); @@ -746,6 +793,11 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vectorget_sched(); @@ -1021,6 +1073,12 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, } void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + v_cells = other->v_cells; + return; + } + // keep track of the max sequence position that we would overwrite with this ubatch // for non-SWA cache, this would be always empty llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; @@ -1552,10 +1610,6 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data goto skip; } - if (ubatch->n_tokens == 1 && p0 == p1) { - goto skip; - } - // M-RoPE causal mask if (is_2d) { if (p0 == p1) { @@ -1819,6 +1873,9 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { } ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + GGML_ASSERT(!other); + auto * ctx = res->get_ctx(); auto * gf = res->get_gf(); @@ -1864,6 +1921,11 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co } void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_UNUSED(flags); io.write(&n_stream, sizeof(n_stream)); @@ -1929,6 +1991,11 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, lla } void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_UNUSED(flags); GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 99f5010195..8ad2412149 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -98,7 +98,7 @@ public: // likely through `struct llama_memory_params` llama_kv_cache( const llama_model & model, - const llama_hparams & hparams, + const llama_hparams & hparams, ggml_type type_k, ggml_type type_v, bool v_trans, @@ -109,8 +109,10 @@ public: uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type, + llama_memory_t mem_src, const layer_filter_cb & filter, - const layer_reuse_cb & reuse); + const layer_reuse_cb & reuse, + const layer_share_cb & share); ~llama_kv_cache() = default; @@ -264,6 +266,9 @@ private: // note: this is not part of the KV state and it's only used to speed-up the find_slot() method std::vector v_heads; + // TODO: temporary until we refactor to be able to share the same cells between 2 kv caches [TAG_KV_CACHE_SHARE_CELLS] + llama_kv_cache * other; + std::vector v_cells; // maps from a sequence id to a stream id diff --git a/src/llama-memory-hybrid-iswa.cpp b/src/llama-memory-hybrid-iswa.cpp index a242079b40..c7d4bcd413 100644 --- a/src/llama-memory-hybrid-iswa.cpp +++ b/src/llama-memory-hybrid-iswa.cpp @@ -43,9 +43,11 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( n_seq_max, n_ubatch, n_pad, + nullptr, filter_attn == nullptr ? [&](int32_t il) { return !hparams.is_recr(il); } : filter_attn, + nullptr, nullptr )), mem_recr(new llama_memory_recurrent( diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 66ec3fd6d5..f2d49cbce5 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -44,9 +44,11 @@ llama_memory_hybrid::llama_memory_hybrid( n_pad, n_swa, swa_type, + nullptr, filter_attn == nullptr ? [&](int32_t il) { return !hparams.is_recr(il); } : filter_attn, + nullptr, nullptr )), mem_recr(new llama_memory_recurrent( diff --git a/src/llama-memory.h b/src/llama-memory.h index c058299ff3..e3025ec789 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -23,6 +23,8 @@ struct llama_memory_params { bool swa_full; llama_context_type ctx_type; + + llama_memory_t mem_src; }; enum llama_memory_status { @@ -76,6 +78,8 @@ struct llama_memory_i { // return negative value to indicate that the layer il should not reuse memory using layer_reuse_cb = std::function; + using layer_share_cb = std::function; + virtual ~llama_memory_i() = default; // split the input batch into a set of ubatches and verify that they can fit into the cache @@ -122,4 +126,4 @@ struct llama_memory_i { virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0; }; -using llama_memory_ptr = std::shared_ptr; +using llama_memory_ptr = std::unique_ptr; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 057a4f85d8..88f72a9415 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1987,7 +1987,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_MODERN_BERT: case LLM_ARCH_GEMMA_EMBEDDING: - case LLM_ARCH_GEMMA4_ASSISTANT: case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: case LLM_ARCH_LLADA_MOE: @@ -2097,8 +2096,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* filter_recr */ std::move(filter_recr)); } } else { - llama_memory_i::layer_reuse_cb reuse = nullptr; llama_kv_cache::layer_filter_cb filter = nullptr; + llama_memory_i::layer_reuse_cb reuse = nullptr; + llama_kv_cache::layer_share_cb share = nullptr; if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { reuse = [&](uint32_t il) { @@ -2124,23 +2124,56 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } } + llama_memory_t mem_src = cparams.ctx_src ? llama_get_memory(cparams.ctx_src) : nullptr; + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(hparams.is_swa_any()); - res = new llama_kv_cache_iswa( - *this, - params.type_k, - params.type_v, - !cparams.flash_attn, - cparams.offload_kqv, - params.swa_full, - cparams.kv_unified, - cparams.n_ctx_seq, - cparams.n_seq_max, - cparams.n_ubatch, - 1, - filter, - reuse); + if (arch == LLM_ARCH_GEMMA4_ASSISTANT) { + share = [&](int32_t il) { + const llama_model * model_src = llama_get_model(cparams.ctx_src); + + if (hparams.is_swa(il)) { + return llama_model_n_layer(model_src) - 2; + } + + return llama_model_n_layer(model_src) - 1; + }; + + res = new llama_kv_cache_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + cparams.n_ubatch, + 1, + mem_src, + filter, + reuse, + share); + } else { + res = new llama_kv_cache_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + cparams.n_ubatch, + 1, + mem_src, + filter, + reuse, + share); + } } else { GGML_ASSERT(!hparams.is_swa_any()); @@ -2157,7 +2190,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, 1, hparams.n_swa, hparams.swa_type, + mem_src, filter, + nullptr, nullptr); } } @@ -2631,11 +2666,3 @@ void llama_model_base::create_tensor_qkv(llama_layer & layer, int bid, layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); } } - -ggml_tensor * llama_model_get_tok_embd(const struct llama_model * model) { - return model->tok_embd; -} - -void llama_model_set_tok_embd(struct llama_model * model, ggml_tensor * tensor) { - model->tok_embd = tensor; -} diff --git a/src/models/gemma4-assistant.cpp b/src/models/gemma4-assistant.cpp index 491fbd9ddc..f7ece04fa5 100644 --- a/src/models/gemma4-assistant.cpp +++ b/src/models/gemma4-assistant.cpp @@ -104,7 +104,10 @@ llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_ res->add_input(std::move(inp)); } - ggml_tensor * x = ggml_get_rows(ctx0, model.tok_embd, inp_tokens); + GGML_ASSERT(cparams.ctx_src != nullptr); + const auto * model_src = llama_get_model(cparams.ctx_src); + + ggml_tensor * x = ggml_get_rows(ctx0, model_src->tok_embd, inp_tokens); x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone)); cb(x, "inp_embd_target", -1); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 47489e8eff..f0081d6155 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -10,8 +10,6 @@ #include "common.h" #include "fit.h" #include "llama.h" -#include "../../src/llama-ext.h" // staging API: llama_set_mtp_source -#include "ggml-cpp.h" #include "log.h" #include "sampling.h" #include "speculative.h" @@ -847,30 +845,14 @@ private: } cparams_dft.n_rs_seq = 0; - bool skip_measure = false; - //TODO: remove this - if (spec_mtp && has_draft) { - struct gguf_init_params meta_params = { - /* .no_alloc = */ true, - /* .ctx = */ nullptr, - }; - gguf_context_ptr meta(gguf_init_from_file(params_dft.model.path.c_str(), meta_params)); - - if (std::string(gguf_get_val_str(meta.get(), gguf_find_key(meta.get(), "general.architecture"))) == "gemma4-assistant") { - skip_measure = true; - SRV_WRN("[spec] skipping --fit memory measurement for Gemma 4 assistant draft model '%s'\n", - params_dft.model.path.c_str()); - } - } - std::vector devs; uint32_t hp_ngl = 0; uint32_t hp_nct = 0; uint32_t hp_nex = 0; - if (!skip_measure) try { + try { auto dmd = common_get_device_memory_data( params_dft.model.path.c_str(), &mparams_dft, &cparams_dft, - devs, hp_ngl, hp_nct, hp_nex, GGML_LOG_LEVEL_ERROR); + devs, hp_ngl, hp_nct, hp_nex, GGML_LOG_LEVEL_DEBUG); GGML_ASSERT(!params_base.fit_params_target.empty()); size_t total = 0; @@ -972,6 +954,8 @@ private: // note: for small models maybe we can set this to the maximum possible draft from all speculative types // the extra memory for small models is likely negligible? cparams.n_rs_seq = 0; + cparams.ctx_src = ctx_tgt; + ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); params_base.speculative.draft.ctx_tgt = ctx_tgt; @@ -987,6 +971,7 @@ private: cparams_mtp.type_v = params_base.speculative.draft.cache_type_v; cparams_mtp.n_rs_seq = 0; cparams_mtp.n_outputs_max = params_base.n_parallel; + cparams_mtp.ctx_src = ctx_tgt; ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp)); if (ctx_dft == nullptr) {