spec : Support Step3.5/3.7 flash mtp3 (#24340)

* add mtp_layer_offset + include nextn flags in graph reuse

* add llama_set_mtp_layer_offset + llama_model_n_nextn_layer API

* offset head select + require all MTP blocks

* speculative multi-head process()

* speculative multi-head draft()

* gather outputs via inp_out_ids

* cleanup

* fix core

* minor cleanup

* merged draft_multi_head into draft()

* mtp rename nextn

* Apply suggestions from code review

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* clean-up comments

* fix for multi seq

* apply suggestions && chain-heads comment

* add a reference for chain_heads discussion

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
This commit is contained in:
YiChen Lv 2026-06-21 16:33:18 +08:00 committed by GitHub
parent 063d9c156e
commit d789527482
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 167 additions and 73 deletions

View File

@ -905,7 +905,13 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
int32_t n_embd = 0;
bool is_mem_shared = false;
// One MTP draft driver, three modes (set once in the ctor):
// is_mem_shared (gemma4): shares the target KV, runs all heads in one graph.
// chain_heads (step35): n_mtp_layers trained heads, one per draft step.
// neither (qwen35 / qwen35moe): a single trained MTP head.
int32_t n_mtp_layers = 1;
bool is_mem_shared = false; // gemma4
bool chain_heads = false; // derived in the ctor: n_mtp_layers > 1 && !is_mem_shared
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
// The last h-row of one process() call needs the first token of the NEXT
@ -920,10 +926,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
std::vector<std::vector<float>> verify_h;
std::vector<int32_t> verify_h_rows;
// Per-seq draft length from the last draft() call, used in accept() to
// roll back ctx_dft's recurrent state past the AR draft's redundant
// pre-advancement before process() mirrored the verify batch.
std::vector<uint16_t> last_n_drafted;
std::vector<int> i_last;
std::vector<std::vector<float>> chain_h;
common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq)
@ -936,6 +940,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft));
GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) &&
"MTP input row width must match the target h_nextn width");
n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft)));
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
@ -982,16 +987,25 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt;
chain_heads = n_mtp_layers > 1 && !is_mem_shared;
if (chain_heads) {
this->params.n_max = std::min(this->params.n_max, n_mtp_layers);
chain_h.assign(n_seq, {});
for (auto & c : chain_h) {
c.reserve((size_t) (this->params.n_max + 1) * n_embd);
}
}
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
i_last.assign(n_seq, -1);
i_batch_beg.assign(n_seq, -1);
i_batch_end.assign(n_seq, -1);
verify_h.assign(n_seq, {});
verify_h_rows.assign(n_seq, 0);
last_n_drafted.assign(n_seq, 0);
}
~common_speculative_impl_draft_mtp() override {
@ -1097,9 +1111,34 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
auto * mem_dft = llama_get_memory(ctx_dft);
bool ok = true;
for (int head = 0; head < n_mtp_layers; ++head) {
if (chain_heads) {
// ref: https://github.com/ggml-org/llama.cpp/pull/24340/changes#r3413498544
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
}
llama_memory_seq_rm(mem_dft, seq_id, batch_in.pos[i_batch_beg[seq_id]], -1);
}
llama_set_nextn_layer_offset(ctx_dft, head);
}
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
__func__, head, (int) rc, (int) batch_in.pos[0]);
ok = false;
break;
}
}
if (chain_heads) {
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
}
if (!ok) {
return false;
}
}
@ -1134,7 +1173,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
int n_drafting = 0;
std::vector<bool> drafting(n_seq);
const float * h_row = nullptr;
const size_t row_bytes = (size_t) n_embd * sizeof(float);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
@ -1149,22 +1187,43 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
common_sampler_reset(smpls[seq_id].get());
common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, pending_h[seq_id].data(), row_bytes);
h_row = pending_h[seq_id].data();
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
}
i_last[seq_id] = batch.n_tokens - 1;
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
return;
if (chain_heads) {
chain_h[seq_id].assign(pending_h[seq_id].begin(), pending_h[seq_id].end());
}
}
int i = 0;
while (n_drafting > 0) {
int i_batch = 0;
// each step decodes under a different head, i.e. a different decoder layer, and
// KV is per layer. process() filled this layer's KV only for positions < n_past
// (prompt + accepted prefix) — nothing in the draft region yet. so reset the
// draft region (the seq_rm lower bound is n_past, leaving the prompt KV intact)
// and select head i so it rebuilds its own layer's KV there; decoding just the
// latest token would leave its attention reading cells only another head wrote.
if (chain_heads) {
auto * mem_dft = llama_get_memory(ctx_dft);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (drafting[seq_id]) {
llama_memory_seq_rm(mem_dft, seq_id, dparams[seq_id].n_past, -1);
}
}
llama_set_nextn_layer_offset(ctx_dft, i);
}
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
break;
}
// rebuild the batch for the next step: the growing-KV paths re-add only the
// new token (the KV already holds the prefix), while chained heads re-add the
// whole prefix at the next head. dropped sequences are simply not re-added.
common_batch_clear(batch);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
@ -1174,9 +1233,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
auto * smpl = smpls[seq_id].get();
common_sampler_sample(smpl, ctx_dft, i_batch, true);
h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_batch);
++i_batch;
common_sampler_sample(smpl, ctx_dft, i_last[seq_id], true);
const float * h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_last[seq_id]);
const auto * cur_p = common_sampler_get_candidates(smpl, true);
@ -1210,30 +1268,41 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
continue;
}
if (is_mem_shared) {
if (chain_heads) {
// ref: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448031546
chain_h[seq_id].insert(chain_h[seq_id].end(), h_row, h_row + n_embd);
const int n_rows = (int) result.size() + 1; // id_last + tokens drafted so far
for (int t = 0; t < n_rows; ++t) {
const llama_token tok = (t == 0) ? dp.id_last : result[t - 1];
common_batch_add(batch, tok, dp.n_past + t, { seq_id }, t == n_rows - 1);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd,
chain_h[seq_id].data() + (size_t) t * n_embd, row_bytes);
}
} else if (is_mem_shared) {
// note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens
// ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37
common_batch_add(batch, id, dp.n_past, { seq_id }, true);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
} else {
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
}
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
i_last[seq_id] = batch.n_tokens - 1;
}
if (batch.n_tokens == 0) {
break;
}
// evaluate the drafted tokens on the draft model
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
break;
}
++i;
}
if (chain_heads) {
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
}
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {
@ -1243,8 +1312,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
if (dp.result->size() < (size_t) params.n_min) {
dp.result->clear();
}
last_n_drafted[seq_id] = (uint16_t) dp.result->size();
}
}
@ -1857,7 +1924,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr;
bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
@ -1895,7 +1962,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
if (has_draft_eagle3) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params));
}
if (has_mtp) {
if (has_draft_mtp) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
}
}

View File

@ -558,14 +558,15 @@ extern "C" {
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_ctx_train (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_layer_nextn(const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
// Get the model's RoPE frequency scaling factor
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);

View File

@ -1156,6 +1156,10 @@ void llama_context::set_embeddings_layer_inp(uint32_t lid, bool enable) {
sched_need_reserve = true;
}
void llama_context::set_nextn_layer_offset(int32_t offset) {
cparams.nextn_layer_offset = offset;
}
void llama_context::set_causal_attn(bool value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
@ -3699,6 +3703,10 @@ void llama_set_embeddings_layer_inp(llama_context * ctx, uint32_t lid, bool valu
ctx->set_embeddings_layer_inp(lid, value);
}
void llama_set_nextn_layer_offset(llama_context * ctx, int32_t offset) {
ctx->set_nextn_layer_offset(offset);
}
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
if (!ctx) {
return nullptr;

View File

@ -115,6 +115,7 @@ struct llama_context {
void set_embeddings (bool value);
void set_embeddings_nextn(bool value, bool masked);
void set_embeddings_layer_inp(uint32_t lid, bool enable);
void set_nextn_layer_offset(int32_t offset);
void set_causal_attn(bool value);
void set_warmup(bool value);

View File

@ -18,6 +18,8 @@ struct llama_cparams {
int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing
int32_t nextn_layer_offset = 0;
float rope_freq_base;
float rope_freq_scale;

View File

@ -95,6 +95,11 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked);
// Select which appended NextN block the DECODER_MTP graph runs (offset past
// the trunk: il = n_layer() + offset). Used by the speculative NextN driver to
// chain multiple trained NextN heads. Default 0 (first head).
LLAMA_API void llama_set_nextn_layer_offset(struct llama_context * ctx, int32_t offset);
// mirrors:
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);

View File

@ -682,9 +682,16 @@ struct llm_graph_params {
}
}
// TODO: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448035248
if (cparams.nextn_layer_offset != other.cparams.nextn_layer_offset) {
return false;
}
return
cparams.embeddings == other.cparams.embeddings &&
cparams.causal_attn == other.cparams.causal_attn &&
cparams.embeddings == other.cparams.embeddings &&
cparams.embeddings_nextn == other.cparams.embeddings_nextn &&
cparams.embeddings_nextn_masked == other.cparams.embeddings_nextn_masked &&
cparams.causal_attn == other.cparams.causal_attn &&
arch == other.arch &&
gtype == other.gtype &&
cvec == other.cvec &&

View File

@ -2312,6 +2312,10 @@ int32_t llama_model_n_layer(const llama_model * model) {
return model->hparams.n_layer();
}
int32_t llama_model_n_layer_nextn(const llama_model * model) {
return model->hparams.n_layer_nextn;
}
int32_t llama_model_n_head(const llama_model * model) {
return model->hparams.n_head();
}

View File

@ -112,7 +112,7 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED);
};
auto load_block_mtp = [&](int i, bool is_first_mtp) {
auto load_block_mtp = [&](int i) {
auto & layer = layers[i];
const uint32_t n_head_l = hparams.n_head(i);
@ -121,15 +121,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
// The MTP block is a full Step3p5 decoder layer (mtp_block) plus the
// NextN-specific wiring (enorm/hnorm/eh_proj + optional shared head).
// `mtp_flags` becomes NOT_REQUIRED when the GGUF is trunk-only.
//
// Only the FIRST MTP block (i == n_main) is required for the
// single-block MTP runtime; trailing MTP blocks are always tolerated
// as missing so pruned GGUFs (block 0 only) load cleanly. Override
// mtp_flags to NOT_REQUIRED for those.
const int eff_mtp_flags = is_first_mtp ? mtp_flags : (mtp_flags | TENSOR_NOT_REQUIRED);
// Multi-block MTP: every declared MTP block is required (the draft chain
// runs all n_layer_nextn heads), so each block uses the captured
// `mtp_flags` directly — already NOT_REQUIRED for a trunk-only GGUF,
// which keeps that path correct.
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, eff_mtp_flags);
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, mtp_flags);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED);
@ -140,12 +137,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED);
}
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, eff_mtp_flags);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, eff_mtp_flags);
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, mtp_flags);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, mtp_flags);
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, eff_mtp_flags);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, mtp_flags);
// dense MLP (leading dense blocks) — present if the MTP block isn't MoE
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
@ -165,9 +162,9 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED);
// NextN-specific tensors that define the MTP block.
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, eff_mtp_flags);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, eff_mtp_flags);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, eff_mtp_flags);
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, mtp_flags);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, mtp_flags);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, mtp_flags);
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
@ -176,13 +173,11 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
for (int i = 0; i < n_layer; ++i) {
load_block_trunk(i, trunk_flags);
}
// Only the first MTP block (i == n_main) is required at runtime — the
// single-block-MTP graph in build_arch_graph always uses that one.
// Trailing MTP blocks are loaded if present (so an un-pruned GGUF with
// all MTP layers still works) but tolerated when absent via the pruning
// path. See scripts/prune_step35_extra_mtp.py for the pruner.
// All n_layer_nextn MTP blocks are required — the multi-block draft chain
// runs every head (head k at offset k). The GGUF declares the count via
// step35.nextn_predict_layers.
for (int i = n_layer; i < n_layer_all; ++i) {
load_block_mtp(i, /*is_first_mtp=*/ i == n_layer);
load_block_mtp(i);
}
}
@ -372,13 +367,14 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
: llm_graph_context(params) {
GGML_ASSERT(hparams.n_layer_nextn > 0 && "STEP35 MTP requires n_layer_nextn > 0");
// Single-block MTP only: always run the first trained MTP block (Qwen
// MTP / vLLM single-MTP-layer style). Multi-block round-robin proved to
// be a much deeper refactor than this PR justifies; the trailing MTP
// blocks are loaded with TENSOR_NOT_REQUIRED so pruned GGUFs (with just
// block 0) also work — see load_arch_tensors below and
// scripts/prune_step35_extra_mtp.py.
const int il = hparams.n_layer();
// Multi-block MTP: the DECODER_MTP graph runs the MTP head selected by
// cparams.nextn_layer_offset (0 = first trained head). The speculative driver
// bumps the offset per draft step to chain heads 45->46->47. offset 0 keeps
// single-block behavior identical to before.
const int il = hparams.n_layer() + cparams.nextn_layer_offset;
GGML_ASSERT(cparams.nextn_layer_offset >= 0 &&
cparams.nextn_layer_offset < (int) hparams.n_layer_nextn &&
"nextn_layer_offset out of range [0, n_layer_nextn)");
const auto & layer = model.layers[il];
GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj");
@ -536,6 +532,9 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "mtp_post_ffn", il);
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
// Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
cb(cur, "h_nextn", -1);
res->t_h_nextn = cur;