fix multi-seq

This commit is contained in:
Aman Gupta 2026-05-19 22:17:09 +08:00
parent f268966d49
commit 9af0434d8c
2 changed files with 23 additions and 1 deletions

View File

@ -2663,6 +2663,29 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k = src_cur->get_k(ctx0, il_src);
ggml_tensor * v = src_cur->get_v(ctx0, il_src);
// build_attn_mha splits q across k->ne[3] (the trunk's stream count). When the
// trunk runs kv_unified=false the assistant's ubatch only references a subset
// of streams (one per active draft seq); q->ne[2] is not divisible by the full
// n_stream and the view collapses tokens. Slice k/v down to exactly the streams
// referenced by this ubatch. Requires those streams to form a contiguous range.
if (k->ne[3] > 1 && (uint32_t) k->ne[3] != ubatch.n_seqs_unq) {
GGML_ASSERT(ubatch.n_seqs_unq > 0 && ubatch.seq_id_unq);
llama_seq_id min_s = ubatch.seq_id_unq[0];
llama_seq_id max_s = ubatch.seq_id_unq[0];
for (uint32_t s = 1; s < ubatch.n_seqs_unq; ++s) {
min_s = std::min(min_s, ubatch.seq_id_unq[s]);
max_s = std::max(max_s, ubatch.seq_id_unq[s]);
}
GGML_ASSERT((uint32_t)(max_s - min_s + 1) == ubatch.n_seqs_unq &&
"MTP src-kv attn requires the active draft seq_ids to be contiguous");
GGML_ASSERT((int64_t) max_s < k->ne[3] && "MTP assistant seq_id beyond trunk stream count");
k = ggml_view_4d(ctx0, k, k->ne[0], k->ne[1], k->ne[2], (int64_t) ubatch.n_seqs_unq,
k->nb[1], k->nb[2], k->nb[3], (size_t) min_s * k->nb[3]);
v = ggml_view_4d(ctx0, v, v->ne[0], v->ne[1], v->ne[2], (int64_t) ubatch.n_seqs_unq,
v->nb[1], v->nb[2], v->nb[3], (size_t) min_s * v->nb[3]);
}
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il_assist);
cb(cur, "kqv_out", il_assist);

View File

@ -329,7 +329,6 @@ llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
ggml_tensor * logits = build_lora_mm(model.output, cur);
cb(logits, "result_output", -1);