This commit is contained in:
Georgi Gerganov 2026-06-06 16:30:41 +03:00
parent 1c4a91c0f3
commit 37c56c245e
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
18 changed files with 188 additions and 110 deletions

View File

@ -418,7 +418,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
int32_t n_embd = 0;
bool kv_shared_with_target = false;
bool is_mem_shared = false;
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
// The last h-row of one process() call needs the first token of the NEXT
@ -494,13 +494,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false);
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
kv_shared_with_target = llama_get_memory(ctx_dft) == nullptr;
if (kv_shared_with_target) {
llama_set_memory(ctx_dft, ctx_tgt);
// TODO: avoid the const cast
llama_model_set_tok_embd(const_cast<llama_model *>(llama_get_model(ctx_dft)), llama_model_get_tok_embd(llama_get_model(ctx_tgt)));
}
// TODO: fix this
is_mem_shared = true;
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
@ -541,7 +536,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
auto * ctx_dft = this->params.ctx_dft;
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_max < N - 1 && !kv_shared_with_target) {
if (pos_max < N - 1 && !is_mem_shared) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
"process() hook may not have run on every prefill ubatch "
"(need_embd / logits=1 on every prompt position?). "
@ -585,7 +581,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
const size_t row_bytes = (size_t) n_embd * sizeof(float);
// if the KV cache is shared with the target (e.g. Gemma4), then we can skip this catch-up decode
if (!kv_shared_with_target) {
if (!is_mem_shared) {
common_batch_clear(batch);
for (int k = 0; k < n_tokens; ++k) {
@ -645,7 +641,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
void draft(common_speculative_draft_params_vec & dparams) override {
auto & ctx_dft = params.ctx_dft;
auto & ctx_tgt = params.ctx_tgt;
common_batch_clear(batch);
@ -671,8 +666,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
h_row = pending_h[seq_id].data();
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, dp.n_past, -1);
}
int ret = llama_decode(ctx_dft, batch);
@ -733,8 +726,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, dp.n_past, -1);
}
if (batch.n_tokens == 0) {
@ -751,8 +742,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
++i;
}
llama_synchronize(ctx_dft);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {

View File

@ -388,6 +388,8 @@ extern "C" {
// note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
struct llama_sampler_seq_config * samplers;
size_t n_samplers;
struct llama_context * ctx_src;
};
struct llama_model_tensor_override {

View File

@ -126,9 +126,10 @@ llama_context::llama_context(
cparams.embeddings_nextn_masked = false;
cparams.offload_kqv = params.offload_kqv;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.warmup = false;
cparams.ctx_type = params.ctx_type;
cparams.pooling_type = params.pooling_type;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@ -141,7 +142,7 @@ llama_context::llama_context(
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.ctx_type = params.ctx_type;
cparams.ctx_src = params.ctx_src;
// Initialize backend samplers here so they are part of the sampling graph
// before the reserve passes run later in this function. This avoids a later
@ -360,7 +361,8 @@ llama_context::llama_context(
/*.type_k =*/ params.type_k,
/*.type_v =*/ params.type_v,
/*.swa_full =*/ params.swa_full,
/*.ctx_type= */ cparams.ctx_type,
/*.ctx_type =*/ cparams.ctx_type,
/*.mem_src =*/ params.ctx_src ? params.ctx_src->memory.get() : nullptr,
};
memory.reset(model.create_memory(params_mem, cparams));
@ -792,14 +794,8 @@ uint32_t llama_context::n_threads_batch() const {
return cparams.n_threads_batch;
}
// Returns the context's `memory` member by value, typed as llama_memory_ptr.
// The method is const and does not modify the context.
llama_memory_ptr llama_context::get_memory() const {
return memory;
}
void llama_context::set_memory(llama_memory_ptr memory) {
this->memory = std::move(memory);
sched_need_reserve = true;
// Returns the handle obtained from `memory.get()`, typed as llama_memory_t.
// Nothing is moved out of or reset on the `memory` member here; the method is const.
llama_memory_t llama_context::get_memory() const {
return memory.get();
}
bool llama_context::memory_update(bool optimize) {
@ -3441,6 +3437,7 @@ llama_context_params llama_context_default_params() {
/*.kv_unified =*/ false,
/*.sampler =*/ nullptr,
/*.n_sampler =*/ 0,
/*.ctx_src =*/ nullptr,
};
return result;
@ -3658,12 +3655,8 @@ void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) {
ctx->set_embeddings_nextn(value, masked);
}
// Passes the memory object returned by src->get_memory() to ctx->set_memory().
// Neither `ctx` nor `src` is checked for nullptr before being dereferenced.
void llama_set_memory(llama_context * ctx, llama_context * src) {
ctx->set_memory(src->get_memory());
}
// Returns the memory handle of the given context by delegating to
// llama_context::get_memory(). `ctx` is dereferenced without a nullptr check.
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
// NOTE(review): two return statements appear back to back; as written only the
// first one can execute. This looks like the before/after lines of a diff hunk
// whose +/- markers were lost - confirm which variant matches the current
// return type of llama_context::get_memory().
return ctx->get_memory().get();
return ctx->get_memory();
}
float * llama_get_embeddings_nextn(llama_context * ctx) {

View File

@ -71,8 +71,7 @@ struct llama_context {
uint32_t n_threads() const;
uint32_t n_threads_batch() const;
llama_memory_ptr get_memory() const;
void set_memory(llama_memory_ptr memory);
llama_memory_t get_memory() const;
// return true if the memory was updated
bool memory_update(bool optimize);

View File

@ -49,4 +49,6 @@ struct llama_cparams {
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
llama_context * ctx_src;
};

View File

@ -93,7 +93,6 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c
// If masked == true, output the embeddings only for the tokens with batch.logits != 0
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked);
LLAMA_API void llama_set_memory(struct llama_context * ctx, struct llama_context * src);
// mirrors:
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
@ -101,6 +100,3 @@ LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);
// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i);
LLAMA_API ggml_tensor * llama_model_get_tok_embd(const struct llama_model * model);
LLAMA_API void llama_model_set_tok_embd( struct llama_model * model, ggml_tensor * tensor);

View File

@ -2543,16 +2543,6 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, v_cur);
}
int il_save = il;
if (arch == LLM_ARCH_GEMMA4_ASSISTANT) {
if (il == n_layer - 1) {
il = 59;
} else {
il = 58;
}
}
const auto * mctx_iswa = inp->mctx;
const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
@ -2576,8 +2566,6 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
il = il_save;
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);

View File

@ -32,7 +32,7 @@ llama_kv_cache_dsa::llama_kv_cache_dsa(
kv_mla = std::make_unique<llama_kv_cache>(
model, model.hparams, type_k, type_v,
v_trans, offload, unified, kv_size, n_seq_max, n_pad,
n_swa, swa_type, filter, reuse);
n_swa, swa_type, nullptr, filter, reuse, nullptr);
// we use llama_kv_cache for caching indexer keys
// by hand-tweaking some hparams we fool it to create
@ -49,7 +49,7 @@ llama_kv_cache_dsa::llama_kv_cache_dsa(
kv_lid = std::make_unique<llama_kv_cache>(
model, hparams_lid, type_k, type_v,
v_trans, offload, unified, kv_size, n_seq_max, n_pad,
n_swa, swa_type, filter, reuse);
n_swa, swa_type, nullptr, filter, reuse, nullptr);
}
void llama_kv_cache_dsa::clear(bool data) {

View File

@ -23,8 +23,10 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad,
llama_memory_t mem_src,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
const layer_reuse_cb & reuse,
const layer_share_cb & share) : hparams(model.hparams), unified(unified) {
// chain filters
const layer_filter_cb filter_base = [&](int32_t il) {
@ -59,17 +61,27 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
llama_memory_t mem_src_base = nullptr;
if (mem_src) {
mem_src_base = static_cast<llama_kv_cache_iswa *>(mem_src)->get_base();
}
llama_memory_t mem_src_swa = nullptr;
if (mem_src) {
mem_src_swa = static_cast<llama_kv_cache_iswa *>(mem_src)->get_swa();
}
kv_base = std::make_unique<llama_kv_cache>(
model, hparams, type_k, type_v,
v_trans, offload, unified, size_base, n_seq_max, n_pad,
0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
0, LLAMA_SWA_TYPE_NONE, mem_src_base, filter_base, reuse, share);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache>(
model, hparams, type_k, type_v,
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
hparams.n_swa, hparams.swa_type, filter_swa, reuse);
hparams.n_swa, hparams.swa_type, mem_src_swa, filter_swa, reuse, share);
}
void llama_kv_cache_iswa::clear(bool data) {

View File

@ -25,8 +25,10 @@ public:
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad,
llama_memory_t mem_src,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
const layer_reuse_cb & reuse,
const layer_share_cb & share);
~llama_kv_cache_iswa() = default;

View File

@ -90,8 +90,10 @@ llama_kv_cache::llama_kv_cache(
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
llama_memory_t mem_src,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) :
const layer_reuse_cb & reuse,
const layer_share_cb & share) :
model(model), hparams(hparams), v_trans(v_trans),
n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
@ -160,6 +162,8 @@ llama_kv_cache::llama_kv_cache(
const bool is_mla = hparams.is_mla();
other = static_cast<llama_kv_cache *>(mem_src);
for (uint32_t il = 0; il < n_layer; il++) {
if (!hparams.has_kv(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
@ -171,6 +175,24 @@ llama_kv_cache::llama_kv_cache(
continue;
}
if (share && other) {
const int32_t il_share = share(il);
if (il_share >= 0) {
const auto & layer_share = other->layers[other->map_layer_ids[il_share]];
LLAMA_LOG_WARN("%s: layer %3d: sharing with layer %d. k = %p, v = %p\n", __func__, il, il_share,
layer_share.k->data, layer_share.v->data);
map_layer_ids[il] = layers.size();
layers.push_back(layer_share);
layers.back().il = il;
continue;
}
}
if (n_embd_head_k_all == 0) {
n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il);
} else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) {
@ -347,6 +369,11 @@ void llama_kv_cache::clear(bool data) {
}
bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return true;
}
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
if (p0 < 0) {
@ -410,6 +437,11 @@ bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
}
void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
@ -497,6 +529,11 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
}
void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
auto & cells = v_cells[seq_to_stream[seq_id]];
@ -519,6 +556,11 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
}
void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
@ -564,6 +606,11 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
}
void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
@ -746,6 +793,11 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_
}
bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return true;
}
bool updated = false;
auto * sched = lctx->get_sched();
@ -1021,6 +1073,12 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
}
void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
v_cells = other->v_cells;
return;
}
// keep track of the max sequence position that we would overwrite with this ubatch
// for non-SWA cache, this would be always empty
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@ -1552,10 +1610,6 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data
goto skip;
}
if (ubatch->n_tokens == 1 && p0 == p1) {
goto skip;
}
// M-RoPE causal mask
if (is_2d) {
if (p0 == p1) {
@ -1819,6 +1873,9 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
}
ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
GGML_ASSERT(!other);
auto * ctx = res->get_ctx();
auto * gf = res->get_gf();
@ -1864,6 +1921,11 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
}
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_UNUSED(flags);
io.write(&n_stream, sizeof(n_stream));
@ -1929,6 +1991,11 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, lla
}
void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_UNUSED(flags);
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));

View File

@ -98,7 +98,7 @@ public:
// likely through `struct llama_memory_params`
llama_kv_cache(
const llama_model & model,
const llama_hparams & hparams,
const llama_hparams & hparams,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
@ -109,8 +109,10 @@ public:
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
llama_memory_t mem_src,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
const layer_reuse_cb & reuse,
const layer_share_cb & share);
~llama_kv_cache() = default;
@ -264,6 +266,9 @@ private:
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
std::vector<uint32_t> v_heads;
// TODO: temporary until we refactor to be able to share the same cells between 2 kv caches [TAG_KV_CACHE_SHARE_CELLS]
llama_kv_cache * other;
std::vector<llama_kv_cells> v_cells;
// maps from a sequence id to a stream id

View File

@ -43,9 +43,11 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
n_seq_max,
n_ubatch,
n_pad,
nullptr,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recr(il); }
: filter_attn,
nullptr,
nullptr
)),
mem_recr(new llama_memory_recurrent(

View File

@ -44,9 +44,11 @@ llama_memory_hybrid::llama_memory_hybrid(
n_pad,
n_swa,
swa_type,
nullptr,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recr(il); }
: filter_attn,
nullptr,
nullptr
)),
mem_recr(new llama_memory_recurrent(

View File

@ -23,6 +23,8 @@ struct llama_memory_params {
bool swa_full;
llama_context_type ctx_type;
llama_memory_t mem_src;
};
enum llama_memory_status {
@ -76,6 +78,8 @@ struct llama_memory_i {
// return negative value to indicate that the layer il should not reuse memory
using layer_reuse_cb = std::function<int32_t(int32_t il)>;
using layer_share_cb = std::function<int32_t(int32_t il)>;
virtual ~llama_memory_i() = default;
// split the input batch into a set of ubatches and verify that they can fit into the cache
@ -122,4 +126,4 @@ struct llama_memory_i {
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
};
using llama_memory_ptr = std::shared_ptr<llama_memory_i>;
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;

View File

@ -1987,7 +1987,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
case LLM_ARCH_WAVTOKENIZER_DEC:
case LLM_ARCH_MODERN_BERT:
case LLM_ARCH_GEMMA_EMBEDDING:
case LLM_ARCH_GEMMA4_ASSISTANT:
case LLM_ARCH_DREAM:
case LLM_ARCH_LLADA:
case LLM_ARCH_LLADA_MOE:
@ -2097,8 +2096,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* filter_recr */ std::move(filter_recr));
}
} else {
llama_memory_i::layer_reuse_cb reuse = nullptr;
llama_kv_cache::layer_filter_cb filter = nullptr;
llama_memory_i::layer_reuse_cb reuse = nullptr;
llama_kv_cache::layer_share_cb share = nullptr;
if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) {
reuse = [&](uint32_t il) {
@ -2124,23 +2124,56 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
}
}
llama_memory_t mem_src = cparams.ctx_src ? llama_get_memory(cparams.ctx_src) : nullptr;
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(hparams.is_swa_any());
res = new llama_kv_cache_iswa(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
1,
filter,
reuse);
if (arch == LLM_ARCH_GEMMA4_ASSISTANT) {
share = [&](int32_t il) {
const llama_model * model_src = llama_get_model(cparams.ctx_src);
if (hparams.is_swa(il)) {
return llama_model_n_layer(model_src) - 2;
}
return llama_model_n_layer(model_src) - 1;
};
res = new llama_kv_cache_iswa(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
1,
mem_src,
filter,
reuse,
share);
} else {
res = new llama_kv_cache_iswa(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
1,
mem_src,
filter,
reuse,
share);
}
} else {
GGML_ASSERT(!hparams.is_swa_any());
@ -2157,7 +2190,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
1,
hparams.n_swa,
hparams.swa_type,
mem_src,
filter,
nullptr,
nullptr);
}
}
@ -2631,11 +2666,3 @@ void llama_model_base::create_tensor_qkv(llama_layer & layer, int bid,
layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED);
}
}
// Returns the `tok_embd` tensor pointer stored in the model.
// `model` is dereferenced without a nullptr check.
ggml_tensor * llama_model_get_tok_embd(const struct llama_model * model) {
return model->tok_embd;
}
// Overwrites the model's `tok_embd` pointer with `tensor`.
// Only the pointer is assigned; this function does not copy tensor data or
// release whatever `tok_embd` pointed to before.
void llama_model_set_tok_embd(struct llama_model * model, ggml_tensor * tensor) {
model->tok_embd = tensor;
}

View File

@ -104,7 +104,10 @@ llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_
res->add_input(std::move(inp));
}
ggml_tensor * x = ggml_get_rows(ctx0, model.tok_embd, inp_tokens);
GGML_ASSERT(cparams.ctx_src != nullptr);
const auto * model_src = llama_get_model(cparams.ctx_src);
ggml_tensor * x = ggml_get_rows(ctx0, model_src->tok_embd, inp_tokens);
x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone));
cb(x, "inp_embd_target", -1);

View File

@ -10,8 +10,6 @@
#include "common.h"
#include "fit.h"
#include "llama.h"
#include "../../src/llama-ext.h" // staging API: llama_set_mtp_source
#include "ggml-cpp.h"
#include "log.h"
#include "sampling.h"
#include "speculative.h"
@ -847,30 +845,14 @@ private:
}
cparams_dft.n_rs_seq = 0;
bool skip_measure = false;
//TODO: remove this
if (spec_mtp && has_draft) {
struct gguf_init_params meta_params = {
/* .no_alloc = */ true,
/* .ctx = */ nullptr,
};
gguf_context_ptr meta(gguf_init_from_file(params_dft.model.path.c_str(), meta_params));
if (std::string(gguf_get_val_str(meta.get(), gguf_find_key(meta.get(), "general.architecture"))) == "gemma4-assistant") {
skip_measure = true;
SRV_WRN("[spec] skipping --fit memory measurement for Gemma 4 assistant draft model '%s'\n",
params_dft.model.path.c_str());
}
}
std::vector<ggml_backend_dev_t> devs;
uint32_t hp_ngl = 0;
uint32_t hp_nct = 0;
uint32_t hp_nex = 0;
if (!skip_measure) try {
try {
auto dmd = common_get_device_memory_data(
params_dft.model.path.c_str(), &mparams_dft, &cparams_dft,
devs, hp_ngl, hp_nct, hp_nex, GGML_LOG_LEVEL_ERROR);
devs, hp_ngl, hp_nct, hp_nex, GGML_LOG_LEVEL_DEBUG);
GGML_ASSERT(!params_base.fit_params_target.empty());
size_t total = 0;
@ -972,6 +954,8 @@ private:
// note: for small models maybe we can set this to the maximum possible draft from all speculative types
// the extra memory for small models is likely negligible?
cparams.n_rs_seq = 0;
cparams.ctx_src = ctx_tgt;
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
params_base.speculative.draft.ctx_tgt = ctx_tgt;
@ -987,6 +971,7 @@ private:
cparams_mtp.type_v = params_base.speculative.draft.cache_type_v;
cparams_mtp.n_rs_seq = 0;
cparams_mtp.n_outputs_max = params_base.n_parallel;
cparams_mtp.ctx_src = ctx_tgt;
ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp));
if (ctx_dft == nullptr) {