mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Add MTP Support for Gemma 4 (#1744)
* gemma-mtp: build the arch to load the MTP model * gemma-mtp: fix mtp kv state * gemma-mtp: refactor some functions and create gguf * gemma-mtp: make usable for embeddings models variant * gemma-mtp: fix qwen mtp load in graph split * gemma-mtp: refactor tensor creation and adjust output tensor handling * Gemma 4 MTP: improve tensor handling, and adjust split mode logic
This commit is contained in:
parent
ab0f22b819
commit
c2b8bca807
@ -1101,8 +1101,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
|
||||
} else if (value == "suffix") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_SUFFIX;
|
||||
} else if (value == "mtp") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
|
||||
params.has_mtp = true;
|
||||
} else {
|
||||
throw std::invalid_argument("unknown speculative decoding type without draft model");
|
||||
throw std::invalid_argument("unknown speculative decoding type");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -2760,7 +2763,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
||||
" per-step save SSM state per draft step in VRAM; no re-decode on rejection\n"
|
||||
" gpu-fallback copy state to GPU buffer; re-decode on rejection\n"
|
||||
" cpu serialise state via llama_state_seq; re-decode on rejection" });
|
||||
options.push_back({ "*", "--spec-type Name [none | ngram - cache | ngram - simple | ngram - map - k | ngram - map - k4v | ngram - mod | suffix]", "type of speculative decoding to use when no draft model is provided (default: %d)\n", (int)params.speculative.type});
|
||||
options.push_back({ "*", "--spec-type Name [none | mtp | ngram - cache | ngram - simple | ngram - map - k | ngram - map - k4v | ngram - mod | suffix]", "type of speculative decoding to use (default: %d)\n", (int)params.speculative.type});
|
||||
options.push_back({ "*", "--spec-ngram-size-n N", "ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)\n",params.speculative.ngram_size_n });
|
||||
|
||||
options.push_back({ "*", "--spec-ngram-size-m N", "ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)\n", params.speculative.ngram_size_m });
|
||||
@ -3355,11 +3358,9 @@ std::string fs_get_cache_file(const std::string & filename) {
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Model utils
|
||||
//
|
||||
struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
|
||||
llama_init_result iparams;
|
||||
|
||||
auto mparams = common_model_params_to_llama(params);
|
||||
|
||||
llama_model * model = nullptr;
|
||||
|
||||
@ -19,6 +19,9 @@
|
||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||
|
||||
void llama_set_mtp_target_context(struct llama_context * ctx, struct llama_context * target_ctx);
|
||||
uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx);
|
||||
|
||||
const std::vector<enum common_speculative_type> common_speculative_types = {
|
||||
COMMON_SPECULATIVE_TYPE_NONE,
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT,
|
||||
@ -154,27 +157,28 @@ struct common_speculative_state_mtp : public common_speculative_state {
|
||||
llama_context * ctx_tgt;
|
||||
llama_context * ctx_mtp = nullptr;
|
||||
common_sampler * smpl;
|
||||
// For Gemma 4 external MTP assistant: draft positions are held constant
|
||||
bool constant_draft_positions = false;
|
||||
|
||||
common_speculative_state_mtp(
|
||||
enum common_speculative_type type,
|
||||
llama_context * ctx_tgt,
|
||||
const llama_context_params & mtp_cparams)
|
||||
llama_context * ctx_mtp,
|
||||
bool constant_draft_positions = false)
|
||||
: common_speculative_state(type)
|
||||
, ctx_tgt(ctx_tgt)
|
||||
, ctx_mtp(ctx_mtp)
|
||||
, constant_draft_positions(constant_draft_positions)
|
||||
{
|
||||
struct common_params_sampling params;
|
||||
params.samplers_sequence = {
|
||||
struct common_params_sampling sparams;
|
||||
sparams.samplers_sequence = {
|
||||
llama_sampler_type::DIST,
|
||||
};
|
||||
smpl = common_sampler_init(llama_get_model(ctx_tgt), params);
|
||||
smpl = common_sampler_init(llama_get_model(ctx_mtp), sparams);
|
||||
llama_set_mtp_target_context(ctx_mtp, ctx_tgt);
|
||||
|
||||
const llama_model * model = llama_get_model(ctx_tgt);
|
||||
ctx_mtp = llama_init_from_model(const_cast<llama_model *>(model), mtp_cparams);
|
||||
if (ctx_mtp) {
|
||||
LOG_INF("%s: created MTP context (n_ctx=%d)\n", __func__, llama_n_ctx(ctx_mtp));
|
||||
} else {
|
||||
LOG_ERR("%s: failed to create MTP context\n", __func__);
|
||||
}
|
||||
LOG_INF("%s: MTP context ready (n_ctx=%d, constant_draft_positions=%s)\n", __func__,
|
||||
llama_n_ctx(ctx_mtp), constant_draft_positions ? "true" : "false");
|
||||
}
|
||||
|
||||
~common_speculative_state_mtp() override {
|
||||
@ -211,7 +215,8 @@ struct common_speculative_state_mtp : public common_speculative_state {
|
||||
params.p_min,
|
||||
id_last,
|
||||
n_past,
|
||||
seq_id
|
||||
seq_id,
|
||||
constant_draft_positions
|
||||
);
|
||||
}
|
||||
|
||||
@ -1029,9 +1034,9 @@ common_speculative * common_speculative_init(
|
||||
// Compute the implementations to use based on the config and their order of preference
|
||||
std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
|
||||
{
|
||||
bool has_draft = !params.mparams_dft.path.empty();
|
||||
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
|
||||
bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP);
|
||||
bool has_draft = !params.mparams_dft.path.empty() && !has_mtp;
|
||||
|
||||
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
|
||||
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
|
||||
@ -1102,15 +1107,20 @@ common_speculative * common_speculative_init(
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_MTP: {
|
||||
auto mtp_state = std::make_unique<common_speculative_state_mtp>(config.type,
|
||||
/* .ctx_tgt = */ ctx_tgt,
|
||||
/* .mtp_cparams = */ params.cparams_dft
|
||||
);
|
||||
if (!mtp_state->ctx_mtp) {
|
||||
LOG_ERR("%s: failed to create MTP context\n", __func__);
|
||||
return nullptr;
|
||||
llama_context * ctx_mtp = ctx_dft;
|
||||
if (!ctx_mtp) {
|
||||
const llama_model * model = llama_get_model(ctx_tgt);
|
||||
ctx_mtp = llama_init_from_model(const_cast<llama_model *>(model), params.cparams_dft);
|
||||
if (!ctx_mtp) {
|
||||
LOG_ERR("%s: failed to create MTP context\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
impls.push_back(std::move(mtp_state));
|
||||
ctx_dft = nullptr;
|
||||
|
||||
const bool use_constant_draft_positions = llama_model_is_gemma4_mtp_assistant(llama_get_model(ctx_mtp));
|
||||
impls.push_back(std::make_unique<common_speculative_state_mtp>(
|
||||
config.type, ctx_tgt, ctx_mtp, use_constant_draft_positions));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_EAGLE3: {
|
||||
@ -1224,7 +1234,7 @@ static mtp_last_embd & mtp_get_last_embd(const llama_context * ctx) {
|
||||
static std::unordered_map<const llama_context *, mtp_last_embd> map;
|
||||
auto & last = map[ctx];
|
||||
if (last.embd.empty()) {
|
||||
auto n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
auto n_embd = llama_mtp_state_n_embd(ctx);
|
||||
last.embd.resize(n_embd);
|
||||
}
|
||||
return last;
|
||||
@ -1377,7 +1387,8 @@ std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
float p_min,
|
||||
llama_token id_last,
|
||||
int32_t n_past,
|
||||
llama_seq_id seq_id) {
|
||||
llama_seq_id seq_id,
|
||||
bool constant_draft_positions) {
|
||||
|
||||
llama_tokens drafts;
|
||||
drafts.reserve(n_draft);
|
||||
@ -1394,7 +1405,7 @@ std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
|
||||
llama_token current_input_id = id_last;
|
||||
int32_t current_n_past = n_past;
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
const int n_embd = llama_mtp_state_n_embd(ctx);
|
||||
|
||||
auto & last = mtp_get_last_embd(ctx);
|
||||
int i0 = 0;
|
||||
@ -1415,7 +1426,8 @@ std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
int n_decode = 0;
|
||||
for (int i = i0; i < n_draft; ++i) {
|
||||
mtp_batch.n_tokens = 0;
|
||||
common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true);
|
||||
const int32_t draft_pos = constant_draft_positions ? n_past : current_n_past;
|
||||
common_batch_add(mtp_batch, current_input_id, draft_pos, {seq_id}, true);
|
||||
|
||||
++n_decode;
|
||||
if (llama_decode(ctx, mtp_batch) != 0) {
|
||||
|
||||
@ -60,7 +60,8 @@ std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
float p_min,
|
||||
llama_token id_last,
|
||||
int32_t n_past,
|
||||
llama_seq_id seq_id);
|
||||
llama_seq_id seq_id,
|
||||
bool constant_draft_positions = false);
|
||||
|
||||
void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);
|
||||
|
||||
|
||||
@ -3175,6 +3175,267 @@ class Gemma2Model(Model):
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
class Gemma4BaseModel(Model):
    """Shared conversion logic for the Gemma 4 family of checkpoints.

    Provides hyperparameter lookup that prefers the nested ``text_config``
    section when the checkpoint has one, plus the Gemma 4 vocabulary export.
    """

    model_arch = gguf.MODEL_ARCH.GEMMA4

    def _text_hparams(self) -> dict[str, Any]:
        """Return ``hparams["text_config"]`` when it is a dict, else ``hparams`` itself."""
        nested = self.hparams.get("text_config")
        return nested if isinstance(nested, dict) else self.hparams

    def _arch_name(self) -> str:
        """Return the GGUF architecture name registered for ``model_arch``."""
        return gguf.MODEL_ARCH_NAMES[self.model_arch]

    def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
        """Look up the first of ``keys`` present in ``text_config``.

        Falls back to the parent implementation when there is no
        ``text_config`` dict or none of the keys is found in it.
        """
        nested = self.hparams.get("text_config")
        if not isinstance(nested, dict):
            return super().find_hparam(keys, optional)

        for candidate in keys:
            if candidate in nested:
                return nested[candidate]
        return super().find_hparam(keys, optional)

    def set_vocab(self):
        """Export the tokenizer, re-typing a fixed set of markers as USER_DEFINED."""
        vocab = gguf.LlamaHfVocab(self.dir_model)

        # Tokens whose type is overridden to USER_DEFINED regardless of what
        # the tokenizer reports for them.
        visible_tokens = {
            "<|channel>",
            "<channel|>",
            "<|tool_call>",
            "<tool_call|>",
            "<|tool_response>",
            "<tool_response|>",
            "<|\"|>",
        }

        tokens = []
        scores = []
        toktypes = []
        for raw_text, score, declared_type in vocab.all_tokens():
            decoded = raw_text.decode()
            is_visible = decoded in visible_tokens

            tokens.append(raw_text)
            scores.append(score)
            toktypes.append(gguf.TokenType.USER_DEFINED if is_visible else declared_type)
            if is_visible:
                logger.info(f"Token {decoded!r} is set to USER_DEFINED")

        assert len(tokens) == vocab.vocab_size

        writer = self.gguf_writer
        writer.add_tokenizer_model("gemma4")
        writer.add_token_list(tokens)
        writer.add_token_scores(scores)
        writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
        special_vocab.add_to_gguf(writer)
        writer.add_add_space_prefix(False)
        writer.add_add_bos_token(True)
|
||||
|
||||
|
||||
@Model.register("Gemma4ForConditionalGeneration")
class Gemma4Model(Gemma4BaseModel):
    """Converter for ``Gemma4ForConditionalGeneration`` checkpoints (language-model tensors only)."""

    model_arch = gguf.MODEL_ARCH.GEMMA4

    def set_gguf_parameters(self):
        """Write the Gemma 4 hyperparameters to the GGUF header.

        Full-attention and sliding-window layers carry separate head
        dimensions, KV head counts and RoPE settings; the sliding-window
        variants are stored under ``<arch>.*_swa`` keys.
        """
        cfg = self._text_hparams()
        writer = self.gguf_writer
        n_layer = cfg["num_hidden_layers"]
        arch = self._arch_name()

        writer.add_context_length(cfg["max_position_embeddings"])
        writer.add_embedding_length(cfg["hidden_size"])
        writer.add_block_count(n_layer)
        writer.add_head_count(cfg["num_attention_heads"])
        writer.add_layer_norm_rms_eps(cfg["rms_norm_eps"])
        writer.add_file_type(self.ftype)

        # One flag per layer: True where the layer uses sliding-window attention.
        is_swa_layer = [kind == "sliding_attention" for kind in cfg["layer_types"]]
        writer.add_sliding_window(cfg["sliding_window"])
        writer.add_sliding_window_pattern(is_swa_layer)

        n_shared_kv = cfg.get("num_kv_shared_layers", 0)
        writer.add_shared_kv_layers(n_shared_kv)

        n_ff = cfg["intermediate_size"]
        if not cfg.get("use_double_wide_mlp", False):
            writer.add_feed_forward_length(n_ff)
        else:
            # Layers from the first KV-sharing layer onwards get a doubled FFN width.
            first_shared = n_layer - n_shared_kv
            per_layer_ff = []
            for layer in range(n_layer):
                per_layer_ff.append(n_ff if layer < first_shared else n_ff * 2)
            writer.add_feed_forward_length(per_layer_ff)

        expert_ff = cfg.get("expert_intermediate_size") or cfg.get("moe_intermediate_size")
        if expert_ff is not None:
            writer.add_expert_feed_forward_length(expert_ff)

        per_layer_embd = cfg.get("hidden_size_per_layer_input") or 0
        writer.add_embedding_length_per_layer_input(per_layer_embd)

        head_dim_full = int(cfg["global_head_dim"])
        head_dim_swa = int(cfg["head_dim"])
        writer.add_key_length(head_dim_full)
        writer.add_value_length(head_dim_full)
        writer.add_uint32(f"{arch}.attention.key_length_swa", head_dim_swa)
        writer.add_uint32(f"{arch}.attention.value_length_swa", head_dim_swa)

        kv_full = cfg.get("num_global_key_value_heads")
        kv_swa = cfg.get("num_key_value_heads")
        if kv_swa is not None:
            if kv_full is None:
                writer.add_head_count_kv(kv_swa)
            else:
                per_layer_kv = [kv_swa if sliding else kv_full for sliding in is_swa_layer]
                writer.add_head_count_kv(per_layer_kv)

        rope_cfg = cfg.get("rope_parameters", {})
        rope_full = rope_cfg.get("full_attention", {})
        rope_swa = rope_cfg.get("sliding_attention", {})
        writer.add_rope_dimension_count(head_dim_full)
        swa_rotary_factor = float(
            rope_swa.get("partial_rotary_factor", cfg.get("partial_rotary_factor", 1.0))
        )
        writer.add_uint32(f"{arch}.rope.dimension_count_swa", int(head_dim_swa * swa_rotary_factor))
        writer.add_rope_freq_base(float(rope_full.get("rope_theta", 1000000.0)))
        writer.add_float32(f"{arch}.rope.freq_base_swa", float(rope_swa.get("rope_theta", 10000.0)))

    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
        """Yield the ROPE_FREQS tensor for the full-attention layers.

        The tensor holds ``global_head_dim / 2`` float32 entries: 1.0 for the
        leading rotated pairs and 1e30 for the rest (presumably so the
        remaining dimensions are left effectively unrotated; confirm against
        the consumer of this tensor).
        """
        cfg = self._text_hparams()
        full_rope = cfg["rope_parameters"]["full_attention"]
        assert full_rope["rope_type"] == "proportional"

        head_dim = int(cfg["global_head_dim"])
        rotary_fraction = full_rope["partial_rotary_factor"]
        n_rotated = int(head_dim * rotary_fraction / 2)
        n_passthrough = int(head_dim / 2) - n_rotated

        factors = torch.tensor([1.0] * n_rotated + [1e30] * n_passthrough, dtype=torch.float32)
        yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), factors)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Map one checkpoint tensor to its GGUF name, dropping non-language-model tensors."""
        if name.endswith(("per_dim_scale", "layer_scalar")):
            name += ".weight"

        # Keep only language-model tensors (and rope_freqs); everything else is dropped.
        if not ("language_model." in name or "rope_freqs" in name):
            return []

        name = name.replace("language_model.", "")

        if name == "lm_head.weight":
            logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.")
            return []

        if name.endswith("router.scale"):
            target = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_INP, bid, ".scale")
            return [(target, data_torch)]

        if ".per_expert_scale" in name:
            target = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN_EXP, bid, ".scale")
            return [(target, data_torch)]

        # Expert tensors may be stored without a ".weight" suffix; add it before mapping.
        if ".experts." in name and not name.endswith(".weight"):
            name += ".weight"

        return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@Model.register("Gemma4AssistantForCausalLM")
class Gemma4AssistantModel(Gemma4BaseModel):
    """Converter for the Gemma 4 MTP assistant (``Gemma4AssistantForCausalLM``).

    Tensor names are translated through the two explicit tables below rather
    than the generic tensor-name map; any tensor not listed is rejected.
    """

    model_arch = gguf.MODEL_ARCH.GEMMA4_MTP

    # Checkpoint name -> GGUF name for tensors that live outside the layer stack.
    # Several spellings of the centroid / token-ordering tensors are accepted.
    _root_tensor_map = {
        "model.embed_tokens.weight": "token_embd.weight",
        "model.norm.weight": "output_norm.weight",
        "pre_projection.weight": "mtp_pre_proj.weight",
        "post_projection.weight": "mtp_post_proj.weight",
        "masked_embedding.centroids.weight": "mtp_centroids.weight",
        "masked_embedding.token_ordering": "mtp_token_ordering.weight",
        "token_ordering": "mtp_token_ordering.weight",
        "token_ordering.weight": "mtp_token_ordering.weight",
        "model.token_ordering": "mtp_token_ordering.weight",
        "model.token_ordering.weight": "mtp_token_ordering.weight",
        "centroids": "mtp_centroids.weight",
        "centroids.weight": "mtp_centroids.weight",
        "model.centroids": "mtp_centroids.weight",
        "model.centroids.weight": "mtp_centroids.weight",
    }

    # Per-layer suffix (after "model.layers.<n>.") -> GGUF suffix (after "blk.<n>.").
    _layer_tensor_map = {
        "input_layernorm.weight": "attn_norm.weight",
        "self_attn.q_proj.weight": "attn_q.weight",
        "self_attn.q_norm.weight": "attn_q_norm.weight",
        "self_attn.o_proj.weight": "attn_output.weight",
        "post_attention_layernorm.weight": "post_attention_norm.weight",
        "pre_feedforward_layernorm.weight": "ffn_norm.weight",
        "mlp.gate_proj.weight": "ffn_gate.weight",
        "mlp.up_proj.weight": "ffn_up.weight",
        "mlp.down_proj.weight": "ffn_down.weight",
        "post_feedforward_layernorm.weight": "post_ffw_norm.weight",
        "layer_scalar": "layer_output_scale.weight",
        "layer_scalar.weight": "layer_output_scale.weight",
    }

    def set_gguf_parameters(self):
        """Write the assistant's hyperparameters to the GGUF header.

        Attention/RoPE values come from the text config; the MTP-specific keys
        (backbone width, ordered-embedding flag, centroid settings) are read
        from the top-level ``hparams``.
        """
        hparams = self._text_hparams()
        arch = self._arch_name()
        sliding_pattern = [layer_type == "sliding_attention" for layer_type in hparams["layer_types"]]

        # Full-attention values fall back to the sliding-window ones when absent.
        head_dim_swa = int(hparams["head_dim"])
        head_dim_full = int(hparams.get("global_head_dim") or head_dim_swa)
        n_kv_swa = int(hparams["num_key_value_heads"])
        n_kv_full = int(hparams.get("num_global_key_value_heads") or n_kv_swa)
        n_kv = [n_kv_swa if is_sliding else n_kv_full for is_sliding in sliding_pattern]

        self.gguf_writer.add_context_length(int(hparams["max_position_embeddings"]))
        self.gguf_writer.add_embedding_length(int(hparams["hidden_size"]))
        self.gguf_writer.add_block_count(int(hparams["num_hidden_layers"]))
        self.gguf_writer.add_feed_forward_length(int(hparams["intermediate_size"]))
        self.gguf_writer.add_head_count(int(hparams["num_attention_heads"]))
        self.gguf_writer.add_head_count_kv(n_kv)
        self.gguf_writer.add_key_length(head_dim_full)
        self.gguf_writer.add_value_length(head_dim_full)
        self.gguf_writer.add_uint32(f"{arch}.attention.key_length_swa", head_dim_swa)
        self.gguf_writer.add_uint32(f"{arch}.attention.value_length_swa", head_dim_swa)
        self.gguf_writer.add_layer_norm_rms_eps(float(hparams["rms_norm_eps"]))
        self.gguf_writer.add_sliding_window(int(hparams["sliding_window"]))
        self.gguf_writer.add_array(f"{arch}.attention.sliding_window_pattern", sliding_pattern)
        self.gguf_writer.add_rope_dimension_count(head_dim_full)
        self.gguf_writer.add_uint32(f"{arch}.rope.dimension_count_swa", head_dim_swa)

        rope_parameters = hparams.get("rope_parameters", {})
        rope_full = rope_parameters.get("full_attention", {})
        rope_swa = rope_parameters.get("sliding_attention", {})
        self.gguf_writer.add_rope_freq_base(float(rope_full.get("rope_theta", 1000000.0)))
        self.gguf_writer.add_float32(f"{arch}.rope.freq_base_swa", float(rope_swa.get("rope_theta", 10000.0)))

        self.gguf_writer.add_uint32(f"{arch}.backbone_embedding_length", int(self.hparams["backbone_hidden_size"]))
        self.gguf_writer.add_bool(f"{arch}.use_ordered_embeddings", bool(self.hparams.get("use_ordered_embeddings", False)))
        # NOTE(review): the centroid count defaults to 0 here but to 2048 in
        # modify_tensors; confirm which default is intended.
        self.gguf_writer.add_uint32(f"{arch}.centroid_count", int(self.hparams.get("num_centroids", 0)))
        self.gguf_writer.add_uint32(f"{arch}.centroid_top_k", int(self.hparams.get("centroid_intermediate_top_k", 0)))
        self.gguf_writer.add_file_type(self.ftype)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Translate one assistant tensor to its GGUF name.

        The token-ordering tensor is not stored as-is: it is inverted and
        converted to an int32 token -> centroid index table.

        Raises:
            ValueError: if the tensor name is not in either mapping table, or
                if the token ordering cannot be split into the configured
                number of centroids.
        """
        del bid  # unused

        mapped_name = self._root_tensor_map.get(name)
        if mapped_name is not None:
            if mapped_name == "mtp_token_ordering.weight":
                n_vocab = int(data_torch.shape[0])
                n_centroids = int(self.hparams.get("num_centroids", 2048))
                # Reject configurations that would otherwise surface as a bare
                # division-by-zero (n_centroids == 0, or fewer tokens than
                # centroids making tokens_per_centroid == 0).
                if n_centroids <= 0 or n_vocab < n_centroids:
                    raise ValueError(
                        f"Cannot split {n_vocab} tokens into {n_centroids} centroids for tensor {name!r}"
                    )
                tokens_per_centroid = n_vocab // n_centroids

                # inv_ordering[token] = position of that token in the ordering.
                inv_ordering = torch.zeros(n_vocab, dtype=torch.int32)
                token_ordering_i64 = data_torch.to(dtype=torch.int64)
                inv_ordering[token_ordering_i64] = torch.arange(n_vocab, dtype=torch.int32)
                token_to_centroid = (inv_ordering // tokens_per_centroid).to(dtype=torch.int32)
                return [(mapped_name, token_to_centroid)]
            return [(mapped_name, data_torch)]

        prefix = "model.layers."
        if not name.startswith(prefix):
            raise ValueError(f"Unsupported Gemma 4 assistant tensor: {name}")

        # partition() leaves the suffix empty when there is no '.', so a bare
        # "model.layers.<n>" is reported as unsupported instead of failing to unpack.
        layer_id, _, suffix = name[len(prefix):].partition(".")
        mapped_suffix = self._layer_tensor_map.get(suffix)
        if mapped_suffix is None:
            raise ValueError(f"Unsupported Gemma 4 assistant tensor: {name}")

        return [(f"blk.{layer_id}.{mapped_suffix}", data_torch)]
|
||||
|
||||
|
||||
@Model.register("Starcoder2ForCausalLM")
|
||||
class StarCoder2Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.STARCODER2
|
||||
|
||||
@ -16,6 +16,8 @@
|
||||
#include <regex>
|
||||
#include <exception>
|
||||
|
||||
uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx);
|
||||
|
||||
static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1, int32_t offset = 0) {
|
||||
if (pos_min == -1) {
|
||||
pos_min = llama_kv_cache_seq_pos_min(ctx, id);
|
||||
@ -44,6 +46,97 @@ static void log_text(const gpt_params & params_base, const std::string & text) {
|
||||
}
|
||||
}
|
||||
|
||||
static bool params_use_gemma4_external_mtp(const gpt_params & params_base) {
|
||||
return params_base.has_mtp &&
|
||||
llama_model_is_gemma4_mtp_assistant(params_base.speculative.model_dft);
|
||||
}
|
||||
|
||||
// Returns the MTP context attached to the slot's speculative state, or `ctx`
// when the speculative state does not provide one.
static llama_context * get_slot_mtp_ctx(server_slot & slot, llama_context * ctx) {
    if (llama_context * dedicated = common_speculative_get_mtp_ctx(slot.spec)) {
        return dedicated;
    }
    return ctx;
}
|
||||
|
||||
// Width of one MTP hidden-state row for `ctx` as reported by
// llama_mtp_state_n_embd; 0 when `ctx` is null.
static int get_ctx_mtp_n_embd(llama_context * ctx) {
    if (ctx == nullptr) {
        return 0;
    }
    return (int) llama_mtp_state_n_embd(ctx);
}
|
||||
|
||||
// Hidden-state row width for whichever context get_slot_mtp_ctx selects for
// this slot.
static int get_slot_mtp_n_embd(server_slot & slot, llama_context * ctx) {
    llama_context * selected = get_slot_mtp_ctx(slot, ctx);
    return get_ctx_mtp_n_embd(selected);
}
|
||||
|
||||
// Replaces the slot's cached MTP hidden state with the single row of `n_embd`
// floats starting at `hidden`. Does nothing when `hidden` is null or
// `n_embd` is not positive.
//
// `hidden` may point into slot.mtp_hidden_state itself: the checkpoint-restore
// path passes slot.mtp_hidden_state as the row source. Calling
// vector::assign(first, last) with iterators into the same vector is undefined
// behaviour, so the row is first copied into a temporary and then swapped in.
// This costs one allocation of n_embd floats per call.
static void cache_slot_mtp_hidden(server_slot & slot, const float * hidden, int n_embd) {
    if (hidden == nullptr || n_embd <= 0) {
        return;
    }

    std::vector<float> row(hidden, hidden + n_embd);
    slot.mtp_hidden_state.swap(row);
}
|
||||
|
||||
// Hands the last complete row of the slot's cached hidden state to the slot's
// MTP context as the draft input. No-op when the slot has no MTP / speculative
// state, the cache is empty, or the cache holds less than one full row.
static void sync_slot_mtp_hidden(server_slot & slot, llama_context * ctx) {
    const bool ready = slot.has_mtp && slot.spec && !slot.mtp_hidden_state.empty();
    if (!ready) {
        return;
    }

    const int n_embd = get_slot_mtp_n_embd(slot, ctx);
    if (n_embd <= 0) {
        return;
    }
    if (slot.mtp_hidden_state.size() < (size_t) n_embd) {
        return;
    }

    // The cache may hold several rows; only the final one is used.
    const int n_hidden = slot.mtp_hidden_state.size() / n_embd;
    const float * last_row = slot.mtp_hidden_state.data() + (n_hidden - 1) * n_embd;
    llama_set_draft_input_hidden_state(get_slot_mtp_ctx(slot, ctx), last_row);
}
|
||||
|
||||
// Stores one hidden-state row in the slot's cache, then pushes the cached
// state to the slot's MTP context.
static void cache_and_sync_slot_mtp_hidden(
        server_slot   & slot,
        llama_context * ctx,
        const float   * hidden,
        int             n_embd) {
    // Step 1: update the per-slot cache (ignored for null / empty input).
    cache_slot_mtp_hidden(slot, hidden, n_embd);
    // Step 2: forward whatever is cached now to the draft context.
    sync_slot_mtp_hidden(slot, ctx);
}
|
||||
|
||||
// Treats `rows` as consecutive rows of `n_embd` floats and caches/syncs only
// the last complete row. No-op when `rows` is empty, `n_embd` is not positive,
// or `rows` holds less than one full row.
static void cache_and_sync_slot_mtp_hidden_from_rows(server_slot & slot, llama_context * ctx, const std::vector<float> & rows, int n_embd) {
    if (n_embd <= 0 || rows.empty()) {
        return;
    }

    const size_t complete_rows = rows.size() / n_embd;
    if (complete_rows == 0) {
        return;
    }

    const float * last_row = rows.data() + (complete_rows - 1) * n_embd;
    cache_and_sync_slot_mtp_hidden(slot, ctx, last_row, n_embd);
}
|
||||
|
||||
// Applies the hidden states gathered for the accepted tokens `ids` to the
// slot's MTP state. No-op when the slot has no MTP, `mtp_hidden_state` is
// empty, or `n_embd` is not positive.
//
//  - Gemma 4 external assistant: only the last row is cached and synced.
//  - Otherwise: the full buffer is copied into the slot, set as the draft
//    input, and mtp_accept_tokens is run from position `mtp_n_past_base`.
static void apply_slot_mtp_accept(
        server_slot & slot,
        llama_context * ctx,
        const std::vector<float> & mtp_hidden_state,
        const std::vector<llama_token> & ids,
        int32_t mtp_n_past_base,
        int n_embd) {
    const bool applicable = slot.has_mtp && !mtp_hidden_state.empty() && n_embd > 0;
    if (!applicable) {
        return;
    }

    if (slot.use_gemma4_external_mtp) {
        cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, mtp_hidden_state, n_embd);
    } else {
        slot.mtp_hidden_state = mtp_hidden_state;
        llama_set_draft_input_hidden_state(get_slot_mtp_ctx(slot, ctx), slot.mtp_hidden_state.data());
        mtp_accept_tokens(get_slot_mtp_ctx(slot, ctx), ids, mtp_n_past_base, slot.id);
    }
}
|
||||
|
||||
// Caches and syncs a single hidden-state row, but only for slots that have
// MTP enabled and an active speculative state; null or empty input is ignored.
static void set_external_mtp_hidden(server_slot & slot, llama_context * ctx, const float * hidden, int n_embd) {
    const bool slot_ready  = slot.has_mtp && slot.spec;
    const bool input_valid = hidden != nullptr && n_embd > 0;

    if (slot_ready && input_valid) {
        cache_and_sync_slot_mtp_hidden(slot, ctx, hidden, n_embd);
    }
}
|
||||
|
||||
// Row-buffer variant of set_external_mtp_hidden: forwards directly to
// cache_and_sync_slot_mtp_hidden_from_rows, which uses the last complete row
// of `rows`.
static void set_external_mtp_hidden_from_rows(
        server_slot              & slot,
        llama_context            * ctx,
        const std::vector<float> & rows,
        int                        n_embd) {
    cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, rows, n_embd);
}
|
||||
|
||||
void server_speculative_checkpoint::clear() {
|
||||
valid = false;
|
||||
per_step_enabled = false;
|
||||
@ -185,6 +278,7 @@ bool server_context::load_model(const gpt_params& params_) {
|
||||
gpt_params params_dft;
|
||||
params_dft.devices = params_base.speculative.devices;
|
||||
params_dft.model = params_base.speculative.model;
|
||||
params_dft.main_gpu = params_base.main_gpu;
|
||||
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
|
||||
params_dft.rpc_servers = params_base.rpc_servers;
|
||||
params_dft.cache_type_k = params_base.speculative.cache_type_k.empty() ? params_base.cache_type_k : params_base.speculative.cache_type_k;
|
||||
@ -279,16 +373,22 @@ void server_context::init() {
|
||||
slot.sparams = params_base.sparams;
|
||||
|
||||
if (params_base.has_mtp) {
|
||||
if (llama_model_n_nextn_layer(model) > 0) {
|
||||
const bool has_external_mtp = params_use_gemma4_external_mtp(params_base);
|
||||
|
||||
if (llama_model_n_nextn_layer(model) > 0 || has_external_mtp) {
|
||||
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
|
||||
params_base.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
params_base.speculative.cparams_dft = common_context_params_to_llama(params_base);
|
||||
if (!has_external_mtp) {
|
||||
params_base.speculative.cparams_dft = common_context_params_to_llama(params_base);
|
||||
}
|
||||
|
||||
params_base.speculative.cparams_dft.mtp = true;
|
||||
params_base.speculative.cparams_dft.mtp_op_type = MTP_OP_WARMUP;
|
||||
params_base.speculative.cparams_dft.embeddings = true;
|
||||
|
||||
slot.has_mtp = true;
|
||||
slot.use_gemma4_external_mtp = has_external_mtp;
|
||||
slot.params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
|
||||
slot.params.speculative.n_min = 0;
|
||||
slot.params.speculative.cparams_dft = params_base.speculative.cparams_dft;
|
||||
@ -3276,20 +3376,14 @@ void server_context::add_sampled_tokens() {
|
||||
auto & params_spec = slot.params.speculative;
|
||||
|
||||
if (slot.has_mtp) {
|
||||
llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec);
|
||||
llama_context * hs_ctx = mtp_ctx ? mtp_ctx : ctx;
|
||||
if (!slot.mtp_hidden_state.empty()) {
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
const int n_hidden = slot.mtp_hidden_state.size() / n_embd;
|
||||
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data() + (n_hidden - 1) * n_embd);
|
||||
sync_slot_mtp_hidden(slot, ctx);
|
||||
} else {
|
||||
LOG_ERROR("MTP hidden state is empty during speculation", {});
|
||||
const float* emb_neg1 = llama_get_embeddings_ith(ctx, -1);
|
||||
if (emb_neg1) {
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
slot.mtp_hidden_state.resize(n_embd);
|
||||
memcpy(slot.mtp_hidden_state.data(), emb_neg1, n_embd * sizeof(float));
|
||||
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data());
|
||||
const int n_embd = get_ctx_mtp_n_embd(ctx);
|
||||
cache_and_sync_slot_mtp_hidden(slot, ctx, emb_neg1, n_embd);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3857,11 +3951,8 @@ static void restore_speculative_checkpoint(
|
||||
|
||||
// Update MTP KV cache and hidden state using embeddings collected before checkpoint restore.
|
||||
if (slot.has_mtp && !mtp_hidden_state_pre.empty()) {
|
||||
slot.mtp_hidden_state = mtp_hidden_state_pre;
|
||||
llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec);
|
||||
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
|
||||
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data());
|
||||
mtp_accept_tokens(mtp_target, ids, mtp_n_past_base, slot.id);
|
||||
const int n_embd = get_ctx_mtp_n_embd(ctx);
|
||||
apply_slot_mtp_accept(slot, ctx, mtp_hidden_state_pre, ids, mtp_n_past_base, n_embd);
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "per-step restore: step=%d (rejected %d drafts)\n",
|
||||
@ -3895,7 +3986,7 @@ static void restore_speculative_checkpoint(
|
||||
SLT_ERR(slot, "failed to re-decode accepted tokens after checkpoint restore: %d\n", ret);
|
||||
}
|
||||
if (slot.has_mtp) {
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
const int n_embd = get_ctx_mtp_n_embd(ctx);
|
||||
|
||||
const int n_accepted = (int)ids.size();
|
||||
slot.mtp_hidden_state.resize(n_accepted * n_embd);
|
||||
@ -3906,15 +3997,17 @@ static void restore_speculative_checkpoint(
|
||||
}
|
||||
}
|
||||
|
||||
llama_context * mtp_ctx_rej = common_speculative_get_mtp_ctx(slot.spec);
|
||||
llama_context * mtp_target_rej = mtp_ctx_rej ? mtp_ctx_rej : ctx;
|
||||
llama_set_draft_input_hidden_state(mtp_target_rej, slot.mtp_hidden_state.data());
|
||||
mtp_accept_tokens(mtp_target_rej, ids, slot.spec_ckpt.n_past, slot.id);
|
||||
if (slot.use_gemma4_external_mtp) {
|
||||
cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, slot.mtp_hidden_state, n_embd);
|
||||
} else {
|
||||
llama_set_draft_input_hidden_state(get_slot_mtp_ctx(slot, ctx), slot.mtp_hidden_state.data());
|
||||
mtp_accept_tokens(get_slot_mtp_ctx(slot, ctx), ids, slot.spec_ckpt.n_past, slot.id);
|
||||
|
||||
if (n_accepted > 1) {
|
||||
memmove(slot.mtp_hidden_state.data(),
|
||||
slot.mtp_hidden_state.data() + (n_accepted - 1) * n_embd,
|
||||
n_embd * sizeof(float));
|
||||
if (n_accepted > 1) {
|
||||
memmove(slot.mtp_hidden_state.data(),
|
||||
slot.mtp_hidden_state.data() + (n_accepted - 1) * n_embd,
|
||||
n_embd * sizeof(float));
|
||||
}
|
||||
}
|
||||
slot.mtp_hidden_state.resize(n_embd);
|
||||
}
|
||||
@ -3972,7 +4065,7 @@ void server_context::speculative_decoding_accept() {
|
||||
if (slot.has_mtp) {
|
||||
mtp_n_past_base = slot.n_past - (slot.drafted.size() + 1);
|
||||
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
const int n_embd = get_ctx_mtp_n_embd(ctx);
|
||||
if (!ids.empty()) {
|
||||
mtp_hidden_state_pre.resize(ids.size() * n_embd);
|
||||
for (size_t i = 0; i < ids.size(); i++) {
|
||||
@ -4018,13 +4111,9 @@ void server_context::speculative_decoding_accept() {
|
||||
restore_speculative_checkpoint(slot, ctx, model, ids, n_draft, mtp_hidden_state_pre, mtp_n_past_base);
|
||||
} else {
|
||||
if (slot.has_mtp && !mtp_hidden_state_pre.empty()) {
|
||||
llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec);
|
||||
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
|
||||
|
||||
slot.mtp_hidden_state = std::move(mtp_hidden_state_pre);
|
||||
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data());
|
||||
mtp_accept_tokens(mtp_target, ids, mtp_n_past_base, slot.id);
|
||||
}
|
||||
const int n_embd = get_ctx_mtp_n_embd(ctx);
|
||||
apply_slot_mtp_accept(slot, ctx, mtp_hidden_state_pre, ids, mtp_n_past_base, n_embd);
|
||||
}
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(slot.n_past), -1);
|
||||
discard_speculative_checkpoint(slot, ctx);
|
||||
}
|
||||
@ -4396,8 +4485,18 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
}
|
||||
|
||||
bool mtp_warmup_needed = false;
|
||||
llama_context * batch_mtp_target = nullptr;
|
||||
std::vector<float> batch_mtp_hidden_state;
|
||||
if (params_base.has_mtp) {
|
||||
for (auto & slot : slots) {
|
||||
if (slot.spec && slot.has_mtp) {
|
||||
llama_context * mc = common_speculative_get_mtp_ctx(slot.spec);
|
||||
if (mc) {
|
||||
batch_mtp_target = mc;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto& slot : slots) {
|
||||
if ((slot.state == SLOT_STATE_PROCESSING && slot.n_decoded == 0) ||
|
||||
(slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT)) {
|
||||
@ -4409,13 +4508,16 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
}
|
||||
}
|
||||
if (mtp_warmup_needed) {
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
llama_context * mtp_target = batch_mtp_target ? batch_mtp_target : ctx;
|
||||
const int n_embd_src = get_ctx_mtp_n_embd(ctx);
|
||||
const int n_embd_dst = get_ctx_mtp_n_embd(mtp_target);
|
||||
const int n_toks = batch_view.n_tokens;
|
||||
batch_mtp_hidden_state.resize(n_toks * n_embd);
|
||||
batch_mtp_hidden_state.assign(n_toks * n_embd_dst, 0.0f);
|
||||
for (int t = 0; t < n_toks; t++) {
|
||||
const float* emb_t = llama_get_embeddings_ith(ctx, t);
|
||||
if (emb_t) {
|
||||
memcpy(batch_mtp_hidden_state.data() + t * n_embd, emb_t, n_embd * sizeof(float));
|
||||
const int n_copy = std::min(n_embd_src, n_embd_dst);
|
||||
memcpy(batch_mtp_hidden_state.data() + t * n_embd_dst, emb_t, n_copy * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -4469,9 +4571,12 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
if (params_base.has_mtp && slot.n_decoded == 0) {
|
||||
const float* emb_i = llama_get_embeddings_ith(ctx, tok_idx);
|
||||
if (emb_i) {
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
slot.mtp_hidden_state.resize(n_embd);
|
||||
memcpy(slot.mtp_hidden_state.data(), emb_i, n_embd * sizeof(float));
|
||||
const int n_embd = get_ctx_mtp_n_embd(ctx);
|
||||
if (slot.use_gemma4_external_mtp) {
|
||||
set_external_mtp_hidden(slot, ctx, emb_i, n_embd);
|
||||
} else {
|
||||
cache_slot_mtp_hidden(slot, emb_i, n_embd);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -4537,16 +4642,17 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
slot.i_batch = -1;
|
||||
}
|
||||
if (mtp_warmup_needed && !batch_mtp_hidden_state.empty()) {
|
||||
llama_context * mtp_ctx = nullptr;
|
||||
for (auto & slot : slots) {
|
||||
if (slot.spec && slot.has_mtp) {
|
||||
llama_context * mc = common_speculative_get_mtp_ctx(slot.spec);
|
||||
if (mc) { mtp_ctx = mc; break; }
|
||||
if (params_use_gemma4_external_mtp(params_base)) {
|
||||
for (auto & slot : slots) {
|
||||
if (slot.spec && slot.has_mtp && !slot.mtp_hidden_state.empty()) {
|
||||
sync_slot_mtp_hidden(slot, ctx);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
llama_context * mtp_target = batch_mtp_target ? batch_mtp_target : ctx;
|
||||
llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data());
|
||||
mtp_update_kv_cache(mtp_target, batch_view, true);
|
||||
}
|
||||
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
|
||||
llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data());
|
||||
mtp_update_kv_cache(mtp_target, batch_view, true);
|
||||
}
|
||||
|
||||
// speculative decoding - main model sample and accept
|
||||
|
||||
@ -171,6 +171,7 @@ struct server_slot {
|
||||
decltype(ctx_sampling->elb_states) elb_prev_states;
|
||||
|
||||
bool has_mtp = false;
|
||||
bool use_gemma4_external_mtp = false;
|
||||
std::vector<float> mtp_hidden_state;
|
||||
|
||||
// saves recurrent state before a speculative batch so it can be restored on rejection
|
||||
|
||||
@ -1107,7 +1107,8 @@ static bool ggml_is_view_op(enum ggml_op op) {
|
||||
#endif
|
||||
|
||||
#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
|
||||
#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC
|
||||
// Gemma4 with per-layer embeddings uses up to 32 inputs
|
||||
#define GGML_SCHED_MAX_SPLIT_INPUTS 32
|
||||
#endif
|
||||
|
||||
#ifndef GGML_SCHED_MAX_COPIES
|
||||
|
||||
@ -234,6 +234,8 @@ class MODEL_ARCH(IntEnum):
|
||||
GEMMA = auto()
|
||||
GEMMA2 = auto()
|
||||
GEMMA3 = auto()
|
||||
GEMMA4 = auto()
|
||||
GEMMA4_MTP = auto()
|
||||
STARCODER2 = auto()
|
||||
MAMBA = auto()
|
||||
XVERSE = auto()
|
||||
@ -282,7 +284,10 @@ class MODEL_TENSOR(IntEnum):
|
||||
FFN_GATE_INP_SHEXP = auto()
|
||||
FFN_NORM = auto()
|
||||
FFN_PRE_NORM = auto()
|
||||
FFN_PRE_NORM_2 = auto()
|
||||
FFN_POST_NORM = auto()
|
||||
FFN_POST_NORM_1 = auto()
|
||||
FFN_POST_NORM_2 = auto()
|
||||
FFN_GATE = auto()
|
||||
FFN_DOWN = auto()
|
||||
FFN_UP = auto()
|
||||
@ -291,6 +296,7 @@ class MODEL_TENSOR(IntEnum):
|
||||
FFN_GATE_EXP = auto()
|
||||
FFN_DOWN_EXP = auto()
|
||||
FFN_UP_EXP = auto()
|
||||
FFN_GATE_UP_EXP = auto()
|
||||
FFN_GATE_SHEXP = auto()
|
||||
FFN_DOWN_SHEXP = auto()
|
||||
FFN_UP_SHEXP = auto()
|
||||
@ -298,6 +304,13 @@ class MODEL_TENSOR(IntEnum):
|
||||
ATTN_Q_NORM = auto()
|
||||
ATTN_K_NORM = auto()
|
||||
LAYER_OUT_NORM = auto()
|
||||
LAYER_OUT_SCALE = auto()
|
||||
PER_LAYER_TOKEN_EMBD = auto()
|
||||
PER_LAYER_MODEL_PROJ = auto()
|
||||
PER_LAYER_INP_GATE = auto()
|
||||
PER_LAYER_PROJ = auto()
|
||||
PER_LAYER_PROJ_NORM = auto()
|
||||
PER_LAYER_POST_NORM = auto()
|
||||
SSM_IN = auto()
|
||||
SSM_CONV1D = auto()
|
||||
SSM_X = auto()
|
||||
@ -349,6 +362,10 @@ class MODEL_TENSOR(IntEnum):
|
||||
NEXTN_HNORM = auto() # nextn tensors (glm4moe)
|
||||
NEXTN_SHARED_HEAD_HEAD = auto() # nextn tensors (glm4moe)
|
||||
NEXTN_SHARED_HEAD_NORM = auto() # nextn tensors (glm4moe)
|
||||
MTP_PRE_PROJ = auto()
|
||||
MTP_POST_PROJ = auto()
|
||||
MTP_TOKEN_ORDERING = auto()
|
||||
MTP_CENTROIDS = auto()
|
||||
|
||||
|
||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
@ -383,6 +400,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.GEMMA: "gemma",
|
||||
MODEL_ARCH.GEMMA2: "gemma2",
|
||||
MODEL_ARCH.GEMMA3: "gemma3",
|
||||
MODEL_ARCH.GEMMA4: "gemma4",
|
||||
MODEL_ARCH.GEMMA4_MTP: "gemma4_mtp",
|
||||
MODEL_ARCH.STARCODER2: "starcoder2",
|
||||
MODEL_ARCH.MAMBA: "mamba",
|
||||
MODEL_ARCH.XVERSE: "xverse",
|
||||
@ -434,7 +453,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp",
|
||||
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
|
||||
MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm",
|
||||
MODEL_TENSOR.FFN_PRE_NORM_2: "blk.{bid}.pre_ffw_norm_2",
|
||||
MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm",
|
||||
MODEL_TENSOR.FFN_POST_NORM_1: "blk.{bid}.post_ffw_norm_1",
|
||||
MODEL_TENSOR.FFN_POST_NORM_2: "blk.{bid}.post_ffw_norm_2",
|
||||
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
|
||||
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
|
||||
@ -446,8 +468,16 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
|
||||
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
|
||||
MODEL_TENSOR.FFN_GATE_UP_EXP: "blk.{bid}.ffn_gate_up_exps",
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
|
||||
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
|
||||
MODEL_TENSOR.LAYER_OUT_SCALE: "blk.{bid}.layer_output_scale",
|
||||
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd",
|
||||
MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj",
|
||||
MODEL_TENSOR.PER_LAYER_INP_GATE: "blk.{bid}.inp_gate",
|
||||
MODEL_TENSOR.PER_LAYER_PROJ: "blk.{bid}.proj",
|
||||
MODEL_TENSOR.PER_LAYER_PROJ_NORM: "per_layer_proj_norm",
|
||||
MODEL_TENSOR.PER_LAYER_POST_NORM: "blk.{bid}.post_norm",
|
||||
MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
|
||||
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
|
||||
MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
|
||||
@ -500,6 +530,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.NEXTN_HNORM: "blk.{bid}.nextn.hnorm",
|
||||
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.nextn.shared_head_head",
|
||||
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.nextn.shared_head_norm",
|
||||
MODEL_TENSOR.MTP_PRE_PROJ: "mtp_pre_proj",
|
||||
MODEL_TENSOR.MTP_POST_PROJ: "mtp_post_proj",
|
||||
MODEL_TENSOR.MTP_TOKEN_ORDERING: "mtp_token_ordering",
|
||||
MODEL_TENSOR.MTP_CENTROIDS: "mtp_centroids",
|
||||
}
|
||||
|
||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
@ -962,6 +996,56 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_PRE_NORM,
|
||||
MODEL_TENSOR.FFN_POST_NORM,
|
||||
],
|
||||
MODEL_ARCH.GEMMA4: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_PRE_NORM_2,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_POST_NORM_1,
|
||||
MODEL_TENSOR.FFN_POST_NORM_2,
|
||||
MODEL_TENSOR.FFN_GATE_UP_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.LAYER_OUT_SCALE,
|
||||
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
|
||||
MODEL_TENSOR.PER_LAYER_MODEL_PROJ,
|
||||
MODEL_TENSOR.PER_LAYER_INP_GATE,
|
||||
MODEL_TENSOR.PER_LAYER_PROJ,
|
||||
MODEL_TENSOR.PER_LAYER_PROJ_NORM,
|
||||
MODEL_TENSOR.PER_LAYER_POST_NORM,
|
||||
],
|
||||
MODEL_ARCH.GEMMA4_MTP: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_POST_NORM,
|
||||
MODEL_TENSOR.LAYER_OUT_SCALE,
|
||||
MODEL_TENSOR.MTP_PRE_PROJ,
|
||||
MODEL_TENSOR.MTP_POST_PROJ,
|
||||
MODEL_TENSOR.MTP_TOKEN_ORDERING,
|
||||
MODEL_TENSOR.MTP_CENTROIDS,
|
||||
],
|
||||
MODEL_ARCH.STARCODER2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
||||
@ -245,12 +245,24 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.pre_feedforward_layernorm", # gemma2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_PRE_NORM_2: (
|
||||
"model.layers.{bid}.pre_feedforward_layernorm_2", # gemma4
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
MODEL_TENSOR.FFN_POST_NORM: (
|
||||
"model.layers.{bid}.post_feedforward_layernorm", # gemma2
|
||||
"model.layers.{bid}.post_moe_norm", # grok-2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_POST_NORM_1: (
|
||||
"model.layers.{bid}.post_feedforward_layernorm_1", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_POST_NORM_2: (
|
||||
"model.layers.{bid}.post_feedforward_layernorm_2", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP: (
|
||||
"layers.{bid}.feed_forward.gate", # mixtral
|
||||
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
|
||||
@ -305,6 +317,11 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe (merged)
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_UP_EXP: (
|
||||
"model.layers.{bid}.mlp.experts.gate_up_proj", # gemma4
|
||||
"model.layers.{bid}.experts.gate_up_proj", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek2
|
||||
@ -413,6 +430,34 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.final_layernorm", # bailingmoe2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.LAYER_OUT_SCALE: (
|
||||
"model.layers.{bid}.layer_scalar", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (
|
||||
"model.embed_tokens_per_layer", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.PER_LAYER_MODEL_PROJ: (
|
||||
"model.per_layer_model_projection", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.PER_LAYER_PROJ_NORM: (
|
||||
"model.per_layer_projection_norm", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.PER_LAYER_INP_GATE: (
|
||||
"model.layers.{bid}.per_layer_input_gate", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.PER_LAYER_PROJ: (
|
||||
"model.layers.{bid}.per_layer_projection", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.PER_LAYER_POST_NORM: (
|
||||
"model.layers.{bid}.post_per_layer_input_norm", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_IN: (
|
||||
"model.layers.{bid}.in_proj",
|
||||
"backbone.layers.{bid}.mixer.in_proj",
|
||||
|
||||
@ -685,6 +685,11 @@ extern "C" {
|
||||
|
||||
LLAMA_API bool llama_model_has_recurrent(const struct llama_model * model);
|
||||
|
||||
// Returns true if the model is a Gemma 4 MTP assistant (external frozen-KV speculative drafter)
|
||||
LLAMA_API bool llama_model_is_gemma4_mtp_assistant(const struct llama_model * model);
|
||||
|
||||
LLAMA_API bool llama_is_gemma4_mtp_file(const char * path);
|
||||
|
||||
LLAMA_API bool llama_model_is_split_mode_graph(const struct llama_model * model);
|
||||
|
||||
// Returns 0 on success
|
||||
|
||||
@ -2,6 +2,128 @@
|
||||
#include "../llama-model.h"
|
||||
#include "../llama-context.h"
|
||||
|
||||
static int gemma4_mtp_target_kv_layer(const llama_hparams & mtp_hparams, const llama_hparams & target_hparams, int mtp_il) {
|
||||
GGML_ASSERT(mtp_il >= 0 && mtp_il < (int) mtp_hparams.n_layer);
|
||||
|
||||
const bool is_sliding = mtp_hparams.swa_layers[mtp_il] != 0;
|
||||
const int target_n_kv_layer = target_hparams.n_layer_kv_from_start > 0
|
||||
? std::min<int>((int) target_hparams.n_layer, target_hparams.n_layer_kv_from_start)
|
||||
: (int) target_hparams.n_layer;
|
||||
|
||||
int target_il = target_n_kv_layer - 1;
|
||||
for (; target_il >= 0; --target_il) {
|
||||
if ((target_hparams.swa_layers[target_il] != 0) == is_sliding) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(target_il >= 0 && "Gemma4 MTP could not find a matching target KV layer");
|
||||
return target_il;
|
||||
}
|
||||
|
||||
// Lazily builds a pair of F16 K/V tensors assembled from the per-device splits of the
// target KV cache layer `target_il`, and stores them in *frozen_k / *frozen_v.
//
// The function is a no-op (outputs left untouched) when:
//   - the outputs are already populated (both must be set, otherwise it asserts),
//   - flash attention is disabled on `lctx`,
//   - the target cache tensors are missing or carry no `extra` payload.
//
// Otherwise every non-null split is viewed as [head_dim, target_n_kv, heads_in_split],
// converted to F32, concatenated along dimension 2, and the result converted to F16.
// `assistant_il` selects the head sizes (taken from the assistant model's hparams) and
// is used to tag the tensors passed to `cb`.
static void gemma4_mtp_prepare_frozen_kv_views(
        ggml_context * ctx0,
        llama_context & lctx,
        const llama_kv_cache & target_kv,
        int assistant_il,
        int target_il,
        int32_t target_n_kv,
        ggml_tensor ** frozen_k,
        ggml_tensor ** frozen_v,
        const llm_build_cb & cb) {
    // Views were already built by a previous call: K and V must come as a pair.
    if (*frozen_k || *frozen_v) {
        GGML_ASSERT(*frozen_k && *frozen_v);
        return;
    }

    if (!lctx.cparams.flash_attn) {
        return;
    }

    GGML_ASSERT(target_il >= 0 && target_il < (int) target_kv.k_l.size() && target_il < (int) target_kv.v_l.size());

    ggml_tensor * const cache_k = target_kv.k_l[target_il];
    ggml_tensor * const cache_v = target_kv.v_l[target_il];
    const bool has_split_info = cache_k && cache_v && cache_k->extra && cache_v->extra;
    if (!has_split_info) {
        return;
    }

    // `extra` is interpreted as the per-device split descriptor of the cache tensor.
    auto * split_k = (ggml_split_tensor_t *) cache_k->extra;
    auto * split_v = (ggml_split_tensor_t *) cache_v->extra;

    GGML_ASSERT(split_k && split_v);
    GGML_ASSERT(split_k->n_device == split_v->n_device);

    const llama_hparams & hp = lctx.model.hparams;
    const int64_t head_dim_k = hp.n_embd_head_k(assistant_il);
    const int64_t head_dim_v = hp.n_embd_head_v(assistant_il);

    // Small conversion helpers: cast only when the tensor is not already of that type.
    const auto to_f32 = [ctx0](ggml_tensor * t) {
        return t->type != GGML_TYPE_F32 ? ggml_cast(ctx0, t, GGML_TYPE_F32) : t;
    };
    const auto to_f16 = [ctx0](ggml_tensor * t) {
        return t->type != GGML_TYPE_F16 ? ggml_cast(ctx0, t, GGML_TYPE_F16) : t;
    };

    std::vector<ggml_tensor *> k_parts;
    std::vector<ggml_tensor *> v_parts;
    k_parts.reserve(split_k->n_device);
    v_parts.reserve(split_v->n_device);

    for (int dev = 0; dev < split_k->n_device; ++dev) {
        ggml_tensor * split_kl = split_k->splits[dev];
        ggml_tensor * split_vl = split_v->splits[dev];

        // A device either holds both K and V for this layer or neither.
        GGML_ASSERT((split_kl && split_vl) || (!split_kl && !split_vl));
        if (split_kl) {
            GGML_ASSERT(target_kv.size > 0);
            GGML_ASSERT(split_kl->ne[1] % target_kv.size == 0);

            // Heads stored on this device, derived from the K split's second dimension.
            const int64_t heads_here = split_kl->ne[1] / target_kv.size;
            const int     tag        = 1000 * (assistant_il + 1) + dev;

            ggml_tensor * k_view = ggml_view_3d(ctx0, split_kl,
                    head_dim_k, target_n_kv, heads_here,
                    ggml_row_size(split_kl->type, head_dim_k) * heads_here,
                    ggml_row_size(split_kl->type, head_dim_k),
                    0);
            k_view = to_f32(k_view);
            cb(k_view, "mtp_frozen_k_split", tag);

            ggml_tensor * v_view = ggml_view_3d(ctx0, split_vl,
                    head_dim_v, target_n_kv, heads_here,
                    ggml_row_size(split_vl->type, heads_here * head_dim_v),
                    ggml_row_size(split_vl->type, head_dim_v),
                    0);
            v_view = to_f32(v_view);
            cb(v_view, "mtp_frozen_v_split", tag);

            k_parts.push_back(k_view);
            v_parts.push_back(v_view);
        }
    }

    GGML_ASSERT(!k_parts.empty() && k_parts.size() == v_parts.size());

    // Stitch the per-device pieces together along the head dimension.
    ggml_tensor * k_joined = k_parts.front();
    ggml_tensor * v_joined = v_parts.front();
    for (size_t part = 1; part < k_parts.size(); ++part) {
        k_joined = ggml_concat(ctx0, k_joined, k_parts[part], 2);
        v_joined = ggml_concat(ctx0, v_joined, v_parts[part], 2);
    }

    k_joined = to_f16(k_joined);
    v_joined = to_f16(v_joined);

    cb(k_joined, "mtp_frozen_k", assistant_il);
    cb(v_joined, "mtp_frozen_v", assistant_il);

    *frozen_k = k_joined;
    *frozen_v = v_joined;
}
|
||||
|
||||
static ggml_cgraph * build_gemma4_graph_parallel(llm_build_context & llm, llama_context & lctx, ggml_context * ctx0,
|
||||
ggml_tensor * inpL, ggml_tensor * inp_pos, ggml_tensor * inp_out_ids,
|
||||
ggml_tensor * KQ_mask, ggml_tensor * KQ_mask_swa, int n_tokens, const llm_build_cb & cb) {
|
||||
@ -363,6 +485,160 @@ static ggml_cgraph * build_gemma4_graph_parallel(llm_build_context & llm, llama_
|
||||
return gf;
|
||||
}
|
||||
|
||||
// Builds the compute graph of the Gemma4 MTP assistant model.
//
// Graph inputs registered on `lctx`:
//   - inp_tokens     : I32 vector of batch.n_tokens token ids
//   - inp_mtp_states : F32 [n_backbone, n_tokens] hidden states
// Graph outputs (both expanded into the returned graph):
//   - "result_mtp_embd" and "result_output"
//
// Pass-through mode (no target context attached, or the batch carries no token ids):
// "result_mtp_embd" is a copy of the input hidden states and the logits are produced by
// build_output() from a view of the first n_embd values of every hidden-state row.
//
// Full mode: token embeddings are looked up in the *target* model's tok_embd, scaled by
// sqrt(n_backbone), concatenated with the hidden states and projected by mtp_pre_proj.
// Each assistant layer then computes only Q and attends, through llm_build_kv(), to the
// target context's KV cache layer chosen by gemma4_mtp_target_kv_layer(), followed by a
// GELU feed-forward block. mtp_post_proj yields "result_mtp_embd"; build_output() yields
// the logits.
ggml_cgraph * llm_build_context::build_gemma4_mtp() {
    ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(n_tokens), false);

    const int64_t n_embd     = hparams.n_embd;
    const int64_t n_vocab    = hparams.n_vocab;
    const int64_t n_backbone = hparams.mtp_backbone_n_embd;
    const int32_t n_layer    = hparams.n_layer;
    GGML_UNUSED(n_vocab);

    GGML_ASSERT(n_backbone > 0);

    // ---- graph inputs -------------------------------------------------------------
    lctx.inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, batch.n_tokens);
    cb(lctx.inp_tokens, "inp_tokens", -1);
    ggml_set_input(lctx.inp_tokens);

    ggml_tensor * mtp_states = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_backbone, n_tokens);
    ggml_set_name(mtp_states, "inp_mtp_states");
    ggml_set_input(mtp_states);
    lctx.inp_mtp_states = mtp_states;

    // ---- pass-through mode --------------------------------------------------------
    const bool passthrough = lctx.mtp_target_ctx == nullptr || !batch.token;
    if (passthrough) {
        // Row stride stays n_backbone; only the leading n_embd values of each row are seen.
        ggml_tensor * head_in = ggml_view_2d(ctx0, mtp_states, n_embd, n_tokens,
                ggml_row_size(mtp_states->type, n_backbone), 0);
        cb(head_in, "mtp_init_hidden_view", -1);

        ggml_tensor * embd_out = ggml_dup(ctx0, mtp_states);
        cb(embd_out, "result_mtp_embd", -1);
        ggml_build_forward_expand(gf, embd_out);

        ggml_tensor * logits_out = build_output(lctx, ctx0, head_in, model.output, model.output_norm, cb);
        cb(logits_out, "result_output", -1);
        ggml_build_forward_expand(gf, logits_out);

        return gf;
    }

    // ---- full mode: everything KV-related comes from the target context -------------
    const llama_model    & target_model   = lctx.mtp_target_ctx->model;
    const llama_hparams  & target_hparams = target_model.hparams;
    const llama_cparams  & target_cparams = lctx.mtp_target_ctx->cparams;
    const llama_kv_cache & target_kv      = lctx.mtp_target_ctx->kv_self;

    GGML_ASSERT(n_tokens <= target_kv.n);

    ggml_tensor * inp_pos = build_inp_pos();

    ggml_tensor * tok_in = ggml_get_rows(ctx0, target_model.tok_embd, lctx.inp_tokens);
    cb(tok_in, "inp_embd_target", -1);
    tok_in = ggml_scale(ctx0, tok_in, std::sqrt(float(n_backbone)));
    cb(tok_in, "inp_embd_scaled", -1);

    ggml_tensor * cur = ggml_concat(ctx0, tok_in, mtp_states, 0);
    cb(cur, "inp_mtp_combined", -1);
    cur = llm_build_lora_mm(lctx, ctx0, model.mtp_pre_proj, cur);
    cb(cur, "mtp_pre_proj", -1);

    const int32_t target_n_kv    = target_kv.n;
    const int32_t target_kv_head = target_kv.head;

    // Attention masks are sized against the target cache window.
    const auto    mask_type     = flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
    const int64_t n_mask_tokens = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD);

    lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, mask_type, target_n_kv, n_mask_tokens);
    cb(lctx.inp_KQ_mask, "KQ_mask", -1);
    ggml_set_input(lctx.inp_KQ_mask);
    ggml_tensor * KQ_mask = lctx.inp_KQ_mask;

    ggml_tensor * KQ_mask_swa = nullptr;
    if (target_hparams.n_swa > 0) {
        lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, mask_type, target_n_kv, n_mask_tokens);
        cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
        ggml_set_input(lctx.inp_KQ_mask_swa);
        KQ_mask_swa = lctx.inp_KQ_mask_swa;
    }

    // One cached K/V view pair per attention kind, filled in on first use.
    ggml_tensor * frozen_k_swa  = nullptr;
    ggml_tensor * frozen_v_swa  = nullptr;
    ggml_tensor * frozen_k_full = nullptr;
    ggml_tensor * frozen_v_full = nullptr;

    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * residual = cur;

        // Per-layer settings: sliding-window layers use the target's SWA rope parameters.
        const bool  sliding   = hparams.swa_layers[il] != 0;
        const float rope_base  = sliding ? target_hparams.rope_freq_base_train_swa  : target_cparams.rope_freq_base;
        const float rope_scale = sliding ? target_hparams.rope_freq_scale_train_swa : target_cparams.rope_freq_scale;
        const int   rope_dims  = sliding ? target_hparams.n_rot_swa : target_hparams.n_rot;
        const int   swa_window = sliding ? target_hparams.n_swa : 0;
        const int   head_dim   = hparams.n_embd_head_k(il);
        const int   head_count = hparams.n_head(il);
        ggml_tensor * layer_mask = sliding ? KQ_mask_swa : KQ_mask;

        cur = llm_build_norm(ctx0, residual, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
        cb(cur, "attn_norm", il);

        // Only the query is computed here; keys and values are read from the target cache.
        ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
        cb(Qcur, "Qcur", il);
        Qcur = ggml_reshape_3d(ctx0, Qcur, head_dim, head_count, n_tokens);
        Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il);
        cb(Qcur, "Qcur_normed", il);
        Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, rope_dims, rope_type, n_ctx_orig, rope_base, rope_scale,
                ext_factor, attn_factor, beta_fast, beta_slow);
        cb(Qcur, "Qcur_rope", il);

        const int target_il = gemma4_mtp_target_kv_layer(hparams, target_hparams, il);
        ggml_tensor ** frozen_k = sliding ? &frozen_k_swa : &frozen_k_full;
        ggml_tensor ** frozen_v = sliding ? &frozen_v_swa : &frozen_v_full;
        gemma4_mtp_prepare_frozen_kv_views(ctx0, lctx, target_kv, il, target_il, target_n_kv, frozen_k, frozen_v, cb);

        cur = llm_build_kv(ctx0, lctx, target_kv, gf, model.layers[il].wo, model.layers[il].bo,
                nullptr, nullptr, Qcur, layer_mask, n_tokens, target_kv_head, target_n_kv, hparams.f_attention_scale, cb, il, nullptr, swa_window, target_il,
                frozen_k, frozen_v);

        cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, cb, il);
        cb(cur, "attn_post_norm", il);
        cur = ggml_add(ctx0, cur, residual);
        cb(cur, "attn_out", il);

        ggml_tensor * ffn_out = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur,
                model.layers[il].ffn_up,   nullptr, nullptr,
                model.layers[il].ffn_gate, nullptr, nullptr,
                model.layers[il].ffn_down, nullptr, nullptr,
                nullptr,
                LLM_FFN_GELU, LLM_FFN_PAR, cb, il, gf, true, false, nullptr, model.layers[il].ffn_post_norm);
        cb(ffn_out, "ffn_out", il);

        cur = ffn_out;
        if (model.layers[il].out_scale) {
            cur = ggml_mul(ctx0, cur, model.layers[il].out_scale);
            cb(cur, "out_scaled", il);
        }
        cur = lctx.cvec.apply_to(ctx0, cur, il);
        cb(cur, "l_out", il);
    }

    ggml_tensor * mtp_embd = llm_build_lora_mm(lctx, ctx0, model.mtp_post_proj, cur);
    cb(mtp_embd, "result_mtp_embd", -1);
    ggml_build_forward_expand(gf, mtp_embd);

    // E2B/E4B: the centroid / token-ordering tensors stay in the GGUF for future use but
    // are not needed for correct inference; the full-vocabulary matmul against the tied
    // output weight still produces valid per-token logits.
    ggml_tensor * logits = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb);
    cb(logits, "result_output", -1);
    ggml_build_forward_expand(gf, logits);

    return gf;
}
|
||||
|
||||
static ggml_tensor * gemma4_project_per_layer_inputs(ggml_context * ctx0, const llama_model & model, const llm_build_cb & cb,
|
||||
int n_embd, int n_embd_per_layer, int n_layer, int n_tokens,
|
||||
ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
|
||||
@ -614,6 +890,12 @@ ggml_cgraph * llm_build_context::build_gemma4() {
|
||||
|
||||
cur = inpL;
|
||||
|
||||
if (cparams.mtp) {
|
||||
ggml_tensor * mtp_embd = ggml_dup(ctx0, cur);
|
||||
cb(mtp_embd, "result_mtp_embd", -1);
|
||||
ggml_build_forward_expand(gf, mtp_embd);
|
||||
}
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
|
||||
@ -78,6 +78,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_GLM_DSA, "glm-dsa" },
|
||||
{ LLM_ARCH_MISTRAL4, "mistral4" },
|
||||
{ LLM_ARCH_GEMMA4, "gemma4" },
|
||||
{ LLM_ARCH_GEMMA4_MTP, "gemma4_mtp" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@ -140,6 +141,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_SWIGLU_LIMITS_SHARED, "%s.swiglu_limits_shared" },
|
||||
{ LLM_KV_SWIGLU_CLAMP_EXP, "%s.swiglu_clamp_exp" },
|
||||
{ LLM_KV_SWIGLU_CLAMP_SHEXP, "%s.swiglu_clamp_shexp" },
|
||||
{ LLM_KV_MTP_BACKBONE_EMBEDDING_LENGTH, "%s.backbone_embedding_length" },
|
||||
{ LLM_KV_MTP_USE_ORDERED_EMBEDDINGS, "%s.use_ordered_embeddings" },
|
||||
{ LLM_KV_MTP_CENTROID_COUNT, "%s.centroid_count" },
|
||||
{ LLM_KV_MTP_CENTROID_TOP_K, "%s.centroid_top_k" },
|
||||
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
|
||||
|
||||
@ -77,6 +77,7 @@ enum llm_arch {
|
||||
LLM_ARCH_GLM_DSA,
|
||||
LLM_ARCH_MISTRAL4,
|
||||
LLM_ARCH_GEMMA4,
|
||||
LLM_ARCH_GEMMA4_MTP,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@ -133,6 +134,10 @@ enum llm_kv {
|
||||
LLM_KV_SWIGLU_CLAMP_EXP,
|
||||
LLM_KV_SWIGLU_CLAMP_SHEXP,
|
||||
LLM_KV_EMBEDDING_LENGTH_PER_LAYER,
|
||||
LLM_KV_MTP_BACKBONE_EMBEDDING_LENGTH,
|
||||
LLM_KV_MTP_USE_ORDERED_EMBEDDINGS,
|
||||
LLM_KV_MTP_CENTROID_COUNT,
|
||||
LLM_KV_MTP_CENTROID_TOP_K,
|
||||
|
||||
LLM_KV_ATTENTION_HEAD_COUNT,
|
||||
LLM_KV_ATTENTION_HEAD_COUNT_KV,
|
||||
@ -358,6 +363,10 @@ enum llm_tensor {
|
||||
LLM_TENSOR_FFN_PRE_NORM_2, // 105
|
||||
LLM_TENSOR_FFN_POST_NORM_1,
|
||||
LLM_TENSOR_FFN_POST_NORM_2,
|
||||
LLM_TENSOR_MTP_PRE_PROJ,
|
||||
LLM_TENSOR_MTP_POST_PROJ,
|
||||
LLM_TENSOR_MTP_TOKEN_ORDERING,
|
||||
LLM_TENSOR_MTP_CENTROIDS,
|
||||
|
||||
LLM_TENSOR_UNKNOWN,
|
||||
};
|
||||
|
||||
@ -1537,37 +1537,46 @@ static ggml_tensor * llm_build_kqv(
|
||||
float kq_scale,
|
||||
const llm_build_cb & cb,
|
||||
int il,
|
||||
ggml_tensor * sinks = nullptr, int n_swa = 0) {
|
||||
ggml_tensor * sinks = nullptr, int n_swa = 0, int kv_il = -1,
|
||||
ggml_tensor ** k_cache_view = nullptr, ggml_tensor ** v_cache_view = nullptr) {
|
||||
const llama_model & model = lctx.model;
|
||||
const llama_hparams & hparams = lctx.model.hparams;
|
||||
const llama_cparams & cparams = lctx.cparams;
|
||||
|
||||
const int64_t n_ctx = cparams.n_ctx;
|
||||
const int64_t n_ctx = kv.size;
|
||||
const int64_t n_head = hparams.n_head(il);
|
||||
const int64_t n_head_kv = hparams.n_head_kv(il);
|
||||
const int64_t n_embd_head_k = hparams.n_embd_head_k(il);
|
||||
//const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
const int64_t n_embd_head_v = hparams.n_embd_head_v(il);
|
||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
const int kv_layer = kv_il >= 0 ? kv_il : il;
|
||||
|
||||
struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
|
||||
cb(q, "q", il);
|
||||
|
||||
auto k_cache = lctx.model.hparams.has_kv(il) ? kv.k_l[il]
|
||||
auto k_cache = kv_il >= 0 ? kv.k_l[kv_layer]
|
||||
: lctx.model.hparams.has_kv(il) ? kv.k_l[il]
|
||||
: lctx.model.hparams.swa_layers[il] ? kv.k_l[hparams.n_layer_kv_from_start-2] : kv.k_l[hparams.n_layer_kv_from_start-1];
|
||||
auto v_cache = lctx.model.hparams.has_kv(il) ? kv.v_l[il]
|
||||
auto v_cache = kv_il >= 0 ? kv.v_l[kv_layer]
|
||||
: lctx.model.hparams.has_kv(il) ? kv.v_l[il]
|
||||
: lctx.model.hparams.swa_layers[il] ? kv.v_l[hparams.n_layer_kv_from_start-2] : kv.v_l[hparams.n_layer_kv_from_start-1];
|
||||
|
||||
GGML_ASSERT(k_cache != nullptr && "k_cache is null in llm_build_kqv");
|
||||
GGML_ASSERT(v_cache != nullptr && "v_cache is null in llm_build_kqv");
|
||||
|
||||
struct ggml_tensor * k =
|
||||
ggml_view_3d(ctx, k_cache,
|
||||
n_embd_head_k, n_kv, n_head_kv,
|
||||
ggml_row_size(k_cache->type, n_embd_head_k)*n_head_kv, //n_embd_k_gqa),
|
||||
ggml_row_size(k_cache->type, n_embd_head_k),
|
||||
0);
|
||||
cb(k, "k", il);
|
||||
struct ggml_tensor * k = k_cache_view ? *k_cache_view : nullptr;
|
||||
if (!k) {
|
||||
k = ggml_view_3d(ctx, k_cache,
|
||||
n_embd_head_k, n_kv, n_head_kv,
|
||||
ggml_row_size(k_cache->type, n_embd_head_k)*n_head_kv, //n_embd_k_gqa),
|
||||
ggml_row_size(k_cache->type, n_embd_head_k),
|
||||
0);
|
||||
if (k_cache_view) {
|
||||
*k_cache_view = k;
|
||||
}
|
||||
cb(k, "k", il);
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_VULKAN
|
||||
constexpr bool use_f32_precision = true;
|
||||
@ -1594,13 +1603,18 @@ static ggml_tensor * llm_build_kqv(
|
||||
GGML_UNUSED(n_ctx);
|
||||
|
||||
// split cached v into n_head heads (not transposed)
|
||||
struct ggml_tensor * v =
|
||||
ggml_view_3d(ctx, v_cache,
|
||||
n_embd_head_v, n_kv, n_head_kv,
|
||||
ggml_row_size(v_cache->type, n_embd_v_gqa),
|
||||
ggml_row_size(v_cache->type, n_embd_head_v),
|
||||
0);
|
||||
cb(v, "v", il);
|
||||
struct ggml_tensor * v = v_cache_view ? *v_cache_view : nullptr;
|
||||
if (!v) {
|
||||
v = ggml_view_3d(ctx, v_cache,
|
||||
n_embd_head_v, n_kv, n_head_kv,
|
||||
ggml_row_size(v_cache->type, n_embd_v_gqa),
|
||||
ggml_row_size(v_cache->type, n_embd_head_v),
|
||||
0);
|
||||
if (v_cache_view) {
|
||||
*v_cache_view = v;
|
||||
}
|
||||
cb(v, "v", il);
|
||||
}
|
||||
|
||||
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
||||
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
||||
@ -1626,22 +1640,27 @@ static ggml_tensor * llm_build_kqv(
|
||||
} else {
|
||||
|
||||
// split cached v into n_head heads
|
||||
struct ggml_tensor * v;
|
||||
if (kv.v_trans) {
|
||||
v = ggml_view_3d(ctx, v_cache,
|
||||
n_kv, n_embd_head_v, n_head_kv,
|
||||
ggml_element_size(v_cache)*n_ctx,
|
||||
ggml_element_size(v_cache)*n_ctx*n_embd_head_v,
|
||||
0);
|
||||
} else {
|
||||
v = ggml_view_3d(ctx, v_cache,
|
||||
n_embd_head_v, n_kv, n_head_kv,
|
||||
ggml_row_size(v_cache->type, n_embd_v_gqa),
|
||||
ggml_row_size(v_cache->type, n_embd_head_v),
|
||||
0);
|
||||
v = ggml_cont(ctx, ggml_transpose(ctx, v));
|
||||
struct ggml_tensor * v = v_cache_view ? *v_cache_view : nullptr;
|
||||
if (!v) {
|
||||
if (kv.v_trans) {
|
||||
v = ggml_view_3d(ctx, v_cache,
|
||||
n_kv, n_embd_head_v, n_head_kv,
|
||||
ggml_element_size(v_cache)*n_ctx,
|
||||
ggml_element_size(v_cache)*n_ctx*n_embd_head_v,
|
||||
0);
|
||||
} else {
|
||||
v = ggml_view_3d(ctx, v_cache,
|
||||
n_embd_head_v, n_kv, n_head_kv,
|
||||
ggml_row_size(v_cache->type, n_embd_v_gqa),
|
||||
ggml_row_size(v_cache->type, n_embd_head_v),
|
||||
0);
|
||||
v = ggml_cont(ctx, ggml_transpose(ctx, v));
|
||||
}
|
||||
if (v_cache_view) {
|
||||
*v_cache_view = v;
|
||||
}
|
||||
cb(v, "v", il);
|
||||
}
|
||||
cb(v, "v", il);
|
||||
|
||||
auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024);
|
||||
if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2] || sinks) {
|
||||
@ -1775,7 +1794,8 @@ ggml_tensor * llm_build_context::llm_build_kv(
|
||||
int32_t kv_head,
|
||||
int32_t n_kv,
|
||||
float kq_scale,
|
||||
const llm_build_cb & cb, int il, ggml_tensor * sinks, int n_swa) {
|
||||
const llm_build_cb & cb, int il, ggml_tensor * sinks, int n_swa, int kv_il,
|
||||
ggml_tensor ** k_cache_view, ggml_tensor ** v_cache_view) {
|
||||
const llama_hparams & hparams = lctx.model.hparams;
|
||||
const llama_cparams & cparams = lctx.cparams;
|
||||
|
||||
@ -1805,7 +1825,8 @@ ggml_tensor * llm_build_context::llm_build_kv(
|
||||
llm_build_kv_store(lctx, ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il);
|
||||
}
|
||||
|
||||
auto cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks, n_swa);
|
||||
auto cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks, n_swa, kv_il,
|
||||
k_cache_view, v_cache_view);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
return cur;
|
||||
@ -2332,6 +2353,10 @@ ggml_cgraph * llm_build_context::llama_build_graph(
|
||||
{
|
||||
result = llm.build_gemma4();
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA4_MTP:
|
||||
{
|
||||
result = llm.build_gemma4_mtp();
|
||||
} break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
result = llm.build_starcoder2();
|
||||
|
||||
@ -240,6 +240,8 @@ struct llm_build_context {
|
||||
|
||||
ggml_cgraph * build_gemma4();
|
||||
|
||||
ggml_cgraph * build_gemma4_mtp();
|
||||
|
||||
ggml_cgraph * build_starcoder2();
|
||||
|
||||
ggml_cgraph * build_mamba();
|
||||
@ -339,7 +341,8 @@ struct llm_build_context {
|
||||
int32_t kv_head,
|
||||
int32_t n_kv,
|
||||
float kq_scale,
|
||||
const llm_build_cb & cb, int il, ggml_tensor * sinks = nullptr, int n_swa = 0);
|
||||
const llm_build_cb & cb, int il, ggml_tensor * sinks = nullptr, int n_swa = 0, int kv_il = -1,
|
||||
ggml_tensor ** k_cache_view = nullptr, ggml_tensor ** v_cache_view = nullptr);
|
||||
|
||||
static ggml_tensor * llm_build_ffn(ggml_context * ctx, llama_context & lctx, ggml_tensor * ffn_norm,
|
||||
ggml_tensor * cur,
|
||||
|
||||
@ -199,6 +199,7 @@ struct llama_context {
|
||||
struct llama_cparams cparams;
|
||||
struct llama_sampling sampling;
|
||||
struct llama_kv_cache kv_self;
|
||||
struct llama_context * mtp_target_ctx = nullptr;
|
||||
struct llama_control_vector cvec;
|
||||
|
||||
std::vector<float> scale_data;
|
||||
|
||||
@ -754,6 +754,26 @@ void llm_load_hparams(
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA4_MTP:
|
||||
{
|
||||
ml.get_key(LLM_KV_MTP_BACKBONE_EMBEDDING_LENGTH, hparams.mtp_backbone_n_embd);
|
||||
ml.get_key(LLM_KV_MTP_USE_ORDERED_EMBEDDINGS, hparams.mtp_use_ordered_embeddings, false);
|
||||
ml.get_key(LLM_KV_MTP_CENTROID_COUNT, hparams.mtp_num_centroids, false);
|
||||
ml.get_key(LLM_KV_MTP_CENTROID_TOP_K, hparams.mtp_centroid_top_k, false);
|
||||
|
||||
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer);
|
||||
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
|
||||
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer;
|
||||
hparams.f_attention_scale = 1.0f;
|
||||
|
||||
switch (hparams.mtp_backbone_n_embd) {
|
||||
case 5376: model.type = e_model::MODEL_32B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
|
||||
@ -134,6 +134,12 @@ struct llama_hparams {
|
||||
// gemma4 per-layer embedding
|
||||
uint32_t n_embd_per_layer = 0;
|
||||
|
||||
// gemma4 separate assistant MTP
|
||||
uint32_t mtp_backbone_n_embd = 0;
|
||||
bool mtp_use_ordered_embeddings = false;
|
||||
uint32_t mtp_num_centroids = 0;
|
||||
uint32_t mtp_centroid_top_k = 0;
|
||||
|
||||
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
|
||||
llama_token dec_start_token_id = -1;
|
||||
@ -152,6 +158,7 @@ struct llama_hparams {
|
||||
if (this->n_vocab != other.n_vocab) return true;
|
||||
if (this->n_ctx_train != other.n_ctx_train) return true;
|
||||
if (this->n_embd != other.n_embd) return true;
|
||||
if (this->mtp_backbone_n_embd != other.mtp_backbone_n_embd) return true;
|
||||
if (this->n_layer != other.n_layer) return true;
|
||||
if (this->n_rot != other.n_rot) return true;
|
||||
if (this->n_swa != other.n_swa) return true;
|
||||
|
||||
@ -95,6 +95,8 @@ struct create_tensors_helper : public create_tensors_helper_interface {
|
||||
|
||||
bool create_gemma4_tensors(const LLM_TN & tn);
|
||||
|
||||
bool create_gemma4_mtp_tensors(const LLM_TN & tn);
|
||||
|
||||
bool create_starcoder2_tensors(const LLM_TN & tn);
|
||||
|
||||
bool create_mamba_tensors(const LLM_TN & tn);
|
||||
@ -2016,6 +2018,7 @@ bool create_tensors_helper::create_gemma4_tensors(const LLM_TN & tn) {
|
||||
|
||||
const uint32_t n_embd_per_layer = hparams.n_embd_per_layer;
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp;
|
||||
const bool use_split_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN;
|
||||
|
||||
if (n_embd_head_k != n_embd_head_v) {
|
||||
throw std::runtime_error("Gemma 4 requires n_embd_head_k == n_embd_head_v");
|
||||
@ -2043,7 +2046,8 @@ bool create_tensors_helper::create_gemma4_tensors(const LLM_TN & tn) {
|
||||
int rope_freqs_flag = 0;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
ggml_context * ctx_split = use_split_ctx ? ctx_for_layer_split(i) : ctx_layer;
|
||||
auto & layer = model.layers[i];
|
||||
const int64_t n_head = hparams.n_head(i);
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k(i);
|
||||
@ -2110,6 +2114,53 @@ bool create_tensors_helper::create_gemma4_tensors(const LLM_TN & tn) {
|
||||
return use_mmap_buffer;
|
||||
}
|
||||
|
||||
bool create_tensors_helper::create_gemma4_mtp_tensors(const LLM_TN & tn) {
|
||||
LOADING_PRELUDE
|
||||
|
||||
const int64_t n_backbone = hparams.mtp_backbone_n_embd;
|
||||
const bool use_split_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN;
|
||||
if (n_backbone <= 0) {
|
||||
throw std::runtime_error("Gemma 4 MTP assistant requires backbone_embedding_length metadata");
|
||||
}
|
||||
|
||||
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
if (model.output == NULL) {
|
||||
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
||||
}
|
||||
model.mtp_pre_proj = create_tensor(ctx_output, tn(LLM_TENSOR_MTP_PRE_PROJ, "weight"), {2*n_backbone, n_embd}, 0);
|
||||
model.mtp_post_proj = create_tensor(ctx_output, tn(LLM_TENSOR_MTP_POST_PROJ, "weight"), {n_embd, n_backbone}, 0);
|
||||
|
||||
model.mtp_token_ordering = create_tensor(ctx_output, tn(LLM_TENSOR_MTP_TOKEN_ORDERING, "weight"), {n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
model.mtp_centroids = create_tensor(ctx_output, tn(LLM_TENSOR_MTP_CENTROIDS, "weight"), {n_embd, hparams.mtp_num_centroids}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
ggml_context * ctx_split = use_split_ctx ? ctx_for_layer_split(i) : ctx_layer;
|
||||
auto & layer = model.layers[i];
|
||||
const int64_t n_head = hparams.n_head(i);
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k(i);
|
||||
const int64_t n_ff_cur = hparams.n_ff(i);
|
||||
|
||||
layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head*n_head}, 0);
|
||||
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head*n_head, n_embd}, 0);
|
||||
|
||||
layer.attn_q_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head}, 0);
|
||||
layer.attn_post_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.out_scale = create_tensor(ctx_split, tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1u}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.ffn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff_cur}, 0);
|
||||
layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur}, 0);
|
||||
layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0);
|
||||
layer.ffn_post_norm = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
}
|
||||
|
||||
return use_mmap_buffer;
|
||||
}
|
||||
|
||||
bool create_tensors_helper::create_starcoder2_tensors(const LLM_TN & tn) {
|
||||
LOADING_PRELUDE
|
||||
|
||||
@ -4071,6 +4122,8 @@ bool create_tensors_helper::create_tensors() {
|
||||
use_mmap_buffer = create_gemma_tensors(tn, 3); break;
|
||||
case LLM_ARCH_GEMMA4:
|
||||
use_mmap_buffer = create_gemma4_tensors(tn); break;
|
||||
case LLM_ARCH_GEMMA4_MTP:
|
||||
use_mmap_buffer = create_gemma4_mtp_tensors(tn); break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
use_mmap_buffer = create_starcoder2_tensors(tn); break;
|
||||
case LLM_ARCH_MAMBA:
|
||||
@ -4140,14 +4193,15 @@ bool create_tensors_helper::create_tensors() {
|
||||
use_mmap_buffer &= !has_buft_overrides;
|
||||
}
|
||||
|
||||
if (model.arch == LLM_ARCH_GEMMA4 && (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN)) {
|
||||
bool supported = true;
|
||||
if (model.tok_embd_per_layer) {
|
||||
supported = false;
|
||||
}
|
||||
if (!supported) {
|
||||
{
|
||||
const bool unsupported =
|
||||
(model.arch == LLM_ARCH_GEMMA4_MTP) ||
|
||||
(model.arch == LLM_ARCH_GEMMA4 && model.tok_embd_per_layer);
|
||||
if (unsupported && (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN)) {
|
||||
LLAMA_LOG_WARN("\n=========================================================\n");
|
||||
LLAMA_LOG_WARN("Split mode 'graph' is not supported for this Gemma4 variant\n");
|
||||
LLAMA_LOG_WARN("Split mode 'graph' is not supported for %s\n",
|
||||
model.arch == LLM_ARCH_GEMMA4_MTP ? "Gemma 4 MTP assistant"
|
||||
: "this Gemma4 variant");
|
||||
LLAMA_LOG_WARN(" => changing split mode to 'layer'\n");
|
||||
LLAMA_LOG_WARN("===========================================================\n\n");
|
||||
model.split_mode = LLAMA_SPLIT_MODE_LAYER;
|
||||
|
||||
@ -803,6 +803,28 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||
{ LLM_TENSOR_PER_LAYER_POST_NORM, "blk.%d.post_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GEMMA4_MTP,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
{ LLM_TENSOR_LAYER_OUT_SCALE, "blk.%d.layer_output_scale" },
|
||||
{ LLM_TENSOR_MTP_PRE_PROJ, "mtp_pre_proj" },
|
||||
{ LLM_TENSOR_MTP_POST_PROJ, "mtp_post_proj" },
|
||||
{ LLM_TENSOR_MTP_TOKEN_ORDERING, "mtp_token_ordering" },
|
||||
{ LLM_TENSOR_MTP_CENTROIDS, "mtp_centroids" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_STARCODER2,
|
||||
{
|
||||
@ -1881,6 +1903,27 @@ bool llama_model_has_recurrent(const llama_model * model) {
|
||||
return llm_arch_is_hybrid(model->arch) || llm_arch_is_recurrent(model->arch);
|
||||
}
|
||||
|
||||
bool llama_model_is_gemma4_mtp_assistant(const llama_model * model) {
|
||||
return model && model->arch == LLM_ARCH_GEMMA4_MTP;
|
||||
}
|
||||
|
||||
bool llama_is_gemma4_mtp_file(const char * path) {
|
||||
if (!path || !*path) return false;
|
||||
struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ nullptr };
|
||||
struct gguf_context * ctx = gguf_init_from_file(path, params);
|
||||
if (!ctx) return false;
|
||||
bool result = false;
|
||||
const int key_id = gguf_find_key(ctx, "general.architecture");
|
||||
if (key_id >= 0) {
|
||||
const char * arch = gguf_get_val_str(ctx, key_id);
|
||||
if (arch && strcmp(arch, "gemma4_mtp") == 0) {
|
||||
result = true;
|
||||
}
|
||||
}
|
||||
gguf_free(ctx);
|
||||
return result;
|
||||
}
|
||||
|
||||
bool llama_model_is_split_mode_graph(const struct llama_model * model) {
|
||||
return model && (model->split_mode == LLAMA_SPLIT_MODE_GRAPH || model->split_mode == LLAMA_SPLIT_MODE_ATTN);
|
||||
}
|
||||
|
||||
@ -405,6 +405,11 @@ struct llama_model {
|
||||
struct ggml_tensor * per_layer_model_proj = nullptr;
|
||||
struct ggml_tensor * per_layer_proj_norm = nullptr;
|
||||
|
||||
struct ggml_tensor * mtp_pre_proj = nullptr;
|
||||
struct ggml_tensor * mtp_post_proj = nullptr;
|
||||
struct ggml_tensor * mtp_token_ordering = nullptr;
|
||||
struct ggml_tensor * mtp_centroids = nullptr;
|
||||
|
||||
struct ggml_tensor * output_norm;
|
||||
struct ggml_tensor * output_norm_b;
|
||||
struct ggml_tensor * output;
|
||||
|
||||
130
src/llama.cpp
130
src/llama.cpp
@ -25,6 +25,9 @@
|
||||
#include "ggml-alloc.h"
|
||||
#include "ggml-backend.h"
|
||||
|
||||
uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx);
|
||||
void llama_set_mtp_target_context(struct llama_context * ctx, struct llama_context * target_ctx);
|
||||
|
||||
// TODO: fix these includes
|
||||
#include "iqk/iqk_quantize.h"
|
||||
#include "iqk/iqk_cpu_ops.h"
|
||||
@ -562,6 +565,7 @@ void llama_context::reset_scheduler() {
|
||||
bool llama_context::can_reuse_graph(const llama_batch & u_batch) {
|
||||
if (!cparams.graph_reuse) return false;
|
||||
if (kv_self.save_per_step_ssm) return false;
|
||||
if (model.arch == LLM_ARCH_GEMMA4_MTP && mtp_target_ctx != nullptr) return false;
|
||||
auto the_prev = cparams.mtp_op_type == MTP_OP_NONE ? prev.get() : prev_mtp.get();
|
||||
if (!the_prev || !the_prev->graph) return false;
|
||||
//if (u_batch.n_tokens > 1) return false;
|
||||
@ -810,6 +814,9 @@ static bool llama_kv_cache_init(
|
||||
const bool is_mtp_tail = qwen_mtp && i >= n_mtp_first;
|
||||
if (split_cache && !is_mtp_tail) {
|
||||
buft_layer_count[model.buft_layer[i].buft_matrix]++;
|
||||
if (model.buft_layer[i].buft != model.buft_layer[i].buft_matrix) {
|
||||
buft_layer_count[model.buft_layer[i].buft]++;
|
||||
}
|
||||
} else {
|
||||
buft_layer_count[model.buft_layer[i].buft]++;
|
||||
}
|
||||
@ -2519,6 +2526,10 @@ static std::pair<std::vector<double>, double> get_layer_sizes(const llama_model_
|
||||
if (name == "output_norm.weight") {
|
||||
continue;
|
||||
}
|
||||
if (name == "mtp_pre_proj.weight" || name == "mtp_post_proj.weight" ||
|
||||
name == "mtp_centroids.weight" || name == "mtp_token_ordering.weight") {
|
||||
continue;
|
||||
}
|
||||
auto pos = name.find("blk.");
|
||||
if (pos != 0) {
|
||||
LLAMA_LOG_WARN("Oops: tensor with strange name %s\n", name.c_str());
|
||||
@ -2706,7 +2717,19 @@ static bool llm_load_tensors(
|
||||
auto & hparams = model.hparams;
|
||||
|
||||
if (split_mode == LLAMA_SPLIT_MODE_GRAPH || split_mode == LLAMA_SPLIT_MODE_ATTN) {
|
||||
if (!is_model_split_supported(model)) {
|
||||
const bool unsupported_gemma_split =
|
||||
model.arch == LLM_ARCH_GEMMA4_MTP ||
|
||||
(model.arch == LLM_ARCH_GEMMA4 && hparams.n_embd_per_layer > 0);
|
||||
|
||||
if (unsupported_gemma_split) {
|
||||
LLAMA_LOG_WARN("\n=========================================================\n");
|
||||
LLAMA_LOG_WARN("Split mode 'graph' is not supported for %s\n",
|
||||
model.arch == LLM_ARCH_GEMMA4_MTP ? "Gemma 4 MTP assistant"
|
||||
: "this Gemma4 variant");
|
||||
LLAMA_LOG_WARN(" => changing split mode to 'layer'\n");
|
||||
LLAMA_LOG_WARN("===========================================================\n\n");
|
||||
split_mode = LLAMA_SPLIT_MODE_LAYER;
|
||||
} else if (!is_model_split_supported(model)) {
|
||||
LLAMA_LOG_WARN("\n=======================================================\n");
|
||||
LLAMA_LOG_WARN("Split mode 'graph' is not supported for this model\n");
|
||||
LLAMA_LOG_WARN(" => changing split mode to 'layer'\n");
|
||||
@ -3028,6 +3051,20 @@ static bool llm_load_tensors(
|
||||
}
|
||||
}
|
||||
}
|
||||
if (model.arch == LLM_ARCH_GEMMA4_MTP && split_mode == LLAMA_SPLIT_MODE_LAYER && device_count > 0 && n_gpu_layers > 0) {
|
||||
const int mtp_device = std::clamp(main_gpu, 0, device_count - 1);
|
||||
|
||||
LLAMA_LOG_INFO("%s: Gemma 4 MTP assistant forcing layer placement to GPU %d under layer split\n",
|
||||
__func__, mtp_device);
|
||||
|
||||
for (int i = i_gpu_start; i < n_layer; ++i) {
|
||||
model.default_layer_device[i] = mtp_device;
|
||||
}
|
||||
if (n_gpu_layers > n_layer) {
|
||||
model.default_layer_device[n_layer] = mtp_device;
|
||||
}
|
||||
}
|
||||
|
||||
// assign the repeating layers to the devices according to the splits
|
||||
if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
|
||||
for (int i = i_gpu_start; i < n_layer; ++i) {
|
||||
@ -3609,7 +3646,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
#endif
|
||||
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
|
||||
if (cparams.causal_attn && !lctx.is_encoding) {
|
||||
const int64_t n_kv = kv_self.n;
|
||||
const llama_kv_cache & mask_kv_self =
|
||||
(lctx.model.arch == LLM_ARCH_GEMMA4_MTP && lctx.mtp_target_ctx != nullptr)
|
||||
? lctx.mtp_target_ctx->kv_self
|
||||
: kv_self;
|
||||
const int64_t n_kv = mask_kv_self.n;
|
||||
const int64_t n_tokens = batch.n_tokens;
|
||||
|
||||
|
||||
@ -3636,21 +3677,21 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
}
|
||||
}
|
||||
|
||||
auto noalibi_f16 = [&lctx, &hparams, n_kv, data_f16, data_swa_f16] (int j, llama_pos pos, llama_seq_id seq_id, int first, int last) {
|
||||
auto noalibi_f16 = [&mask_kv_self, &hparams, n_kv, data_f16, data_swa_f16] (int j, llama_pos pos, llama_seq_id seq_id, int first, int last) {
|
||||
ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
|
||||
ggml_half h_zero = ggml_fp32_to_fp16(0.f);
|
||||
for (int i = first; i < last; ++i) {
|
||||
ggml_half h = !lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos ? h_inf : h_zero;
|
||||
ggml_half h = !mask_kv_self.cells[i].has_seq_id(seq_id) || mask_kv_self.cells[i].pos > pos ? h_inf : h_zero;
|
||||
if (data_f16) data_f16[j*n_kv + i] = h;
|
||||
if (data_swa_f16) {
|
||||
if (h != h_inf) {
|
||||
if (hparams.n_attn_chunk) {
|
||||
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
|
||||
if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
|
||||
if (mask_kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
|
||||
h = h_inf;
|
||||
}
|
||||
} else {
|
||||
if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||
if (pos - mask_kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||
h = h_inf;
|
||||
}
|
||||
}
|
||||
@ -3663,7 +3704,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
if (n_kv >= 1024 && n_tokens >= 32) {
|
||||
int n_thread = std::max(1, int(std::thread::hardware_concurrency()/2));
|
||||
int npt = (n_kv + n_thread - 1)/n_thread;
|
||||
auto compute = [&batch, &lctx, &hparams, &cparams, &noalibi_f16, n_tokens, n_kv, npt, data, data_swa, data_f16, data_swa_f16] (int ith) {
|
||||
auto compute = [&batch, &mask_kv_self, &hparams, &cparams, &noalibi_f16, n_tokens, n_kv, npt, data, data_swa, data_f16, data_swa_f16] (int ith) {
|
||||
int first = ith * npt;
|
||||
int last = std::min(int(n_kv), first + npt);
|
||||
if (last <= first) return;
|
||||
@ -3678,11 +3719,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
|
||||
for (int i = first; i < last; ++i) {
|
||||
float f;
|
||||
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
|
||||
if (!mask_kv_self.cells[i].has_seq_id(seq_id) || mask_kv_self.cells[i].pos > pos) {
|
||||
f = -INFINITY;
|
||||
} else {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(lctx.kv_self.cells[i].pos - pos);
|
||||
f = -std::abs(mask_kv_self.cells[i].pos - pos);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
@ -3700,11 +3741,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
if (f > -INFINITY) {
|
||||
if (hparams.n_attn_chunk) {
|
||||
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
|
||||
if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
|
||||
if (mask_kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
|
||||
f = -INFINITY;
|
||||
}
|
||||
} else {
|
||||
if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||
if (pos - mask_kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||
f = -INFINITY;
|
||||
}
|
||||
}
|
||||
@ -3759,11 +3800,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
float f;
|
||||
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
|
||||
if (!mask_kv_self.cells[i].has_seq_id(seq_id) || mask_kv_self.cells[i].pos > pos) {
|
||||
f = -INFINITY;
|
||||
} else {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(lctx.kv_self.cells[i].pos - pos);
|
||||
f = -std::abs(mask_kv_self.cells[i].pos - pos);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
@ -3780,11 +3821,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
if (data_swa || data_swa_f16) {
|
||||
if (hparams.n_attn_chunk) {
|
||||
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
|
||||
if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
|
||||
if (mask_kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
|
||||
f = -INFINITY;
|
||||
}
|
||||
} else {
|
||||
if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||
if (pos - mask_kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||
f = -INFINITY;
|
||||
}
|
||||
}
|
||||
@ -4125,6 +4166,21 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
|
||||
// Make sure enough space is available for outputs.
|
||||
// Returns max number of outputs for which space was reserved.
|
||||
static uint32_t llama_output_embd_width(const llama_context & lctx) {
|
||||
const auto & hparams = lctx.model.hparams;
|
||||
if (lctx.cparams.mtp && lctx.model.arch == LLM_ARCH_GEMMA4_MTP && hparams.mtp_backbone_n_embd > 0) {
|
||||
return hparams.mtp_backbone_n_embd;
|
||||
}
|
||||
return hparams.n_embd;
|
||||
}
|
||||
|
||||
static bool llama_context_has_mtp_outputs(const llama_context & lctx) {
|
||||
return lctx.cparams.mtp && (
|
||||
lctx.model.hparams.nextn_predict_layers > 0 ||
|
||||
lctx.model.arch == LLM_ARCH_GEMMA4 ||
|
||||
lctx.model.arch == LLM_ARCH_GEMMA4_MTP);
|
||||
}
|
||||
|
||||
static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
|
||||
const auto & cparams = lctx.cparams;
|
||||
const auto & hparams = lctx.model.hparams;
|
||||
@ -4133,10 +4189,10 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
|
||||
|
||||
const auto n_batch = cparams.n_batch;
|
||||
const auto n_vocab = hparams.n_vocab;
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto n_embd = llama_output_embd_width(lctx);
|
||||
|
||||
// TODO: use a per-batch flag for logits presence instead
|
||||
const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.cparams.mtp;
|
||||
const bool has_mtp = llama_context_has_mtp_outputs(lctx);
|
||||
const bool has_logits = !cparams.embeddings || has_mtp;
|
||||
const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)) || has_mtp;
|
||||
|
||||
@ -4305,7 +4361,8 @@ static int llama_decode_internal(
|
||||
|
||||
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
||||
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
||||
const bool has_mtp = cparams.mtp && hparams.nextn_predict_layers > 0;
|
||||
const bool has_mtp = llama_context_has_mtp_outputs(lctx);
|
||||
const uint32_t n_embd_output = llama_output_embd_width(lctx);
|
||||
|
||||
// count outputs
|
||||
if (batch_all.logits && !embd_pooled) {
|
||||
@ -4521,7 +4578,8 @@ static int llama_decode_internal(
|
||||
printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
//if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
|
||||
if (u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
|
||||
if (u_batch.embd == nullptr && lctx.cparams.graph_reuse &&
|
||||
!(lctx.model.arch == LLM_ARCH_GEMMA4_MTP && lctx.mtp_target_ctx != nullptr)) {
|
||||
prev = std::make_unique<llama_context::Prev>(llama_context::Prev{
|
||||
(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n,
|
||||
(int)u_batch.n_tokens, cparams.mtp_op_type, gf});
|
||||
@ -4546,13 +4604,13 @@ static int llama_decode_internal(
|
||||
res = nullptr;
|
||||
}
|
||||
else {
|
||||
const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.model.mtp;
|
||||
const bool use_qwen_mtp_embd = has_mtp && (lctx.model.arch == LLM_ARCH_QWEN35 ||
|
||||
lctx.model.arch == LLM_ARCH_QWEN35MOE);
|
||||
const bool has_mtp = llama_context_has_mtp_outputs(lctx);
|
||||
const bool use_raw_mtp_embd = has_mtp && (lctx.model.arch == LLM_ARCH_QWEN35 ||
|
||||
lctx.model.arch == LLM_ARCH_QWEN35MOE || lctx.model.arch == LLM_ARCH_GEMMA4 || lctx.model.arch == LLM_ARCH_GEMMA4_MTP);
|
||||
if (cparams.embeddings || has_mtp) {
|
||||
for (int i = gf->n_nodes - 1; i >= 0; --i) {
|
||||
if (use_qwen_mtp_embd && strcmp(gf->nodes[i]->name, "result_mtp_embd") == 0) {
|
||||
// Qwen 3.5 uses raw hidden state before the final shared-head normalization.
|
||||
if (use_raw_mtp_embd && strcmp(gf->nodes[i]->name, "result_mtp_embd") == 0) {
|
||||
// MTP recurrent state can be wider/different than the logits head hidden state.
|
||||
embd = gf->nodes[i];
|
||||
break;
|
||||
}
|
||||
@ -4565,7 +4623,7 @@ static int llama_decode_internal(
|
||||
}
|
||||
}
|
||||
}
|
||||
if (cparams.embeddings && lctx.model.hparams.nextn_predict_layers == 0) {
|
||||
if (cparams.embeddings && lctx.model.hparams.nextn_predict_layers == 0 && !has_mtp) {
|
||||
res = nullptr; // do not extract logits for embedding case
|
||||
} else {
|
||||
if (!embd) { // do not extract embeddings when not needed
|
||||
@ -4667,13 +4725,13 @@ static int llama_decode_internal(
|
||||
{
|
||||
// extract token embeddings
|
||||
GGML_ASSERT(lctx.embd != nullptr);
|
||||
float * embd_out = lctx.embd + n_outputs_prev_embd*n_embd;
|
||||
const int32_t n_outputs_new_embd = has_mtp ? n_tokens : lctx.n_outputs;
|
||||
float * embd_out = lctx.embd + n_outputs_prev_embd*n_embd_output;
|
||||
const int32_t n_outputs_new_embd = has_mtp ? embd->ne[1] : lctx.n_outputs;
|
||||
|
||||
if (n_outputs_new_embd) {
|
||||
GGML_ASSERT( n_outputs_prev_embd + n_outputs_new_embd <= n_outputs_embd);
|
||||
GGML_ASSERT((n_outputs_prev_embd + n_outputs_new_embd)*n_embd <= (int64_t) lctx.embd_size);
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new_embd*n_embd*sizeof(float));
|
||||
GGML_ASSERT((n_outputs_prev_embd + n_outputs_new_embd)*n_embd_output <= (int64_t) lctx.embd_size);
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new_embd*n_embd_output*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_MEAN:
|
||||
@ -4704,7 +4762,7 @@ static int llama_decode_internal(
|
||||
#endif
|
||||
}
|
||||
n_outputs_prev += lctx.n_outputs;
|
||||
n_outputs_prev_embd += has_mtp ? n_tokens : lctx.n_outputs;
|
||||
n_outputs_prev_embd += (has_mtp && embd) ? embd->ne[1] : lctx.n_outputs;
|
||||
cur_token += n_tokens;
|
||||
if (reset_previous) {
|
||||
// We need to discard this graph. Otherwise, with CUDA graphs enabled, the graph will get reused and this will reset the
|
||||
@ -6033,7 +6091,8 @@ struct llama_context * llama_init_from_model(
|
||||
}
|
||||
|
||||
if (model->arch != LLM_ARCH_GLM4_MOE && model->arch != LLM_ARCH_QWEN35 &&
|
||||
model->arch != LLM_ARCH_QWEN35MOE && cparams.mtp != 0) {
|
||||
model->arch != LLM_ARCH_QWEN35MOE && model->arch != LLM_ARCH_GEMMA4 &&
|
||||
model->arch != LLM_ARCH_GEMMA4_MTP && cparams.mtp != 0) {
|
||||
cparams.mtp = 0;
|
||||
}
|
||||
|
||||
@ -6572,6 +6631,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_SEED_OSS:
|
||||
case LLM_ARCH_STEP35:
|
||||
case LLM_ARCH_GEMMA4:
|
||||
case LLM_ARCH_GEMMA4_MTP:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
@ -9903,6 +9963,14 @@ void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float
|
||||
ctx->draft_input_hidden_state = hidden_state;
|
||||
}
|
||||
|
||||
uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx) {
|
||||
return llama_output_embd_width(*ctx);
|
||||
}
|
||||
|
||||
// Records `target_ctx` in `ctx->mtp_target_ctx`. The pointer is stored as-is
// (no ownership is taken here); passing nullptr clears the association.
// `ctx` must not be null.
void llama_set_mtp_target_context(struct llama_context * ctx, struct llama_context * target_ctx) {
    llama_context & self = *ctx;
    self.mtp_target_ctx = target_ctx;
}
|
||||
|
||||
size_t llama_fill_from_utf8(void* utf8, void* cpts, void* scripts) {
|
||||
return unicode_fill_from_utf8((std::string*)utf8, (std::vector<uint32_t>*)cpts, (std::vector<std::string>*)scripts);
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user