spec: support MTP

This commit is contained in:
Aman Gupta 2026-05-11 11:18:17 +08:00
parent 634275fbbb
commit a55493bbda
24 changed files with 1210 additions and 45 deletions

View File

@ -159,6 +159,7 @@ enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding
COMMON_SPECULATIVE_TYPE_MTP, // multi-token prediction head loaded from the target GGUF
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values

View File

@ -3,6 +3,7 @@
#include "common.h"
#include "ggml.h"
#include "llama.h"
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_pre_norm / llama_get_embeddings_pre_norm_ith (used by MTP)
#include "log.h"
#include "ngram-cache.h"
#include "ngram-map.h"
@ -23,6 +24,7 @@ const std::map<std::string, common_speculative_type> common_speculative_type_fro
{"none", COMMON_SPECULATIVE_TYPE_NONE},
{"draft-simple", COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE},
{"draft-eagle3", COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3},
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
{"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
{"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
{"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@ -364,6 +366,330 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
}
};
// Speculative-decoding implementation that drafts with a multi-token
// prediction (MTP) head running in its own context (ctx_dft), conditioned on
// the target context's (ctx_tgt) pre-norm hidden states.
//
// Call flow, as implemented below:
//   process() - after a target decode, pairs each trunk pre-norm row h_k with
//               the following token and decodes those pairs on ctx_dft
//   draft()   - greedily samples up to n_max draft tokens from ctx_dft
//   accept()  - trims rejected draft positions from ctx_dft's memory
//   begin()   - resets the per-sequence bookkeeping for a new prompt
//
// Only a single sequence (seq_id 0) is supported; this is asserted in the ctor.
struct common_speculative_state_mtp : public common_speculative_impl {
    common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft)

    // Scratch batch carrying BOTH token ids and embedding rows (see ctor).
    llama_batch batch;

    // Row capacity of `batch` (token / embd / pos / seq_id / logits arrays).
    int32_t n_batch_max = 0;

    std::vector<common_sampler_ptr> smpls;

    int32_t n_embd = 0;

    // Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
    // The last h-row of one process() call needs the first token of the NEXT
    // call to pair with, so it's stashed here until that next call fires.
    std::vector<std::vector<float>> pending_h;   // [n_seq][n_embd]
    std::vector<llama_pos>          pending_pos; // [n_seq], -1 == nothing pending

    std::vector<uint16_t> last_n_drafted;
    std::vector<int32_t>  last_n_accepted; // -1 == no accept() since begin()

    // Number of trunk output rows produced by the most recent process() call.
    // Used by draft() for the first AR step (when last_n_accepted is -1) to
    // pick the last prefill row out of ctx_tgt's pre-norm buffer.
    std::vector<int32_t> last_trunk_n_outputs;

    common_speculative_state_mtp(const common_params_speculative & params, uint32_t n_seq)
        : common_speculative_impl(COMMON_SPECULATIVE_TYPE_MTP, n_seq)
        , params(params.draft)
    {
        GGML_ASSERT(n_seq == 1 && "MTP currently supports only single-sequence speculation");

        auto * ctx_tgt = this->params.ctx_tgt;
        auto * ctx_dft = this->params.ctx_dft;
        GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");

        n_embd = llama_model_n_embd(llama_get_model(ctx_dft));

        n_batch_max = (int32_t) llama_n_ubatch(ctx_dft);
        GGML_ASSERT(n_batch_max > 0 && "MTP requires a non-zero ubatch size on ctx_dft");

        batch = llama_batch_init(/*n_tokens=*/ n_batch_max, /*embd=*/ n_embd, /*n_seq_max=*/ 1);

        // llama_batch_init allocates only one of token/embd; MTP needs both.
        // TODO: fix, how to call without malloc
        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_batch_max);
        GGML_ASSERT(batch.token != nullptr && "failed to allocate the MTP token buffer");

        // One greedy (top-k = 1) sampler per sequence.
        smpls.resize(n_seq);
        for (auto & s : smpls) {
            common_params_sampling sparams;
            sparams.no_perf  = false;
            sparams.top_k    = 1;
            sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
            s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
        }

        // Both contexts must export their pre-norm hidden state: ctx_tgt feeds
        // the first draft step, ctx_dft feeds the subsequent ones.
        llama_set_embeddings_pre_norm(ctx_tgt, true);
        llama_set_embeddings_pre_norm(ctx_dft, true);

        pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
        pending_pos.assign(n_seq, -1);
        last_n_drafted.assign(n_seq, 0);
        last_n_accepted.assign(n_seq, -1);
        last_trunk_n_outputs.assign(n_seq, 0);
    }

    // The struct owns the malloc'd batch.token buffer (and the llama_batch
    // arrays); a copy would free them twice.
    common_speculative_state_mtp(const common_speculative_state_mtp &) = delete;
    common_speculative_state_mtp & operator=(const common_speculative_state_mtp &) = delete;

    ~common_speculative_state_mtp() override {
        // Release the manually allocated token buffer first and null it so
        // llama_batch_free() does not see it again.
        if (batch.token != nullptr) {
            free(batch.token);
            batch.token = nullptr;
        }
        llama_batch_free(batch);
    }

    // Resets the draft/accept/pending bookkeeping of `seq_id` and warns if
    // ctx_dft's memory holds fewer positions than the prompt has tokens.
    void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
        GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < pending_pos.size());

        last_n_accepted[seq_id] = -1;
        last_n_drafted [seq_id] = 0;
        pending_pos    [seq_id] = -1;

        const int32_t N = (int32_t) prompt.size();
        if (N <= 0) {
            return;
        }

        auto * ctx_dft = this->params.ctx_dft;
        const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
        if (pos_max < N - 1) {
            LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d — "
                    "process() hook may not have run on every prefill ubatch "
                    "(need_embd / logits=1 on every prompt position?). "
                    "Drafts may degrade.\n",
                    __func__, (int) pos_max, N - 1);
        }
    }

    // Mirrors a just-decoded target batch into ctx_dft.
    //
    // Returns false only when llama_decode(ctx_dft) fails; unsupported or
    // already-seen batches are skipped and reported as success.
    bool process(const llama_batch & batch_in) override {
        if (batch_in.n_tokens <= 0) {
            return true;
        }

        // Single-seq for now (asserted in ctor). Future: bucket by seq_id.
        const llama_seq_id seq_id = 0;

        // TODO: how to make it work with vision tokens?
        // Batches without token ids, with embeddings, or without explicit
        // pos/logits arrays cannot be paired below; drop any carryover and skip.
        if (batch_in.token == nullptr || batch_in.embd != nullptr ||
            batch_in.pos   == nullptr || batch_in.logits == nullptr) {
            pending_pos[seq_id] = -1;
            return true;
        }

        auto * ctx_tgt = this->params.ctx_tgt;
        auto * ctx_dft = this->params.ctx_dft;

        const int32_t   n_rows      = batch_in.n_tokens;
        const llama_pos pos_start   = batch_in.pos[0];
        const llama_pos pos_max_dft = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);

        // ctx_dft already holds this position (or a later one): nothing to do.
        if (pos_start <= pos_max_dft) {
            return true;
        }

        // Stale pending: discard if the new batch doesn't start one past it.
        const bool pending_continues =
            pending_pos[seq_id] >= 0 && pending_pos[seq_id] + 1 == pos_start;
        if (pending_pos[seq_id] >= 0 && !pending_continues) {
            pending_pos[seq_id] = -1;
        }

        // Build a paired hook batch:
        //   row 0           = (pending_h, batch_in.token[0]) at pos_start   if pending_continues
        //   following rows  = (h_k from this batch, batch_in.token[k+1]) at pos[k+1]
        // The last h-row (h_{n_rows-1}) is stashed as the new pending and is *not*
        // decoded this call — it waits for the next batch's first token to pair.
        const size_t row_bytes = (size_t) n_embd * sizeof(float);
        const char * func_name = __func__; // __func__ inside a lambda names the lambda, not process()

        common_batch_clear(batch);
        int out_idx = 0;

        // Decodes the rows collected so far on ctx_dft and empties the scratch batch.
        auto flush = [&]() -> bool {
            if (out_idx == 0) {
                return true;
            }
            batch.n_tokens = out_idx;
            const int32_t rc = llama_decode(ctx_dft, batch);
            if (rc != 0) {
                LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d, n=%d)\n",
                        func_name, (int) rc, (int) pos_start, out_idx);
                return false;
            }
            common_batch_clear(batch);
            out_idx = 0;
            return true;
        };

        // Appends one (hidden-state, token) pair. The scratch batch only has
        // n_batch_max rows, so a full batch is decoded before appending more;
        // without this an incoming batch larger than n_batch_max would write
        // past the end of the batch arrays.
        auto add_pair = [&](const float * h_row, llama_token tok, llama_pos pos) -> bool {
            if (out_idx >= n_batch_max && !flush()) {
                return false;
            }
            std::memcpy(batch.embd + (size_t) out_idx * n_embd, h_row, row_bytes);
            batch.token   [out_idx]    = tok;
            batch.pos     [out_idx]    = pos;
            batch.n_seq_id[out_idx]    = 1;
            batch.seq_id  [out_idx][0] = seq_id;
            batch.logits  [out_idx]    = 0;
            ++out_idx;
            return true;
        };

        if (pending_continues) {
            if (!add_pair(pending_h[seq_id].data(), batch_in.token[0], pos_start)) {
                return false;
            }
        }

        // TODO: is there a faster way to build this batch?
        for (int k = 0; k + 1 < n_rows; ++k) {
            if (batch_in.logits[k] == 0) {
                LOG_WRN("%s: batch_in.logits[%d] == 0 (need_embd / logits=1 missing on prefill); stopping hook at this row\n",
                        __func__, k);
                break;
            }
            const float * h_k = llama_get_embeddings_pre_norm_ith(ctx_tgt, k);
            if (h_k == nullptr) {
                LOG_WRN("%s: ctx_tgt has no pre-norm row at i=%d; stopping hook\n", __func__, k);
                break;
            }
            if (!add_pair(h_k, batch_in.token[k + 1], batch_in.pos[k + 1])) {
                return false;
            }
        }

        if (!flush()) {
            return false;
        }

        // Record the number of trunk rows so draft() (on its first step, when
        // last_n_accepted < 0) can find the last pre-norm row of this batch.
        // We assume every batch position has logits=1 (server sets need_embd
        // for MTP slots) → n_outputs == n_tokens.
        last_trunk_n_outputs[seq_id] = n_rows;

        // Stash the last h-row (h_{n_rows-1}) as the new pending for the next
        // process() call's first token to pair with.
        if (batch_in.logits[n_rows - 1] != 0) {
            const float * h_last = llama_get_embeddings_pre_norm_ith(ctx_tgt, n_rows - 1);
            if (h_last != nullptr) {
                std::memcpy(pending_h[seq_id].data(), h_last, row_bytes);
                pending_pos[seq_id] = batch_in.pos[n_rows - 1];
            } else {
                pending_pos[seq_id] = -1;
            }
        } else {
            // No trunk output at the tail — can't carry over.
            pending_pos[seq_id] = -1;
        }

        return true;
    }

    // Fills *dparams[0].result with up to n_max greedily sampled draft tokens.
    // Step 0 conditions on a pre-norm row of ctx_tgt; every later step
    // conditions on the pre-norm row ctx_dft produced in the previous step.
    void draft(common_speculative_draft_params_vec & dparams) override {
        // Single-seq for now (asserted in ctor). Future: iterate over dparams.
        const llama_seq_id seq_id = 0;
        if ((size_t) seq_id >= dparams.size()) {
            return;
        }

        auto & dp = dparams[seq_id];
        if (!dp.drafting) {
            return;
        }

        auto * ctx_tgt = this->params.ctx_tgt;
        auto * ctx_dft = this->params.ctx_dft;
        auto * smpl    = smpls[seq_id].get();

        GGML_ASSERT(dp.result != nullptr);
        auto & draft_tokens = *dp.result;
        draft_tokens.clear();

        // A previous draft that was never settled by accept(): remove all but
        // the first of its positions from ctx_dft and treat it as 0 accepted.
        if (last_n_drafted[seq_id] > 0) {
            const int32_t n_to_drop = (int32_t) last_n_drafted[seq_id] - 1;
            if (n_to_drop > 0) {
                const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
                if (pos_max >= 0) {
                    const llama_pos drop_from = pos_max - n_to_drop + 1;
                    llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1);
                }
            }
            last_n_drafted[seq_id]  = 0;
            last_n_accepted[seq_id] = 0;
        }

        // Effective draft length: min(global cap, per-sequence override).
        int32_t n_max = std::max(1, params.n_max);
        if (dp.n_max > 0) {
            n_max = std::min(n_max, dp.n_max);
        }

        const size_t row_bytes = (size_t) n_embd * sizeof(float);

        common_sampler_reset(smpl);

        llama_token cond_tok = dp.id_last;
        llama_pos   pos      = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id) + 1;

        for (int32_t k = 0; k < n_max; ++k) {
            const float * h_row = nullptr;

            if (k == 0) {
                // Condition on the trunk's pre-norm row.
                int32_t row_idx;
                if (last_n_accepted[seq_id] < 0) {
                    // First draft after begin(): use the last prefill row.
                    row_idx = last_trunk_n_outputs[seq_id] - 1;
                } else {
                    // After accept(n_accepted): row of the next conditioning
                    // position in the trunk's verify batch.
                    row_idx = last_n_accepted[seq_id];
                }
                if (row_idx < 0) {
                    LOG_WRN("%s: no trunk pre-norm row available (row_idx=%d); stopping chain\n",
                            __func__, row_idx);
                    break;
                }
                h_row = llama_get_embeddings_pre_norm_ith(ctx_tgt, row_idx);
            } else {
                // AR step: condition on the MTP head's own pre-norm row from
                // the just-completed single-token decode. n_outputs=1 there,
                // so the row is at batch position 0.
                h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, 0);
            }

            if (h_row == nullptr) {
                LOG_WRN("%s: missing pre-norm row at k=%d; stopping chain\n", __func__, k);
                break;
            }

            // 1-token batch carrying both (token, h_pre_norm).
            common_batch_clear(batch);
            std::memcpy(batch.embd, h_row, row_bytes);
            batch.token   [0]    = cond_tok;
            batch.pos     [0]    = pos;
            batch.n_seq_id[0]    = 1;
            batch.seq_id  [0][0] = seq_id;
            batch.logits  [0]    = 1; // need logits for sampling
            batch.n_tokens       = 1;

            const int32_t rc = llama_decode(ctx_dft, batch);
            if (rc != 0) {
                LOG_WRN("%s: llama_decode(ctx_dft) failed rc=%d at k=%d; stopping chain\n",
                        __func__, rc, k);
                break;
            }

            const llama_token best = common_sampler_sample(smpl, ctx_dft, 0);
            common_sampler_accept(smpl, best, /*is_generated=*/ false);

            draft_tokens.push_back(best);
            cond_tok = best;
            ++pos;
        }

        last_n_drafted[seq_id] = (uint16_t) draft_tokens.size();
    }

    // Records how many of the last draft's tokens were accepted and removes
    // the trailing (n_drafted - n_accepted - 1) positions from ctx_dft's memory.
    void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
        GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < last_n_drafted.size());

        auto * ctx_dft = this->params.ctx_dft;

        const llama_pos pos_max        = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
        const int32_t   n_drafted_last = (int32_t) last_n_drafted[seq_id];
        const int32_t   n_to_drop      = std::max(0, n_drafted_last - (int32_t) n_accepted - 1);

        // The draft is settled on every path. Clearing the counter here (also
        // when ctx_dft's memory is empty) keeps the next draft() from treating
        // it as an un-accepted draft and overwriting last_n_accepted with 0.
        last_n_drafted [seq_id] = 0;
        last_n_accepted[seq_id] = (int32_t) n_accepted;

        if (pos_max < 0) {
            return;
        }

        if (n_to_drop > 0) {
            const llama_pos drop_from = pos_max - n_to_drop + 1;
            llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1);
        }
    }
};
// state of self-speculation (simple implementation, not ngram-map)
struct common_speculative_impl_ngram_simple : public common_speculative_impl {
common_params_speculative_ngram_map params;
@ -820,6 +1146,7 @@ std::string common_speculative_type_to_str(common_speculative_type type) {
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: return "draft-simple";
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: return "draft-eagle3";
case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v";
@ -875,8 +1202,8 @@ common_speculative * common_speculative_init(common_params_speculative & params,
bool has_draft_model_path = !params.draft.mparams.path.empty();
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
// bool has_mtp = false; // TODO: add MTP here
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_MTP)) && params.draft.ctx_dft != nullptr;
bool has_ngram_cache = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_CACHE));
bool has_ngram_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE));
@ -885,7 +1212,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
bool has_ngram_mod = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MOD));
// when adding a new type - update here the logic above
static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 8);
static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 9);
// this list here defines the priority of the speculators
// the one with highest priority are listed first
@ -919,10 +1246,12 @@ common_speculative * common_speculative_init(common_params_speculative & params,
if (has_draft_simple) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, params));
}
// TODO: add MTP here
if (has_draft_eagle3) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params));
}
if (has_mtp) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
}
}
std::vector<std::unique_ptr<common_speculative_impl>> impls = {};
@ -940,6 +1269,10 @@ common_speculative * common_speculative_init(common_params_speculative & params,
impls.push_back(std::make_unique<common_speculative_impl_draft_eagle3>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_MTP: {
impls.push_back(std::make_unique<common_speculative_state_mtp>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);

View File

@ -5549,13 +5549,70 @@ class _Qwen35MRopeMixin:
self.gguf_writer.add_rope_dimension_sections(self._QWEN35_DEFAULT_MROPE_SECTION)
class _Qwen35MtpMixin:
    """MTP support shared by the Qwen3.5/3.6 text converters.

    The HF config reports the number of MTP blocks as `mtp_num_hidden_layers`
    and stores their weights under `mtp.*`. This mixin:
      - grows `block_count` by the number of MTP blocks,
      - writes the nextn-predict-layers metadata key,
      - renames `mtp.*` tensors to layer-indexed `model.layers.N.*` names so
        the regular tensor map can resolve them.
    """

    # Declared at class level so the type checker knows about the attributes
    # that the concrete classes further down the MRO provide.
    hparams: dict[str, Any]
    model_arch: gguf.MODEL_ARCH
    gguf_writer: gguf.GGUFWriter
    block_count: int
    tensor_map: gguf.TensorNameMap

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        n_trunk = self.hparams["num_hidden_layers"]
        n_mtp = self.hparams.get("mtp_num_hidden_layers", 0)
        # MTP blocks are appended after the trunk blocks.
        self.block_count = n_trunk + n_mtp
        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

    def set_gguf_parameters(self):
        super().set_gguf_parameters()  # ty: ignore[unresolved-attribute]
        n_mtp = self.hparams.get("mtp_num_hidden_layers", 0)
        if n_mtp > 0:
            self.gguf_writer.add_nextn_predict_layers(n_mtp)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # Multimodal Qwen3.5/3.6 wrap the text model under `model.language_model.*`;
        # strip that wrapper (first matching prefix only).
        for wrapped, plain in (("model.language_model.", "model."), ("language_model.", "")):
            if name.startswith(wrapped):
                name = plain + name[len(wrapped):]
                break

        # Anything outside the MTP block goes straight to the parent converter.
        if not name.startswith("mtp."):
            yield from super().modify_tensors(data_torch, name, bid)  # ty: ignore[unresolved-attribute]
            return

        # Remap MTP block tensors to llama.cpp's layer-indexed nextn naming.
        #   HF: mtp.layers.0.*   (transformer block at MTP slot 0)
        #       mtp.fc / mtp.pre_fc_norm_embedding / mtp.pre_fc_norm_hidden / mtp.norm
        n_layer = self.hparams["num_hidden_layers"]

        if "layers." in name:
            # Per-block tensor: shift the MTP slot index past the trunk layers.
            assert bid is not None
            name = name.replace(f"mtp.layers.{bid}", f"model.layers.{bid + n_layer}")
            yield from super().modify_tensors(data_torch, name, bid)  # ty: ignore[unresolved-attribute]
            return

        # Block-less MTP tensor: emit one copy per MTP block.
        remapper = {
            "mtp.fc": "model.layers.{bid}.eh_proj",
            "mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm",
            "mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm",
            "mtp.norm": "model.layers.{bid}.shared_head.norm",
        }
        # Path() is used purely to split "<stem>.<suffix>", e.g. "mtp.fc" + ".weight".
        parts = Path(name)
        tmpl = remapper[parts.stem] + parts.suffix
        for b in range(n_layer, self.block_count):
            yield from super().modify_tensors(data_torch, tmpl.format(bid=b), b)  # ty: ignore[unresolved-attribute]
        return
@ModelBase.register("Qwen3_5ForConditionalGeneration", "Qwen3_5ForCausalLM")
class Qwen3_5TextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase):
class Qwen3_5TextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
model_arch = gguf.MODEL_ARCH.QWEN35
@ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
class Qwen3_5MoeTextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase):
class Qwen3_5MoeTextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
model_arch = gguf.MODEL_ARCH.QWEN35MOE

View File

@ -2114,7 +2114,14 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_BETA,
MODEL_TENSOR.SSM_ALPHA,
MODEL_TENSOR.SSM_OUT
MODEL_TENSOR.SSM_OUT,
# NextN/MTP tensors - preserved but unused
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.QWEN35MOE: [
MODEL_TENSOR.TOKEN_EMBD,
@ -2145,7 +2152,14 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_BETA,
MODEL_TENSOR.SSM_ALPHA,
MODEL_TENSOR.SSM_OUT
MODEL_TENSOR.SSM_OUT,
# NextN/MTP tensors - preserved but unused
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.PLAMO: [
MODEL_TENSOR.TOKEN_EMBD,

View File

@ -310,6 +310,9 @@ extern "C" {
// override key-value pairs of the model meta data
const struct llama_model_kv_override * kv_overrides;
// override architecture from GGUF (e.g. load the MTP head of a Qwen3.5 GGUF as "qwen35_mtp")
const char * override_arch;
// Keep the booleans together to avoid misalignment during copy-by-value.
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible

View File

@ -41,6 +41,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
{ LLM_ARCH_QWEN35, "qwen35" },
{ LLM_ARCH_QWEN35MOE, "qwen35moe" },
{ LLM_ARCH_QWEN35_MTP, "qwen35_mtp" },
{ LLM_ARCH_QWEN35MOE_MTP, "qwen35moe_mtp" },
{ LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PHI3, "phi3" },
{ LLM_ARCH_PHIMOE, "phimoe" },
@ -757,14 +759,15 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
// These tensors only exist in the last layer(s) and are treated as output tensors
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
// NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the
// last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so
// the model loader doesn't fault on the block index.
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
// Nemotron 3 Super
{LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},

View File

@ -45,6 +45,8 @@ enum llm_arch {
LLM_ARCH_QWEN3VLMOE,
LLM_ARCH_QWEN35,
LLM_ARCH_QWEN35MOE,
LLM_ARCH_QWEN35_MTP,
LLM_ARCH_QWEN35MOE_MTP,
LLM_ARCH_PHI2,
LLM_ARCH_PHI3,
LLM_ARCH_PHIMOE,

View File

@ -49,6 +49,7 @@ llama_context::llama_context(
cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
cparams.embeddings = params.embeddings;
cparams.embeddings_pre_norm = false;
cparams.offload_kqv = params.offload_kqv;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
@ -860,6 +861,33 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
return it->second.data();
}
// Returns the base pointer of the pre-norm embeddings output buffer
// (embd_pre_norm.data). output_reorder() is called first so the rows are in
// the order that function establishes. No null check is done here: the
// pointer is returned as-is.
float * llama_context::get_embeddings_pre_norm() {
output_reorder();
return embd_pre_norm.data;
}
// Returns a pointer to the pre-norm hidden-state row for index `i`, i.e.
// embd_pre_norm.data + j*n_embd where j = output_resolve_row(i).
//
// Failure handling: a missing buffer (and any std::exception raised while
// resolving the row) is logged; debug builds (NDEBUG undefined) then abort,
// release builds return nullptr.
float * llama_context::get_embeddings_pre_norm_ith(int32_t i) {
output_reorder();
try {
// buffer not allocated -> routed through the common error path below
if (embd_pre_norm.data == nullptr) {
throw std::runtime_error("no pre-norm embeddings");
}
// NOTE(review): output_resolve_row presumably throws on an invalid index — confirm
const int64_t j = output_resolve_row(i);
const uint32_t n_embd = model.hparams.n_embd;
return embd_pre_norm.data + j*n_embd;
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG
GGML_ABORT("fatal error");
#else
return nullptr;
#endif
}
}
llama_token llama_context::get_sampled_token_ith(int32_t idx) {
output_reorder();
@ -1040,6 +1068,12 @@ void llama_context::set_embeddings(bool value) {
//sched_need_reserve = true;
}
// Enables or disables extraction of the pre-norm hidden state by storing
// `value` in cparams.embeddings_pre_norm. Only the flag is updated here; no
// buffer is allocated or released by this call.
void llama_context::set_embeddings_pre_norm(bool value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
cparams.embeddings_pre_norm = value;
}
void llama_context::set_causal_attn(bool value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
@ -1241,7 +1275,9 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
}
int llama_context::encode(const llama_batch & batch_inp) {
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
// MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
// so accept either present rather than requiring exactly one.
GGML_ASSERT(batch_inp.token || batch_inp.embd);
if (batch_inp.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@ -1312,8 +1348,9 @@ int llama_context::encode(const llama_batch & batch_inp) {
}
}
auto * t_logits = res->get_logits();
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
auto * t_logits = res->get_logits();
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr;
// extract logits
if (logits.data && t_logits) {
@ -1379,6 +1416,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
}
}
// extract pre-norm embeddings (hidden state before the final output norm)
if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
GGML_ASSERT(backend_h != nullptr);
const uint32_t n_embd = hparams.n_embd;
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size);
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float));
}
// TODO: hacky solution
if (model.arch == LLM_ARCH_T5 && t_embd) {
//cross.t_embd = t_embd;
@ -1531,7 +1578,9 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_s
}
int llama_context::decode(const llama_batch & batch_inp) {
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
// MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
// so accept either present rather than requiring exactly one.
GGML_ASSERT(batch_inp.token || batch_inp.embd);
if (!memory) {
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
@ -1727,8 +1776,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}
auto * t_logits = res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
auto * t_logits = res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr;
if (t_embd && res->get_embd_pooled()) {
t_embd = res->get_embd_pooled();
@ -1809,6 +1859,20 @@ int llama_context::decode(const llama_batch & batch_inp) {
}
}
// extract pre-norm embeddings (hidden state before the final output norm)
// only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
GGML_ASSERT(backend_h != nullptr);
const uint32_t n_embd = hparams.n_embd;
float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev*n_embd;
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_pre_norm.size);
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_outputs*n_embd*sizeof(float));
}
// Copy backend sampling output if this ubatch produced any sampling tensors.
if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) {
const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
@ -1893,10 +1957,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
const auto n_batch = cparams.n_batch;
const auto n_vocab = vocab.n_tokens();
const auto n_embd = hparams.n_embd;
const auto n_embd_out = hparams.n_embd_out();
bool has_logits = true;
bool has_embd = cparams.embeddings;
bool has_logits = true;
bool has_embd = cparams.embeddings;
bool has_embd_pre_norm = cparams.embeddings_pre_norm;
// TODO: hacky enc-dec support
if (model.arch == LLM_ARCH_T5) {
@ -1908,8 +1974,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
size_t backend_float_count = 0;
size_t backend_token_count = 0;
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0;
// Allocate backend sampling output buffers if there are backend samplers configured.
const bool has_sampling = !sampling.samplers.empty();
@ -1925,8 +1992,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
const size_t new_size =
(logits.size + embd.size + backend_float_count) * sizeof(float) +
( backend_token_count) * sizeof(llama_token);
(logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) +
( backend_token_count) * sizeof(llama_token);
// alloc only when more than the current capacity is required
// TODO: also consider shrinking the buffer
@ -1942,6 +2009,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
buf_output = nullptr;
logits.data = nullptr;
embd.data = nullptr;
embd_pre_norm.data = nullptr;
}
auto * buft = ggml_backend_cpu_buffer_type();
@ -1970,6 +2038,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
offset += embd.size * sizeof(float);
embd_pre_norm = has_embd_pre_norm ? buffer_view<float>{(float *) (base + offset), embd_pre_norm.size} : buffer_view<float>{nullptr, 0};
offset += embd_pre_norm.size * sizeof(float);
if (has_sampling) {
sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
offset += sampling.logits.size * sizeof(float);
@ -2034,6 +2105,12 @@ void llama_context::output_reorder() {
}
}
if (embd_pre_norm.size > 0) {
for (uint64_t k = 0; k < n_embd; k++) {
std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]);
}
}
if (!sampling.samplers.empty()) {
assert(sampling.logits.size > 0);
assert(sampling.probs.size > 0);
@ -3436,6 +3513,22 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
return ctx->get_embeddings_seq(seq_id);
}
// C API wrapper: forwards to llama_context::set_embeddings_pre_norm().
// `ctx` is dereferenced unconditionally and must not be null.
void llama_set_embeddings_pre_norm(llama_context * ctx, bool value) {
ctx->set_embeddings_pre_norm(value);
}
// C API wrapper: synchronizes the context, then returns the base pointer of
// the pre-norm embeddings buffer via llama_context::get_embeddings_pre_norm().
// `ctx` is dereferenced unconditionally and must not be null.
float * llama_get_embeddings_pre_norm(llama_context * ctx) {
// NOTE(review): presumably waits for pending async output copies — confirm
ctx->synchronize();
return ctx->get_embeddings_pre_norm();
}
// C API wrapper: synchronizes the context, then returns the pre-norm row for
// index `i` via llama_context::get_embeddings_pre_norm_ith(). The result is
// whatever that method returns, which includes nullptr on its error path in
// release builds. `ctx` is dereferenced unconditionally and must not be null.
float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) {
// NOTE(review): presumably waits for pending async output copies — confirm
ctx->synchronize();
return ctx->get_embeddings_pre_norm_ith(i);
}
bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
return ctx->set_sampler(seq_id, smpl);
}

View File

@ -84,6 +84,9 @@ struct llama_context {
float * get_embeddings_ith(int32_t i);
float * get_embeddings_seq(llama_seq_id seq_id);
float * get_embeddings_pre_norm();
float * get_embeddings_pre_norm_ith(int32_t i);
llama_token * get_sampled_tokens() const;
llama_token get_sampled_token_ith(int32_t idx);
@ -107,6 +110,7 @@ struct llama_context {
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
void set_embeddings (bool value);
void set_embeddings_pre_norm(bool value);
void set_causal_attn(bool value);
void set_warmup(bool value);
@ -278,6 +282,11 @@ private:
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
buffer_view<float> embd = {nullptr, 0};
// hidden state before the final output norm (2-dimensional array: [n_outputs][n_embd])
// populated only when cparams.embeddings_pre_norm is enabled and the model graph
// sets llm_graph_result::t_h_pre_norm
buffer_view<float> embd_pre_norm = {nullptr, 0};
struct sampling_info {
// !samplers.empty() to check if any samplers are active
std::map<llama_seq_id, llama_sampler *> samplers;

View File

@ -27,6 +27,7 @@ struct llama_cparams {
float yarn_beta_slow;
bool embeddings;
bool embeddings_pre_norm; // also extract the hidden state before the final output norm
bool causal_attn;
bool offload_kqv;
bool flash_attn;

View File

@ -88,3 +88,19 @@ LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model);
LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i);
LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx);
//
// pre-norm embeddings (hidden state before the final output norm)
//
// Staging API consumed by the MTP speculative-decoding code. Each declaration
// below is the pre-norm counterpart of the embeddings function quoted in the
// "mirrors:" comment that precedes it.
// NOTE(review): the row layout, the lifetime of the returned pointers and the
// null-return conditions are presumably the same as for the mirrored
// functions - confirm against the implementation.
// mirrors:
// LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value);
// mirrors:
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
LLAMA_API float * llama_get_embeddings_pre_norm(struct llama_context * ctx);
// mirrors:
// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i);

View File

@ -644,6 +644,7 @@ public:
ggml_tensor * get_logits() const { return t_logits; }
ggml_tensor * get_embd() const { return t_embd; }
ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; }
ggml_cgraph * get_gf() const { return gf; }
ggml_context * get_ctx() const { return ctx_compute.get(); }
@ -672,6 +673,7 @@ public:
ggml_tensor * t_logits = nullptr;
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm
std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
std::map<llama_seq_id, ggml_tensor*> t_candidates;

View File

@ -229,6 +229,12 @@ uint32_t llama_hparams::n_embd_head_v_mla() const {
}
bool llama_hparams::has_kv(uint32_t il) const {
if (kv_only_nextn) {
// MTP head: only the trailing nextn_predict_layers blocks own a KV cache;
// the leading trunk blocks are not executed in this graph.
return nextn_predict_layers > 0 && il >= (n_layer - nextn_predict_layers);
}
if (n_layer_kv_from_start >= 0) {
if (il < (uint32_t) n_layer_kv_from_start) {
return true;

View File

@ -92,6 +92,8 @@ struct llama_hparams {
uint32_t moe_latent_size = 0;
uint32_t nextn_predict_layers = 0;
bool kv_only_nextn = false; // if true, only the last nextn_predict_layers blocks have a KV cache (MTP head arches)
float f_norm_eps;
float f_norm_rms_eps;
float f_norm_group_eps;

View File

@ -1312,9 +1312,16 @@ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_conte
return tensor;
}
void llama_model_loader::done_getting_tensors() const {
if (n_created != n_tensors) {
throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
void llama_model_loader::done_getting_tensors(bool partial) const {
if (n_created > n_tensors) {
throw std::runtime_error(format("%s: too many tensors created; expected %d, got %d", __func__, n_tensors, n_created));
}
if (n_created < n_tensors) {
if (!partial) {
throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
}
LLAMA_LOG_INFO("%s: partial load — used %d of %d tensors in the file (rest belong to a sibling model on the same .gguf)\n",
__func__, n_created, n_tensors);
}
if (n_tensors_moved > 0) {
LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %zu others) cannot be used with preferred buffer type %s, using %s instead\n",

View File

@ -184,7 +184,7 @@ struct llama_model_loader {
struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list<int64_t> & ne, size_t offset, bool required = true);
void done_getting_tensors() const;
void done_getting_tensors(bool partial = false) const;
void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr);

View File

@ -276,6 +276,10 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
return new llama_model_qwen35(params);
case LLM_ARCH_QWEN35MOE:
return new llama_model_qwen35moe(params);
case LLM_ARCH_QWEN35_MTP:
return new llama_model_qwen35_mtp(params);
case LLM_ARCH_QWEN35MOE_MTP:
return new llama_model_qwen35moe_mtp(params);
case LLM_ARCH_MISTRAL3:
return new llama_model_mistral3(params);
case LLM_ARCH_MIMO2:
@ -309,6 +313,15 @@ llama_model * llama_model_create(llama_model_loader & ml, const llama_model_para
if (arch == LLM_ARCH_UNKNOWN) {
throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'");
}
if (params.override_arch != nullptr && params.override_arch[0] != '\0') {
const llm_arch override = llm_arch_from_string(params.override_arch);
if (override == LLM_ARCH_UNKNOWN) {
throw std::runtime_error(std::string("unknown override architecture: '") + params.override_arch + "'");
}
LLAMA_LOG_INFO("%s: overriding architecture %s -> %s\n",
__func__, llm_arch_name(arch), llm_arch_name(override));
arch = override;
}
return llama_model_create(arch, params);
}
@ -1396,7 +1409,8 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) {
}
}
ml.done_getting_tensors();
const bool partial_load = (arch == LLM_ARCH_QWEN35_MTP || arch == LLM_ARCH_QWEN35MOE_MTP);
ml.done_getting_tensors(partial_load);
// populate tensors_by_name
for (auto & [_, ctx_ptr] : ml.ctx_map) {
@ -2089,6 +2103,7 @@ llama_model_params llama_model_default_params() {
/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,
/*.kv_overrides =*/ nullptr,
/*.override_arch =*/ nullptr,
/*.vocab_only =*/ false,
/*.use_mmap =*/ true,
/*.use_direct_io =*/ false,
@ -2313,6 +2328,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_QWEN3VLMOE:
case LLM_ARCH_QWEN35:
case LLM_ARCH_QWEN35MOE:
case LLM_ARCH_QWEN35_MTP:
case LLM_ARCH_QWEN35MOE_MTP:
return LLAMA_ROPE_TYPE_IMROPE;
case LLM_ARCH_GLM4:

View File

@ -1785,6 +1785,32 @@ struct llama_model_qwen35moe : public llama_model_base {
};
struct llama_model_qwen35_mtp : public llama_model_base {
llama_model_qwen35_mtp(const struct llama_model_params & params) : llama_model_base(params) {}
void load_arch_hparams(llama_model_loader & ml) override;
void load_arch_tensors(llama_model_loader & ml) override;
struct graph : public llm_graph_context {
graph(const llama_model & model, const llm_graph_params & params);
};
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
};
struct llama_model_qwen35moe_mtp : public llama_model_base {
llama_model_qwen35moe_mtp(const struct llama_model_params & params) : llama_model_base(params) {}
void load_arch_hparams(llama_model_loader & ml) override;
void load_arch_tensors(llama_model_loader & ml) override;
struct graph : public llm_graph_context {
graph(const llama_model & model, const llm_graph_params & params);
};
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
};
struct llama_model_mistral3 : public llama_model_base {
llama_model_mistral3(const struct llama_model_params & params) : llama_model_base(params) {}
void load_arch_hparams(llama_model_loader & ml) override;

View File

@ -12,16 +12,23 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
// Mark recurrent layers (linear attention layers)
// NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
// Mark recurrent layers (linear attention layers). MTP layers are dense
// attention-only and must be flagged non-recurrent.
{
const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers;
uint32_t full_attn_interval = 4;
ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false);
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0);
hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0);
}
}
switch (hparams.n_layer) {
switch (hparams.n_layer - hparams.nextn_predict_layers) {
case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break;
case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break;
case 64: type = LLM_TYPE_27B; break;
@ -83,6 +90,16 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader &) {
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
// NextN/MTP tensors (preserved but unused) - only bound on MTP layers
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
}
}
}
@ -111,7 +128,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
// MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers;
for (int il = 0; il < n_transformer_layers; ++il) {
ggml_tensor * inpSA = inpL;
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
@ -128,7 +147,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
}
if (il == n_layer - 1 && inp_out_ids) {
if (il == n_transformer_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
@ -160,6 +179,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
}
cur = inpL;
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;
// Final norm
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);

207
src/models/qwen35_mtp.cpp Normal file
View File

@ -0,0 +1,207 @@
#include "models.h"
// Load the hyperparameters used by the QWEN35_MTP draft head.
//
// Reads the RMS-norm epsilon, the rope sections and the number of NextN/MTP
// layers, then configures hparams so that only the trailing MTP layers own a
// KV cache and no layer is flagged as recurrent.
void llama_model_qwen35_mtp::load_arch_hparams(llama_model_loader & ml) {
    ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
    ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);

    // the asserts below reject files that carry no MTP layers, or more MTP
    // layers than there are layers in total
    ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
    GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35_MTP requires nextn_predict_layers > 0");
    GGML_ASSERT(hparams.nextn_predict_layers <= hparams.n_layer);

    // clear the recurrent flag on every layer of the file
    uint32_t il = 0;
    while (il < hparams.n_layer) {
        hparams.recurrent_layer_arr[il++] = false;
    }

    // KV cache is allocated for the MTP layers only; the trunk layers are skipped
    hparams.n_layer_kv_from_start = -1;
    hparams.kv_only_nextn         = true;

    type = LLM_TYPE_UNKNOWN;
}
// Create the tensors used by the QWEN35_MTP draft head.
//
// Only the global embedding/output tensors and the trailing
// nextn_predict_layers blocks are created; every leading (trunk) layer index
// is skipped.
void llama_model_qwen35_mtp::load_arch_tensors(llama_model_loader &) {
    LLAMA_LOAD_LOCALS;

    // global tensors
    tok_embd    = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), { n_embd, n_vocab }, 0);
    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd },          TENSOR_NOT_REQUIRED);
    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
    if (!output) {
        // no dedicated output matrix in the file: reuse the token embedding matrix
        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
    }

    // index of the first MTP block; everything below it belongs to the trunk
    const uint32_t first_mtp_layer = n_layer - hparams.nextn_predict_layers;

    for (int il = 0; il < n_layer; ++il) {
        const bool is_trunk_layer = static_cast<uint32_t>(il) < first_mtp_layer;
        if (is_trunk_layer) {
            continue; // owned by the sibling QWEN35 model
        }

        auto & blk = layers[il];

        // the MTP block has the shape of a full-attention Qwen3.5 decoder block
        blk.attn_norm      = create_tensor(tn(LLM_TENSOR_ATTN_NORM,      "weight", il), { n_embd }, 0);
        blk.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0);

        // the Q projection is twice as wide as n_embd_head_k * n_head
        create_tensor_qkv(blk, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0);
        blk.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0);

        blk.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0);
        blk.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0);

        // dense feed-forward
        blk.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), { n_embd, n_ff }, 0);
        blk.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd }, 0);
        blk.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", il), { n_embd, n_ff }, 0);

        // NextN tensors: the projection and the two input norms are mandatory ...
        blk.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0);
        blk.nextn.enorm   = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM,   "weight", il), { n_embd }, 0);
        blk.nextn.hnorm   = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM,   "weight", il), { n_embd }, 0);

        // ... while the dedicated embedding table and output head may be absent
        blk.nextn.embed_tokens     = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS,     "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
        blk.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
        blk.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED);
    }
}
// Build the QWEN35_MTP compute graph for `params` and hand it back through the
// generic llm_graph_context interface.
std::unique_ptr<llm_graph_context> llama_model_qwen35_mtp::build_arch_graph(const llm_graph_params & params) const {
    std::unique_ptr<llm_graph_context> gctx = std::make_unique<graph>(*this, params);
    return gctx;
}
// LLM_ARCH_QWEN35_MTP draft head for Qwen3.5/3.6 dense series
//
// Builds the compute graph of the single MTP block:
//   1. look up the token embeddings and take a hidden state per token as a
//      second graph input ("mtp_h_input"),
//   2. RMS-normalize both, concatenate them and project with nextn.eh_proj,
//   3. run one gated-attention decoder block with a dense SiLU FFN,
//   4. expose the block output as t_h_pre_norm and, after the head norm and
//      the LM head, as t_logits.
// NOTE(review): unlike the trunk graphs, no build_inp_out_ids() row selection
// is applied here, so t_h_pre_norm and t_logits carry one row per batch token
// (n_tokens) - confirm that callers always request outputs at every position.
llama_model_qwen35_mtp::graph::graph(const llama_model & model, const llm_graph_params & params)
: llm_graph_context(params) {
// the second assert subsumes the first; only one MTP block is supported for now
GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35_MTP requires nextn_predict_layers > 0");
GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35_MTP currently only supports a single MTP block");
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
// The MTP block lives at the source file's original layer index.
const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers;
const auto & layer = model.layers[il];
GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj");
GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm");
GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm");
// local copy of the rope sections, passed to ggml_rope_multi() below
int sections[4];
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
// A single input object carries BOTH the token ids and a float hidden state
// of shape [n_embd, n_tokens]; both tensors are flagged as graph inputs.
auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd);
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_set_input(inp->tokens);
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
ggml_set_input(inp->embd);
ggml_set_name(inp->embd, "mtp_h_input");
// prefer the block's own embedding table, fall back to the model-wide one
ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd;
ggml_tensor * h_input = inp->embd;
ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens);
cb(tok_embd, "mtp_tok_embd", il);
// ownership of the input object moves to the result; the tensors captured above stay in use
res->add_input(std::move(inp));
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv();
// normalize the hidden state and the token embedding separately ...
ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
cb(h_norm, "mtp_hnorm", il);
ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il);
cb(e_norm, "mtp_enorm", il);
// ... then stack them along dim 0 (embedding first, hidden state second) and
// project with eh_proj
ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0);
cb(concat, "mtp_concat", il);
ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat);
cb(cur, "mtp_eh_proj", il);
// residual input of the attention sub-block
ggml_tensor * inpSA = cur;
cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "mtp_attn_norm", il);
// The Q projection yields 2 * n_embd_head values per head: the views below
// take the first n_embd_head values of each head as the query (offset 0) and
// the second n_embd_head values as the attention output gate.
ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s);
cb(Qcur_full, "mtp_Qcur_full", il);
ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full,
n_embd_head, n_head, n_tokens,
ggml_element_size(Qcur_full) * n_embd_head * 2,
ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
0);
Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il);
cb(Qcur, "mtp_Qcur_normed", il);
ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full,
n_embd_head, n_head, n_tokens,
ggml_element_size(Qcur_full) * n_embd_head * 2,
ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
ggml_element_size(Qcur_full) * n_embd_head);
// the gate view is strided; make it a contiguous [n_embd_head * n_head, n_tokens] matrix
gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
cb(gate, "mtp_gate", il);
ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il);
cb(Kcur, "mtp_Kcur_normed", il);
ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
cb(Vcur, "mtp_Vcur", il);
// rotary position embedding on Q and K, applied after the per-head norms
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
// f_attention_scale == 0 selects the default 1/sqrt(head_dim) scaling
const float kq_scale = hparams.f_attention_scale == 0.0f
? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
// no output projection is passed to build_attn(): the sigmoid gate has to be
// applied first, then wo is applied explicitly
cur = build_attn(inp_attn,
nullptr, nullptr, nullptr,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
cb(cur, "mtp_attn_pregate", il);
cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));
cur = build_lora_mm(layer.wo, cur, layer.wo_s);
cb(cur, "mtp_attn_out", il);
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "mtp_attn_residual", il);
// residual input of the feed-forward sub-block
ggml_tensor * ffn_residual = cur;
cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "mtp_attn_post_norm", il);
cur = build_ffn(cur,
layer.ffn_up, nullptr, layer.ffn_up_s,
layer.ffn_gate, nullptr, layer.ffn_gate_s,
layer.ffn_down, nullptr, layer.ffn_down_s,
nullptr,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "mtp_ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_residual);
cb(cur, "mtp_post_ffn", il);
// Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
// (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.)
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;
// head norm: the block's own shared_head_norm if present, else the model output norm
ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
? layer.nextn.shared_head_norm
: model.output_norm;
GGML_ASSERT(head_norm_w && "QWEN35_MTP: missing both nextn.shared_head_norm and output_norm");
cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1);
cb(cur, "mtp_shared_head_norm", -1);
// LM head: the block's own shared_head_head if present, else the model output matrix
ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output;
GGML_ASSERT(head_w && "QWEN35_MTP: missing LM head (nextn.shared_head_head or model.output)");
cur = build_lora_mm(head_w, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}

View File

@ -15,16 +15,23 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
// Mark recurrent layers (linear attention layers)
// NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
// Mark recurrent layers (linear attention layers). MTP layers are dense
// attention-only and must be flagged non-recurrent.
{
const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers;
uint32_t full_attn_interval = 4;
ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false);
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0);
hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0);
}
}
switch (hparams.n_layer) {
switch (hparams.n_layer - hparams.nextn_predict_layers) {
case 40: type = LLM_TYPE_35B_A3B; break;
case 48: type = LLM_TYPE_122B_A10B; break;
case 60: type = LLM_TYPE_397B_A17B; break;
@ -96,6 +103,16 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) {
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0);
// NextN/MTP tensors (preserved but unused) - only bound on MTP layers
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
}
}
}
@ -124,7 +141,9 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
// MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers;
for (int il = 0; il < n_transformer_layers; ++il) {
ggml_tensor * inpSA = inpL;
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
@ -141,7 +160,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
}
if (il == n_layer - 1 && inp_out_ids) {
if (il == n_transformer_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
@ -173,6 +192,9 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
}
cur = inpL;
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;
// Final norm
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);

View File

@ -0,0 +1,252 @@
#include "models.h"
// Load the hyperparameters used by the QWEN35MOE_MTP draft head.
//
// Reads the expert feed-forward sizes, the RMS-norm epsilon, the rope sections
// and the number of NextN/MTP layers, then configures hparams so that only the
// trailing MTP layers own a KV cache and no layer is flagged as recurrent.
void llama_model_qwen35moe_mtp::load_arch_hparams(llama_model_loader & ml) {
    ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp,   false);
    ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
    ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,       hparams.f_norm_rms_eps);
    ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);

    // the asserts below reject files that carry no MTP layers, more MTP layers
    // than there are layers in total, or no experts
    ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
    GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE_MTP requires nextn_predict_layers > 0");
    GGML_ASSERT(hparams.nextn_predict_layers <= hparams.n_layer);
    GGML_ASSERT(hparams.n_expert > 0 && "QWEN35MOE_MTP requires n_expert > 0");

    // clear the recurrent flag on every layer of the file
    uint32_t il = 0;
    while (il < hparams.n_layer) {
        hparams.recurrent_layer_arr[il++] = false;
    }

    // KV cache is allocated for the MTP layers only; the trunk layers are skipped
    hparams.n_layer_kv_from_start = -1;
    hparams.kv_only_nextn         = true;

    type = LLM_TYPE_UNKNOWN;
}
// Create the tensors used by the QWEN35MOE_MTP draft head.
//
// Only the global embedding/output tensors and the trailing
// nextn_predict_layers blocks are created; every leading (trunk) layer index
// is skipped.
void llama_model_qwen35moe_mtp::load_arch_tensors(llama_model_loader &) {
    LLAMA_LOAD_LOCALS;

    // global tensors
    tok_embd    = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), { n_embd, n_vocab }, 0);
    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd },          TENSOR_NOT_REQUIRED);
    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
    if (!output) {
        // no dedicated output matrix in the file: reuse the token embedding matrix
        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
    }

    // expert sizes: use the explicit hparams when set, otherwise derive them from n_ff
    const int64_t n_ff_exp   = hparams.n_ff_exp   != 0 ? hparams.n_ff_exp   : n_ff / n_expert_used;
    const int64_t n_ff_shexp = hparams.n_ff_shexp != 0 ? hparams.n_ff_shexp : n_ff;

    // index of the first MTP block; everything below it belongs to the trunk
    const uint32_t first_mtp_layer = n_layer - hparams.nextn_predict_layers;

    for (int il = 0; il < n_layer; ++il) {
        const bool is_trunk_layer = static_cast<uint32_t>(il) < first_mtp_layer;
        if (is_trunk_layer) {
            continue; // owned by the sibling QWEN35MOE model
        }

        auto & blk = layers[il];

        // the MTP block has the shape of a full-attention Qwen3.5 decoder block with a MoE FFN
        blk.attn_norm      = create_tensor(tn(LLM_TENSOR_ATTN_NORM,      "weight", il), { n_embd }, 0);
        blk.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0);

        // the Q projection is twice as wide as n_embd_head_k * n_head
        create_tensor_qkv(blk, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0);
        blk.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0);

        blk.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0);
        blk.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0);

        // router and routed experts
        blk.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", il), { n_embd, n_expert }, 0);
        blk.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, 0);
        create_tensor_gate_up_exps(blk, il, n_embd, n_ff_exp, n_expert, 0);

        // shared expert and its gate
        blk.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, 0);
        blk.ffn_gate_shexp     = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP,     "weight", il), { n_embd, n_ff_shexp }, 0);
        blk.ffn_up_shexp       = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,       "weight", il), { n_embd, n_ff_shexp }, 0);
        blk.ffn_down_shexp     = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP,     "weight", il), { n_ff_shexp, n_embd }, 0);

        // NextN tensors: the projection and the two input norms are mandatory ...
        blk.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0);
        blk.nextn.enorm   = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM,   "weight", il), { n_embd }, 0);
        blk.nextn.hnorm   = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM,   "weight", il), { n_embd }, 0);

        // ... while the dedicated embedding table and output head may be absent
        blk.nextn.embed_tokens     = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS,     "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
        blk.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
        blk.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED);
    }
}
// Build the QWEN35MOE_MTP compute graph for `params` and hand it back through
// the generic llm_graph_context interface.
std::unique_ptr<llm_graph_context> llama_model_qwen35moe_mtp::build_arch_graph(const llm_graph_params & params) const {
    std::unique_ptr<llm_graph_context> gctx = std::make_unique<graph>(*this, params);
    return gctx;
}
// LLM_ARCH_QWEN35MOE_MTP draft head for Qwen3.5/3.6 MoE
//
// Builds the compute graph of the single MTP block:
//   1. look up the token embeddings and take a hidden state per token as a
//      second graph input ("mtp_h_input"),
//   2. RMS-normalize both, concatenate them and project with nextn.eh_proj,
//   3. run one gated-attention decoder block whose FFN is a mixture of routed
//      experts plus a sigmoid-gated shared expert,
//   4. expose the block output as t_h_pre_norm and, after the head norm and
//      the LM head, as t_logits.
// NOTE(review): unlike the trunk graphs, no build_inp_out_ids() row selection
// is applied here, so t_h_pre_norm and t_logits carry one row per batch token
// (n_tokens) - confirm that callers always request outputs at every position.
llama_model_qwen35moe_mtp::graph::graph(const llama_model & model, const llm_graph_params & params)
: llm_graph_context(params) {
// the second assert subsumes the first; only one MTP block is supported for now
GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE_MTP requires nextn_predict_layers > 0");
GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE_MTP currently only supports a single MTP block");
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
// index of the MTP block: the first layer after the trunk layers
const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers;
const auto & layer = model.layers[il];
GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj");
GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm");
GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm");
GGML_ASSERT(layer.ffn_gate_inp && "MTP block missing ffn_gate_inp");
// local copy of the rope sections, passed to ggml_rope_multi() below
int sections[4];
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
// A single input object carries BOTH the token ids and a float hidden state
// of shape [n_embd, n_tokens]; both tensors are flagged as graph inputs.
auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd);
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_set_input(inp->tokens);
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
ggml_set_input(inp->embd);
ggml_set_name(inp->embd, "mtp_h_input");
// prefer the block's own embedding table, fall back to the model-wide one
ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd;
ggml_tensor * h_input = inp->embd;
ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens);
cb(tok_embd, "mtp_tok_embd", il);
// ownership of the input object moves to the result; the tensors captured above stay in use
res->add_input(std::move(inp));
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv();
// normalize the hidden state and the token embedding separately ...
ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
cb(h_norm, "mtp_hnorm", il);
ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il);
cb(e_norm, "mtp_enorm", il);
// ... then stack them along dim 0 (embedding first, hidden state second) and
// project with eh_proj
ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0);
cb(concat, "mtp_concat", il);
ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat);
cb(cur, "mtp_eh_proj", il);
// residual input of the attention sub-block
ggml_tensor * inpSA = cur;
cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "mtp_attn_norm", il);
// The Q projection yields 2 * n_embd_head values per head: the views below
// take the first n_embd_head values of each head as the query (offset 0) and
// the second n_embd_head values as the attention output gate.
ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s);
cb(Qcur_full, "mtp_Qcur_full", il);
ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full,
n_embd_head, n_head, n_tokens,
ggml_element_size(Qcur_full) * n_embd_head * 2,
ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
0);
Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il);
cb(Qcur, "mtp_Qcur_normed", il);
ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full,
n_embd_head, n_head, n_tokens,
ggml_element_size(Qcur_full) * n_embd_head * 2,
ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
ggml_element_size(Qcur_full) * n_embd_head);
// the gate view is strided; make it a contiguous [n_embd_head * n_head, n_tokens] matrix
gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
cb(gate, "mtp_gate", il);
ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il);
cb(Kcur, "mtp_Kcur_normed", il);
ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
cb(Vcur, "mtp_Vcur", il);
// rotary position embedding on Q and K, applied after the per-head norms
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
// f_attention_scale == 0 selects the default 1/sqrt(head_dim) scaling
const float kq_scale = hparams.f_attention_scale == 0.0f
? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
// no output projection is passed to build_attn(): the sigmoid gate has to be
// applied first, then wo is applied explicitly
cur = build_attn(inp_attn,
nullptr, nullptr, nullptr,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
cb(cur, "mtp_attn_pregate", il);
cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));
cur = build_lora_mm(layer.wo, cur, layer.wo_s);
cb(cur, "mtp_attn_out", il);
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "mtp_attn_residual", il);
// residual input of the feed-forward sub-block
ggml_tensor * ffn_residual = cur;
cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "mtp_attn_post_norm", il);
// MoE FFN - routed experts plus gated shared expert (mirrors qwen35moe).
ggml_tensor * moe_out =
build_moe_ffn(cur,
layer.ffn_gate_inp,
layer.ffn_up_exps,
layer.ffn_gate_exps,
layer.ffn_down_exps,
nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, true,
hparams.expert_weights_scale,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il,
nullptr, layer.ffn_gate_up_exps,
layer.ffn_up_exps_s,
layer.ffn_gate_exps_s,
layer.ffn_down_exps_s);
cb(moe_out, "mtp_ffn_moe_out", il);
// Shared expert: a dense SiLU FFN on the same input, scaled by a sigmoid
// gate computed from ffn_gate_inp_shexp, then added to the routed output.
if (layer.ffn_up_shexp != nullptr) {
ggml_tensor * ffn_shexp =
build_ffn(cur,
layer.ffn_up_shexp, nullptr, layer.ffn_up_shexp_s,
layer.ffn_gate_shexp, nullptr, layer.ffn_gate_shexp_s,
layer.ffn_down_shexp, nullptr, layer.ffn_down_shexp_s,
nullptr,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(ffn_shexp, "mtp_ffn_shexp", il);
ggml_tensor * shared_gate = build_lora_mm(layer.ffn_gate_inp_shexp, cur);
shared_gate = ggml_sigmoid(ctx0, shared_gate);
cb(shared_gate, "mtp_shared_expert_gate_sigmoid", il);
ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
cb(ffn_shexp, "mtp_ffn_shexp_gated", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
} else {
// no shared expert tensors: the routed experts alone form the FFN output
cur = moe_out;
}
cb(cur, "mtp_ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_residual);
cb(cur, "mtp_post_ffn", il);
// Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;
// head norm: the block's own shared_head_norm if present, else the model output norm
ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
? layer.nextn.shared_head_norm
: model.output_norm;
GGML_ASSERT(head_norm_w && "QWEN35MOE_MTP: missing both nextn.shared_head_norm and output_norm");
cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1);
cb(cur, "mtp_shared_head_norm", -1);
// LM head: the block's own shared_head_head if present, else the model output matrix
ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output;
GGML_ASSERT(head_w && "QWEN35MOE_MTP: missing LM head (nextn.shared_head_head or model.output)");
cur = build_lora_mm(head_w, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}

View File

@ -406,6 +406,9 @@ static bool arch_supported(const llm_arch arch) {
if (arch == LLM_ARCH_DEEPSEEK2OCR) {
return false;
}
if (arch == LLM_ARCH_QWEN35_MTP || arch == LLM_ARCH_QWEN35MOE_MTP) {
return false; // MTP-only arch; requires a sibling trunk model and cannot run standalone.
}
// FIXME some models are segfaulting with WebGPU:
#ifdef GGML_USE_WEBGPU

View File

@ -57,6 +57,11 @@ struct server_slot {
llama_context * ctx_tgt = nullptr;
llama_context * ctx_dft = nullptr;
// True when this slot's speculative impl is MTP (ctx_dft is the MTP head).
// MTP needs every prefill position to carry logits=1 so the streaming
// hook in common_speculative_state_mtp::process() can read t_h_pre_norm.
bool is_mtp_enabled = false;
// multimodal
mtmd_context * mctx = nullptr;
@ -237,8 +242,20 @@ struct server_slot {
(ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size());
}
// Accessor for the is_mtp_enabled flag of this slot.
bool is_mtp() const {
    return is_mtp_enabled;
}
// Slot-level "need embeddings" predicate. It is true in two cases:
//   1. the current task itself asks for embeddings (task->need_embd()), or
//   2. this slot runs MTP (is_mtp()) — the streaming hook in process() reads
//      h_pre_norm at every prompt position, so the trunk has to emit output
//      for every prefill position.
// Requires an active task (asserted).
bool need_embd() const {
    GGML_ASSERT(task);
    if (task->need_embd()) {
        return true;
    }
    return is_mtp();
}
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
// also we cannot split if the pooling would require any past tokens
// (MTP does not prevent splitting: this check uses task->need_embd(), not the slot-level need_embd())
bool can_split() const {
GGML_ASSERT(task);
@ -743,6 +760,53 @@ private:
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
} else if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) {
// MTP head lives in the *target* GGUF — load it as a sibling model
// with override_arch and feed it through the existing ctx_dft slot.
//
// Steps performed below:
//   1. read the trunk's "general.architecture" metadata string,
//   2. map it to the matching "*_mtp" architecture name (or fail),
//   3. reject configurations with n_parallel > 1,
//   4. load the same model file again with override_arch set,
//   5. create the MTP context and publish both contexts in the draft params.
// Every failure path logs via SRV_ERR and returns false.

// Zero-initialised so the buffer holds an empty string if the metadata lookup
// writes nothing. NOTE(review): the lookup's return value is not checked here.
char trunk_arch[64] = {0};
llama_model_meta_val_str(model_tgt, "general.architecture", trunk_arch, sizeof(trunk_arch));
// Only the two trunk architectures listed below have an MTP counterpart here.
const char * mtp_arch = nullptr;
if (std::string(trunk_arch) == "qwen35") {
mtp_arch = "qwen35_mtp";
} else if (std::string(trunk_arch) == "qwen35moe") {
mtp_arch = "qwen35moe_mtp";
} else {
SRV_ERR("MTP not supported for trunk architecture '%s'\n", trunk_arch);
return false;
}
// Checked before the model load below, so this rejection happens without loading the head.
if (params_base.n_parallel > 1) {
SRV_ERR("MTP currently supports only n_parallel=1; got %d\n", params_base.n_parallel);
return false;
}
SRV_INF("loading MTP head from '%s' (override_arch=%s)\n",
params_base.model.path.c_str(), mtp_arch);
// Start from the base model params and only swap the architecture; the file
// path passed to the loader is the target model's own path.
auto mparams_mtp = common_model_params_to_llama(params_base);
mparams_mtp.override_arch = mtp_arch;
model_dft.reset(llama_model_load_from_file(params_base.model.path.c_str(), mparams_mtp));
if (model_dft == nullptr) {
SRV_ERR("failed to load MTP head from '%s'\n", params_base.model.path.c_str());
return false;
}
// Context params start from the base params; n_ctx is taken from the target
// context (presumably its per-sequence size — confirm against llama_n_ctx_seq)
// and the MTP context is limited to a single sequence.
auto cparams_mtp = common_context_params_to_llama(params_base);
cparams_mtp.n_ctx = llama_n_ctx_seq(ctx_tgt);
cparams_mtp.n_seq_max = 1;
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams_mtp));
if (ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create MTP context\n");
return false;
}
// Record the sequence-removal capability of the new context and hand both
// contexts to the speculative draft params (same slots the draft-model path uses).
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
}
@ -855,6 +919,7 @@ private:
slot.ctx_tgt = ctx_tgt;
slot.ctx_dft = ctx_dft.get();
slot.spec = spec.get();
slot.is_mtp_enabled = (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) && (ctx_dft != nullptr);
slot.n_ctx = n_ctx_slot;
slot.mctx = mctx;
@ -2716,12 +2781,14 @@ private:
break;
}
// embedding requires all tokens in the batch to be output
// embedding requires all tokens in the batch to be output;
// MTP also wants logits at every prompt position so the
// streaming hook can mirror t_h_pre_norm into ctx_dft.
common_batch_add(batch,
cur_tok,
slot.prompt.tokens.pos_next(),
{ slot.id },
slot.task->need_embd());
slot.need_embd());
slot.prompt.tokens.push_back(cur_tok);
slot.n_prompt_tokens_processed++;
@ -2838,7 +2905,7 @@ private:
slot_batched->lora[alora_disabled_id].scale = alora_scale;
}
llama_set_embeddings(ctx_tgt, slot_batched->task->need_embd());
llama_set_embeddings(ctx_tgt, slot_batched->need_embd());
}
if (batch.n_tokens == 0) {