diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 948cc016a9..76a7735351 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2663,6 +2663,29 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k = src_cur->get_k(ctx0, il_src); ggml_tensor * v = src_cur->get_v(ctx0, il_src); + // build_attn_mha splits q across k->ne[3] (the trunk's stream count). When the + // trunk runs kv_unified=false the assistant's ubatch only references a subset + // of streams (one per active draft seq); q->ne[2] is not divisible by the full + // n_stream and the view collapses tokens. Slice k/v down to exactly the streams + // referenced by this ubatch. Requires those streams to form a contiguous range. + if (k->ne[3] > 1 && (uint32_t) k->ne[3] != ubatch.n_seqs_unq) { + GGML_ASSERT(ubatch.n_seqs_unq > 0 && ubatch.seq_id_unq); + llama_seq_id min_s = ubatch.seq_id_unq[0]; + llama_seq_id max_s = ubatch.seq_id_unq[0]; + for (uint32_t s = 1; s < ubatch.n_seqs_unq; ++s) { + min_s = std::min(min_s, ubatch.seq_id_unq[s]); + max_s = std::max(max_s, ubatch.seq_id_unq[s]); + } + GGML_ASSERT((uint32_t)(max_s - min_s + 1) == ubatch.n_seqs_unq && + "MTP src-kv attn requires the active draft seq_ids to be contiguous"); + GGML_ASSERT((int64_t) max_s < k->ne[3] && "MTP assistant seq_id beyond trunk stream count"); + + k = ggml_view_4d(ctx0, k, k->ne[0], k->ne[1], k->ne[2], (int64_t) ubatch.n_seqs_unq, + k->nb[1], k->nb[2], k->nb[3], (size_t) min_s * k->nb[3]); + v = ggml_view_4d(ctx0, v, v->ne[0], v->ne[1], v->ne[2], (int64_t) ubatch.n_seqs_unq, + v->nb[1], v->nb[2], v->nb[3], (size_t) min_s * v->nb[3]); + } + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il_assist); cb(cur, "kqv_out", il_assist); diff --git a/src/models/gemma4.cpp b/src/models/gemma4.cpp index 822db10cb1..f837b23686 100644 --- a/src/models/gemma4.cpp +++ b/src/models/gemma4.cpp @@ -329,7 +329,6 @@ llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_ cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); - res->t_embd = cur; ggml_tensor * logits = build_lora_mm(model.output, cur); cb(logits, "result_output", -1);