llama: Gemma 4 MTP

This commit is contained in:
Aman Gupta 2026-05-19 20:18:00 +08:00
parent 0066404085
commit f268966d49
19 changed files with 574 additions and 51 deletions

View File

@ -418,6 +418,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
int32_t n_embd = 0;
bool kv_shared_with_target = false;
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
// The last h-row of one process() call needs the first token of the NEXT
// call to pair with, so it's stashed here until that next call fires.
@ -444,7 +446,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
auto * ctx_dft = this->params.ctx_dft;
GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");
n_embd = llama_model_n_embd(llama_get_model(ctx_dft));
n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft));
GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) &&
"MTP input row width must match the target h_nextn width");
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
@ -489,6 +493,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false);
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
llama_set_mtp_source(ctx_dft, ctx_tgt);
kv_shared_with_target = llama_model_n_layer_kv(llama_get_model(ctx_dft)) == 0;
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
@ -526,9 +533,10 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
if (N <= 0) {
return;
}
auto * ctx_dft = this->params.ctx_dft;
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_max < N - 1) {
if (pos_max < N - 1 && !kv_shared_with_target) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
"process() hook may not have run on every prefill ubatch "
"(need_embd / logits=1 on every prompt position?). "
@ -571,48 +579,42 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
const size_t row_bytes = (size_t) n_embd * sizeof(float);
common_batch_clear(batch);
// if kv is shared with target (e.g Gemma4), then we can skip this catch-up decode
if (!kv_shared_with_target) {
common_batch_clear(batch);
for (int k = 0; k < n_tokens; ++k) {
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
}
// shift the tgt embeddings to the right by one position
// assumes that the tokens in the batch are sequential for each sequence
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
// ^--- this is a problem
// TODO:this is generally true, but would be nice to assert it
{
const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt);
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
//{
// // string with seq_ids in the batch
// std::stringstream ss;
// for (int i = 0; i < n_tokens; ++i) {
// ss << batch_in.seq_id[i][0] << ",";
// }
// LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str());
//}
}
// fill the pending embeddings from a previous run
auto set_h = [&](int idx, const float * h_row) {
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
};
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
for (int k = 0; k < n_tokens; ++k) {
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
}
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}
// shift the tgt embeddings to the right by one position
// assumes that the tokens in the batch are sequential for each sequence
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
// ^--- this is a problem
// TODO:this is generally true, but would be nice to assert it
{
const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt);
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
}
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
return false;
// fill the pending embeddings from a previous run
auto set_h = [&](int idx, const float * h_row) {
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
};
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
}
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
return false;
}
}
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {

View File

@ -75,6 +75,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"Gemma3TextModel": "gemma",
"Gemma3nForCausalLM": "gemma",
"Gemma3nForConditionalGeneration": "gemma",
"Gemma4AssistantForCausalLM": "gemma",
"Gemma4ForConditionalGeneration": "gemma",
"Gemma4ForCausalLM": "gemma",
"Gemma4UnifiedForConditionalGeneration": "gemma",

View File

@ -785,6 +785,16 @@ class Gemma4UnifiedModel(Gemma4Model):
self.gguf_writer.add_suppress_tokens(suppress_tokens)
@ModelBase.register("Gemma4AssistantForCausalLM")
class Gemma4AssistantModel(Gemma4Model):
model_arch = gguf.MODEL_ARCH.GEMMA4_ASSISTANT
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_embedding_length_out(self.hparams["backbone_hidden_size"])
self.gguf_writer.add_nextn_predict_layers(self.block_count)
@ModelBase.register("Gemma4ForConditionalGeneration")
class Gemma4VisionAudioModel(MmprojModel):
has_audio_encoder = True

View File

@ -434,6 +434,7 @@ class MODEL_ARCH(IntEnum):
GEMMA3 = auto()
GEMMA3N = auto()
GEMMA4 = auto()
GEMMA4_ASSISTANT = auto()
GEMMA_EMBEDDING = auto()
STARCODER2 = auto()
RWKV6 = auto()
@ -866,6 +867,8 @@ class MODEL_TENSOR(IntEnum):
A_PER_DIM_K_SCALE = auto() # gemma4
A_PER_DIM_SCALE = auto() # gemma4
# nextn/mtp
NEXTN_PRE_PROJ = auto()
NEXTN_POST_PROJ = auto()
NEXTN_EH_PROJ = auto()
NEXTN_EMBED_TOKENS = auto()
NEXTN_ENORM = auto()
@ -955,6 +958,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.GEMMA3: "gemma3",
MODEL_ARCH.GEMMA3N: "gemma3n",
MODEL_ARCH.GEMMA4: "gemma4",
MODEL_ARCH.GEMMA4_ASSISTANT: "gemma4-assistant",
MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
@ -1417,6 +1421,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.A_QF_FFN_DOWN: "a.proj_blk.{bid}.ffn_down",
MODEL_TENSOR.A_QF_FFN_NORM: "a.proj_blk.{bid}.ffn_norm",
# NextN/MTP
MODEL_TENSOR.NEXTN_PRE_PROJ: "nextn.pre_projection",
MODEL_TENSOR.NEXTN_POST_PROJ: "nextn.post_projection",
MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj",
MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens",
MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.nextn.enorm",
@ -2500,6 +2506,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.PER_LAYER_PROJ_NORM,
MODEL_TENSOR.PER_LAYER_POST_NORM,
],
MODEL_ARCH.GEMMA4_ASSISTANT: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.NEXTN_PRE_PROJ,
MODEL_TENSOR.NEXTN_POST_PROJ,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
MODEL_TENSOR.LAYER_OUT_SCALE,
],
MODEL_ARCH.GEMMA_EMBEDDING: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,

View File

@ -2279,6 +2279,14 @@ class TensorNameMap:
),
# NextN/MTP tensors
MODEL_TENSOR.NEXTN_PRE_PROJ: (
"pre_projection",
),
MODEL_TENSOR.NEXTN_POST_PROJ: (
"post_projection",
),
MODEL_TENSOR.NEXTN_EH_PROJ: (
"model.layers.{bid}.eh_proj",
),

View File

@ -57,6 +57,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_GEMMA3N, "gemma3n" },
{ LLM_ARCH_GEMMA4, "gemma4" },
{ LLM_ARCH_GEMMA4_ASSISTANT, "gemma4-assistant" },
{ LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
@ -452,6 +453,8 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" },
{ LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
{ LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
{ LLM_TENSOR_NEXTN_PRE_PROJ, "nextn.pre_projection" },
{ LLM_TENSOR_NEXTN_POST_PROJ, "nextn.post_projection" },
{ LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" },
{ LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" },
{ LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" },
@ -764,6 +767,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_PRE_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_POST_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the
// last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so
// the model loader doesn't fault on the block index.

View File

@ -61,6 +61,7 @@ enum llm_arch {
LLM_ARCH_GEMMA3,
LLM_ARCH_GEMMA3N,
LLM_ARCH_GEMMA4,
LLM_ARCH_GEMMA4_ASSISTANT,
LLM_ARCH_GEMMA_EMBEDDING,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
@ -556,6 +557,8 @@ enum llm_tensor {
LLM_TENSOR_INDEXER_PROJ,
LLM_TENSOR_INDEXER_ATTN_K,
LLM_TENSOR_INDEXER_ATTN_Q_B,
LLM_TENSOR_NEXTN_PRE_PROJ,
LLM_TENSOR_NEXTN_POST_PROJ,
LLM_TENSOR_NEXTN_EH_PROJ,
LLM_TENSOR_NEXTN_EMBED_TOKENS,
LLM_TENSOR_NEXTN_ENORM,

View File

@ -30,6 +30,21 @@ static llm_graph_type ctx_type_to_graph_type(llama_context_type ctx_type) {
throw std::runtime_error("Unsupported ctx type");
}
static uint32_t ctx_type_to_embd_inp(const llama_hparams & hparams, llama_context_type ctx_type) {
switch (ctx_type) {
case LLAMA_CONTEXT_TYPE_DEFAULT: return hparams.n_embd_inp();
case LLAMA_CONTEXT_TYPE_MTP : return hparams.n_embd_out();
}
throw std::runtime_error("Unsupported ctx type");
}
namespace {
struct src_mctx_reset_on_exit {
llama_memory_context_ptr * slot;
~src_mctx_reset_on_exit() { if (slot) slot->reset(); }
};
}
llama_context::llama_context(
const llama_model & model,
llama_context_params params) :
@ -372,7 +387,11 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__);
}
sched_reserve();
// MTP draft contexts can't reserve until the source context is wired
// via llama_set_mtp_source — defer to the first decode.
if (cparams.ctx_type != LLAMA_CONTEXT_TYPE_MTP) {
sched_reserve();
}
if (!cparams.flash_attn) {
if (ggml_is_quantized(params.type_v)) {
@ -446,6 +465,23 @@ void llama_context::sched_reserve() {
}
}
// When called from decode(), src_mctx_for_decode is already populated and
// we must not drop it on exit (process_ubatch still needs it). Snapshot
// only when sched_reserve runs standalone (e.g. lazy first-decode reserve
// when set_mtp_source flipped sched_need_reserve).
const bool owns_src_snapshot = src_ctx && !src_mctx_for_decode;
if (owns_src_snapshot) {
auto * src_memory = src_ctx->get_memory();
if (!src_memory) {
throw std::runtime_error("MTP source context has no memory module");
}
src_mctx_for_decode = src_memory->init_full();
if (!src_mctx_for_decode) {
throw std::runtime_error("failed to initialize MTP source memory snapshot");
}
}
src_mctx_reset_on_exit reserve_src_drop{owns_src_snapshot ? &src_mctx_for_decode : nullptr};
// avoid reserving graphs with zero outputs - assume one output per sequence
const int n_outputs = n_seqs;
@ -904,7 +940,7 @@ float * llama_context::get_embeddings_nextn_ith(int32_t i) {
throw std::runtime_error("no nextn embeddings");
}
const uint32_t n_embd = model.hparams.n_embd;
const uint32_t n_embd = model.hparams.n_embd_out();
if (!cparams.embeddings_nextn_masked) {
// unmasked: nextn rows are stored densely, indexed by raw token position.
@ -1113,6 +1149,17 @@ void llama_context::set_embeddings_nextn(bool value, bool masked) {
cparams.embeddings_nextn_masked = masked;
}
void llama_context::set_mtp_source(llama_context * src) {
if (src_ctx == src) {
return;
}
src_ctx = src;
src_mctx_for_decode.reset();
// worst-case compute buffers were reserved without knowing about the source
// memory; force a re-reserve so the next decode sees src views
sched_need_reserve = true;
}
void llama_context::set_causal_attn(bool value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
@ -1338,7 +1385,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
const auto & hparams = model.hparams;
const int64_t n_embd = hparams.n_embd_inp();
const int64_t n_embd = ctx_type_to_embd_inp(hparams, cparams.ctx_type);
const int64_t n_vocab = model.vocab.n_tokens();
// note: during encode, we always pass the full sequence starting from pos = 0
@ -1473,7 +1520,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
GGML_ASSERT(backend_h != nullptr);
const uint32_t n_embd = hparams.n_embd;
const uint32_t n_embd = hparams.n_embd_out();
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_nextn.size);
ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn.data, 0, n_tokens*n_embd*sizeof(float));
}
@ -1648,7 +1695,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
const auto & hparams = model.hparams;
const int64_t n_vocab = vocab.n_tokens();
const int64_t n_embd = hparams.n_embd_inp();
const int64_t n_embd = ctx_type_to_embd_inp(hparams, cparams.ctx_type);
// when computing embeddings, all tokens are output
const bool output_all = cparams.embeddings;
@ -1710,6 +1757,20 @@ int llama_context::decode(const llama_batch & batch_inp) {
embd_seq.clear();
output_swaps.clear();
src_mctx_reset_on_exit decode_src_drop{&src_mctx_for_decode};
if (src_ctx) {
auto * src_memory = src_ctx->get_memory();
if (!src_memory) {
LLAMA_LOG_ERROR("%s: MTP source context has no memory module\n", __func__);
return -2;
}
src_mctx_for_decode = src_memory->init_full();
if (!src_mctx_for_decode) {
LLAMA_LOG_ERROR("%s: failed to snapshot MTP source memory\n", __func__);
return -2;
}
}
sched_reserve();
bool did_optimize = false;
@ -1924,7 +1985,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
GGML_ASSERT(backend_h != nullptr);
const uint32_t n_embd = hparams.n_embd;
const uint32_t n_embd = hparams.n_embd_out();
float * embd_nextn_out = embd_nextn.data + offset*n_embd;
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_nextn.size);
@ -2017,7 +2078,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
const auto n_batch = cparams.n_batch;
const auto n_vocab = vocab.n_tokens();
const auto n_embd = hparams.n_embd;
const auto n_embd_out = hparams.n_embd_out();
bool has_logits = true;
@ -2036,12 +2096,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
embd_nextn.size = has_embd_nextn ? n_embd*n_outputs_max : 0;
embd_nextn.size = has_embd_nextn ? n_embd_out*n_outputs_max : 0;
if (has_embd_nextn && !cparams.embeddings_nextn_masked) {
// unmasked: nextn row exists for every token in the batch, not just
// those flagged via batch.logits[i] -> size by token count instead.
embd_nextn.size = (size_t) n_embd * n_batch;
embd_nextn.size = (size_t) n_embd_out * n_batch;
}
// Allocate backend sampling output buffers if there are backend samplers configured.
@ -2304,6 +2364,8 @@ llm_graph_params llama_context::graph_params(
/*.cvec =*/ cvec.get(),
/*.loras =*/ loras.get(),
/*.mctx =*/ mctx,
/*.src_mctx =*/ src_mctx_for_decode.get(),
/*.src_model =*/ src_ctx ? &src_ctx->get_model() : nullptr,
/*.cross =*/ &cross,
/*.samplers =*/ sampling.samplers,
/*.n_outputs =*/ n_outputs,
@ -3593,6 +3655,10 @@ void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) {
ctx->set_embeddings_nextn(value, masked);
}
void llama_set_mtp_source(llama_context * ctx, llama_context * src) {
ctx->set_mtp_source(src);
}
float * llama_get_embeddings_nextn(llama_context * ctx) {
ctx->synchronize();

View File

@ -6,6 +6,7 @@
#include "llama-graph.h"
#include "llama-adapter.h"
#include "llama-impl.h"
#include "llama-memory.h"
#include "ggml-cpp.h"
#include "ggml-opt.h"
@ -111,6 +112,7 @@ struct llama_context {
void set_embeddings (bool value);
void set_embeddings_nextn(bool value, bool masked);
void set_mtp_source(llama_context * src);
void set_causal_attn(bool value);
void set_warmup(bool value);
@ -275,6 +277,12 @@ private:
std::unique_ptr<llama_memory_i> memory;
// external KV source used by MTP draft contexts. src_ctx is the target
// context whose memory we read; src_mctx_for_decode is a per-decode
// snapshot held for the duration of one decode/sched_reserve call.
llama_context * src_ctx = nullptr;
llama_memory_context_ptr src_mctx_for_decode;
// decode output (2-dimensional array: [n_outputs][n_vocab])
buffer_view<float> logits = {nullptr, 0};

View File

@ -85,6 +85,11 @@ using llama_memory_breakdown = std::map<ggml_backend_buffer_type_t, llama_memory
LLAMA_API int32_t llama_model_n_expert (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model);
// number of layers that own KV (i.e. layers whose graph writes K/V).
// 0 means the model owns no KV — e.g. a Gemma4-style MTP draft that reads
// trunk KV via llama_set_mtp_source.
LLAMA_API int32_t llama_model_n_layer_kv(const struct llama_model * model);
LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i);
LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx);
@ -93,6 +98,7 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c
// If masked == true, output the embeddings only for the tokens with batch.logits != 0
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked);
LLAMA_API void llama_set_mtp_source(struct llama_context * ctx, struct llama_context * src);
// mirrors:
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);

View File

@ -620,6 +620,22 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
return res;
}
void llm_graph_input_attn_src_kv_iswa::set_input(const llama_ubatch * ubatch) {
src_mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
src_mctx->get_swa() ->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
bool llm_graph_input_attn_src_kv_iswa::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.src_mctx);
this->src_mctx = mctx;
bool res = true;
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
return res;
}
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
GGML_ASSERT(cross_kq_mask);
@ -1034,6 +1050,8 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
cvec (params.cvec),
loras (params.loras),
mctx (params.mctx),
src_mctx (params.src_mctx),
src_model (params.src_model),
cross (params.cross),
samplers (params.samplers),
cb_func (params.cb),
@ -2604,6 +2622,59 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
}
llm_graph_input_attn_src_kv_iswa * llm_graph_context::build_attn_inp_src_kv_iswa() const {
GGML_ASSERT(src_mctx && "MTP draft graph requires src_mctx (set via llama_set_mtp_source)");
const auto * src_iswa = static_cast<const llama_kv_cache_iswa_context *>(src_mctx);
auto inp = std::make_unique<llm_graph_input_attn_src_kv_iswa>(hparams, cparams, src_iswa);
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, src_iswa->get_base(), ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, src_iswa->get_swa(), ubatch, cparams);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
return (llm_graph_input_attn_src_kv_iswa *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_src_kv_iswa * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * wo_s,
ggml_tensor * q_cur,
ggml_tensor * kq_b,
ggml_tensor * sinks,
ggml_tensor * v_mla,
float kq_scale,
int il_assist,
int il_src) const {
const bool is_swa = hparams.is_swa(il_assist);
const auto * src_iswa = inp->src_mctx;
const auto * src_cur = is_swa ? src_iswa->get_swa() : src_iswa->get_base();
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
ggml_build_forward_expand(gf, q_cur);
ggml_tensor * q = q_cur;
ggml_tensor * k = src_cur->get_k(ctx0, il_src);
ggml_tensor * v = src_cur->get_v(ctx0, il_src);
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il_assist);
cb(cur, "kqv_out", il_assist);
if (wo) {
cur = build_lora_mm(wo, cur, wo_s);
}
if (wo_b) {
cur = ggml_add(ctx0, cur, wo_b);
}
return cur;
}
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_cross * inp,
ggml_tensor * wo,

View File

@ -459,6 +459,37 @@ public:
const llama_kv_cache_iswa_context * mctx;
};
// mask-only input for attention against an external (read-only) ISWA KV cache.
// used by MTP draft graphs that attend to the target's KV without owning any.
class llm_graph_input_attn_src_kv_iswa : public llm_graph_input_i {
public:
llm_graph_input_attn_src_kv_iswa(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_iswa_context * src_mctx) :
hparams(hparams),
cparams(cparams),
src_mctx(src_mctx) {
}
~llm_graph_input_attn_src_kv_iswa() = default;
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
ggml_tensor * self_kq_mask = nullptr;
ggml_tensor * self_kq_mask_cnv = nullptr;
ggml_tensor * self_kq_mask_swa = nullptr;
ggml_tensor * self_kq_mask_swa_cnv = nullptr;
const llama_hparams hparams;
const llama_cparams cparams;
const llama_kv_cache_iswa_context * src_mctx;
};
class llm_graph_input_attn_cross : public llm_graph_input_i {
public:
llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
@ -601,6 +632,11 @@ struct llm_graph_params {
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_context_i * mctx;
// per-decode snapshot of an external memory module the graph reads from
// (never writes) — e.g. ctx_dft reading target KV during MTP draft.
// nullptr for a main decode. Rebound inside reuse-aware input classes.
const llama_memory_context_i * src_mctx;
const llama_model * src_model;
const llama_cross * cross;
std::map<llama_seq_id, llama_sampler *> samplers;
@ -818,6 +854,8 @@ struct llm_graph_context {
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_context_i * mctx;
const llama_memory_context_i * src_mctx;
const llama_model * src_model;
const llama_cross * cross;
std::map<llama_seq_id, llama_sampler *> samplers;
@ -1047,6 +1085,24 @@ struct llm_graph_context {
float kq_scale,
int il) const;
llm_graph_input_attn_src_kv_iswa * build_attn_inp_src_kv_iswa() const;
// Q-only attention against an external ISWA KV cache (no K/V projections,
// no writes). il_assist labels the attention block in the local graph for
// logging; il_src indexes the source K/V layer to attend to.
ggml_tensor * build_attn(
llm_graph_input_attn_src_kv_iswa * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * wo_s,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il_assist,
int il_src) const;
llm_graph_input_attn_cross * build_attn_inp_cross() const;
ggml_tensor * build_attn(

View File

@ -2462,6 +2462,10 @@ uint32_t llama_kv_cache_context::get_n_kv() const {
return n_kv;
}
llama_pos llama_kv_cache_context::seq_pos_max(llama_seq_id seq_id) const {
return kv->seq_pos_max(seq_id);
}
ggml_type llama_kv_cache_context::type_k() const {
return kv->type_k();
}

View File

@ -354,6 +354,11 @@ public:
uint32_t get_n_kv() const;
// last position recorded in the cache for this sequence; -1 if absent.
// exposed for cross-context KV consumers (e.g. MTP draft) that need to
// anchor the source position without owning a memory module of their own.
llama_pos seq_pos_max(llama_seq_id seq_id) const;
ggml_type type_k() const;
ggml_type type_v() const;

View File

@ -139,6 +139,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
return new llama_model_gemma3n(params);
case LLM_ARCH_GEMMA4:
return new llama_model_gemma4(params);
case LLM_ARCH_GEMMA4_ASSISTANT:
return new llama_model_gemma4_assistant(params);
case LLM_ARCH_GEMMA_EMBEDDING:
return new llama_model_gemma_embedding(params);
case LLM_ARCH_STARCODER2:
@ -2378,6 +2380,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GEMMA3:
case LLM_ARCH_GEMMA3N:
case LLM_ARCH_GEMMA4:
case LLM_ARCH_GEMMA4_ASSISTANT:
case LLM_ARCH_GEMMA_EMBEDDING:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
@ -2572,6 +2575,10 @@ int32_t llama_model_n_devices(const struct llama_model * model) {
return (int32_t)model->devices.size();
}
int32_t llama_model_n_layer_kv(const struct llama_model * model) {
return (int32_t) model->hparams.n_layer_kv();
}
ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i) {
if (i < 0 || i >= (int)model->devices.size()) {
return nullptr;

View File

@ -548,6 +548,10 @@ struct llama_model {
struct ggml_tensor * output_s = nullptr;
struct ggml_tensor * output_in_s = nullptr;
// NextN/MTP model-level projections
struct ggml_tensor * nextn_pre_proj = nullptr;
struct ggml_tensor * nextn_post_proj = nullptr;
// classifier
struct ggml_tensor * cls = nullptr;
struct ggml_tensor * cls_b = nullptr;

View File

@ -135,6 +135,214 @@ std::unique_ptr<llm_graph_context> llama_model_gemma4::build_arch_graph(const ll
return std::make_unique<graph>(*this, params);
}
void llama_model_gemma4_assistant::load_arch_hparams(llama_model_loader & ml) {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer);
uint32_t n_kv_shared_layers = 0;
ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false);
hparams.n_layer_kv_from_start = hparams.n_layer - (int32_t) n_kv_shared_layers;
hparams.f_attention_scale = 1.0f;
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa);
if (hparams.n_layer == 4) {
type = LLM_TYPE_31B;
}
}
void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) {
LLAMA_LOAD_LOCALS;
if (n_embd_head_k != n_embd_head_v) {
throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k == n_embd_head_v");
}
if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) {
throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k_swa == n_embd_head_v_swa");
}
if (hparams.n_embd_out() == n_embd) {
throw std::runtime_error("Gemma 4 assistant requires embedding_length_out to carry the target hidden size");
}
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
const int64_t n_embd_backbone = hparams.n_embd_out();
nextn_pre_proj = create_tensor(tn(LLM_TENSOR_NEXTN_PRE_PROJ, "weight"), { 2*n_embd_backbone, n_embd }, 0);
nextn_post_proj = create_tensor(tn(LLM_TENSOR_NEXTN_POST_PROJ, "weight"), { n_embd, n_embd_backbone }, 0);
int rope_freqs_flag = 0;
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
const int64_t n_head = hparams.n_head(i);
const int64_t n_embd_head = hparams.n_embd_head_k(i);
const int64_t n_ff = hparams.n_ff(i);
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head*n_head }, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head*n_head, n_embd }, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head }, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), { 1u }, 0);
if (!hparams.is_swa(i)) {
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_embd_head/2 }, rope_freqs_flag);
rope_freqs_flag = TENSOR_DUPLICATED;
}
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), { n_embd }, 0);
}
}
std::unique_ptr<llm_graph_context> llama_model_gemma4_assistant::build_arch_graph(const llm_graph_params & params) const {
return std::make_unique<graph>(*this, params);
}
llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
GGML_ASSERT(src_mctx && "Gemma 4 assistant graph requires an MTP source (llama_set_mtp_source)");
GGML_ASSERT(src_model && "Gemma 4 assistant graph requires a source model");
GGML_ASSERT(src_model->tok_embd && "source model missing tok_embd");
const auto & src_hparams = src_model->hparams;
// By convention the MTP draft reads from the trunk's final SWA and full layers.
const int32_t src_layer_full = (int32_t) src_hparams.n_layer - 1;
const int32_t src_layer_swa = (int32_t) src_hparams.n_layer - 2;
GGML_ASSERT(!src_hparams.is_swa(src_layer_full) && "trunk's last layer must be full attention");
GGML_ASSERT( src_hparams.is_swa(src_layer_swa) && "trunk's penultimate layer must be SWA");
const int64_t n_embd_backbone = hparams.n_embd_out();
ggml_tensor * inp_tokens;
ggml_tensor * inp_h;
{
auto inp = std::make_unique<llm_graph_input_embd>(n_embd_backbone);
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
cb(inp->tokens, "inp_tokens", -1);
ggml_set_input(inp->tokens);
inp_tokens = inp->tokens;
res->t_inp_tokens = inp->tokens;
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_backbone, ubatch.n_tokens);
cb(inp->embd, "inp_h", -1);
ggml_set_input(inp->embd);
inp_h = inp->embd;
res->t_inp_embd = inp->embd;
res->add_input(std::move(inp));
}
ggml_tensor * x = ggml_get_rows(ctx0, src_model->tok_embd, inp_tokens);
x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone));
cb(x, "inp_embd_target", -1);
ggml_tensor * xh = ggml_concat(ctx0, x, inp_h, 0);
cb(xh, "inp_xh", -1);
ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_pre_proj, xh);
cb(cur, "pre_proj", -1);
auto * inp_attn = build_attn_inp_src_kv_iswa();
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = build_inp_out_ids();
ggml_tensor * inpL = cur;
for (int il = 0; il < n_layer; ++il) {
const bool is_swa = hparams.is_swa(il);
const int32_t il_src = is_swa ? src_layer_swa : src_layer_full;
const int64_t n_embd_head = hparams.n_embd_head_k(il);
const int64_t n_head = hparams.n_head(il);
const float freq_base_l = model.get_rope_freq_base(cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
const int n_rot_l = hparams.n_rot(il);
ggml_tensor * cur_norm = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur_norm, "attn_norm", il);
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur_norm);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
ggml_tensor * freq_factors = is_swa ? nullptr : model.layers[il].rope_freqs;
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig,
freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur_pos", il);
cur = build_attn(inp_attn, model.layers[il].wo, nullptr, nullptr,
Qcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il, il_src);
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
cur = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "attn_post_norm", il);
ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL);
cb(attn_out, "attn_out", il);
cur = build_norm(attn_out, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, nullptr, nullptr,
model.layers[il].ffn_gate, nullptr, nullptr,
model.layers[il].ffn_down, nullptr, nullptr,
nullptr,
LLM_FFN_GELU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
cur = build_norm(cur, model.layers[il].ffn_post_norm, nullptr, LLM_NORM_RMS, -1);
cb(cur, "ffn_post_norm", il);
cur = ggml_add(ctx0, cur, attn_out);
cur = ggml_mul(ctx0, cur, model.layers[il].out_scale);
cb(cur, "out_scaled", il);
inpL = cur;
}
cur = inpL;
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
ggml_tensor * logits = build_lora_mm(model.output, cur);
cb(logits, "result_output", -1);
res->t_logits = logits;
ggml_tensor * h_next = ggml_mul_mat(ctx0, model.nextn_post_proj, cur);
cb(h_next, "h_nextn", -1);
res->t_h_nextn = h_next;
ggml_build_forward_expand(gf, logits);
ggml_build_forward_expand(gf, h_next);
}
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) {
GGML_ASSERT(idx < (int) x->ne[2]);
@ -270,7 +478,8 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
}
// TODO @ngxson : strip unused token right after the last KV layer to speed up prompt processing
if (il == n_layer - 1 && inp_out_ids) {
// keep all rows when extracting unmasked nextn embeddings (MTP target needs the hidden state for every token)
if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
@ -370,7 +579,7 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens]
// TODO @ngxson : improve this
if (il == n_layer - 1 && inp_out_ids) {
if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
inp_this_layer = ggml_get_rows(ctx0, inp_this_layer, inp_out_ids);
}
@ -401,6 +610,17 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
model.output_norm, nullptr,
LLM_NORM_RMS, -1);
// Expose the post-output-norm hidden state (the LM-head input feature) so that
// MTP draft contexts can read it via llama_get_embeddings_nextn_ith() as the
// recurrent h input. This matches the reference (transformers/vLLM/SGLang),
// which feeds the drafter the target's post-final-norm hidden state.
cb(cur, "h_nextn", -1);
res->t_h_nextn = cur;
if (!cparams.embeddings_nextn_masked && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}
cb(cur, "result_norm", -1);
res->t_embd = cur;

View File

@ -822,6 +822,19 @@ struct llama_model_gemma4 : public llama_model_base {
};
struct llama_model_gemma4_assistant : public llama_model_base {
llama_model_gemma4_assistant(const struct llama_model_params & params) : llama_model_base(params) {}
void load_arch_hparams(llama_model_loader & ml) override;
void load_arch_tensors(llama_model_loader & ml) override;
struct graph : public llm_graph_context {
graph(const llama_model & model, const llm_graph_params & params);
};
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
};
struct llama_model_gemma_embedding : public llama_model_base {
llama_model_gemma_embedding(const struct llama_model_params & params) : llama_model_base(params) {}
void load_arch_hparams(llama_model_loader & ml) override;

View File

@ -10,6 +10,7 @@
#include "common.h"
#include "fit.h"
#include "llama.h"
#include "../../src/llama-ext.h" // staging API: llama_set_mtp_source
#include "log.h"
#include "sampling.h"
#include "speculative.h"
@ -949,6 +950,11 @@ private:
cparams.n_rs_seq = 0;
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
if (spec_mtp) {
// MTP draft must know its target before the first decode
llama_set_mtp_source(ctx_dft.get(), ctx_tgt);
}
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
params_base.speculative.draft.ctx_tgt = ctx_tgt;
@ -971,6 +977,10 @@ private:
return false;
}
// wire the source before any decode (the seq-rm probe below
// triggers sched_reserve which needs src for Gemma4-style MTP)
llama_set_mtp_source(ctx_dft.get(), ctx_tgt);
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
params_base.speculative.draft.ctx_tgt = ctx_tgt;