This commit is contained in:
Georgi Gerganov 2026-06-06 10:48:36 +03:00
parent 65eef9549c
commit 1c4a91c0f3
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
13 changed files with 97 additions and 314 deletions

View File

@ -491,11 +491,16 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
}
}
kv_shared_with_target = llama_get_memory(ctx_dft) == nullptr;
llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false);
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
llama_set_mtp_source(ctx_dft, ctx_tgt);
kv_shared_with_target = llama_get_memory(ctx_dft) == nullptr;
if (kv_shared_with_target) {
llama_set_memory(ctx_dft, ctx_tgt);
// TODO: avoid the const cast
llama_model_set_tok_embd(const_cast<llama_model *>(llama_get_model(ctx_dft)), llama_model_get_tok_embd(llama_get_model(ctx_tgt)));
}
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
@ -640,6 +645,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
void draft(common_speculative_draft_params_vec & dparams) override {
auto & ctx_dft = params.ctx_dft;
auto & ctx_tgt = params.ctx_tgt;
common_batch_clear(batch);
@ -665,6 +671,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
h_row = pending_h[seq_id].data();
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, dp.n_past, -1);
}
int ret = llama_decode(ctx_dft, batch);
@ -725,6 +733,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, dp.n_past, -1);
}
if (batch.n_tokens == 0) {
@ -741,6 +751,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
++i;
}
llama_synchronize(ctx_dft);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {

View File

@ -301,16 +301,16 @@ bool llama_batch_allocr::init(
ok = false;
}
if (!ok) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
__func__, s, s, p0, s, seq_pos_min(s));
//if (!ok) {
// LLAMA_LOG_ERROR(
// "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
// " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
// " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
// " it is required that the sequence positions remain consecutive: Y = X + 1\n",
// __func__, s, s, p0, s, seq_pos_min(s));
return false;
}
// return false;
//}
}
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {

View File

@ -38,12 +38,6 @@ static uint32_t ctx_type_to_embd_inp(const llama_hparams & hparams, llama_contex
throw std::runtime_error("Unsupported ctx type");
}
namespace {
// scope guard: when it goes out of scope it resets the memory-context smart pointer
// that `slot` points to; a null `slot` turns the guard into a no-op
struct src_mctx_reset_on_exit {
// address of the smart pointer to reset on exit (nullptr = do nothing)
llama_memory_context_ptr * slot;
~src_mctx_reset_on_exit() { if (slot) slot->reset(); }
};
static void llama_assert_gemma4_mtp_source_placement(
const llama_context * ctx,
const llama_context * src) {
@ -92,7 +86,6 @@ static void llama_assert_gemma4_mtp_source_placement(
}
}
}
}
llama_context::llama_context(
const llama_model & model,
@ -514,23 +507,6 @@ void llama_context::sched_reserve() {
}
}
// When called from decode(), src_mctx_for_decode is already populated and
// we must not drop it on exit (process_ubatch still needs it). Snapshot
// only when sched_reserve runs standalone (e.g. lazy first-decode reserve
// when set_mtp_source flipped sched_need_reserve).
const bool owns_src_snapshot = src_ctx && !src_mctx_for_decode;
if (owns_src_snapshot) {
auto * src_memory = src_ctx->get_memory();
if (!src_memory) {
throw std::runtime_error("MTP source context has no memory module");
}
src_mctx_for_decode = src_memory->init_full();
if (!src_mctx_for_decode) {
throw std::runtime_error("failed to initialize MTP source memory snapshot");
}
}
src_mctx_reset_on_exit reserve_src_drop{owns_src_snapshot ? &src_mctx_for_decode : nullptr};
// avoid reserving graphs with zero outputs - assume one output per sequence
const int n_outputs = n_seqs;
@ -816,8 +792,14 @@ uint32_t llama_context::n_threads_batch() const {
return cparams.n_threads_batch;
}
llama_memory_t llama_context::get_memory() const {
return memory.get();
llama_memory_ptr llama_context::get_memory() const {
    // hand out a copy of the smart pointer to the context's memory module;
    // the result can be empty (callers compare it against nullptr)
    llama_memory_ptr shared = memory;
    return shared;
}
void llama_context::set_memory(llama_memory_ptr memory) {
    // the memory module is being replaced -> flag that a new reserve is needed
    sched_need_reserve = true;

    // take over the pointer passed by the caller (the previous one is released)
    this->memory = std::move(memory);
}
bool llama_context::memory_update(bool optimize) {
@ -1198,18 +1180,6 @@ void llama_context::set_embeddings_nextn(bool value, bool masked) {
cparams.embeddings_nextn_masked = masked;
}
// register `src` as the MTP source context of this context
// no-op when `src` is already the current source; otherwise the placement assertion is run,
// any per-decode source snapshot is dropped and a re-reserve is requested
void llama_context::set_mtp_source(llama_context * src) {
if (src_ctx == src) {
return;
}
llama_assert_gemma4_mtp_source_placement(this, src);
src_ctx = src;
// the old snapshot belongs to the previous source - discard it
src_mctx_for_decode.reset();
// worst-case compute buffers were reserved without knowing about the source
// memory; force a re-reserve so the next decode sees src views
sched_need_reserve = true;
}
void llama_context::set_causal_attn(bool value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
@ -1807,20 +1777,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
embd_seq.clear();
output_swaps.clear();
src_mctx_reset_on_exit decode_src_drop{&src_mctx_for_decode};
if (src_ctx) {
auto * src_memory = src_ctx->get_memory();
if (!src_memory) {
LLAMA_LOG_ERROR("%s: MTP source context has no memory module\n", __func__);
return -2;
}
src_mctx_for_decode = src_memory->init_full();
if (!src_mctx_for_decode) {
LLAMA_LOG_ERROR("%s: failed to snapshot MTP source memory\n", __func__);
return -2;
}
}
sched_reserve();
bool did_optimize = false;
@ -2414,8 +2370,6 @@ llm_graph_params llama_context::graph_params(
/*.cvec =*/ cvec.get(),
/*.loras =*/ loras.get(),
/*.mctx =*/ mctx,
/*.src_mctx =*/ src_mctx_for_decode.get(),
/*.src_model =*/ src_ctx ? &src_ctx->get_model() : nullptr,
/*.cross =*/ &cross,
/*.samplers =*/ sampling.samplers,
/*.n_outputs =*/ n_outputs,
@ -3704,8 +3658,12 @@ void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) {
ctx->set_embeddings_nextn(value, masked);
}
void llama_set_mtp_source(llama_context * ctx, llama_context * src) {
ctx->set_mtp_source(src);
void llama_set_memory(llama_context * ctx, llama_context * src) {
    // fetch the memory module of the source context and install the same module in ctx;
    // neither pointer is checked for null here
    llama_memory_ptr src_memory = src->get_memory();
    ctx->set_memory(std::move(src_memory));
}
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
    // expose the context's memory module as a raw, non-owning handle
    // (nullptr when the context has no memory module)
    const llama_memory_ptr mem = ctx->get_memory();
    return mem.get();
}
float * llama_get_embeddings_nextn(llama_context * ctx) {
@ -3771,7 +3729,7 @@ struct ggml_cgraph * llama_graph_reserve(
uint32_t n_tokens,
uint32_t n_seqs,
uint32_t n_outputs) {
auto * memory = ctx->get_memory();
auto memory = ctx->get_memory();
llama_memory_context_ptr mctx;
if (memory) {
mctx = memory->init_full();
@ -3811,10 +3769,6 @@ int32_t llama_set_adapter_cvec(
// memory
//
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
return ctx->get_memory();
}
void llama_memory_clear(llama_memory_t mem, bool data) {
if (!mem) {
return;

View File

@ -71,7 +71,8 @@ struct llama_context {
uint32_t n_threads() const;
uint32_t n_threads_batch() const;
llama_memory_t get_memory() const;
llama_memory_ptr get_memory() const;
void set_memory(llama_memory_ptr memory);
// return true if the memory was updated
bool memory_update(bool optimize);
@ -112,7 +113,6 @@ struct llama_context {
void set_embeddings (bool value);
void set_embeddings_nextn(bool value, bool masked);
void set_mtp_source(llama_context * src);
void set_causal_attn(bool value);
void set_warmup(bool value);
@ -275,13 +275,7 @@ private:
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
std::unique_ptr<llama_memory_i> memory;
// external KV source used by MTP draft contexts. src_ctx is the target
// context whose memory we read; src_mctx_for_decode is a per-decode
// snapshot held for the duration of one decode/sched_reserve call.
llama_context * src_ctx = nullptr;
llama_memory_context_ptr src_mctx_for_decode;
llama_memory_ptr memory;
// decode output (2-dimensional array: [n_outputs][n_vocab])
buffer_view<float> logits = {nullptr, 0};

View File

@ -93,7 +93,7 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c
// If masked == true, output the embeddings only for the tokens with batch.logits != 0
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked);
LLAMA_API void llama_set_mtp_source(struct llama_context * ctx, struct llama_context * src);
// Make ctx use the memory module of src (the module is shared by both contexts, it is not copied)
LLAMA_API void llama_set_memory(struct llama_context * ctx, struct llama_context * src);
// mirrors:
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
@ -101,3 +101,6 @@ LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);
// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i);
// Read / replace the token-embedding tensor pointer stored in the model
// NOTE(review): the setter only stores the pointer - presumably the caller must keep the tensor alive; confirm
LLAMA_API ggml_tensor * llama_model_get_tok_embd(const struct llama_model * model);
LLAMA_API void llama_model_set_tok_embd( struct llama_model * model, ggml_tensor * tensor);

View File

@ -397,7 +397,7 @@ static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n
case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
};
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@ -565,18 +565,18 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
if (self_k_idxs && self_k_idxs->buffer) {
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
// swa tensors may not be allocated if there are no SWA attention layers
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
if (self_k_rot) {
mctx->get_base()->set_input_k_rot(self_k_rot);
}
@ -605,47 +605,18 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
if (self_k_idxs && self_k_idxs->buffer) {
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
}
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
// swa tensors may not be allocated if there are no SWA attention layers
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
}
return res;
}
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
// fill the graph inputs from the source (external) ISWA memory context:
// both KQ masks are always set, the rotation inputs only when the corresponding tensor exists
void llm_graph_input_attn_src_kv_iswa::set_input(const llama_ubatch * ubatch) {
src_mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
src_mctx->get_swa() ->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
// rotation inputs of the base (non-SWA) cache
if (self_k_rot) {
src_mctx->get_base()->set_input_k_rot(self_k_rot);
}
if (self_v_rot) {
src_mctx->get_base()->set_input_v_rot(self_v_rot);
}
// rotation inputs of the SWA cache
if (self_k_rot_swa) {
src_mctx->get_swa()->set_input_k_rot(self_k_rot_swa);
}
if (self_v_rot_swa) {
src_mctx->get_swa()->set_input_v_rot(self_v_rot_swa);
}
}
// returns true when both KQ masks can be reused for the new graph params
// side effect: src_mctx is rebound to params.src_mctx before the checks, regardless of the result
bool llm_graph_input_attn_src_kv_iswa::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.src_mctx);
this->src_mctx = mctx;
bool res = true;
// base mask and SWA mask are checked independently against their own cache context
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
return res;
}
@ -785,18 +756,18 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
}
attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
// swa tensors may not be allocated if there are no SWA attention layers
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
}
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
if (inp_attn->self_k_rot) {
attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot);
}
@ -839,18 +810,18 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params)
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
}
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
// swa tensors may not be allocated if there are no SWA attention layers
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
}
res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
@ -1063,8 +1034,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
cvec (params.cvec),
loras (params.loras),
mctx (params.mctx),
src_mctx (params.src_mctx),
src_model (params.src_model),
cross (params.cross),
samplers (params.samplers),
cb_func (params.cb),
@ -2574,6 +2543,16 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, v_cur);
}
// keep the caller's layer index so it can be restored later
int il_save = il;
// NOTE(review): hard-coded layer indices for LLM_ARCH_GEMMA4_ASSISTANT - the last layer is
// redirected to index 59 and every other layer to index 58; presumably these are the last
// two layers of one specific target model - confirm and derive them instead of hard-coding
if (arch == LLM_ARCH_GEMMA4_ASSISTANT) {
if (il == n_layer - 1) {
il = 59;
} else {
il = 58;
}
}
const auto * mctx_iswa = inp->mctx;
const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
@ -2597,6 +2576,8 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
il = il_save;
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
@ -2635,98 +2616,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
}
// create the attention input that reads from the source ISWA memory context (src_mctx):
// builds the base and SWA KQ masks plus the optional rotation inputs and registers the input with `res`
// asserts that src_mctx is set
llm_graph_input_attn_src_kv_iswa * llm_graph_context::build_attn_inp_src_kv_iswa() const {
GGML_ASSERT(src_mctx && "MTP draft graph requires src_mctx (set via llama_set_mtp_source)");
const auto * src_iswa = static_cast<const llama_kv_cache_iswa_context *>(src_mctx);
auto inp = std::make_unique<llm_graph_input_attn_src_kv_iswa>(hparams, cparams, src_iswa);
// the *_cnv masks are F16 casts when flash attention is enabled, otherwise they alias the original mask
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, src_iswa->get_base(), ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, src_iswa->get_swa(), ubatch, cparams);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
// rotation inputs come from the source caches (the results are null-checked by the users of these fields)
inp->self_k_rot = src_iswa->get_base()->build_input_k_rot(ctx0);
inp->self_v_rot = src_iswa->get_base()->build_input_v_rot(ctx0);
inp->self_k_rot_swa = src_iswa->get_swa()->build_input_k_rot(ctx0);
inp->self_v_rot_swa = src_iswa->get_swa()->build_input_v_rot(ctx0);
return (llm_graph_input_attn_src_kv_iswa *) res->add_input(std::move(inp));
}
// attention of q_cur against K/V taken from the source ISWA memory context of `inp`
// - il_assist selects SWA vs base (hparams.is_swa) and is the layer index used for the mha call and the callback
// - il_src is the layer index used to fetch K/V from the source cache
// - wo / wo_b are applied to the result only when non-null (wo_s is passed along with wo)
// returns the resulting tensor
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_src_kv_iswa * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * wo_s,
ggml_tensor * q_cur,
ggml_tensor * kq_b,
ggml_tensor * sinks,
ggml_tensor * v_mla,
float kq_scale,
int il_assist,
int il_src) const {
const bool is_swa = hparams.is_swa(il_assist);
// pick the cache context, mask and rotation tensors that match the layer type
const auto * src_iswa = inp->src_mctx;
const auto * src_cur = is_swa ? src_iswa->get_swa() : src_iswa->get_base();
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot;
auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot;
// when a K rotation is present it is applied to the query before attention
if (k_rot) {
q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot);
}
ggml_build_forward_expand(gf, q_cur);
ggml_tensor * q = q_cur;
ggml_tensor * k = src_cur->get_k(ctx0, il_src);
ggml_tensor * v = src_cur->get_v(ctx0, il_src);
// build_attn_mha splits q across k->ne[3] (the trunk's stream count). When the
// trunk runs kv_unified=false the assistant's ubatch only references a subset
// of streams (one per active draft seq); q->ne[2] is not divisible by the full
// n_stream and the view collapses tokens. Slice k/v down to exactly the streams
// referenced by this ubatch. Requires those streams to form a contiguous range.
if (k->ne[3] > 1 && (uint32_t) k->ne[3] != ubatch.n_seqs_unq) {
GGML_ASSERT(ubatch.n_seqs_unq > 0 && ubatch.seq_id_unq);
// find the smallest and largest sequence id referenced by the ubatch
llama_seq_id min_s = ubatch.seq_id_unq[0];
llama_seq_id max_s = ubatch.seq_id_unq[0];
for (uint32_t s = 1; s < ubatch.n_seqs_unq; ++s) {
min_s = std::min(min_s, ubatch.seq_id_unq[s]);
max_s = std::max(max_s, ubatch.seq_id_unq[s]);
}
GGML_ASSERT((uint32_t)(max_s - min_s + 1) == ubatch.n_seqs_unq &&
"MTP src-kv attn requires the active draft seq_ids to be contiguous");
GGML_ASSERT((int64_t) max_s < k->ne[3] && "MTP assistant seq_id beyond trunk stream count");
// views over n_seqs_unq entries of dim 3, starting at stream min_s
k = ggml_view_4d(ctx0, k, k->ne[0], k->ne[1], k->ne[2], (int64_t) ubatch.n_seqs_unq,
k->nb[1], k->nb[2], k->nb[3], (size_t) min_s * k->nb[3]);
v = ggml_view_4d(ctx0, v, v->ne[0], v->ne[1], v->ne[2], (int64_t) ubatch.n_seqs_unq,
v->nb[1], v->nb[2], v->nb[3], (size_t) min_s * v->nb[3]);
}
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il_assist);
cb(cur, "kqv_out", il_assist);
// when a V rotation is present it is applied to the attention output
if (v_rot) {
cur = ggml_mul_mat_aux(ctx0, cur, v_rot);
}
// optional output projection and bias
if (wo) {
cur = build_lora_mm(wo, cur, wo_s);
}
if (wo_b) {
cur = ggml_add(ctx0, cur, wo_b);
}
return cur;
}
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_cross * inp,
ggml_tensor * wo,

View File

@ -459,42 +459,6 @@ public:
const llama_kv_cache_iswa_context * mctx;
};
// mask-only input for attention against an external (read-only) ISWA KV cache.
// used by MTP draft graphs that attend to the target's KV without owning any.
// holds the KQ masks (base + SWA) and the optional rotation tensors for such a graph
class llm_graph_input_attn_src_kv_iswa : public llm_graph_input_i {
public:
// hparams and cparams are stored as copies, src_mctx as a non-owning pointer
llm_graph_input_attn_src_kv_iswa(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_iswa_context * src_mctx) :
hparams(hparams),
cparams(cparams),
src_mctx(src_mctx) {
}
~llm_graph_input_attn_src_kv_iswa() = default;
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
// the getters return the *_cnv variants of the masks
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
// KQ masks for the base and SWA caches and their *_cnv counterparts
ggml_tensor * self_kq_mask = nullptr;
ggml_tensor * self_kq_mask_cnv = nullptr;
ggml_tensor * self_kq_mask_swa = nullptr;
ggml_tensor * self_kq_mask_swa_cnv = nullptr;
// optional rotation inputs (nullptr when not used)
ggml_tensor * self_k_rot = nullptr;
ggml_tensor * self_v_rot = nullptr;
ggml_tensor * self_k_rot_swa = nullptr;
ggml_tensor * self_v_rot_swa = nullptr;
const llama_hparams hparams;
const llama_cparams cparams;
// source memory context the masks are built from (not owned)
const llama_kv_cache_iswa_context * src_mctx;
};
class llm_graph_input_attn_cross : public llm_graph_input_i {
public:
llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
@ -637,11 +601,6 @@ struct llm_graph_params {
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_context_i * mctx;
// per-decode snapshot of an external memory module the graph reads from
// (never writes) — e.g. ctx_dft reading target KV during MTP draft.
// nullptr for a main decode. Rebound inside reuse-aware input classes.
const llama_memory_context_i * src_mctx;
const llama_model * src_model;
const llama_cross * cross;
std::map<llama_seq_id, llama_sampler *> samplers;
@ -859,8 +818,6 @@ struct llm_graph_context {
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_context_i * mctx;
const llama_memory_context_i * src_mctx;
const llama_model * src_model;
const llama_cross * cross;
std::map<llama_seq_id, llama_sampler *> samplers;
@ -1090,24 +1047,6 @@ struct llm_graph_context {
float kq_scale,
int il) const;
llm_graph_input_attn_src_kv_iswa * build_attn_inp_src_kv_iswa() const;
// Q-only attention against an external ISWA KV cache (no K/V projections,
// no writes). il_assist labels the attention block in the local graph for
// logging; il_src indexes the source K/V layer to attend to.
ggml_tensor * build_attn(
llm_graph_input_attn_src_kv_iswa * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * wo_s,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il_assist,
int il_src) const;
llm_graph_input_attn_cross * build_attn_inp_cross() const;
ggml_tensor * build_attn(

View File

@ -1552,6 +1552,10 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data
goto skip;
}
if (ubatch->n_tokens == 1 && p0 == p1) {
goto skip;
}
// M-RoPE causal mask
if (is_2d) {
if (p0 == p1) {

View File

@ -122,4 +122,4 @@ struct llama_memory_i {
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
};
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
using llama_memory_ptr = std::shared_ptr<llama_memory_i>;

View File

@ -1987,6 +1987,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
case LLM_ARCH_WAVTOKENIZER_DEC:
case LLM_ARCH_MODERN_BERT:
case LLM_ARCH_GEMMA_EMBEDDING:
case LLM_ARCH_GEMMA4_ASSISTANT:
case LLM_ARCH_DREAM:
case LLM_ARCH_LLADA:
case LLM_ARCH_LLADA_MOE:
@ -2630,3 +2631,11 @@ void llama_model_base::create_tensor_qkv(llama_layer & layer, int bid,
layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED);
}
}
ggml_tensor * llama_model_get_tok_embd(const struct llama_model * model) {
    // plain accessor: return the model's token-embedding tensor pointer as stored
    ggml_tensor * tok_embd = model->tok_embd;
    return tok_embd;
}
void llama_model_set_tok_embd(struct llama_model * model, ggml_tensor * tensor) {
    // overwrite the model's token-embedding tensor pointer; only the pointer is stored
    // NOTE(review): the previous tensor is not released here and the new one is not copied -
    // presumably ownership stays with whoever created the tensors; confirm
    llama_model & mdl = *model;
    mdl.tok_embd = tensor;
}

View File

@ -706,6 +706,7 @@ const char * llm_type_name(llm_type type);
#define LLAMA_LOAD_LOCALS \
const int n_layer = hparams.n_layer(); GGML_UNUSED(n_layer); \
const int n_layer_all = hparams.n_layer_all; GGML_UNUSED(n_layer_all); \
const int n_layer_nextn = hparams.n_layer_nextn; GGML_UNUSED(n_layer_nextn); \
const int64_t n_head = hparams.n_head(); GGML_UNUSED(n_head); \
const int64_t n_head_kv = hparams.n_head_kv(); GGML_UNUSED(n_head_kv); \
const int64_t n_embd = hparams.n_embd; GGML_UNUSED(n_embd); \

View File

@ -82,18 +82,6 @@ std::unique_ptr<llm_graph_context> llama_model_gemma4_assistant::build_arch_grap
llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
GGML_ASSERT(src_mctx && "Gemma 4 assistant graph requires an MTP source (llama_set_mtp_source)");
GGML_ASSERT(src_model && "Gemma 4 assistant graph requires a source model");
GGML_ASSERT(src_model->tok_embd && "source model missing tok_embd");
const auto & src_hparams = src_model->hparams;
// By convention the MTP draft reads from the trunk's final SWA and full layers.
const int32_t src_layer_full = (int32_t) src_hparams.n_layer() - 1;
const int32_t src_layer_swa = (int32_t) src_hparams.n_layer() - 2;
GGML_ASSERT(!src_hparams.is_swa(src_layer_full) && "trunk's last layer must be full attention");
GGML_ASSERT( src_hparams.is_swa(src_layer_swa) && "trunk's penultimate layer must be SWA");
const int64_t n_embd_backbone = hparams.n_embd_out();
ggml_tensor * inp_tokens;
@ -116,7 +104,7 @@ llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_
res->add_input(std::move(inp));
}
ggml_tensor * x = ggml_get_rows(ctx0, src_model->tok_embd, inp_tokens);
ggml_tensor * x = ggml_get_rows(ctx0, model.tok_embd, inp_tokens);
x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone));
cb(x, "inp_embd_target", -1);
@ -126,15 +114,14 @@ llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_
ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_pre_proj, xh);
cb(cur, "pre_proj", -1);
auto * inp_attn = build_attn_inp_src_kv_iswa();
auto * inp_attn = build_attn_inp_kv_iswa();
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = build_inp_out_ids();
ggml_tensor * inpL = cur;
for (int il = 0; il < n_layer; ++il) {
const bool is_swa = hparams.is_swa(il);
const int32_t il_src = is_swa ? src_layer_swa : src_layer_full;
const bool is_swa = hparams.is_swa(il);
const int64_t n_embd_head = hparams.n_embd_head_k(il);
const int64_t n_head = hparams.n_head(il);
@ -157,7 +144,7 @@ llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_
cb(Qcur, "Qcur_pos", il);
cur = build_attn(inp_attn, model.layers[il].wo, nullptr, nullptr,
Qcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il, il_src);
Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);

View File

@ -974,13 +974,6 @@ private:
cparams.n_rs_seq = 0;
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
if (spec_mtp) {
// MTP draft must know its target before the first decode
llama_set_mtp_source(ctx_dft.get(), ctx_tgt);
}
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
} else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(),
@ -1001,12 +994,6 @@ private:
return false;
}
// wire the source before any decode (the seq-rm probe below
// triggers sched_reserve which needs src for Gemma4-style MTP)
llama_set_mtp_source(ctx_dft.get(), ctx_tgt);
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
}
@ -1094,6 +1081,10 @@ private:
}
}
if (ctx_dft) {
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
}
if (spec) {
SRV_INF("%s", "speculative decoding context initialized\n");
} else {