diff --git a/common/speculative.cpp b/common/speculative.cpp index 9c20585dc3..3c38ae2b02 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -905,7 +905,13 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { int32_t n_embd = 0; - bool is_mem_shared = false; + // One MTP draft driver, three modes (set once in the ctor): + // is_mem_shared (gemma4): shares the target KV, runs all heads in one graph. + // chain_heads (step35): n_mtp_layers trained heads, one per draft step. + // neither (qwen35 / qwen35moe): a single trained MTP head. + int32_t n_mtp_layers = 1; + bool is_mem_shared = false; // gemma4 + bool chain_heads = false; // derived in the ctor: n_mtp_layers > 1 && !is_mem_shared // Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1. // The last h-row of one process() call needs the first token of the NEXT @@ -920,10 +926,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { std::vector> verify_h; std::vector verify_h_rows; - // Per-seq draft length from the last draft() call, used in accept() to - // roll back ctx_dft's recurrent state past the AR draft's redundant - // pre-advancement before process() mirrored the verify batch. - std::vector last_n_drafted; + std::vector i_last; + std::vector> chain_h; common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq) : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq) @@ -936,6 +940,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft)); GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) && "MTP input row width must match the target h_nextn width"); + n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft))); LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__); LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling); @@ -982,16 +987,25 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true); is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt; + chain_heads = n_mtp_layers > 1 && !is_mem_shared; + + if (chain_heads) { + this->params.n_max = std::min(this->params.n_max, n_mtp_layers); + + chain_h.assign(n_seq, {}); + for (auto & c : chain_h) { + c.reserve((size_t) (this->params.n_max + 1) * n_embd); + } + } pending_h.assign(n_seq, std::vector(n_embd, 0.0f)); + i_last.assign(n_seq, -1); i_batch_beg.assign(n_seq, -1); i_batch_end.assign(n_seq, -1); verify_h.assign(n_seq, {}); verify_h_rows.assign(n_seq, 0); - - last_n_drafted.assign(n_seq, 0); } ~common_speculative_impl_draft_mtp() override { @@ -1097,9 +1111,34 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { set_h(i_batch_beg[seq_id], pending_h[seq_id].data()); } - const int32_t rc = llama_decode(ctx_dft, batch); - if (rc != 0) { - LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]); + auto * mem_dft = llama_get_memory(ctx_dft); + + bool ok = true; + for (int head = 0; head < n_mtp_layers; ++head) { + if (chain_heads) { + // ref: https://github.com/ggml-org/llama.cpp/pull/24340/changes#r3413498544 + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (i_batch_beg[seq_id] < 0) { + continue; + } + llama_memory_seq_rm(mem_dft, seq_id, batch_in.pos[i_batch_beg[seq_id]], -1); + } + llama_set_nextn_layer_offset(ctx_dft, head); + } + + const int32_t rc = llama_decode(ctx_dft, batch); + if (rc != 0) { + LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n", + __func__, head, (int) rc, (int) batch_in.pos[0]); + ok = false; + break; + } + } + + if (chain_heads) { + llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes + } + if (!ok) { return false; } } @@ -1134,7 +1173,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { int n_drafting = 0; std::vector drafting(n_seq); - const float * h_row = nullptr; const size_t row_bytes = (size_t) n_embd * sizeof(float); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { @@ -1149,22 +1187,43 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { common_sampler_reset(smpls[seq_id].get()); common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true); + std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, pending_h[seq_id].data(), row_bytes); - h_row = pending_h[seq_id].data(); - std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); - } + i_last[seq_id] = batch.n_tokens - 1; - int ret = llama_decode(ctx_dft, batch); - if (ret != 0) { - LOG_WRN("%s: llama_decode returned %d\n", __func__, ret); - return; + if (chain_heads) { + chain_h[seq_id].assign(pending_h[seq_id].begin(), pending_h[seq_id].end()); + } } int i = 0; while (n_drafting > 0) { - int i_batch = 0; + // each step decodes under a different head, i.e. a different decoder layer, and + // KV is per layer. process() filled this layer's KV only for positions < n_past + // (prompt + accepted prefix) — nothing in the draft region yet. so reset the + // draft region (the seq_rm lower bound is n_past, leaving the prompt KV intact) + // and select head i so it rebuilds its own layer's KV there; decoding just the + // latest token would leave its attention reading cells only another head wrote. + if (chain_heads) { + auto * mem_dft = llama_get_memory(ctx_dft); + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (drafting[seq_id]) { + llama_memory_seq_rm(mem_dft, seq_id, dparams[seq_id].n_past, -1); + } + } + llama_set_nextn_layer_offset(ctx_dft, i); + } + int ret = llama_decode(ctx_dft, batch); + if (ret != 0) { + LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret); + break; + } + + // rebuild the batch for the next step: the growing-KV paths re-add only the + // new token (the KV already holds the prefix), while chained heads re-add the + // whole prefix at the next head. dropped sequences are simply not re-added. common_batch_clear(batch); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { @@ -1174,9 +1233,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { auto * smpl = smpls[seq_id].get(); - common_sampler_sample(smpl, ctx_dft, i_batch, true); - h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_batch); - ++i_batch; + common_sampler_sample(smpl, ctx_dft, i_last[seq_id], true); + const float * h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_last[seq_id]); const auto * cur_p = common_sampler_get_candidates(smpl, true); @@ -1210,30 +1268,41 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { continue; } - if (is_mem_shared) { + if (chain_heads) { + // ref: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448031546 + chain_h[seq_id].insert(chain_h[seq_id].end(), h_row, h_row + n_embd); + + const int n_rows = (int) result.size() + 1; // id_last + tokens drafted so far + for (int t = 0; t < n_rows; ++t) { + const llama_token tok = (t == 0) ? dp.id_last : result[t - 1]; + common_batch_add(batch, tok, dp.n_past + t, { seq_id }, t == n_rows - 1); + std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, + chain_h[seq_id].data() + (size_t) t * n_embd, row_bytes); + } + } else if (is_mem_shared) { // note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens // ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37 common_batch_add(batch, id, dp.n_past, { seq_id }, true); + std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes); } else { common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true); + std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes); } - std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); + + i_last[seq_id] = batch.n_tokens - 1; } if (batch.n_tokens == 0) { break; } - // evaluate the drafted tokens on the draft model - ret = llama_decode(ctx_dft, batch); - if (ret != 0) { - LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret); - break; - } - ++i; } + if (chain_heads) { + llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes + } + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { @@ -1243,8 +1312,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { if (dp.result->size() < (size_t) params.n_min) { dp.result->clear(); } - - last_n_drafted[seq_id] = (uint16_t) dp.result->size(); } } @@ -1857,7 +1924,7 @@ common_speculative * common_speculative_init(common_params_speculative & params, bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE)); bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr; - bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr; + bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr; @@ -1895,7 +1962,7 @@ common_speculative * common_speculative_init(common_params_speculative & params, if (has_draft_eagle3) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params)); } - if (has_mtp) { + if (has_draft_mtp) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params)); } } diff --git a/include/llama.h b/include/llama.h index 27e4806742..f723c9f60c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -558,14 +558,15 @@ extern "C" { LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); - LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model); - LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_ctx_train (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_layer_nextn(const struct llama_model * model); + LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); // Get the model's RoPE frequency scaling factor LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 529bc4a5e9..220240ea95 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1156,6 +1156,10 @@ void llama_context::set_embeddings_layer_inp(uint32_t lid, bool enable) { sched_need_reserve = true; } +void llama_context::set_nextn_layer_offset(int32_t offset) { + cparams.nextn_layer_offset = offset; +} + void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -3699,6 +3703,10 @@ void llama_set_embeddings_layer_inp(llama_context * ctx, uint32_t lid, bool valu ctx->set_embeddings_layer_inp(lid, value); } +void llama_set_nextn_layer_offset(llama_context * ctx, int32_t offset) { + ctx->set_nextn_layer_offset(offset); +} + llama_memory_t llama_get_memory(const struct llama_context * ctx) { if (!ctx) { return nullptr; diff --git a/src/llama-context.h b/src/llama-context.h index 853052be2c..f8b7805871 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -115,6 +115,7 @@ struct llama_context { void set_embeddings (bool value); void set_embeddings_nextn(bool value, bool masked); void set_embeddings_layer_inp(uint32_t lid, bool enable); + void set_nextn_layer_offset(int32_t offset); void set_causal_attn(bool value); void set_warmup(bool value); diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 2b109f909c..546ae1e2c1 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -18,6 +18,8 @@ struct llama_cparams { int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing + int32_t nextn_layer_offset = 0; + float rope_freq_base; float rope_freq_scale; diff --git a/src/llama-ext.h b/src/llama-ext.h index 8b5679b690..348bbae957 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -95,6 +95,11 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c // If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked); +// Select which appended NextN block the DECODER_MTP graph runs (offset past +// the trunk: il = n_layer() + offset). Used by the speculative NextN driver to +// chain multiple trained NextN heads. Default 0 (first head). +LLAMA_API void llama_set_nextn_layer_offset(struct llama_context * ctx, int32_t offset); + // mirrors: // LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx); diff --git a/src/llama-graph.h b/src/llama-graph.h index 5e8a658350..a6e8c3985b 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -682,9 +682,16 @@ struct llm_graph_params { } } + // TODO: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448035248 + if (cparams.nextn_layer_offset != other.cparams.nextn_layer_offset) { + return false; + } + return - cparams.embeddings == other.cparams.embeddings && - cparams.causal_attn == other.cparams.causal_attn && + cparams.embeddings == other.cparams.embeddings && + cparams.embeddings_nextn == other.cparams.embeddings_nextn && + cparams.embeddings_nextn_masked == other.cparams.embeddings_nextn_masked && + cparams.causal_attn == other.cparams.causal_attn && arch == other.arch && gtype == other.gtype && cvec == other.cvec && diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c528755339..d041a9ce3e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2312,6 +2312,10 @@ int32_t llama_model_n_layer(const llama_model * model) { return model->hparams.n_layer(); } +int32_t llama_model_n_layer_nextn(const llama_model * model) { + return model->hparams.n_layer_nextn; +} + int32_t llama_model_n_head(const llama_model * model) { return model->hparams.n_head(); } diff --git a/src/models/step35.cpp b/src/models/step35.cpp index e2218c5870..9b7b18a367 100644 --- a/src/models/step35.cpp +++ b/src/models/step35.cpp @@ -112,7 +112,7 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); }; - auto load_block_mtp = [&](int i, bool is_first_mtp) { + auto load_block_mtp = [&](int i) { auto & layer = layers[i]; const uint32_t n_head_l = hparams.n_head(i); @@ -121,15 +121,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { // The MTP block is a full Step3p5 decoder layer (mtp_block) plus the // NextN-specific wiring (enorm/hnorm/eh_proj + optional shared head). - // `mtp_flags` becomes NOT_REQUIRED when the GGUF is trunk-only. - // - // Only the FIRST MTP block (i == n_main) is required for the - // single-block MTP runtime; trailing MTP blocks are always tolerated - // as missing so pruned GGUFs (block 0 only) load cleanly. Override - // mtp_flags to NOT_REQUIRED for those. - const int eff_mtp_flags = is_first_mtp ? mtp_flags : (mtp_flags | TENSOR_NOT_REQUIRED); + // Multi-block MTP: every declared MTP block is required (the draft chain + // runs all n_layer_nextn heads), so each block uses the captured + // `mtp_flags` directly — already NOT_REQUIRED for a trunk-only GGUF, + // which keeps that path correct. - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, eff_mtp_flags); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, mtp_flags); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); @@ -140,12 +137,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); } - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, eff_mtp_flags); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, eff_mtp_flags); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, mtp_flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, mtp_flags); layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, eff_mtp_flags); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, mtp_flags); // dense MLP (leading dense blocks) — present if the MTP block isn't MoE layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); @@ -165,9 +162,9 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); // NextN-specific tensors that define the MTP block. - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, eff_mtp_flags); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, eff_mtp_flags); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, eff_mtp_flags); + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, mtp_flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, mtp_flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, mtp_flags); layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); @@ -176,13 +173,11 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { load_block_trunk(i, trunk_flags); } - // Only the first MTP block (i == n_main) is required at runtime — the - // single-block-MTP graph in build_arch_graph always uses that one. - // Trailing MTP blocks are loaded if present (so an un-pruned GGUF with - // all MTP layers still works) but tolerated when absent via the pruning - // path. See scripts/prune_step35_extra_mtp.py for the pruner. + // All n_layer_nextn MTP blocks are required — the multi-block draft chain + // runs every head (head k at offset k). The GGUF declares the count via + // step35.nextn_predict_layers. for (int i = n_layer; i < n_layer_all; ++i) { - load_block_mtp(i, /*is_first_mtp=*/ i == n_layer); + load_block_mtp(i); } } @@ -372,13 +367,14 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr : llm_graph_context(params) { GGML_ASSERT(hparams.n_layer_nextn > 0 && "STEP35 MTP requires n_layer_nextn > 0"); - // Single-block MTP only: always run the first trained MTP block (Qwen - // MTP / vLLM single-MTP-layer style). Multi-block round-robin proved to - // be a much deeper refactor than this PR justifies; the trailing MTP - // blocks are loaded with TENSOR_NOT_REQUIRED so pruned GGUFs (with just - // block 0) also work — see load_arch_tensors below and - // scripts/prune_step35_extra_mtp.py. - const int il = hparams.n_layer(); + // Multi-block MTP: the DECODER_MTP graph runs the MTP head selected by + // cparams.nextn_layer_offset (0 = first trained head). The speculative driver + // bumps the offset per draft step to chain heads 45->46->47. offset 0 keeps + // single-block behavior identical to before. + const int il = hparams.n_layer() + cparams.nextn_layer_offset; + GGML_ASSERT(cparams.nextn_layer_offset >= 0 && + cparams.nextn_layer_offset < (int) hparams.n_layer_nextn && + "nextn_layer_offset out of range [0, n_layer_nextn)"); const auto & layer = model.layers[il]; GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); @@ -536,6 +532,9 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "mtp_post_ffn", il); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. cb(cur, "h_nextn", -1); res->t_h_nextn = cur;