From d5c04c15fd11230a02198ee5381f3a9a7a5196cc Mon Sep 17 00:00:00 2001 From: Samuel Oliveira Alves <107287165+SamuelOliveirads@users.noreply.github.com> Date: Fri, 19 Jun 2026 04:04:54 -0300 Subject: [PATCH] clean redundancy in dflash graph and small logics (#1994) --- src/graphs/build_dflash.cpp | 148 +++++++++++++---------------- src/llama-dflash.cpp | 128 ++++++++++++++++--------- src/llama-hparams.cpp | 32 ++++--- src/llama-model-loader.cpp | 4 +- src/llama-spec-features-dflash.cpp | 20 ++-- src/llama-spec-features-dflash.h | 1 - 6 files changed, 188 insertions(+), 145 deletions(-) diff --git a/src/graphs/build_dflash.cpp b/src/graphs/build_dflash.cpp index 4fcc43fb..3adacce6 100644 --- a/src/graphs/build_dflash.cpp +++ b/src/graphs/build_dflash.cpp @@ -67,11 +67,11 @@ ggml_cgraph * llm_build_context::build_dflash_kv_workspace() { }; for (int il = 0; il < n_layer; ++il) { - GGML_ASSERT((size_t) il < lctx.dflash.kv.k_ctx_cache.size()); - GGML_ASSERT((size_t) il < lctx.dflash.kv.v_ctx_cache.size()); + GGML_ASSERT(il < (int32_t) lctx.dflash.kv.k_ctx_cache.size()); + GGML_ASSERT(il < (int32_t) lctx.dflash.kv.v_ctx_cache.size()); - ggml_tensor * Kordered = build_ordered_cache_view(lctx.dflash.kv.k_ctx_cache[(size_t) il]); - ggml_tensor * Vordered = build_ordered_cache_view(lctx.dflash.kv.v_ctx_cache[(size_t) il]); + ggml_tensor * Kordered = build_ordered_cache_view(lctx.dflash.kv.k_ctx_cache[il]); + ggml_tensor * Vordered = build_ordered_cache_view(lctx.dflash.kv.v_ctx_cache[il]); cb(Kordered, "dflash_workspace_k_ctx_view", il); cb(Vordered, "dflash_workspace_v_ctx_view", il); @@ -80,19 +80,19 @@ ggml_cgraph * llm_build_context::build_dflash_kv_workspace() { cb(Kworkspace, "dflash_workspace_k_perm_cont", il); cb(Vworkspace, "dflash_workspace_v_perm_cont", il); - ggml_tensor * Kdst = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_workspace[(size_t) il], - lctx.dflash.kv.k_ctx_workspace[(size_t) il]->ne[0], + ggml_tensor * Kdst = ggml_view_3d(ctx0, 
lctx.dflash.kv.k_ctx_workspace[il], + lctx.dflash.kv.k_ctx_workspace[il]->ne[0], ctx_len, - lctx.dflash.kv.k_ctx_workspace[(size_t) il]->ne[2], - lctx.dflash.kv.k_ctx_workspace[(size_t) il]->nb[1], - lctx.dflash.kv.k_ctx_workspace[(size_t) il]->nb[2], + lctx.dflash.kv.k_ctx_workspace[il]->ne[2], + lctx.dflash.kv.k_ctx_workspace[il]->nb[1], + lctx.dflash.kv.k_ctx_workspace[il]->nb[2], 0); - ggml_tensor * Vdst = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_workspace[(size_t) il], - lctx.dflash.kv.v_ctx_workspace[(size_t) il]->ne[0], + ggml_tensor * Vdst = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_workspace[il], + lctx.dflash.kv.v_ctx_workspace[il]->ne[0], ctx_len, - lctx.dflash.kv.v_ctx_workspace[(size_t) il]->ne[2], - lctx.dflash.kv.v_ctx_workspace[(size_t) il]->nb[1], - lctx.dflash.kv.v_ctx_workspace[(size_t) il]->nb[2], + lctx.dflash.kv.v_ctx_workspace[il]->ne[2], + lctx.dflash.kv.v_ctx_workspace[il]->nb[1], + lctx.dflash.kv.v_ctx_workspace[il]->nb[2], 0); ggml_tensor * Kstore = ggml_cpy(ctx0, Kworkspace, Kdst); @@ -137,8 +137,8 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() { cb(fused_target, "dflash_kv_fused_target", -1); for (int il = 0; il < n_layer; ++il) { - GGML_ASSERT((size_t) il < lctx.dflash.kv.k_ctx_cache.size()); - GGML_ASSERT((size_t) il < lctx.dflash.kv.v_ctx_cache.size()); + GGML_ASSERT(il < (int32_t) lctx.dflash.kv.k_ctx_cache.size()); + GGML_ASSERT(il < (int32_t) lctx.dflash.kv.v_ctx_cache.size()); ggml_tensor * Kcur_ctx_proj = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused_target); cb(Kcur_ctx_proj, "dflash_kv_k_proj", il); @@ -177,20 +177,20 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() { Vcur_ctx->nb[1], Vcur_ctx->nb[2], 0); - ggml_tensor * Kdst_first = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_cache[(size_t) il], - lctx.dflash.kv.k_ctx_cache[(size_t) il]->ne[0], - lctx.dflash.kv.k_ctx_cache[(size_t) il]->ne[1], + ggml_tensor * Kdst_first = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_cache[il], + 
lctx.dflash.kv.k_ctx_cache[il]->ne[0], + lctx.dflash.kv.k_ctx_cache[il]->ne[1], first_rows, - lctx.dflash.kv.k_ctx_cache[(size_t) il]->nb[1], - lctx.dflash.kv.k_ctx_cache[(size_t) il]->nb[2], - (size_t) write_pos * lctx.dflash.kv.k_ctx_cache[(size_t) il]->nb[2]); - ggml_tensor * Vdst_first = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_cache[(size_t) il], - lctx.dflash.kv.v_ctx_cache[(size_t) il]->ne[0], - lctx.dflash.kv.v_ctx_cache[(size_t) il]->ne[1], + lctx.dflash.kv.k_ctx_cache[il]->nb[1], + lctx.dflash.kv.k_ctx_cache[il]->nb[2], + (size_t) write_pos * lctx.dflash.kv.k_ctx_cache[il]->nb[2]); + ggml_tensor * Vdst_first = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_cache[il], + lctx.dflash.kv.v_ctx_cache[il]->ne[0], + lctx.dflash.kv.v_ctx_cache[il]->ne[1], first_rows, - lctx.dflash.kv.v_ctx_cache[(size_t) il]->nb[1], - lctx.dflash.kv.v_ctx_cache[(size_t) il]->nb[2], - (size_t) write_pos * lctx.dflash.kv.v_ctx_cache[(size_t) il]->nb[2]); + lctx.dflash.kv.v_ctx_cache[il]->nb[1], + lctx.dflash.kv.v_ctx_cache[il]->nb[2], + (size_t) write_pos * lctx.dflash.kv.v_ctx_cache[il]->nb[2]); ggml_tensor * Kstore_first = ggml_cpy(ctx0, Ksrc_first, Kdst_first); cb(Kstore_first, "dflash_kv_k_store", il); @@ -216,19 +216,19 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() { Vcur_ctx->nb[1], Vcur_ctx->nb[2], (size_t) first_rows * Vcur_ctx->nb[2]); - ggml_tensor * Kdst_second = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_cache[(size_t) il], - lctx.dflash.kv.k_ctx_cache[(size_t) il]->ne[0], - lctx.dflash.kv.k_ctx_cache[(size_t) il]->ne[1], + ggml_tensor * Kdst_second = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_cache[il], + lctx.dflash.kv.k_ctx_cache[il]->ne[0], + lctx.dflash.kv.k_ctx_cache[il]->ne[1], second_rows, - lctx.dflash.kv.k_ctx_cache[(size_t) il]->nb[1], - lctx.dflash.kv.k_ctx_cache[(size_t) il]->nb[2], + lctx.dflash.kv.k_ctx_cache[il]->nb[1], + lctx.dflash.kv.k_ctx_cache[il]->nb[2], 0); - ggml_tensor * Vdst_second = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_cache[(size_t) il], - 
lctx.dflash.kv.v_ctx_cache[(size_t) il]->ne[0], - lctx.dflash.kv.v_ctx_cache[(size_t) il]->ne[1], + ggml_tensor * Vdst_second = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_cache[il], + lctx.dflash.kv.v_ctx_cache[il]->ne[0], + lctx.dflash.kv.v_ctx_cache[il]->ne[1], second_rows, - lctx.dflash.kv.v_ctx_cache[(size_t) il]->nb[1], - lctx.dflash.kv.v_ctx_cache[(size_t) il]->nb[2], + lctx.dflash.kv.v_ctx_cache[il]->nb[1], + lctx.dflash.kv.v_ctx_cache[il]->nb[2], 0); ggml_tensor * Kstore_second = ggml_cpy(ctx0, Ksrc_second, Kdst_second); @@ -264,40 +264,39 @@ ggml_cgraph * llm_build_context::build_dflash() { ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max(n_tokens, ctx_len)) + 32 * n_layer, false); - bool have_swa_layers = false; - for (int il = 0; il < n_layer; ++il) { - if (hparams.swa_layers[il]) { - have_swa_layers = true; - break; + const bool needs_swa_mask = hparams.n_swa > 0 && [&]() { + for (int il = 0; il < n_layer; ++il) { + if (hparams.swa_layers[il]) { + return true; + } } - } + return false; + }(); + const ggml_type mask_type = flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; - lctx.dflash.inputs.kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + lctx.dflash.inputs.kq_mask = ggml_new_tensor_2d(ctx0, mask_type, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); lctx.dflash.kv.kq_mask_tensor = lctx.dflash.inputs.kq_mask; ggml_set_input(lctx.dflash.inputs.kq_mask); cb(lctx.dflash.inputs.kq_mask, "dflash_kq_mask", -1); - ggml_tensor * dflash_kq_mask_full = flash_attn ? 
ggml_cast(ctx0, lctx.dflash.inputs.kq_mask, GGML_TYPE_F16) : lctx.dflash.inputs.kq_mask; + ggml_tensor * dflash_kq_mask_full = lctx.dflash.inputs.kq_mask; ggml_tensor * dflash_kq_mask_swa = nullptr; lctx.dflash.inputs.kq_mask_swa = nullptr; lctx.dflash.kv.kq_mask_swa_tensor = nullptr; - if (have_swa_layers && hparams.n_swa > 0) { - lctx.dflash.inputs.kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + if (needs_swa_mask) { + lctx.dflash.inputs.kq_mask_swa = ggml_new_tensor_2d(ctx0, mask_type, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); lctx.dflash.kv.kq_mask_swa_tensor = lctx.dflash.inputs.kq_mask_swa; ggml_set_input(lctx.dflash.inputs.kq_mask_swa); cb(lctx.dflash.inputs.kq_mask_swa, "dflash_kq_mask_swa", -1); - dflash_kq_mask_swa = flash_attn ? ggml_cast(ctx0, lctx.dflash.inputs.kq_mask_swa, GGML_TYPE_F16) : lctx.dflash.inputs.kq_mask_swa; + dflash_kq_mask_swa = lctx.dflash.inputs.kq_mask_swa; } ggml_tensor * tok_embd = model.tok_embd; - if (tok_embd == nullptr) { - tok_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab); - } + GGML_ASSERT(tok_embd != nullptr); ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, tok_embd, cb); ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = (n_tokens > 1 && n_outputs < n_tokens) ? 
build_inp_out_ids() : nullptr; - bool result_rows_selected = false; const float kq_scale = 1.0f / std::sqrt((float) n_embd_head_k); @@ -327,24 +326,24 @@ ggml_cgraph * llm_build_context::build_dflash() { Vcur_noise = ggml_reshape_3d(ctx0, Vcur_noise, n_embd_head_v, n_head_kv, n_tokens); cb(Vcur_noise, "Vcur_noise", il); - GGML_ASSERT((size_t) il < lctx.dflash.kv.k_ctx_workspace.size()); - GGML_ASSERT((size_t) il < lctx.dflash.kv.v_ctx_workspace.size()); - GGML_ASSERT(lctx.dflash.kv.k_ctx_workspace[(size_t) il] != nullptr); - GGML_ASSERT(lctx.dflash.kv.v_ctx_workspace[(size_t) il] != nullptr); + GGML_ASSERT(il < (int32_t) lctx.dflash.kv.k_ctx_workspace.size()); + GGML_ASSERT(il < (int32_t) lctx.dflash.kv.v_ctx_workspace.size()); + GGML_ASSERT(lctx.dflash.kv.k_ctx_workspace[il] != nullptr); + GGML_ASSERT(lctx.dflash.kv.v_ctx_workspace[il] != nullptr); - ggml_tensor * Kcur_ctx = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_workspace[(size_t) il], - lctx.dflash.kv.k_ctx_workspace[(size_t) il]->ne[0], + ggml_tensor * Kcur_ctx = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_workspace[il], + lctx.dflash.kv.k_ctx_workspace[il]->ne[0], ctx_len, - lctx.dflash.kv.k_ctx_workspace[(size_t) il]->ne[2], - lctx.dflash.kv.k_ctx_workspace[(size_t) il]->nb[1], - lctx.dflash.kv.k_ctx_workspace[(size_t) il]->nb[2], + lctx.dflash.kv.k_ctx_workspace[il]->ne[2], + lctx.dflash.kv.k_ctx_workspace[il]->nb[1], + lctx.dflash.kv.k_ctx_workspace[il]->nb[2], 0); - ggml_tensor * Vcur_ctx = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_workspace[(size_t) il], - lctx.dflash.kv.v_ctx_workspace[(size_t) il]->ne[0], + ggml_tensor * Vcur_ctx = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_workspace[il], + lctx.dflash.kv.v_ctx_workspace[il]->ne[0], ctx_len, - lctx.dflash.kv.v_ctx_workspace[(size_t) il]->ne[2], - lctx.dflash.kv.v_ctx_workspace[(size_t) il]->nb[1], - lctx.dflash.kv.v_ctx_workspace[(size_t) il]->nb[2], + lctx.dflash.kv.v_ctx_workspace[il]->ne[2], + lctx.dflash.kv.v_ctx_workspace[il]->nb[1], + 
lctx.dflash.kv.v_ctx_workspace[il]->nb[2], 0); cb(Kcur_ctx, "Kcur_ctx_workspace", il); cb(Vcur_ctx, "Vcur_ctx_workspace", il); @@ -400,7 +399,6 @@ ggml_cgraph * llm_build_context::build_dflash() { if (inp_out_ids != nullptr && il == n_layer - 1) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); cb(cur, "result_output_rows", -1); - result_rows_selected = true; } ggml_tensor * ffn_residual = cur; @@ -421,18 +419,8 @@ ggml_cgraph * llm_build_context::build_dflash() { inpL = cur; } - ggml_tensor * output = const_cast(llama_model_dflash_output_tensor(&model)); - if (output == nullptr) { - output = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab); - } - - ggml_tensor * result_input = inpL; - if (inp_out_ids && !result_rows_selected) { - result_input = ggml_get_rows(ctx0, result_input, inp_out_ids); - cb(result_input, "result_output_rows", -1); - } - - ggml_tensor * result = build_output(lctx, ctx0, result_input, output, model.output_norm, cb); + GGML_ASSERT(model.output_mtp != nullptr); + ggml_tensor * result = build_output(lctx, ctx0, inpL, model.output_mtp, model.output_norm, cb); cb(result, "result_output", -1); ggml_build_forward_expand(gf, result); diff --git a/src/llama-dflash.cpp b/src/llama-dflash.cpp index 277a6ffd..aa912a96 100644 --- a/src/llama-dflash.cpp +++ b/src/llama-dflash.cpp @@ -25,12 +25,12 @@ void llama_sync_dflash_workspace_if_pending(struct llama_context & lctx) { } static ggml_backend_buffer_type_t llama_dflash_kv_cache_layer_buft(const llama_context & lctx, int32_t il) { - if (il >= 0 && (size_t) il < lctx.model.buft_layer.size() && lctx.model.buft_layer[(size_t) il].buft != nullptr) { - return lctx.model.buft_layer[(size_t) il].buft; + if (il >= 0 && il < (int32_t) lctx.model.buft_layer.size() && lctx.model.buft_layer[il].buft != nullptr) { + return lctx.model.buft_layer[il].buft; } - if (il >= 0 && (size_t) il < lctx.model.layers.size()) { - const ggml_tensor * wk = lctx.model.layers[(size_t) il].wk; + if (il >= 0 && il < 
(int32_t) lctx.model.layers.size()) { + const ggml_tensor * wk = lctx.model.layers[il].wk; if (wk != nullptr && wk->buffer != nullptr) { return ggml_backend_buffer_get_type(wk->buffer); } @@ -123,6 +123,11 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { dflash.kv.cache_bufs.reserve((size_t) std::max(1, n_layer) * 4); for (int32_t il = 0; il < n_layer; ++il) { ggml_backend_buffer_type_t layer_buft = llama_dflash_kv_cache_layer_buft(*this, il); + ggml_tensor *& k_ctx_cache = dflash.kv.k_ctx_cache[il]; + ggml_tensor *& v_ctx_cache = dflash.kv.v_ctx_cache[il]; + ggml_tensor *& k_ctx_workspace = dflash.kv.k_ctx_workspace[il]; + ggml_tensor *& v_ctx_workspace = dflash.kv.v_ctx_workspace[il]; + auto alloc_kv_input = [&](ggml_tensor *& tensor, const char * tensor_tag, const char * tensor_name, int64_t ne0, int64_t ne1, int64_t ne2) -> bool { tensor = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, ne0, ne1, ne2); @@ -150,13 +155,13 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { return true; }; - if (!alloc_kv_input(dflash.kv.k_ctx_cache[(size_t) il], "dflash_k_ctx_cache", "dflash_k_ctx_cache_%d", + if (!alloc_kv_input(k_ctx_cache, "dflash_k_ctx_cache", "dflash_k_ctx_cache_%d", n_embd_head_k, n_head_kv, target_cross_ctx) || - !alloc_kv_input(dflash.kv.v_ctx_cache[(size_t) il], "dflash_v_ctx_cache", "dflash_v_ctx_cache_%d", + !alloc_kv_input(v_ctx_cache, "dflash_v_ctx_cache", "dflash_v_ctx_cache_%d", n_embd_head_v, n_head_kv, target_cross_ctx) || - !alloc_kv_input(dflash.kv.k_ctx_workspace[(size_t) il], "dflash_k_ctx_workspace", "dflash_k_ctx_workspace_%d", + !alloc_kv_input(k_ctx_workspace, "dflash_k_ctx_workspace", "dflash_k_ctx_workspace_%d", n_embd_head_k, target_workspace_n_kv_total, n_head_kv) || - !alloc_kv_input(dflash.kv.v_ctx_workspace[(size_t) il], "dflash_v_ctx_workspace", "dflash_v_ctx_workspace_%d", + !alloc_kv_input(v_ctx_workspace, "dflash_v_ctx_workspace", "dflash_v_ctx_workspace_%d", 
n_embd_head_v, target_workspace_n_kv_total, n_head_kv)) { free_dflash_kv_cache_tensors(); return false; @@ -267,23 +272,23 @@ static bool validate_dflash_graph_contract(const llama_context & lctx) { const auto & hparams = model.hparams; auto rope_dim_for_layer = [&hparams](int32_t il) -> uint32_t { - if (hparams.rope_dim_per_layer[(size_t) il] != 0) { - return hparams.rope_dim_per_layer[(size_t) il]; + if (hparams.rope_dim_per_layer[il] != 0) { + return hparams.rope_dim_per_layer[il]; } - return hparams.swa_layers[(size_t) il] ? hparams.n_rot_swa : hparams.n_rot; + return hparams.swa_layers[il] ? hparams.n_rot_swa : hparams.n_rot; }; auto rope_base_for_layer = [&hparams](int32_t il) -> float { if (hparams.has_rope_freq_base_per_layer) { - return hparams.rope_freq_base_per_layer[(size_t) il]; + return hparams.rope_freq_base_per_layer[il]; } - return hparams.swa_layers[(size_t) il] ? hparams.rope_freq_base_train_swa : hparams.rope_freq_base_train; + return hparams.swa_layers[il] ? hparams.rope_freq_base_train_swa : hparams.rope_freq_base_train; }; auto rope_scale_for_layer = [&hparams](int32_t il) -> float { - return hparams.swa_layers[(size_t) il] ? hparams.rope_freq_scale_train_swa : hparams.rope_freq_scale_train; + return hparams.swa_layers[il] ? 
hparams.rope_freq_scale_train_swa : hparams.rope_freq_scale_train; }; const uint32_t ref_n_head = hparams.n_head(0); @@ -322,31 +327,31 @@ static bool validate_dflash_graph_contract(const llama_context & lctx) { return false; } - if (model.layers[(size_t) il].attn_norm == nullptr || - model.layers[(size_t) il].attn_q_norm == nullptr || - model.layers[(size_t) il].attn_k_norm == nullptr) { + if (model.layers[il].attn_norm == nullptr || + model.layers[il].attn_q_norm == nullptr || + model.layers[il].attn_k_norm == nullptr) { LLAMA_LOG_ERROR("%s: DFlash graph requires attn_norm, attn_q_norm, and attn_k_norm weights, but layer %d is missing one or more of them\n", __func__, il); return false; } - const bool has_q_norm = model.layers[(size_t) il].attn_q_norm != nullptr; - const bool has_k_norm = model.layers[(size_t) il].attn_k_norm != nullptr; + const bool has_q_norm = model.layers[il].attn_q_norm != nullptr; + const bool has_k_norm = model.layers[il].attn_k_norm != nullptr; if (has_q_norm != has_k_norm) { LLAMA_LOG_ERROR("%s: DFlash graph requires symmetric Q/K norm presence, but layer %d has q_norm=%d k_norm=%d\n", __func__, il, (int) has_q_norm, (int) has_k_norm); return false; } - if (model.layers[(size_t) il].attn_norm_b != nullptr || - model.layers[(size_t) il].attn_q_norm_b != nullptr || - model.layers[(size_t) il].attn_k_norm_b != nullptr) { + if (model.layers[il].attn_norm_b != nullptr || + model.layers[il].attn_q_norm_b != nullptr || + model.layers[il].attn_k_norm_b != nullptr) { LLAMA_LOG_ERROR("%s: DFlash graph does not implement norm-bias tensors, but layer %d requires attn_norm_b/q_norm_b/k_norm_b\n", __func__, il); return false; } - if (dflash_layer_has_attention_bias(model.layers[(size_t) il])) { + if (dflash_layer_has_attention_bias(model.layers[il])) { LLAMA_LOG_ERROR("%s: DFlash graph does not implement attention bias tensors, but layer %d requires them\n", __func__, il); return false; @@ -655,39 +660,76 @@ bool llama_prepare_dflash_graph_inputs( 
const int32_t full_visible_first = left_pad; const int32_t full_visible_last = cross_ctx + (int32_t) n_tokens - 1; - lctx.dflash.target.kq_mask_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY); - for (uint32_t j = 0; j < n_tokens; ++j) { - float * row = lctx.dflash.target.kq_mask_data.data() + (size_t) j * (size_t) n_kv_total; - for (int32_t i = full_visible_first; i <= full_visible_last; ++i) { - row[i] = 0.0f; + const size_t mask_elems = (size_t) n_kv_total * (size_t) n_mask_tokens; + if (kq_mask->type == GGML_TYPE_F16) { + const ggml_fp16_t h_inf = ggml_fp32_to_fp16(-INFINITY); + const ggml_fp16_t h_zero = ggml_fp32_to_fp16(0.0f); + std::vector mask_f16(mask_elems, h_inf); + std::vector row_f16((size_t) n_kv_total, h_inf); + std::fill(row_f16.begin() + full_visible_first, row_f16.begin() + full_visible_last + 1, h_zero); + for (uint32_t j = 0; j < n_tokens; ++j) { + std::memcpy(mask_f16.data() + (size_t) j * (size_t) n_kv_total, row_f16.data(), (size_t) n_kv_total * sizeof(ggml_fp16_t)); } + ggml_backend_tensor_set(kq_mask, mask_f16.data(), 0, ggml_nbytes(kq_mask)); + } else { + lctx.dflash.target.kq_mask_data.assign(mask_elems, -INFINITY); + std::vector row_f32((size_t) n_kv_total, -INFINITY); + std::fill(row_f32.begin() + full_visible_first, row_f32.begin() + full_visible_last + 1, 0.0f); + for (uint32_t j = 0; j < n_tokens; ++j) { + std::memcpy(lctx.dflash.target.kq_mask_data.data() + (size_t) j * (size_t) n_kv_total, row_f32.data(), (size_t) n_kv_total * sizeof(float)); + } + ggml_backend_tensor_set(kq_mask, lctx.dflash.target.kq_mask_data.data(), 0, ggml_nbytes(kq_mask)); } - ggml_backend_tensor_set(kq_mask, lctx.dflash.target.kq_mask_data.data(), 0, ggml_nbytes(kq_mask)); if (kq_mask_swa != nullptr) { - lctx.dflash.target.kq_mask_swa_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY); const int32_t swa_window = (int32_t) lctx.model.hparams.n_swa; const int32_t draft_pos_base = (int32_t) last_target_pos; - for (uint32_t 
j = 0; j < n_tokens; ++j) { - float * row = lctx.dflash.target.kq_mask_swa_data.data() + (size_t) j * (size_t) n_kv_total; - const int32_t q_pos = draft_pos_base + (int32_t) j; - for (int32_t k = left_pad; k < cross_ctx; ++k) { - const int32_t k_pos = (int32_t) lctx.dflash.target.pos_ctx_data[(size_t) k]; - if (q_pos - k_pos < swa_window) { - row[k] = 0.0f; + if (kq_mask_swa->type == GGML_TYPE_F16) { + const ggml_fp16_t h_inf = ggml_fp32_to_fp16(-INFINITY); + const ggml_fp16_t h_zero = ggml_fp32_to_fp16(0.0f); + std::vector mask_swa_f16(mask_elems, h_inf); + for (uint32_t j = 0; j < n_tokens; ++j) { + ggml_fp16_t * row = mask_swa_f16.data() + (size_t) j * (size_t) n_kv_total; + const int32_t q_pos = draft_pos_base + (int32_t) j; + + for (int32_t k = left_pad; k < cross_ctx; ++k) { + const int32_t k_pos = (int32_t) lctx.dflash.target.pos_ctx_data[(size_t) k]; + if (q_pos - k_pos < swa_window) { + row[k] = h_zero; + } + } + + for (int32_t k = cross_ctx; k < cross_ctx + (int32_t) n_tokens; ++k) { + const int32_t block_k = k - cross_ctx; + if (block_k <= (int32_t) j) { + row[k] = h_zero; + } } } + ggml_backend_tensor_set(kq_mask_swa, mask_swa_f16.data(), 0, ggml_nbytes(kq_mask_swa)); + } else { + lctx.dflash.target.kq_mask_swa_data.assign(mask_elems, -INFINITY); + for (uint32_t j = 0; j < n_tokens; ++j) { + float * row = lctx.dflash.target.kq_mask_swa_data.data() + (size_t) j * (size_t) n_kv_total; + const int32_t q_pos = draft_pos_base + (int32_t) j; - for (int32_t k = cross_ctx; k < cross_ctx + (int32_t) n_tokens; ++k) { - const int32_t block_k = k - cross_ctx; - if (block_k <= (int32_t) j) { - row[k] = 0.0f; + for (int32_t k = left_pad; k < cross_ctx; ++k) { + const int32_t k_pos = (int32_t) lctx.dflash.target.pos_ctx_data[(size_t) k]; + if (q_pos - k_pos < swa_window) { + row[k] = 0.0f; + } + } + + for (int32_t k = cross_ctx; k < cross_ctx + (int32_t) n_tokens; ++k) { + const int32_t block_k = k - cross_ctx; + if (block_k <= (int32_t) j) { + row[k] = 0.0f; + } } } 
+ ggml_backend_tensor_set(kq_mask_swa, lctx.dflash.target.kq_mask_swa_data.data(), 0, ggml_nbytes(kq_mask_swa)); } - - ggml_backend_tensor_set(kq_mask_swa, lctx.dflash.target.kq_mask_swa_data.data(), 0, ggml_nbytes(kq_mask_swa)); } return true; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 80f852f1..b041844c 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -55,33 +55,39 @@ static bool load_dflash_target_layer_ids( throw std::runtime_error(format("dflash: %s must be a uint32/int32 array", key.c_str())); } - const size_t n = gguf_get_arr_n(ml.meta, kid); + uint32_t n = 0; + ml.get_arr_n(key, n, true); if (n == 0) { throw std::runtime_error(format("dflash: %s must not be empty", key.c_str())); } if (n > 8) { - throw std::runtime_error(format("dflash: %s has %zu entries, max is 8", key.c_str(), n)); + throw std::runtime_error(format("dflash: %s has %u entries, max is 8", key.c_str(), n)); } - hparams.dflash_n_target_layers = (uint32_t) n; + hparams.dflash_n_target_layers = n; for (uint32_t & id : hparams.dflash_target_layer_ids) { id = 0; } - const void * data = gguf_get_arr_data(ml.meta, kid); - for (uint32_t i = 0; i < hparams.dflash_n_target_layers; ++i) { - if (type == GGUF_TYPE_INT32) { - const int32_t id = ((const int32_t *) data)[i]; - if (id < 0) { - throw std::runtime_error(format("dflash: %s contains negative layer id %d", key.c_str(), id)); + if (type == GGUF_TYPE_INT32) { + std::array layer_ids = {}; + ml.get_arr(key, layer_ids, true); + for (uint32_t i = 0; i < hparams.dflash_n_target_layers; ++i) { + if (layer_ids[i] < 0) { + throw std::runtime_error(format("dflash: %s contains negative layer id %d", key.c_str(), layer_ids[i])); } - hparams.dflash_target_layer_ids[i] = (uint32_t) id; - } else { - hparams.dflash_target_layer_ids[i] = ((const uint32_t *) data)[i]; + hparams.dflash_target_layer_ids[i] = (uint32_t) layer_ids[i]; } + } else { + std::array layer_ids = {}; + ml.get_arr(key, layer_ids, true); + for (uint32_t i 
= 0; i < hparams.dflash_n_target_layers; ++i) { + hparams.dflash_target_layer_ids[i] = layer_ids[i]; + } + } + for (uint32_t i = 0; i < hparams.dflash_n_target_layers; ++i) { const uint32_t id = hparams.dflash_target_layer_ids[i]; - for (uint32_t j = 0; j < i; ++j) { if (hparams.dflash_target_layer_ids[j] == id) { throw std::runtime_error(format( diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index a871a035..b372a678 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -1260,5 +1260,7 @@ template bool llama_model_loader::get_key_or_arr>(enum llm_kv template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); +template std::enable_if::value, bool>::type llama_model_loader::get_arr_n(const std::string &, unsigned int &, bool); template std::enable_if::value, bool>::type llama_model_loader::get_arr_n(enum llm_kv, unsigned int&, bool); - +template bool llama_model_loader::get_arr(const std::string &, std::array &, bool); +template bool llama_model_loader::get_arr(const std::string &, std::array &, bool); diff --git a/src/llama-spec-features-dflash.cpp b/src/llama-spec-features-dflash.cpp index 4be45727..92ae6426 100644 --- a/src/llama-spec-features-dflash.cpp +++ b/src/llama-spec-features-dflash.cpp @@ -118,7 +118,7 @@ int32_t llama_model_dflash_target_mask_token_id(const struct llama_model * model return (int32_t) model->vocab.token_mask(); } -const struct ggml_tensor * llama_model_dflash_output_tensor( +static const ggml_tensor * llama_dflash_output_tensor( const struct llama_model * model) { if (model == nullptr) { return nullptr; @@ -142,8 +142,8 @@ int32_t llama_model_dflash_io_mode( return LLAMA_DFLASH_IO_MODE_INVALID; } - const ggml_tensor * draft_output = llama_model_dflash_output_tensor(draft_model); - const ggml_tensor * target_output = 
llama_model_dflash_output_tensor(target_model); + const ggml_tensor * draft_output = llama_dflash_output_tensor(draft_model); + const ggml_tensor * target_output = llama_dflash_output_tensor(target_model); if (draft_model->tok_embd == nullptr || draft_output == nullptr || target_model->tok_embd == nullptr || target_output == nullptr) { return LLAMA_DFLASH_IO_MODE_INVALID; } @@ -165,7 +165,7 @@ bool llama_model_dflash_io_tensors_match( const struct llama_model * draft_model, int32_t n_embd, int32_t n_vocab) { - const ggml_tensor * output = llama_model_dflash_output_tensor(draft_model); + const ggml_tensor * output = llama_dflash_output_tensor(draft_model); if (draft_model == nullptr || draft_model->tok_embd == nullptr || output == nullptr || n_embd <= 0 || n_vocab <= 0) { return false; } @@ -202,11 +202,17 @@ bool llama_model_share_dflash_io_tensors( const bool uses_shared_output = draft_model->output == target_model->output || draft_model->output == target_model->tok_embd; - if (draft_model->output_mtp == nullptr && target_model->output_mtp != nullptr && uses_shared_tok && uses_shared_output) { - draft_model->output_mtp = target_model->output_mtp; + if (draft_model->output_mtp == nullptr) { + if (target_model->output_mtp != nullptr && uses_shared_tok && uses_shared_output) { + draft_model->output_mtp = target_model->output_mtp; + } else if (draft_model->output != nullptr) { + draft_model->output_mtp = draft_model->output; + } else { + draft_model->output_mtp = draft_model->tok_embd; + } } - const struct ggml_tensor * output = llama_model_dflash_output_tensor(draft_model); + const struct ggml_tensor * output = llama_dflash_output_tensor(draft_model); return draft_model->tok_embd != nullptr && output != nullptr; } diff --git a/src/llama-spec-features-dflash.h b/src/llama-spec-features-dflash.h index c893db7a..cec99c0b 100644 --- a/src/llama-spec-features-dflash.h +++ b/src/llama-spec-features-dflash.h @@ -85,7 +85,6 @@ int32_t llama_model_dflash_n_target_layers(const 
struct llama_model * model); int32_t llama_model_dflash_n_target_features(const struct llama_model * model); int32_t llama_model_dflash_target_layer_ids(const struct llama_model * model, int32_t * layer_ids, int32_t capacity); int32_t llama_model_dflash_target_mask_token_id(const struct llama_model * model); -const struct ggml_tensor * llama_model_dflash_output_tensor(const struct llama_model * model); enum llama_dflash_io_mode { LLAMA_DFLASH_IO_MODE_INVALID = 0,