mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Fix MTP warmup for GLM models
This commit is contained in:
parent
3b81f63acd
commit
2c1dc8781b
@ -2347,8 +2347,6 @@ void common_speculative_checkpoint_restore(
|
|||||||
common_speculative_checkpoint_discard(ckpt, ctx);
|
common_speculative_checkpoint_discard(ckpt, ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool mtp_model_uses_recurrent_conditioning(const common_speculative_state_mtp & state);
|
|
||||||
|
|
||||||
void common_speculative_commit(
|
void common_speculative_commit(
|
||||||
common_speculative * spec,
|
common_speculative * spec,
|
||||||
llama_context * ctx,
|
llama_context * ctx,
|
||||||
@ -2559,6 +2557,7 @@ static bool mtp_model_uses_recurrent_conditioning(const common_speculative_state
|
|||||||
if (state.ctx_mtp == nullptr) {
|
if (state.ctx_mtp == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
return true;
|
||||||
|
|
||||||
const llama_model * model = llama_get_model(state.ctx_mtp);
|
const llama_model * model = llama_get_model(state.ctx_mtp);
|
||||||
if (!llama_model_has_recurrent(model)) {
|
if (!llama_model_has_recurrent(model)) {
|
||||||
|
|||||||
@ -519,8 +519,8 @@ static ggml_cgraph * build_gemma4_graph_parallel(llm_build_context & llm, llama_
|
|||||||
}
|
}
|
||||||
|
|
||||||
cur = llm_build_context::build_output(lctx, ctx0, cur, model.output, model.output_norm, cb);
|
cur = llm_build_context::build_output(lctx, ctx0, cur, model.output, model.output_norm, cb);
|
||||||
cb(cur, "almost_result", -1);
|
|
||||||
if (hparams.f_final_logit_softcapping > 0) {
|
if (hparams.f_final_logit_softcapping > 0) {
|
||||||
|
cb(cur, "almost_result", -1);
|
||||||
cur = ggml_softcap(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping, hparams.f_final_logit_softcapping);
|
cur = ggml_softcap(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping, hparams.f_final_logit_softcapping);
|
||||||
}
|
}
|
||||||
cb(cur, "result_output", -1);
|
cb(cur, "result_output", -1);
|
||||||
@ -666,6 +666,7 @@ ggml_cgraph * llm_build_context::build_gemma4_mtp() {
|
|||||||
|
|
||||||
ggml_tensor * mtp_embd = llm_build_lora_mm(lctx, ctx0, model.mtp_post_proj, cur);
|
ggml_tensor * mtp_embd = llm_build_lora_mm(lctx, ctx0, model.mtp_post_proj, cur);
|
||||||
cb(mtp_embd, "result_mtp_embd", -1);
|
cb(mtp_embd, "result_mtp_embd", -1);
|
||||||
|
ggml_set_output(mtp_embd);
|
||||||
ggml_build_forward_expand(gf, mtp_embd);
|
ggml_build_forward_expand(gf, mtp_embd);
|
||||||
|
|
||||||
ggml_tensor * logits;
|
ggml_tensor * logits;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user