mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
llama: Gemma 4 MTP
This commit is contained in:
parent
0066404085
commit
f268966d49
@ -418,6 +418,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
int32_t n_embd = 0;
|
||||
|
||||
bool kv_shared_with_target = false;
|
||||
|
||||
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
|
||||
// The last h-row of one process() call needs the first token of the NEXT
|
||||
// call to pair with, so it's stashed here until that next call fires.
|
||||
@ -444,7 +446,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");
|
||||
|
||||
n_embd = llama_model_n_embd(llama_get_model(ctx_dft));
|
||||
n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft));
|
||||
GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) &&
|
||||
"MTP input row width must match the target h_nextn width");
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
|
||||
@ -489,6 +493,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false);
|
||||
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
|
||||
llama_set_mtp_source(ctx_dft, ctx_tgt);
|
||||
|
||||
kv_shared_with_target = llama_model_n_layer_kv(llama_get_model(ctx_dft)) == 0;
|
||||
|
||||
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
|
||||
|
||||
@ -526,9 +533,10 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
if (N <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
|
||||
if (pos_max < N - 1) {
|
||||
if (pos_max < N - 1 && !kv_shared_with_target) {
|
||||
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
|
||||
"process() hook may not have run on every prefill ubatch "
|
||||
"(need_embd / logits=1 on every prompt position?). "
|
||||
@ -571,48 +579,42 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
const size_t row_bytes = (size_t) n_embd * sizeof(float);
|
||||
|
||||
common_batch_clear(batch);
|
||||
// if kv is shared with target (e.g Gemma4), then we can skip this catch-up decode
|
||||
if (!kv_shared_with_target) {
|
||||
common_batch_clear(batch);
|
||||
|
||||
for (int k = 0; k < n_tokens; ++k) {
|
||||
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
|
||||
}
|
||||
|
||||
// shift the tgt embeddings to the right by one position
|
||||
// assumes that the tokens in the batch are sequential for each sequence
|
||||
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
|
||||
// ^--- this is a problem
|
||||
// TODO:this is generally true, but would be nice to assert it
|
||||
{
|
||||
const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt);
|
||||
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
|
||||
|
||||
//{
|
||||
// // string with seq_ids in the batch
|
||||
// std::stringstream ss;
|
||||
// for (int i = 0; i < n_tokens; ++i) {
|
||||
// ss << batch_in.seq_id[i][0] << ",";
|
||||
// }
|
||||
// LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str());
|
||||
//}
|
||||
}
|
||||
|
||||
// fill the pending embeddings from a previous run
|
||||
auto set_h = [&](int idx, const float * h_row) {
|
||||
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
|
||||
};
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
if (i_batch_beg[seq_id] < 0) {
|
||||
continue;
|
||||
for (int k = 0; k < n_tokens; ++k) {
|
||||
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
|
||||
}
|
||||
|
||||
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
|
||||
}
|
||||
// shift the tgt embeddings to the right by one position
|
||||
// assumes that the tokens in the batch are sequential for each sequence
|
||||
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
|
||||
// ^--- this is a problem
|
||||
// TODO:this is generally true, but would be nice to assert it
|
||||
{
|
||||
const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt);
|
||||
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
|
||||
}
|
||||
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
|
||||
return false;
|
||||
// fill the pending embeddings from a previous run
|
||||
auto set_h = [&](int idx, const float * h_row) {
|
||||
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
|
||||
};
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
if (i_batch_beg[seq_id] < 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
|
||||
}
|
||||
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
|
||||
@ -75,6 +75,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"Gemma3TextModel": "gemma",
|
||||
"Gemma3nForCausalLM": "gemma",
|
||||
"Gemma3nForConditionalGeneration": "gemma",
|
||||
"Gemma4AssistantForCausalLM": "gemma",
|
||||
"Gemma4ForConditionalGeneration": "gemma",
|
||||
"Gemma4ForCausalLM": "gemma",
|
||||
"Gemma4UnifiedForConditionalGeneration": "gemma",
|
||||
|
||||
@ -785,6 +785,16 @@ class Gemma4UnifiedModel(Gemma4Model):
|
||||
self.gguf_writer.add_suppress_tokens(suppress_tokens)
|
||||
|
||||
|
||||
@ModelBase.register("Gemma4AssistantForCausalLM")
|
||||
class Gemma4AssistantModel(Gemma4Model):
|
||||
model_arch = gguf.MODEL_ARCH.GEMMA4_ASSISTANT
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_embedding_length_out(self.hparams["backbone_hidden_size"])
|
||||
self.gguf_writer.add_nextn_predict_layers(self.block_count)
|
||||
|
||||
|
||||
@ModelBase.register("Gemma4ForConditionalGeneration")
|
||||
class Gemma4VisionAudioModel(MmprojModel):
|
||||
has_audio_encoder = True
|
||||
|
||||
@ -434,6 +434,7 @@ class MODEL_ARCH(IntEnum):
|
||||
GEMMA3 = auto()
|
||||
GEMMA3N = auto()
|
||||
GEMMA4 = auto()
|
||||
GEMMA4_ASSISTANT = auto()
|
||||
GEMMA_EMBEDDING = auto()
|
||||
STARCODER2 = auto()
|
||||
RWKV6 = auto()
|
||||
@ -866,6 +867,8 @@ class MODEL_TENSOR(IntEnum):
|
||||
A_PER_DIM_K_SCALE = auto() # gemma4
|
||||
A_PER_DIM_SCALE = auto() # gemma4
|
||||
# nextn/mtp
|
||||
NEXTN_PRE_PROJ = auto()
|
||||
NEXTN_POST_PROJ = auto()
|
||||
NEXTN_EH_PROJ = auto()
|
||||
NEXTN_EMBED_TOKENS = auto()
|
||||
NEXTN_ENORM = auto()
|
||||
@ -955,6 +958,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.GEMMA3: "gemma3",
|
||||
MODEL_ARCH.GEMMA3N: "gemma3n",
|
||||
MODEL_ARCH.GEMMA4: "gemma4",
|
||||
MODEL_ARCH.GEMMA4_ASSISTANT: "gemma4-assistant",
|
||||
MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding",
|
||||
MODEL_ARCH.STARCODER2: "starcoder2",
|
||||
MODEL_ARCH.RWKV6: "rwkv6",
|
||||
@ -1417,6 +1421,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.A_QF_FFN_DOWN: "a.proj_blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.A_QF_FFN_NORM: "a.proj_blk.{bid}.ffn_norm",
|
||||
# NextN/MTP
|
||||
MODEL_TENSOR.NEXTN_PRE_PROJ: "nextn.pre_projection",
|
||||
MODEL_TENSOR.NEXTN_POST_PROJ: "nextn.post_projection",
|
||||
MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj",
|
||||
MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens",
|
||||
MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.nextn.enorm",
|
||||
@ -2500,6 +2506,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.PER_LAYER_PROJ_NORM,
|
||||
MODEL_TENSOR.PER_LAYER_POST_NORM,
|
||||
],
|
||||
MODEL_ARCH.GEMMA4_ASSISTANT: [
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.NEXTN_PRE_PROJ,
|
||||
MODEL_TENSOR.NEXTN_POST_PROJ,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_PRE_NORM,
|
||||
MODEL_TENSOR.FFN_POST_NORM,
|
||||
MODEL_TENSOR.LAYER_OUT_SCALE,
|
||||
],
|
||||
MODEL_ARCH.GEMMA_EMBEDDING: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
|
||||
@ -2279,6 +2279,14 @@ class TensorNameMap:
|
||||
),
|
||||
|
||||
# NextN/MTP tensors
|
||||
MODEL_TENSOR.NEXTN_PRE_PROJ: (
|
||||
"pre_projection",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.NEXTN_POST_PROJ: (
|
||||
"post_projection",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.NEXTN_EH_PROJ: (
|
||||
"model.layers.{bid}.eh_proj",
|
||||
),
|
||||
|
||||
@ -57,6 +57,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_GEMMA3, "gemma3" },
|
||||
{ LLM_ARCH_GEMMA3N, "gemma3n" },
|
||||
{ LLM_ARCH_GEMMA4, "gemma4" },
|
||||
{ LLM_ARCH_GEMMA4_ASSISTANT, "gemma4-assistant" },
|
||||
{ LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" },
|
||||
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
||||
{ LLM_ARCH_MAMBA, "mamba" },
|
||||
@ -452,6 +453,8 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
|
||||
{ LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" },
|
||||
{ LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
|
||||
{ LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
|
||||
{ LLM_TENSOR_NEXTN_PRE_PROJ, "nextn.pre_projection" },
|
||||
{ LLM_TENSOR_NEXTN_POST_PROJ, "nextn.post_projection" },
|
||||
{ LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" },
|
||||
{ LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" },
|
||||
{ LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" },
|
||||
@ -764,6 +767,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_PRE_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_POST_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}},
|
||||
// NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the
|
||||
// last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so
|
||||
// the model loader doesn't fault on the block index.
|
||||
|
||||
@ -61,6 +61,7 @@ enum llm_arch {
|
||||
LLM_ARCH_GEMMA3,
|
||||
LLM_ARCH_GEMMA3N,
|
||||
LLM_ARCH_GEMMA4,
|
||||
LLM_ARCH_GEMMA4_ASSISTANT,
|
||||
LLM_ARCH_GEMMA_EMBEDDING,
|
||||
LLM_ARCH_STARCODER2,
|
||||
LLM_ARCH_MAMBA,
|
||||
@ -556,6 +557,8 @@ enum llm_tensor {
|
||||
LLM_TENSOR_INDEXER_PROJ,
|
||||
LLM_TENSOR_INDEXER_ATTN_K,
|
||||
LLM_TENSOR_INDEXER_ATTN_Q_B,
|
||||
LLM_TENSOR_NEXTN_PRE_PROJ,
|
||||
LLM_TENSOR_NEXTN_POST_PROJ,
|
||||
LLM_TENSOR_NEXTN_EH_PROJ,
|
||||
LLM_TENSOR_NEXTN_EMBED_TOKENS,
|
||||
LLM_TENSOR_NEXTN_ENORM,
|
||||
|
||||
@ -30,6 +30,21 @@ static llm_graph_type ctx_type_to_graph_type(llama_context_type ctx_type) {
|
||||
throw std::runtime_error("Unsupported ctx type");
|
||||
}
|
||||
|
||||
static uint32_t ctx_type_to_embd_inp(const llama_hparams & hparams, llama_context_type ctx_type) {
|
||||
switch (ctx_type) {
|
||||
case LLAMA_CONTEXT_TYPE_DEFAULT: return hparams.n_embd_inp();
|
||||
case LLAMA_CONTEXT_TYPE_MTP : return hparams.n_embd_out();
|
||||
}
|
||||
throw std::runtime_error("Unsupported ctx type");
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct src_mctx_reset_on_exit {
|
||||
llama_memory_context_ptr * slot;
|
||||
~src_mctx_reset_on_exit() { if (slot) slot->reset(); }
|
||||
};
|
||||
}
|
||||
|
||||
llama_context::llama_context(
|
||||
const llama_model & model,
|
||||
llama_context_params params) :
|
||||
@ -372,7 +387,11 @@ llama_context::llama_context(
|
||||
LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__);
|
||||
}
|
||||
|
||||
sched_reserve();
|
||||
// MTP draft contexts can't reserve until the source context is wired
|
||||
// via llama_set_mtp_source — defer to the first decode.
|
||||
if (cparams.ctx_type != LLAMA_CONTEXT_TYPE_MTP) {
|
||||
sched_reserve();
|
||||
}
|
||||
|
||||
if (!cparams.flash_attn) {
|
||||
if (ggml_is_quantized(params.type_v)) {
|
||||
@ -446,6 +465,23 @@ void llama_context::sched_reserve() {
|
||||
}
|
||||
}
|
||||
|
||||
// When called from decode(), src_mctx_for_decode is already populated and
|
||||
// we must not drop it on exit (process_ubatch still needs it). Snapshot
|
||||
// only when sched_reserve runs standalone (e.g. lazy first-decode reserve
|
||||
// when set_mtp_source flipped sched_need_reserve).
|
||||
const bool owns_src_snapshot = src_ctx && !src_mctx_for_decode;
|
||||
if (owns_src_snapshot) {
|
||||
auto * src_memory = src_ctx->get_memory();
|
||||
if (!src_memory) {
|
||||
throw std::runtime_error("MTP source context has no memory module");
|
||||
}
|
||||
src_mctx_for_decode = src_memory->init_full();
|
||||
if (!src_mctx_for_decode) {
|
||||
throw std::runtime_error("failed to initialize MTP source memory snapshot");
|
||||
}
|
||||
}
|
||||
src_mctx_reset_on_exit reserve_src_drop{owns_src_snapshot ? &src_mctx_for_decode : nullptr};
|
||||
|
||||
// avoid reserving graphs with zero outputs - assume one output per sequence
|
||||
const int n_outputs = n_seqs;
|
||||
|
||||
@ -904,7 +940,7 @@ float * llama_context::get_embeddings_nextn_ith(int32_t i) {
|
||||
throw std::runtime_error("no nextn embeddings");
|
||||
}
|
||||
|
||||
const uint32_t n_embd = model.hparams.n_embd;
|
||||
const uint32_t n_embd = model.hparams.n_embd_out();
|
||||
|
||||
if (!cparams.embeddings_nextn_masked) {
|
||||
// unmasked: nextn rows are stored densely, indexed by raw token position.
|
||||
@ -1113,6 +1149,17 @@ void llama_context::set_embeddings_nextn(bool value, bool masked) {
|
||||
cparams.embeddings_nextn_masked = masked;
|
||||
}
|
||||
|
||||
void llama_context::set_mtp_source(llama_context * src) {
|
||||
if (src_ctx == src) {
|
||||
return;
|
||||
}
|
||||
src_ctx = src;
|
||||
src_mctx_for_decode.reset();
|
||||
// worst-case compute buffers were reserved without knowing about the source
|
||||
// memory; force a re-reserve so the next decode sees src views
|
||||
sched_need_reserve = true;
|
||||
}
|
||||
|
||||
void llama_context::set_causal_attn(bool value) {
|
||||
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
||||
|
||||
@ -1338,7 +1385,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int64_t n_embd = hparams.n_embd_inp();
|
||||
const int64_t n_embd = ctx_type_to_embd_inp(hparams, cparams.ctx_type);
|
||||
const int64_t n_vocab = model.vocab.n_tokens();
|
||||
|
||||
// note: during encode, we always pass the full sequence starting from pos = 0
|
||||
@ -1473,7 +1520,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
|
||||
GGML_ASSERT(backend_h != nullptr);
|
||||
|
||||
const uint32_t n_embd = hparams.n_embd;
|
||||
const uint32_t n_embd = hparams.n_embd_out();
|
||||
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_nextn.size);
|
||||
ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn.data, 0, n_tokens*n_embd*sizeof(float));
|
||||
}
|
||||
@ -1648,7 +1695,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int64_t n_vocab = vocab.n_tokens();
|
||||
const int64_t n_embd = hparams.n_embd_inp();
|
||||
const int64_t n_embd = ctx_type_to_embd_inp(hparams, cparams.ctx_type);
|
||||
|
||||
// when computing embeddings, all tokens are output
|
||||
const bool output_all = cparams.embeddings;
|
||||
@ -1710,6 +1757,20 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
embd_seq.clear();
|
||||
output_swaps.clear();
|
||||
|
||||
src_mctx_reset_on_exit decode_src_drop{&src_mctx_for_decode};
|
||||
if (src_ctx) {
|
||||
auto * src_memory = src_ctx->get_memory();
|
||||
if (!src_memory) {
|
||||
LLAMA_LOG_ERROR("%s: MTP source context has no memory module\n", __func__);
|
||||
return -2;
|
||||
}
|
||||
src_mctx_for_decode = src_memory->init_full();
|
||||
if (!src_mctx_for_decode) {
|
||||
LLAMA_LOG_ERROR("%s: failed to snapshot MTP source memory\n", __func__);
|
||||
return -2;
|
||||
}
|
||||
}
|
||||
|
||||
sched_reserve();
|
||||
|
||||
bool did_optimize = false;
|
||||
@ -1924,7 +1985,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
|
||||
GGML_ASSERT(backend_h != nullptr);
|
||||
|
||||
const uint32_t n_embd = hparams.n_embd;
|
||||
const uint32_t n_embd = hparams.n_embd_out();
|
||||
float * embd_nextn_out = embd_nextn.data + offset*n_embd;
|
||||
|
||||
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_nextn.size);
|
||||
@ -2017,7 +2078,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
|
||||
const auto n_batch = cparams.n_batch;
|
||||
const auto n_vocab = vocab.n_tokens();
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto n_embd_out = hparams.n_embd_out();
|
||||
|
||||
bool has_logits = true;
|
||||
@ -2036,12 +2096,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
|
||||
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
|
||||
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
|
||||
embd_nextn.size = has_embd_nextn ? n_embd*n_outputs_max : 0;
|
||||
embd_nextn.size = has_embd_nextn ? n_embd_out*n_outputs_max : 0;
|
||||
|
||||
if (has_embd_nextn && !cparams.embeddings_nextn_masked) {
|
||||
// unmasked: nextn row exists for every token in the batch, not just
|
||||
// those flagged via batch.logits[i] -> size by token count instead.
|
||||
embd_nextn.size = (size_t) n_embd * n_batch;
|
||||
embd_nextn.size = (size_t) n_embd_out * n_batch;
|
||||
}
|
||||
|
||||
// Allocate backend sampling output buffers if there are backend samplers configured.
|
||||
@ -2304,6 +2364,8 @@ llm_graph_params llama_context::graph_params(
|
||||
/*.cvec =*/ cvec.get(),
|
||||
/*.loras =*/ loras.get(),
|
||||
/*.mctx =*/ mctx,
|
||||
/*.src_mctx =*/ src_mctx_for_decode.get(),
|
||||
/*.src_model =*/ src_ctx ? &src_ctx->get_model() : nullptr,
|
||||
/*.cross =*/ &cross,
|
||||
/*.samplers =*/ sampling.samplers,
|
||||
/*.n_outputs =*/ n_outputs,
|
||||
@ -3593,6 +3655,10 @@ void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) {
|
||||
ctx->set_embeddings_nextn(value, masked);
|
||||
}
|
||||
|
||||
void llama_set_mtp_source(llama_context * ctx, llama_context * src) {
|
||||
ctx->set_mtp_source(src);
|
||||
}
|
||||
|
||||
float * llama_get_embeddings_nextn(llama_context * ctx) {
|
||||
ctx->synchronize();
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
#include "llama-graph.h"
|
||||
#include "llama-adapter.h"
|
||||
#include "llama-impl.h"
|
||||
#include "llama-memory.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
#include "ggml-opt.h"
|
||||
@ -111,6 +112,7 @@ struct llama_context {
|
||||
|
||||
void set_embeddings (bool value);
|
||||
void set_embeddings_nextn(bool value, bool masked);
|
||||
void set_mtp_source(llama_context * src);
|
||||
void set_causal_attn(bool value);
|
||||
void set_warmup(bool value);
|
||||
|
||||
@ -275,6 +277,12 @@ private:
|
||||
|
||||
std::unique_ptr<llama_memory_i> memory;
|
||||
|
||||
// external KV source used by MTP draft contexts. src_ctx is the target
|
||||
// context whose memory we read; src_mctx_for_decode is a per-decode
|
||||
// snapshot held for the duration of one decode/sched_reserve call.
|
||||
llama_context * src_ctx = nullptr;
|
||||
llama_memory_context_ptr src_mctx_for_decode;
|
||||
|
||||
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
||||
buffer_view<float> logits = {nullptr, 0};
|
||||
|
||||
|
||||
@ -85,6 +85,11 @@ using llama_memory_breakdown = std::map<ggml_backend_buffer_type_t, llama_memory
|
||||
LLAMA_API int32_t llama_model_n_expert (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model);
|
||||
|
||||
// number of layers that own KV (i.e. layers whose graph writes K/V).
|
||||
// 0 means the model owns no KV — e.g. a Gemma4-style MTP draft that reads
|
||||
// trunk KV via llama_set_mtp_source.
|
||||
LLAMA_API int32_t llama_model_n_layer_kv(const struct llama_model * model);
|
||||
|
||||
LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i);
|
||||
|
||||
LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx);
|
||||
@ -93,6 +98,7 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c
|
||||
// If masked == true, output the embeddings only for the tokens with batch.logits != 0
|
||||
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
|
||||
LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked);
|
||||
LLAMA_API void llama_set_mtp_source(struct llama_context * ctx, struct llama_context * src);
|
||||
|
||||
// mirrors:
|
||||
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||
|
||||
@ -620,6 +620,22 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
||||
return res;
|
||||
}
|
||||
|
||||
void llm_graph_input_attn_src_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
src_mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
src_mctx->get_swa() ->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
bool llm_graph_input_attn_src_kv_iswa::can_reuse(const llm_graph_params & params) {
|
||||
const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.src_mctx);
|
||||
|
||||
this->src_mctx = mctx;
|
||||
|
||||
bool res = true;
|
||||
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
|
||||
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
|
||||
return res;
|
||||
}
|
||||
|
||||
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
||||
GGML_ASSERT(cross_kq_mask);
|
||||
|
||||
@ -1034,6 +1050,8 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
||||
cvec (params.cvec),
|
||||
loras (params.loras),
|
||||
mctx (params.mctx),
|
||||
src_mctx (params.src_mctx),
|
||||
src_model (params.src_model),
|
||||
cross (params.cross),
|
||||
samplers (params.samplers),
|
||||
cb_func (params.cb),
|
||||
@ -2604,6 +2622,59 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
|
||||
return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
llm_graph_input_attn_src_kv_iswa * llm_graph_context::build_attn_inp_src_kv_iswa() const {
|
||||
GGML_ASSERT(src_mctx && "MTP draft graph requires src_mctx (set via llama_set_mtp_source)");
|
||||
|
||||
const auto * src_iswa = static_cast<const llama_kv_cache_iswa_context *>(src_mctx);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_src_kv_iswa>(hparams, cparams, src_iswa);
|
||||
|
||||
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, src_iswa->get_base(), ubatch, cparams);
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
|
||||
inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, src_iswa->get_swa(), ubatch, cparams);
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
|
||||
return (llm_graph_input_attn_src_kv_iswa *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_attn(
|
||||
llm_graph_input_attn_src_kv_iswa * inp,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * wo_s,
|
||||
ggml_tensor * q_cur,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il_assist,
|
||||
int il_src) const {
|
||||
const bool is_swa = hparams.is_swa(il_assist);
|
||||
|
||||
const auto * src_iswa = inp->src_mctx;
|
||||
const auto * src_cur = is_swa ? src_iswa->get_swa() : src_iswa->get_base();
|
||||
|
||||
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
||||
|
||||
ggml_build_forward_expand(gf, q_cur);
|
||||
|
||||
ggml_tensor * q = q_cur;
|
||||
ggml_tensor * k = src_cur->get_k(ctx0, il_src);
|
||||
ggml_tensor * v = src_cur->get_v(ctx0, il_src);
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il_assist);
|
||||
cb(cur, "kqv_out", il_assist);
|
||||
|
||||
if (wo) {
|
||||
cur = build_lora_mm(wo, cur, wo_s);
|
||||
}
|
||||
if (wo_b) {
|
||||
cur = ggml_add(ctx0, cur, wo_b);
|
||||
}
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_attn(
|
||||
llm_graph_input_attn_cross * inp,
|
||||
ggml_tensor * wo,
|
||||
|
||||
@ -459,6 +459,37 @@ public:
|
||||
const llama_kv_cache_iswa_context * mctx;
|
||||
};
|
||||
|
||||
// mask-only input for attention against an external (read-only) ISWA KV cache.
|
||||
// used by MTP draft graphs that attend to the target's KV without owning any.
|
||||
class llm_graph_input_attn_src_kv_iswa : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_attn_src_kv_iswa(
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_iswa_context * src_mctx) :
|
||||
hparams(hparams),
|
||||
cparams(cparams),
|
||||
src_mctx(src_mctx) {
|
||||
}
|
||||
~llm_graph_input_attn_src_kv_iswa() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
bool can_reuse(const llm_graph_params & params) override;
|
||||
|
||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr;
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr;
|
||||
ggml_tensor * self_kq_mask_swa = nullptr;
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr;
|
||||
|
||||
const llama_hparams hparams;
|
||||
const llama_cparams cparams;
|
||||
|
||||
const llama_kv_cache_iswa_context * src_mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_attn_cross : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
|
||||
@ -601,6 +632,11 @@ struct llm_graph_params {
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
const llama_memory_context_i * mctx;
|
||||
// per-decode snapshot of an external memory module the graph reads from
|
||||
// (never writes) — e.g. ctx_dft reading target KV during MTP draft.
|
||||
// nullptr for a main decode. Rebound inside reuse-aware input classes.
|
||||
const llama_memory_context_i * src_mctx;
|
||||
const llama_model * src_model;
|
||||
const llama_cross * cross;
|
||||
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
@ -818,6 +854,8 @@ struct llm_graph_context {
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
const llama_memory_context_i * mctx;
|
||||
const llama_memory_context_i * src_mctx;
|
||||
const llama_model * src_model;
|
||||
const llama_cross * cross;
|
||||
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
@ -1047,6 +1085,24 @@ struct llm_graph_context {
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
llm_graph_input_attn_src_kv_iswa * build_attn_inp_src_kv_iswa() const;
|
||||
|
||||
// Q-only attention against an external ISWA KV cache (no K/V projections,
|
||||
// no writes). il_assist labels the attention block in the local graph for
|
||||
// logging; il_src indexes the source K/V layer to attend to.
|
||||
ggml_tensor * build_attn(
|
||||
llm_graph_input_attn_src_kv_iswa * inp,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * wo_s,
|
||||
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks, // [n_head_q]
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
int il_assist,
|
||||
int il_src) const;
|
||||
|
||||
llm_graph_input_attn_cross * build_attn_inp_cross() const;
|
||||
|
||||
ggml_tensor * build_attn(
|
||||
|
||||
@ -2462,6 +2462,10 @@ uint32_t llama_kv_cache_context::get_n_kv() const {
|
||||
return n_kv;
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_context::seq_pos_max(llama_seq_id seq_id) const {
|
||||
return kv->seq_pos_max(seq_id);
|
||||
}
|
||||
|
||||
ggml_type llama_kv_cache_context::type_k() const {
|
||||
return kv->type_k();
|
||||
}
|
||||
|
||||
@ -354,6 +354,11 @@ public:
|
||||
|
||||
uint32_t get_n_kv() const;
|
||||
|
||||
// last position recorded in the cache for this sequence; -1 if absent.
|
||||
// exposed for cross-context KV consumers (e.g. MTP draft) that need to
|
||||
// anchor the source position without owning a memory module of their own.
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const;
|
||||
|
||||
ggml_type type_k() const;
|
||||
ggml_type type_v() const;
|
||||
|
||||
|
||||
@ -139,6 +139,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
|
||||
return new llama_model_gemma3n(params);
|
||||
case LLM_ARCH_GEMMA4:
|
||||
return new llama_model_gemma4(params);
|
||||
case LLM_ARCH_GEMMA4_ASSISTANT:
|
||||
return new llama_model_gemma4_assistant(params);
|
||||
case LLM_ARCH_GEMMA_EMBEDDING:
|
||||
return new llama_model_gemma_embedding(params);
|
||||
case LLM_ARCH_STARCODER2:
|
||||
@ -2378,6 +2380,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_GEMMA3:
|
||||
case LLM_ARCH_GEMMA3N:
|
||||
case LLM_ARCH_GEMMA4:
|
||||
case LLM_ARCH_GEMMA4_ASSISTANT:
|
||||
case LLM_ARCH_GEMMA_EMBEDDING:
|
||||
case LLM_ARCH_STARCODER2:
|
||||
case LLM_ARCH_OPENELM:
|
||||
@ -2572,6 +2575,10 @@ int32_t llama_model_n_devices(const struct llama_model * model) {
|
||||
return (int32_t)model->devices.size();
|
||||
}
|
||||
|
||||
int32_t llama_model_n_layer_kv(const struct llama_model * model) {
|
||||
return (int32_t) model->hparams.n_layer_kv();
|
||||
}
|
||||
|
||||
ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i) {
|
||||
if (i < 0 || i >= (int)model->devices.size()) {
|
||||
return nullptr;
|
||||
|
||||
@ -548,6 +548,10 @@ struct llama_model {
|
||||
struct ggml_tensor * output_s = nullptr;
|
||||
struct ggml_tensor * output_in_s = nullptr;
|
||||
|
||||
// NextN/MTP model-level projections
|
||||
struct ggml_tensor * nextn_pre_proj = nullptr;
|
||||
struct ggml_tensor * nextn_post_proj = nullptr;
|
||||
|
||||
// classifier
|
||||
struct ggml_tensor * cls = nullptr;
|
||||
struct ggml_tensor * cls_b = nullptr;
|
||||
|
||||
@ -135,6 +135,214 @@ std::unique_ptr<llm_graph_context> llama_model_gemma4::build_arch_graph(const ll
|
||||
return std::make_unique<graph>(*this, params);
|
||||
}
|
||||
|
||||
void llama_model_gemma4_assistant::load_arch_hparams(llama_model_loader & ml) {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer);
|
||||
|
||||
uint32_t n_kv_shared_layers = 0;
|
||||
ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false);
|
||||
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer - (int32_t) n_kv_shared_layers;
|
||||
hparams.f_attention_scale = 1.0f;
|
||||
|
||||
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
|
||||
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa);
|
||||
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa);
|
||||
|
||||
if (hparams.n_layer == 4) {
|
||||
type = LLM_TYPE_31B;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) {
|
||||
LLAMA_LOAD_LOCALS;
|
||||
|
||||
if (n_embd_head_k != n_embd_head_v) {
|
||||
throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k == n_embd_head_v");
|
||||
}
|
||||
if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) {
|
||||
throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k_swa == n_embd_head_v_swa");
|
||||
}
|
||||
if (hparams.n_embd_out() == n_embd) {
|
||||
throw std::runtime_error("Gemma 4 assistant requires embedding_length_out to carry the target hidden size");
|
||||
}
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
|
||||
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
|
||||
|
||||
const int64_t n_embd_backbone = hparams.n_embd_out();
|
||||
nextn_pre_proj = create_tensor(tn(LLM_TENSOR_NEXTN_PRE_PROJ, "weight"), { 2*n_embd_backbone, n_embd }, 0);
|
||||
nextn_post_proj = create_tensor(tn(LLM_TENSOR_NEXTN_POST_PROJ, "weight"), { n_embd, n_embd_backbone }, 0);
|
||||
|
||||
int rope_freqs_flag = 0;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
const int64_t n_head = hparams.n_head(i);
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k(i);
|
||||
const int64_t n_ff = hparams.n_ff(i);
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head*n_head }, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head*n_head, n_embd }, 0);
|
||||
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head }, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
|
||||
|
||||
layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), { 1u }, 0);
|
||||
|
||||
if (!hparams.is_swa(i)) {
|
||||
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_embd_head/2 }, rope_freqs_flag);
|
||||
rope_freqs_flag = TENSOR_DUPLICATED;
|
||||
}
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
|
||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), { n_embd }, 0);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<llm_graph_context> llama_model_gemma4_assistant::build_arch_graph(const llm_graph_params & params) const {
|
||||
return std::make_unique<graph>(*this, params);
|
||||
}
|
||||
|
||||
llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
GGML_ASSERT(src_mctx && "Gemma 4 assistant graph requires an MTP source (llama_set_mtp_source)");
|
||||
GGML_ASSERT(src_model && "Gemma 4 assistant graph requires a source model");
|
||||
GGML_ASSERT(src_model->tok_embd && "source model missing tok_embd");
|
||||
|
||||
const auto & src_hparams = src_model->hparams;
|
||||
|
||||
// By convention the MTP draft reads from the trunk's final SWA and full layers.
|
||||
const int32_t src_layer_full = (int32_t) src_hparams.n_layer - 1;
|
||||
const int32_t src_layer_swa = (int32_t) src_hparams.n_layer - 2;
|
||||
GGML_ASSERT(!src_hparams.is_swa(src_layer_full) && "trunk's last layer must be full attention");
|
||||
GGML_ASSERT( src_hparams.is_swa(src_layer_swa) && "trunk's penultimate layer must be SWA");
|
||||
|
||||
const int64_t n_embd_backbone = hparams.n_embd_out();
|
||||
|
||||
ggml_tensor * inp_tokens;
|
||||
ggml_tensor * inp_h;
|
||||
{
|
||||
auto inp = std::make_unique<llm_graph_input_embd>(n_embd_backbone);
|
||||
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
cb(inp->tokens, "inp_tokens", -1);
|
||||
ggml_set_input(inp->tokens);
|
||||
inp_tokens = inp->tokens;
|
||||
res->t_inp_tokens = inp->tokens;
|
||||
|
||||
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_backbone, ubatch.n_tokens);
|
||||
cb(inp->embd, "inp_h", -1);
|
||||
ggml_set_input(inp->embd);
|
||||
inp_h = inp->embd;
|
||||
res->t_inp_embd = inp->embd;
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
ggml_tensor * x = ggml_get_rows(ctx0, src_model->tok_embd, inp_tokens);
|
||||
x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone));
|
||||
cb(x, "inp_embd_target", -1);
|
||||
|
||||
ggml_tensor * xh = ggml_concat(ctx0, x, inp_h, 0);
|
||||
cb(xh, "inp_xh", -1);
|
||||
|
||||
ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_pre_proj, xh);
|
||||
cb(cur, "pre_proj", -1);
|
||||
|
||||
auto * inp_attn = build_attn_inp_src_kv_iswa();
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
ggml_tensor * inpL = cur;
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
const int32_t il_src = is_swa ? src_layer_swa : src_layer_full;
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k(il);
|
||||
const int64_t n_head = hparams.n_head(il);
|
||||
|
||||
const float freq_base_l = model.get_rope_freq_base(cparams, il);
|
||||
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
|
||||
const int n_rot_l = hparams.n_rot(il);
|
||||
|
||||
ggml_tensor * cur_norm = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(cur_norm, "attn_norm", il);
|
||||
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur_norm);
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
|
||||
ggml_tensor * freq_factors = is_swa ? nullptr : model.layers[il].rope_freqs;
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig,
|
||||
freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Qcur, "Qcur_pos", il);
|
||||
|
||||
cur = build_attn(inp_attn, model.layers[il].wo, nullptr, nullptr,
|
||||
Qcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il, il_src);
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
||||
}
|
||||
|
||||
cur = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_post_norm", il);
|
||||
|
||||
ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL);
|
||||
cb(attn_out, "attn_out", il);
|
||||
|
||||
cur = build_norm(attn_out, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, nullptr, nullptr,
|
||||
model.layers[il].ffn_gate, nullptr, nullptr,
|
||||
model.layers[il].ffn_down, nullptr, nullptr,
|
||||
nullptr,
|
||||
LLM_FFN_GELU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = build_norm(cur, model.layers[il].ffn_post_norm, nullptr, LLM_NORM_RMS, -1);
|
||||
cb(cur, "ffn_post_norm", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, attn_out);
|
||||
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].out_scale);
|
||||
cb(cur, "out_scaled", il);
|
||||
|
||||
inpL = cur;
|
||||
}
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
ggml_tensor * logits = build_lora_mm(model.output, cur);
|
||||
cb(logits, "result_output", -1);
|
||||
res->t_logits = logits;
|
||||
|
||||
ggml_tensor * h_next = ggml_mul_mat(ctx0, model.nextn_post_proj, cur);
|
||||
cb(h_next, "h_nextn", -1);
|
||||
res->t_h_nextn = h_next;
|
||||
|
||||
ggml_build_forward_expand(gf, logits);
|
||||
ggml_build_forward_expand(gf, h_next);
|
||||
}
|
||||
|
||||
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
|
||||
static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) {
|
||||
GGML_ASSERT(idx < (int) x->ne[2]);
|
||||
@ -270,7 +478,8 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
|
||||
}
|
||||
|
||||
// TODO @ngxson : strip unused token right after the last KV layer to speed up prompt processing
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
// keep all rows when extracting unmasked nextn embeddings (MTP target needs the hidden state for every token)
|
||||
if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
||||
}
|
||||
@ -370,7 +579,7 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
|
||||
ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens]
|
||||
|
||||
// TODO @ngxson : improve this
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
|
||||
inp_this_layer = ggml_get_rows(ctx0, inp_this_layer, inp_out_ids);
|
||||
}
|
||||
|
||||
@ -401,6 +610,17 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
|
||||
model.output_norm, nullptr,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
// Expose the post-output-norm hidden state (the LM-head input feature) so that
|
||||
// MTP draft contexts can read it via llama_get_embeddings_nextn_ith() as the
|
||||
// recurrent h input. This matches the reference (transformers/vLLM/SGLang),
|
||||
// which feeds the drafter the target's post-final-norm hidden state.
|
||||
cb(cur, "h_nextn", -1);
|
||||
res->t_h_nextn = cur;
|
||||
|
||||
if (!cparams.embeddings_nextn_masked && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
}
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
|
||||
@ -822,6 +822,19 @@ struct llama_model_gemma4 : public llama_model_base {
|
||||
};
|
||||
|
||||
|
||||
struct llama_model_gemma4_assistant : public llama_model_base {
|
||||
llama_model_gemma4_assistant(const struct llama_model_params & params) : llama_model_base(params) {}
|
||||
void load_arch_hparams(llama_model_loader & ml) override;
|
||||
void load_arch_tensors(llama_model_loader & ml) override;
|
||||
|
||||
struct graph : public llm_graph_context {
|
||||
graph(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
|
||||
};
|
||||
|
||||
|
||||
struct llama_model_gemma_embedding : public llama_model_base {
|
||||
llama_model_gemma_embedding(const struct llama_model_params & params) : llama_model_base(params) {}
|
||||
void load_arch_hparams(llama_model_loader & ml) override;
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
#include "common.h"
|
||||
#include "fit.h"
|
||||
#include "llama.h"
|
||||
#include "../../src/llama-ext.h" // staging API: llama_set_mtp_source
|
||||
#include "log.h"
|
||||
#include "sampling.h"
|
||||
#include "speculative.h"
|
||||
@ -949,6 +950,11 @@ private:
|
||||
cparams.n_rs_seq = 0;
|
||||
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
|
||||
|
||||
if (spec_mtp) {
|
||||
// MTP draft must know its target before the first decode
|
||||
llama_set_mtp_source(ctx_dft.get(), ctx_tgt);
|
||||
}
|
||||
|
||||
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
|
||||
|
||||
params_base.speculative.draft.ctx_tgt = ctx_tgt;
|
||||
@ -971,6 +977,10 @@ private:
|
||||
return false;
|
||||
}
|
||||
|
||||
// wire the source before any decode (the seq-rm probe below
|
||||
// triggers sched_reserve which needs src for Gemma4-style MTP)
|
||||
llama_set_mtp_source(ctx_dft.get(), ctx_tgt);
|
||||
|
||||
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
|
||||
|
||||
params_base.speculative.draft.ctx_tgt = ctx_tgt;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user