MTP: Reuse graphs (again)

This commit is contained in:
Kawrakow 2026-05-11 15:31:15 +00:00
parent eb570eb966
commit 16369dbf0f
5 changed files with 74 additions and 31 deletions

View File

@ -1602,6 +1602,12 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
// matrix multiplication of a and b that stores its output in the memory of an existing tensor
// if result is NULL, this is the same as ggml_mul_mat(ctx, a, b)
// otherwise the returned tensor is a 2D view of result (offset 0) with a->ne[1] x b->ne[1] elements
// requirements checked with GGML_ASSERT when result is not NULL:
//   - a is not transposed, b->ne[2] == 1 and b->ne[3] == 1
//   - result is GGML_TYPE_F32 and has at least a->ne[1]*b->ne[1] elements
GGML_API struct ggml_tensor * ggml_mul_mat_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * result);
// change the precision of a matrix multiplication
// set to GGML_PREC_F32 for higher precision (useful for phi-2)
GGML_API void ggml_mul_mat_set_prec(

View File

@ -7799,6 +7799,32 @@ struct ggml_tensor * ggml_mul_mat(
return result;
}
// Multiply a by b and place the product in the memory of result_in.
//
// A NULL result_in simply forwards to ggml_mul_mat(). Otherwise the returned
// node is a 2D view of result_in (offset 0, rows of a->ne[1] floats, b->ne[1]
// rows) that is re-tagged as a GGML_OP_MUL_MAT with sources a and b.
//
// Preconditions for the non-NULL case are enforced with GGML_ASSERT:
// a and b must pass ggml_can_mul_mat, a must not be transposed, b must have
// ne[2] == ne[3] == 1, and result_in must be F32 with enough elements.
struct ggml_tensor * ggml_mul_mat_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        struct ggml_tensor  * result_in) {
    // No destination supplied: fall back to the allocating variant.
    if (result_in == NULL) {
        return ggml_mul_mat(ctx, a, b);
    }

    // Operand checks.
    GGML_ASSERT(ggml_can_mul_mat(a, b));
    GGML_ASSERT(!ggml_is_transposed(a));
    GGML_ASSERT(b->ne[2] == 1 && b->ne[3] == 1);

    // Shape of the product; only the first two dimensions are used below.
    const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };

    // Destination checks.
    GGML_ASSERT(result_in->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_nelements(result_in) >= ne[0]*ne[1]);

    // Carve a densely packed ne[0] x ne[1] float matrix out of the start of result_in.
    const size_t row_bytes = ne[0]*sizeof(float);
    struct ggml_tensor * dst = ggml_view_2d(ctx, result_in, ne[0], ne[1], row_bytes, 0);

    // Turn the view node into the matrix multiplication node itself.
    dst->src[0] = a;
    dst->src[1] = b;
    dst->grad   = NULL;
    dst->op     = GGML_OP_MUL_MAT;

    return dst;
}
void ggml_mul_mat_set_prec(
struct ggml_tensor * a,
enum ggml_prec prec) {

View File

@ -146,11 +146,13 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
return {output_tokens, new_state};
}
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(llama_context & lctx, ggml_context * ctx0, ggml_tensor * wqkv, ggml_tensor * wqkv_gate,
ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf) {
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(llama_context & lctx, ggml_context * ctx0,
ggml_tensor * wqkv, ggml_tensor * wqkv_gate,
ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf, ggml_tensor * qkv_cpy) {
const int64_t n_tok = input->ne[1];
ggml_tensor * qkv_mixed = llm_build_context::llm_build_lora_mm(lctx, ctx0, wqkv, input);
auto qkv_mixed = qkv_cpy ? ggml_mul_mat_inplace(ctx0, wqkv, input, qkv_cpy)
: llm_build_context::llm_build_lora_mm(lctx, ctx0, wqkv, input);
cb(qkv_mixed, "qkv_mixed", il);
ggml_tensor * z = llm_build_context::llm_build_lora_mm(lctx, ctx0, wqkv_gate, input);
cb(z, "z", il);
@ -163,7 +165,7 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(llama_context & lc
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(llama_context & lctx, ggml_context * ctx0, ggml_tensor * ssm_in,
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads,
ggml_tensor * input, int il, const llm_build_cb & cb) {
ggml_tensor * input, int il, const llm_build_cb & cb, ggml_tensor * qkv_cpy) {
const int64_t n_tok = input->ne[1];
@ -210,17 +212,17 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(llama_context & lc
cb(value_flat, "value_flat", il);
ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
qkv_mixed = qkv_cpy ? ggml_concat_inplace(ctx0, qkv_mixed, value_flat, qkv_cpy, 0) : ggml_concat(ctx0, qkv_mixed, value_flat, 0);
cb(qkv_mixed, "qkv_mixed", il);
return { qkv_mixed, z };
}
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(llama_context & lctx, ggml_context * ctx0, ggml_tensor * wqkv, ggml_tensor * wqkv_gate, ggml_tensor * ssm_in,
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf) {
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf, ggml_tensor * qkv_cpy) {
GGML_ASSERT((wqkv && wqkv_gate) || ssm_in);
return wqkv && wqkv_gate ? build_qkvz(lctx, ctx0, wqkv, wqkv_gate, input, il, cb, gf)
: build_qkvz(lctx, ctx0, ssm_in, head_k_dim, num_k_heads, head_v_dim, num_v_heads, input, il, cb);
return wqkv && wqkv_gate ? build_qkvz(lctx, ctx0, wqkv, wqkv_gate, input, il, cb, gf, qkv_cpy)
: build_qkvz(lctx, ctx0, ssm_in, head_k_dim, num_k_heads, head_v_dim, num_v_heads, input, il, cb, qkv_cpy);
}
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_beta_gate(llama_context & lctx, ggml_context * ctx0,
@ -488,17 +490,21 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
int il_cb = 1000*il + id;
int64_t num_k_heads_id, num_v_heads_id;
ggml_tensor *qkv_mixed, *z;
auto qkv_cpy = save_per_step_states &&
il < (int)kv_self.ckpt.per_step_qkv.size() &&
id < (int)kv_self.ckpt.per_step_qkv[il].size() ?
kv_self.ckpt.per_step_qkv[il][id] : nullptr;
if (split_wqkv && split_wqkv_gate) {
num_k_heads_id = split_wqkv->splits[id]->ne[1]/(head_k_dim*(2 + gqa_ratio));
num_v_heads_id = num_k_heads_id * gqa_ratio;
auto p = build_qkvz(lctx, ctx0, split_wqkv->splits[id], split_wqkv_gate->splits[id], cur, il_cb, cb, gf);
auto p = build_qkvz(lctx, ctx0, split_wqkv->splits[id], split_wqkv_gate->splits[id], cur, il_cb, cb, gf, qkv_cpy);
qkv_mixed = p.first;
z = p.second;
} else {
num_k_heads_id = split_smm_in->splits[id]->ne[1]/(2*head_k_dim*(1 + gqa_ratio));
num_v_heads_id = num_k_heads_id * gqa_ratio;
auto p = build_qkvz(lctx, ctx0, nullptr, nullptr, split_smm_in->splits[id], head_k_dim, num_k_heads_id, head_v_dim, num_v_heads_id, cur, il, cb, gf);
//auto p = build_qkvz(lctx, ctx0, split_smm_in->splits[id], head_k_dim, num_k_heads_id, head_v_dim, num_v_heads_id, cur, il_cb, cb);
auto p = build_qkvz(lctx, ctx0, nullptr, nullptr, split_smm_in->splits[id],
head_k_dim, num_k_heads_id, head_v_dim, num_v_heads_id, cur, il, cb, gf, qkv_cpy);
qkv_mixed = p.first;
z = p.second;
}
@ -529,9 +535,9 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
per_step_ckpt = kv_self.ckpt.per_step_ssm[il][id];
//GGML_ASSERT(per_step_ckpt);
}
if (save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && id < (int)kv_self.ckpt.per_step_qkv[il].size()) {
build_qkv_cpy(ctx0, gf, qkv_mixed, kv_self.ckpt.per_step_qkv[il][id]);
}
//if (save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && id < (int)kv_self.ckpt.per_step_qkv[il].size()) {
// build_qkv_cpy(ctx0, gf, qkv_mixed, kv_self.ckpt.per_step_qkv[il][id]);
//}
auto output = build_qkv(ctx0, split_s_l->splits[id], split_ssm_conv1d->splits[id], qkv_mixed, inp_s_seq_qnext, beta, gate,
head_k_dim, num_k_heads_id, head_v_dim, num_v_heads_id, hparams.ssm_d_conv,
state_seq_id_local, qnext_state_slots, reset_state_local, hparams.f_norm_rms_eps,
@ -579,8 +585,10 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
auto norm = model.layers[il].attn_norm->extra ? ((ggml_split_tensor_t *)model.layers[il].attn_norm->extra)->splits[idx] : model.layers[il].attn_norm;
auto cur = llm_build_context::llm_build_norm(ctx0, input, hparams, norm, nullptr, LLM_NORM_RMS, cb, il);
auto qkv_cpy = save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && !kv_self.ckpt.per_step_qkv[il].empty()
? kv_self.ckpt.per_step_qkv[il].front() : nullptr;
auto [qkv_mixed, z] = build_qkvz(lctx, ctx0, model.layers[il].wqkv, model.layers[il].wqkv_gate, model.layers[il].ssm_in,
head_k_dim, num_k_heads, head_v_dim, num_v_heads, cur, il, cb, gf);
head_k_dim, num_k_heads, head_v_dim, num_v_heads, cur, il, cb, gf, qkv_cpy);
auto [beta, gate] = build_beta_gate(lctx, ctx0, model.layers[il].ssm_beta_alpha, model.layers[il].ssm_beta, model.layers[il].ssm_alpha,
model.layers[il].ssm_dt, model.layers[il].ssm_a, num_k_heads, num_v_heads, n_seqs, cur, il, cb, gf);
@ -592,16 +600,9 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
}
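// qkv_mixed features for per-step conv state reconstruction are now written straight into
// kv_self.ckpt.per_step_qkv[il] by build_qkvz (through its qkv_cpy argument), so no separate copy is built here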
if (save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && !kv_self.ckpt.per_step_qkv[il].empty()) {
build_qkv_cpy(ctx0, gf, qkv_mixed, kv_self.ckpt.per_step_qkv[il].front());
//const int64_t conv_dim = qkv_mixed->ne[0];
//const int64_t n_tok_qkv = qkv_mixed->ne[1] * qkv_mixed->ne[2];
//ggml_tensor * qkv_flat = ggml_reshape_2d(ctx0, qkv_mixed, conv_dim, n_tok_qkv);
//ggml_tensor * qkv_dst = ggml_view_2d(ctx0, kv_self.ckpt.per_step_qkv[il].front(),
// conv_dim, n_tok_qkv, conv_dim * sizeof(float), 0);
//auto qkv_cpy = ggml_cpy(ctx0, qkv_flat, qkv_dst);
//ggml_build_forward_expand(gf, qkv_cpy);
}
//if (save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && !kv_self.ckpt.per_step_qkv[il].empty()) {
// build_qkv_cpy(ctx0, gf, qkv_mixed, kv_self.ckpt.per_step_qkv[il].front());
//}
auto output = build_qkv(ctx0, kv_self.s_l[il], model.layers[il].ssm_conv1d,
qkv_mixed, inp_s_seq_qnext, beta, gate,

View File

@ -33,13 +33,17 @@ private:
bool has_unique_seq_ids;
static std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(llama_context & lctx, ggml_context * ctx0,
ggml_tensor * wqkv, ggml_tensor * wqkv_gate, ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf);
ggml_tensor * wqkv, ggml_tensor * wqkv_gate, ggml_tensor * input, int il, const llm_build_cb & cb,
ggml_cgraph * gf, ggml_tensor * qkv_copy);
static std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(llama_context & lctx, ggml_context * ctx0, ggml_tensor * ssm_in,
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, ggml_tensor * input, int il, const llm_build_cb & cb);
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, ggml_tensor * input, int il,
const llm_build_cb & cb, ggml_tensor * qkv_copy);
static std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(llama_context & lctx, ggml_context * ctx0, ggml_tensor * wqkv, ggml_tensor * wqkv_gate, ggml_tensor * ssm_in,
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf);
static std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(llama_context & lctx, ggml_context * ctx0,
ggml_tensor * wqkv, ggml_tensor * wqkv_gate, ggml_tensor * ssm_in,
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, ggml_tensor * input,
int il, const llm_build_cb & cb, ggml_cgraph * gf, ggml_tensor * qkv_copy);
static std::pair<ggml_tensor *, ggml_tensor *> build_beta_gate(llama_context & lctx, ggml_context * ctx0,
ggml_tensor * ssm_beta_alpha, ggml_tensor * ssm_beta, ggml_tensor * ssm_alpha,

View File

@ -553,6 +553,8 @@ struct llama_context::Prev {
int n_outputs;
int n_kv;
int n_tokens;
int save_per_step_ssm;
int per_step_max_allocated;
llama_mtp_op_type mtp_op_type;
ggml_cgraph * graph;
};
@ -565,12 +567,14 @@ void llama_context::reset_scheduler() {
bool llama_context::can_reuse_graph(const llama_batch & u_batch) {
if (!cparams.graph_reuse) return false;
if (kv_self.save_per_step_ssm) return false;
//if (kv_self.save_per_step_ssm) return false;
if (model.arch == LLM_ARCH_GEMMA4_MTP && mtp_target_ctx != nullptr) return false;
auto the_prev = cparams.mtp_op_type == MTP_OP_NONE ? prev.get() : prev_mtp.get();
if (!the_prev || !the_prev->graph) return false;
//if (u_batch.n_tokens > 1) return false;
if (u_batch.embd) return false;
if (the_prev->save_per_step_ssm != kv_self.save_per_step_ssm ||
the_prev->per_step_max_allocated != kv_self.ckpt.per_step_max_allocated) return false;
return u_batch.all_seq_id == the_prev->all_seq_id &&
kv_self.head > 0 &&
kv_self.n == the_prev->n_kv &&
@ -4669,7 +4673,9 @@ static int llama_decode_internal(
!(lctx.model.arch == LLM_ARCH_GEMMA4_MTP && lctx.mtp_target_ctx != nullptr)) {
prev = std::make_unique<llama_context::Prev>(llama_context::Prev{
(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n,
(int)u_batch.n_tokens, cparams.mtp_op_type, gf});
(int)u_batch.n_tokens,
lctx.kv_self.save_per_step_ssm, lctx.kv_self.ckpt.per_step_max_allocated,
cparams.mtp_op_type, gf});
}
} else {
//printf("Reusing graph with n_kv = %d, n_tokens = %d\n", (int)prev->n_kv, (int)prev->n_tokens);