mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
MTP: Reuse graphs (again)
This commit is contained in:
parent
eb570eb966
commit
16369dbf0f
@ -1602,6 +1602,12 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
|
||||
// matrix multiplication that stores its output in an existing tensor
// result == NULL: same as ggml_mul_mat(ctx, a, b)
// otherwise:      a must not be transposed, b must be 2D (b->ne[2] == b->ne[3] == 1),
//                 result must be F32 with at least a->ne[1]*b->ne[1] elements;
//                 the returned tensor is a 2D view {a->ne[1], b->ne[1]} at offset 0 of result
GGML_API struct ggml_tensor * ggml_mul_mat_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * result);
|
||||
|
||||
// change the precision of a matrix multiplication
|
||||
// set to GGML_PREC_F32 for higher precision (useful for phi-2)
|
||||
GGML_API void ggml_mul_mat_set_prec(
|
||||
|
||||
@ -7799,6 +7799,32 @@ struct ggml_tensor * ggml_mul_mat(
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_mul_mat_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * result_in) {
|
||||
if (!result_in) {
|
||||
return ggml_mul_mat(ctx, a, b);
|
||||
}
|
||||
GGML_ASSERT(ggml_can_mul_mat(a, b));
|
||||
GGML_ASSERT(!ggml_is_transposed(a));
|
||||
GGML_ASSERT(b->ne[2] == 1 && b->ne[3] == 1);
|
||||
|
||||
const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
|
||||
GGML_ASSERT(result_in->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_nelements(result_in) >= ne[0]*ne[1]);
|
||||
|
||||
struct ggml_tensor * result = ggml_view_2d(ctx, result_in, ne[0], ne[1], ne[0]*sizeof(float), 0);
|
||||
|
||||
result->op = GGML_OP_MUL_MAT;
|
||||
result->grad = NULL;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void ggml_mul_mat_set_prec(
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_prec prec) {
|
||||
|
||||
@ -146,11 +146,13 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
|
||||
return {output_tokens, new_state};
|
||||
}
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(llama_context & lctx, ggml_context * ctx0, ggml_tensor * wqkv, ggml_tensor * wqkv_gate,
|
||||
ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf) {
|
||||
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(llama_context & lctx, ggml_context * ctx0,
|
||||
ggml_tensor * wqkv, ggml_tensor * wqkv_gate,
|
||||
ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf, ggml_tensor * qkv_cpy) {
|
||||
|
||||
const int64_t n_tok = input->ne[1];
|
||||
ggml_tensor * qkv_mixed = llm_build_context::llm_build_lora_mm(lctx, ctx0, wqkv, input);
|
||||
auto qkv_mixed = qkv_cpy ? ggml_mul_mat_inplace(ctx0, wqkv, input, qkv_cpy)
|
||||
: llm_build_context::llm_build_lora_mm(lctx, ctx0, wqkv, input);
|
||||
cb(qkv_mixed, "qkv_mixed", il);
|
||||
ggml_tensor * z = llm_build_context::llm_build_lora_mm(lctx, ctx0, wqkv_gate, input);
|
||||
cb(z, "z", il);
|
||||
@ -163,7 +165,7 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(llama_context & lc
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(llama_context & lctx, ggml_context * ctx0, ggml_tensor * ssm_in,
|
||||
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads,
|
||||
ggml_tensor * input, int il, const llm_build_cb & cb) {
|
||||
ggml_tensor * input, int il, const llm_build_cb & cb, ggml_tensor * qkv_cpy) {
|
||||
|
||||
const int64_t n_tok = input->ne[1];
|
||||
|
||||
@ -210,17 +212,17 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(llama_context & lc
|
||||
cb(value_flat, "value_flat", il);
|
||||
|
||||
ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
|
||||
qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
|
||||
qkv_mixed = qkv_cpy ? ggml_concat_inplace(ctx0, qkv_mixed, value_flat, qkv_cpy, 0) : ggml_concat(ctx0, qkv_mixed, value_flat, 0);
|
||||
cb(qkv_mixed, "qkv_mixed", il);
|
||||
|
||||
return { qkv_mixed, z };
|
||||
}
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(llama_context & lctx, ggml_context * ctx0, ggml_tensor * wqkv, ggml_tensor * wqkv_gate, ggml_tensor * ssm_in,
|
||||
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf) {
|
||||
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf, ggml_tensor * qkv_cpy) {
|
||||
GGML_ASSERT((wqkv && wqkv_gate) || ssm_in);
|
||||
return wqkv && wqkv_gate ? build_qkvz(lctx, ctx0, wqkv, wqkv_gate, input, il, cb, gf)
|
||||
: build_qkvz(lctx, ctx0, ssm_in, head_k_dim, num_k_heads, head_v_dim, num_v_heads, input, il, cb);
|
||||
return wqkv && wqkv_gate ? build_qkvz(lctx, ctx0, wqkv, wqkv_gate, input, il, cb, gf, qkv_cpy)
|
||||
: build_qkvz(lctx, ctx0, ssm_in, head_k_dim, num_k_heads, head_v_dim, num_v_heads, input, il, cb, qkv_cpy);
|
||||
}
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_beta_gate(llama_context & lctx, ggml_context * ctx0,
|
||||
@ -488,17 +490,21 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
int il_cb = 1000*il + id;
|
||||
int64_t num_k_heads_id, num_v_heads_id;
|
||||
ggml_tensor *qkv_mixed, *z;
|
||||
auto qkv_cpy = save_per_step_states &&
|
||||
il < (int)kv_self.ckpt.per_step_qkv.size() &&
|
||||
id < (int)kv_self.ckpt.per_step_qkv[il].size() ?
|
||||
kv_self.ckpt.per_step_qkv[il][id] : nullptr;
|
||||
if (split_wqkv && split_wqkv_gate) {
|
||||
num_k_heads_id = split_wqkv->splits[id]->ne[1]/(head_k_dim*(2 + gqa_ratio));
|
||||
num_v_heads_id = num_k_heads_id * gqa_ratio;
|
||||
auto p = build_qkvz(lctx, ctx0, split_wqkv->splits[id], split_wqkv_gate->splits[id], cur, il_cb, cb, gf);
|
||||
auto p = build_qkvz(lctx, ctx0, split_wqkv->splits[id], split_wqkv_gate->splits[id], cur, il_cb, cb, gf, qkv_cpy);
|
||||
qkv_mixed = p.first;
|
||||
z = p.second;
|
||||
} else {
|
||||
num_k_heads_id = split_smm_in->splits[id]->ne[1]/(2*head_k_dim*(1 + gqa_ratio));
|
||||
num_v_heads_id = num_k_heads_id * gqa_ratio;
|
||||
auto p = build_qkvz(lctx, ctx0, nullptr, nullptr, split_smm_in->splits[id], head_k_dim, num_k_heads_id, head_v_dim, num_v_heads_id, cur, il, cb, gf);
|
||||
//auto p = build_qkvz(lctx, ctx0, split_smm_in->splits[id], head_k_dim, num_k_heads_id, head_v_dim, num_v_heads_id, cur, il_cb, cb);
|
||||
auto p = build_qkvz(lctx, ctx0, nullptr, nullptr, split_smm_in->splits[id],
|
||||
head_k_dim, num_k_heads_id, head_v_dim, num_v_heads_id, cur, il, cb, gf, qkv_cpy);
|
||||
qkv_mixed = p.first;
|
||||
z = p.second;
|
||||
}
|
||||
@ -529,9 +535,9 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
per_step_ckpt = kv_self.ckpt.per_step_ssm[il][id];
|
||||
//GGML_ASSERT(per_step_ckpt);
|
||||
}
|
||||
if (save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && id < (int)kv_self.ckpt.per_step_qkv[il].size()) {
|
||||
build_qkv_cpy(ctx0, gf, qkv_mixed, kv_self.ckpt.per_step_qkv[il][id]);
|
||||
}
|
||||
//if (save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && id < (int)kv_self.ckpt.per_step_qkv[il].size()) {
|
||||
// build_qkv_cpy(ctx0, gf, qkv_mixed, kv_self.ckpt.per_step_qkv[il][id]);
|
||||
//}
|
||||
auto output = build_qkv(ctx0, split_s_l->splits[id], split_ssm_conv1d->splits[id], qkv_mixed, inp_s_seq_qnext, beta, gate,
|
||||
head_k_dim, num_k_heads_id, head_v_dim, num_v_heads_id, hparams.ssm_d_conv,
|
||||
state_seq_id_local, qnext_state_slots, reset_state_local, hparams.f_norm_rms_eps,
|
||||
@ -579,8 +585,10 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
auto norm = model.layers[il].attn_norm->extra ? ((ggml_split_tensor_t *)model.layers[il].attn_norm->extra)->splits[idx] : model.layers[il].attn_norm;
|
||||
auto cur = llm_build_context::llm_build_norm(ctx0, input, hparams, norm, nullptr, LLM_NORM_RMS, cb, il);
|
||||
|
||||
auto qkv_cpy = save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && !kv_self.ckpt.per_step_qkv[il].empty()
|
||||
? kv_self.ckpt.per_step_qkv[il].front() : nullptr;
|
||||
auto [qkv_mixed, z] = build_qkvz(lctx, ctx0, model.layers[il].wqkv, model.layers[il].wqkv_gate, model.layers[il].ssm_in,
|
||||
head_k_dim, num_k_heads, head_v_dim, num_v_heads, cur, il, cb, gf);
|
||||
head_k_dim, num_k_heads, head_v_dim, num_v_heads, cur, il, cb, gf, qkv_cpy);
|
||||
|
||||
auto [beta, gate] = build_beta_gate(lctx, ctx0, model.layers[il].ssm_beta_alpha, model.layers[il].ssm_beta, model.layers[il].ssm_alpha,
|
||||
model.layers[il].ssm_dt, model.layers[il].ssm_a, num_k_heads, num_v_heads, n_seqs, cur, il, cb, gf);
|
||||
@ -592,16 +600,9 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
}
|
||||
|
||||
// Save qkv_mixed features for per-step conv state reconstruction
|
||||
if (save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && !kv_self.ckpt.per_step_qkv[il].empty()) {
|
||||
build_qkv_cpy(ctx0, gf, qkv_mixed, kv_self.ckpt.per_step_qkv[il].front());
|
||||
//const int64_t conv_dim = qkv_mixed->ne[0];
|
||||
//const int64_t n_tok_qkv = qkv_mixed->ne[1] * qkv_mixed->ne[2];
|
||||
//ggml_tensor * qkv_flat = ggml_reshape_2d(ctx0, qkv_mixed, conv_dim, n_tok_qkv);
|
||||
//ggml_tensor * qkv_dst = ggml_view_2d(ctx0, kv_self.ckpt.per_step_qkv[il].front(),
|
||||
// conv_dim, n_tok_qkv, conv_dim * sizeof(float), 0);
|
||||
//auto qkv_cpy = ggml_cpy(ctx0, qkv_flat, qkv_dst);
|
||||
//ggml_build_forward_expand(gf, qkv_cpy);
|
||||
}
|
||||
//if (save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && !kv_self.ckpt.per_step_qkv[il].empty()) {
|
||||
// build_qkv_cpy(ctx0, gf, qkv_mixed, kv_self.ckpt.per_step_qkv[il].front());
|
||||
//}
|
||||
|
||||
auto output = build_qkv(ctx0, kv_self.s_l[il], model.layers[il].ssm_conv1d,
|
||||
qkv_mixed, inp_s_seq_qnext, beta, gate,
|
||||
|
||||
@ -33,13 +33,17 @@ private:
|
||||
bool has_unique_seq_ids;
|
||||
|
||||
static std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(llama_context & lctx, ggml_context * ctx0,
|
||||
ggml_tensor * wqkv, ggml_tensor * wqkv_gate, ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf);
|
||||
ggml_tensor * wqkv, ggml_tensor * wqkv_gate, ggml_tensor * input, int il, const llm_build_cb & cb,
|
||||
ggml_cgraph * gf, ggml_tensor * qkv_copy);
|
||||
|
||||
static std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(llama_context & lctx, ggml_context * ctx0, ggml_tensor * ssm_in,
|
||||
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, ggml_tensor * input, int il, const llm_build_cb & cb);
|
||||
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, ggml_tensor * input, int il,
|
||||
const llm_build_cb & cb, ggml_tensor * qkv_copy);
|
||||
|
||||
static std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(llama_context & lctx, ggml_context * ctx0, ggml_tensor * wqkv, ggml_tensor * wqkv_gate, ggml_tensor * ssm_in,
|
||||
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf);
|
||||
static std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(llama_context & lctx, ggml_context * ctx0,
|
||||
ggml_tensor * wqkv, ggml_tensor * wqkv_gate, ggml_tensor * ssm_in,
|
||||
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, ggml_tensor * input,
|
||||
int il, const llm_build_cb & cb, ggml_cgraph * gf, ggml_tensor * qkv_copy);
|
||||
|
||||
static std::pair<ggml_tensor *, ggml_tensor *> build_beta_gate(llama_context & lctx, ggml_context * ctx0,
|
||||
ggml_tensor * ssm_beta_alpha, ggml_tensor * ssm_beta, ggml_tensor * ssm_alpha,
|
||||
|
||||
@ -553,6 +553,8 @@ struct llama_context::Prev {
|
||||
int n_outputs;
|
||||
int n_kv;
|
||||
int n_tokens;
|
||||
int save_per_step_ssm;
|
||||
int per_step_max_allocated;
|
||||
llama_mtp_op_type mtp_op_type;
|
||||
ggml_cgraph * graph;
|
||||
};
|
||||
@ -565,12 +567,14 @@ void llama_context::reset_scheduler() {
|
||||
|
||||
bool llama_context::can_reuse_graph(const llama_batch & u_batch) {
|
||||
if (!cparams.graph_reuse) return false;
|
||||
if (kv_self.save_per_step_ssm) return false;
|
||||
//if (kv_self.save_per_step_ssm) return false;
|
||||
if (model.arch == LLM_ARCH_GEMMA4_MTP && mtp_target_ctx != nullptr) return false;
|
||||
auto the_prev = cparams.mtp_op_type == MTP_OP_NONE ? prev.get() : prev_mtp.get();
|
||||
if (!the_prev || !the_prev->graph) return false;
|
||||
//if (u_batch.n_tokens > 1) return false;
|
||||
if (u_batch.embd) return false;
|
||||
if (the_prev->save_per_step_ssm != kv_self.save_per_step_ssm ||
|
||||
the_prev->per_step_max_allocated != kv_self.ckpt.per_step_max_allocated) return false;
|
||||
return u_batch.all_seq_id == the_prev->all_seq_id &&
|
||||
kv_self.head > 0 &&
|
||||
kv_self.n == the_prev->n_kv &&
|
||||
@ -4669,7 +4673,9 @@ static int llama_decode_internal(
|
||||
!(lctx.model.arch == LLM_ARCH_GEMMA4_MTP && lctx.mtp_target_ctx != nullptr)) {
|
||||
prev = std::make_unique<llama_context::Prev>(llama_context::Prev{
|
||||
(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n,
|
||||
(int)u_batch.n_tokens, cparams.mtp_op_type, gf});
|
||||
(int)u_batch.n_tokens,
|
||||
lctx.kv_self.save_per_step_ssm, lctx.kv_self.ckpt.per_step_max_allocated,
|
||||
cparams.mtp_op_type, gf});
|
||||
}
|
||||
} else {
|
||||
//printf("Reusing graph with n_kv = %d, n_tokens = %d\n", (int)prev->n_kv, (int)prev->n_tokens);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user