Split mode graph for dense Qwen35 MTP (#2027)

* WIP: Split mode graph for Gemma4 assistant

Something is not right - acceptance drops to nearly zero.

* Per model CUDA contexts

Still not working!?

* This works

The issue was that I was not correctly calculating the number
of KV heads for the split KV cache.

* Compiler warnings

* It is better to use llama_context pointers as keys

* Split mode graph for dense Qwen35 MTP
This commit is contained in:
Kawrakow 2026-06-25 11:12:22 +02:00 committed by GitHub
parent d3e86a5431
commit b84902d2ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 10 additions and 19 deletions

View File

@ -291,14 +291,10 @@ struct ggml_tensor * llm_build_context::build_qwen35_mtp(
const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
cur = build_std_attention(gf, mtp_layer.attn_norm, cur,
inp_pos, nullptr, nullptr,
inp_pos, inp_out_ids, nullptr,
KQ_mask, nullptr, nullptr,
kq_scale, 0.0f, 0, il, true, false, true, false, true, nullptr);
if (inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}
// Dense FFN — optional (9B and 4B don't have FFN in MTP layer)
if (mtp_layer.ffn_gate != nullptr) {
cur = llm_build_ffn(ctx0, lctx, mtp_layer.ffn_norm, cur,
@ -312,11 +308,6 @@ struct ggml_tensor * llm_build_context::build_qwen35_mtp(
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "ffn_out", il);
// As far as I can tell this was wrong. We need the FFN output, and not the normalized result.
//cur = llm_build_norm(ctx0, cur, hparams, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "result_norm", -1);
//cur = build_output(lctx, ctx0, cur, model.output, nullptr, cb);
cur = build_output(lctx, ctx0, cur, model.output_mtp, mtp_layer.nextn.shared_head_norm, cb);
cb(cur, "result_output", -1);

View File

@ -179,10 +179,10 @@ struct create_tensors_helper : public create_tensors_helper_interface {
// Returns the ggml_context that ctx_map associates with the `buft` entry of
// model.buft_layer[i]. This is the per-layer lookup; ctx_for_layer_split()
// consults the `buft_matrix` entry instead for layers it treats as split.
// NOTE(review): relies on ctx_map.at() for the lookup, so a buffer type with
// no registered context presumably raises rather than returning nullptr —
// confirm against ctx_map's declaration. `i` is not range-checked here.
inline ggml_context * ctx_for_layer(int i) const {
return ctx_map.at(model.buft_layer[i].buft);
}
inline ggml_context * ctx_for_layer_split(int i) const {
inline ggml_context * ctx_for_layer_split(int i, bool force_split = false) const {
const bool is_mtp_layer = model.hparams.nextn_predict_layers > 0 &&
static_cast<uint32_t>(i) >= model.hparams.n_layer - model.hparams.nextn_predict_layers;
return is_mtp_layer ? ctx_map.at(model.buft_layer[i].buft) : ctx_map.at(model.buft_layer[i].buft_matrix);
return is_mtp_layer && !force_split ? ctx_map.at(model.buft_layer[i].buft) : ctx_map.at(model.buft_layer[i].buft_matrix);
}
std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
@ -1733,8 +1733,7 @@ bool create_tensors_helper::create_qwen35_tensors(const LLM_TN & tn) {
const bool is_mtp_layer = hparams.nextn_predict_layers > 0 &&
static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers;
// For now only run MTP into the per-layer
ggml_context * ctx_split = is_mtp_layer ? ctx_for_layer(i) : ctx_for_layer_split(i);
ggml_context * ctx_split = ctx_for_layer_split(i, true);
int flags = 0;
// Skip loading MTP layers if the feature is disabled
@ -4609,6 +4608,7 @@ bool create_tensors_helper::create_tensors() {
for (int il = 0; il < n_layer; ++il) {
// For now only run MTP into the per-layer
if (model.mtp && hparams.nextn_predict_layers > 0 &&
model.arch != LLM_ARCH_QWEN35 &&
static_cast<uint32_t>(il) >= static_cast<uint32_t>(n_layer) - hparams.nextn_predict_layers) {
LLAMA_LOG_DEBUG("%s: not splitting MTP tail layer %d (forced non-split)\n", __func__, il);
continue;

View File

@ -880,7 +880,7 @@ static bool llama_kv_cache_init(
std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
if (offload) {
const bool is_mtp = (model.arch == LLM_ARCH_GLM_DSA ||
model.arch == LLM_ARCH_QWEN35 ||
//model.arch == LLM_ARCH_QWEN35 ||
model.arch == LLM_ARCH_QWEN35MOE) && hparams.nextn_predict_layers > 0;
const int64_t n_mtp_first = hparams.n_layer - hparams.nextn_predict_layers;
for (int64_t i = 0; i < n_layer; ++i) {
@ -960,10 +960,10 @@ static bool llama_kv_cache_init(
int n_mla = 0;
int n_kv_active_layers = 0;
const int64_t n_mtp_first_layer = hparams.n_layer - hparams.nextn_predict_layers;
const int n_mtp_first_layer = hparams.n_layer - hparams.nextn_predict_layers;
for (int i = 0; i < (int) n_layer; i++) {
// For MTP-only context, skip KV allocation for non-MTP layers
if (cparams.mtp_op_type != MTP_OP_NONE && i < (int)n_mtp_first_layer) {
if (cparams.mtp_op_type != MTP_OP_NONE && i < n_mtp_first_layer) {
cache.k_l.push_back(nullptr);
if (!is_mla_attn || !cparams.mla_attn || (cparams.mla_attn == 1 && !cparams.flash_attn)) {
cache.v_l.push_back(nullptr);
@ -976,10 +976,10 @@ static bool llama_kv_cache_init(
const uint32_t n_head_kv = hparams.n_head_kv(i);
const uint32_t n_embd_head_k= hparams.n_embd_head_k(i);
const bool is_mtp_tail_layer = (model.arch == LLM_ARCH_QWEN35 ||
const bool is_mtp_tail_layer = (//model.arch == LLM_ARCH_QWEN35 ||
model.arch == LLM_ARCH_QWEN35MOE ||
model.arch == LLM_ARCH_GLM_DSA) &&
hparams.nextn_predict_layers > 0 && i >= (int)n_mtp_first_layer;
hparams.nextn_predict_layers > 0 && i >= n_mtp_first_layer;
//struct ggml_context * ctx = split_cache && !qnext_recurrent ? ctx_map.at(model.buft_layer[i].buft_matrix) : offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
struct ggml_context * ctx = ((split_cache || replicate_mla) && !is_mtp_tail_layer) ? ctx_map.at(model.buft_layer[i].buft_matrix) : offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
ggml_tensor * k = nullptr;