clean redudance in dflash graph and small logics (#1994)

This commit is contained in:
Samuel Oliveira Alves 2026-06-19 04:04:54 -03:00 committed by GitHub
parent 7321648844
commit d5c04c15fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 188 additions and 145 deletions

View File

@ -67,11 +67,11 @@ ggml_cgraph * llm_build_context::build_dflash_kv_workspace() {
};
for (int il = 0; il < n_layer; ++il) {
GGML_ASSERT((size_t) il < lctx.dflash.kv.k_ctx_cache.size());
GGML_ASSERT((size_t) il < lctx.dflash.kv.v_ctx_cache.size());
GGML_ASSERT(il < (int32_t) lctx.dflash.kv.k_ctx_cache.size());
GGML_ASSERT(il < (int32_t) lctx.dflash.kv.v_ctx_cache.size());
ggml_tensor * Kordered = build_ordered_cache_view(lctx.dflash.kv.k_ctx_cache[(size_t) il]);
ggml_tensor * Vordered = build_ordered_cache_view(lctx.dflash.kv.v_ctx_cache[(size_t) il]);
ggml_tensor * Kordered = build_ordered_cache_view(lctx.dflash.kv.k_ctx_cache[il]);
ggml_tensor * Vordered = build_ordered_cache_view(lctx.dflash.kv.v_ctx_cache[il]);
cb(Kordered, "dflash_workspace_k_ctx_view", il);
cb(Vordered, "dflash_workspace_v_ctx_view", il);
@ -80,19 +80,19 @@ ggml_cgraph * llm_build_context::build_dflash_kv_workspace() {
cb(Kworkspace, "dflash_workspace_k_perm_cont", il);
cb(Vworkspace, "dflash_workspace_v_perm_cont", il);
ggml_tensor * Kdst = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_workspace[(size_t) il],
lctx.dflash.kv.k_ctx_workspace[(size_t) il]->ne[0],
ggml_tensor * Kdst = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_workspace[il],
lctx.dflash.kv.k_ctx_workspace[il]->ne[0],
ctx_len,
lctx.dflash.kv.k_ctx_workspace[(size_t) il]->ne[2],
lctx.dflash.kv.k_ctx_workspace[(size_t) il]->nb[1],
lctx.dflash.kv.k_ctx_workspace[(size_t) il]->nb[2],
lctx.dflash.kv.k_ctx_workspace[il]->ne[2],
lctx.dflash.kv.k_ctx_workspace[il]->nb[1],
lctx.dflash.kv.k_ctx_workspace[il]->nb[2],
0);
ggml_tensor * Vdst = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_workspace[(size_t) il],
lctx.dflash.kv.v_ctx_workspace[(size_t) il]->ne[0],
ggml_tensor * Vdst = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_workspace[il],
lctx.dflash.kv.v_ctx_workspace[il]->ne[0],
ctx_len,
lctx.dflash.kv.v_ctx_workspace[(size_t) il]->ne[2],
lctx.dflash.kv.v_ctx_workspace[(size_t) il]->nb[1],
lctx.dflash.kv.v_ctx_workspace[(size_t) il]->nb[2],
lctx.dflash.kv.v_ctx_workspace[il]->ne[2],
lctx.dflash.kv.v_ctx_workspace[il]->nb[1],
lctx.dflash.kv.v_ctx_workspace[il]->nb[2],
0);
ggml_tensor * Kstore = ggml_cpy(ctx0, Kworkspace, Kdst);
@ -137,8 +137,8 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() {
cb(fused_target, "dflash_kv_fused_target", -1);
for (int il = 0; il < n_layer; ++il) {
GGML_ASSERT((size_t) il < lctx.dflash.kv.k_ctx_cache.size());
GGML_ASSERT((size_t) il < lctx.dflash.kv.v_ctx_cache.size());
GGML_ASSERT(il < (int32_t) lctx.dflash.kv.k_ctx_cache.size());
GGML_ASSERT(il < (int32_t) lctx.dflash.kv.v_ctx_cache.size());
ggml_tensor * Kcur_ctx_proj = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused_target);
cb(Kcur_ctx_proj, "dflash_kv_k_proj", il);
@ -177,20 +177,20 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() {
Vcur_ctx->nb[1],
Vcur_ctx->nb[2],
0);
ggml_tensor * Kdst_first = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_cache[(size_t) il],
lctx.dflash.kv.k_ctx_cache[(size_t) il]->ne[0],
lctx.dflash.kv.k_ctx_cache[(size_t) il]->ne[1],
ggml_tensor * Kdst_first = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_cache[il],
lctx.dflash.kv.k_ctx_cache[il]->ne[0],
lctx.dflash.kv.k_ctx_cache[il]->ne[1],
first_rows,
lctx.dflash.kv.k_ctx_cache[(size_t) il]->nb[1],
lctx.dflash.kv.k_ctx_cache[(size_t) il]->nb[2],
(size_t) write_pos * lctx.dflash.kv.k_ctx_cache[(size_t) il]->nb[2]);
ggml_tensor * Vdst_first = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_cache[(size_t) il],
lctx.dflash.kv.v_ctx_cache[(size_t) il]->ne[0],
lctx.dflash.kv.v_ctx_cache[(size_t) il]->ne[1],
lctx.dflash.kv.k_ctx_cache[il]->nb[1],
lctx.dflash.kv.k_ctx_cache[il]->nb[2],
(size_t) write_pos * lctx.dflash.kv.k_ctx_cache[il]->nb[2]);
ggml_tensor * Vdst_first = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_cache[il],
lctx.dflash.kv.v_ctx_cache[il]->ne[0],
lctx.dflash.kv.v_ctx_cache[il]->ne[1],
first_rows,
lctx.dflash.kv.v_ctx_cache[(size_t) il]->nb[1],
lctx.dflash.kv.v_ctx_cache[(size_t) il]->nb[2],
(size_t) write_pos * lctx.dflash.kv.v_ctx_cache[(size_t) il]->nb[2]);
lctx.dflash.kv.v_ctx_cache[il]->nb[1],
lctx.dflash.kv.v_ctx_cache[il]->nb[2],
(size_t) write_pos * lctx.dflash.kv.v_ctx_cache[il]->nb[2]);
ggml_tensor * Kstore_first = ggml_cpy(ctx0, Ksrc_first, Kdst_first);
cb(Kstore_first, "dflash_kv_k_store", il);
@ -216,19 +216,19 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() {
Vcur_ctx->nb[1],
Vcur_ctx->nb[2],
(size_t) first_rows * Vcur_ctx->nb[2]);
ggml_tensor * Kdst_second = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_cache[(size_t) il],
lctx.dflash.kv.k_ctx_cache[(size_t) il]->ne[0],
lctx.dflash.kv.k_ctx_cache[(size_t) il]->ne[1],
ggml_tensor * Kdst_second = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_cache[il],
lctx.dflash.kv.k_ctx_cache[il]->ne[0],
lctx.dflash.kv.k_ctx_cache[il]->ne[1],
second_rows,
lctx.dflash.kv.k_ctx_cache[(size_t) il]->nb[1],
lctx.dflash.kv.k_ctx_cache[(size_t) il]->nb[2],
lctx.dflash.kv.k_ctx_cache[il]->nb[1],
lctx.dflash.kv.k_ctx_cache[il]->nb[2],
0);
ggml_tensor * Vdst_second = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_cache[(size_t) il],
lctx.dflash.kv.v_ctx_cache[(size_t) il]->ne[0],
lctx.dflash.kv.v_ctx_cache[(size_t) il]->ne[1],
ggml_tensor * Vdst_second = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_cache[il],
lctx.dflash.kv.v_ctx_cache[il]->ne[0],
lctx.dflash.kv.v_ctx_cache[il]->ne[1],
second_rows,
lctx.dflash.kv.v_ctx_cache[(size_t) il]->nb[1],
lctx.dflash.kv.v_ctx_cache[(size_t) il]->nb[2],
lctx.dflash.kv.v_ctx_cache[il]->nb[1],
lctx.dflash.kv.v_ctx_cache[il]->nb[2],
0);
ggml_tensor * Kstore_second = ggml_cpy(ctx0, Ksrc_second, Kdst_second);
@ -264,40 +264,39 @@ ggml_cgraph * llm_build_context::build_dflash() {
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max<int64_t>(n_tokens, ctx_len)) + 32 * n_layer, false);
bool have_swa_layers = false;
const bool needs_swa_mask = hparams.n_swa > 0 && [&]() {
for (int il = 0; il < n_layer; ++il) {
if (hparams.swa_layers[il]) {
have_swa_layers = true;
break;
return true;
}
}
return false;
}();
const ggml_type mask_type = flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
lctx.dflash.inputs.kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
lctx.dflash.inputs.kq_mask = ggml_new_tensor_2d(ctx0, mask_type, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
lctx.dflash.kv.kq_mask_tensor = lctx.dflash.inputs.kq_mask;
ggml_set_input(lctx.dflash.inputs.kq_mask);
cb(lctx.dflash.inputs.kq_mask, "dflash_kq_mask", -1);
ggml_tensor * dflash_kq_mask_full = flash_attn ? ggml_cast(ctx0, lctx.dflash.inputs.kq_mask, GGML_TYPE_F16) : lctx.dflash.inputs.kq_mask;
ggml_tensor * dflash_kq_mask_full = lctx.dflash.inputs.kq_mask;
ggml_tensor * dflash_kq_mask_swa = nullptr;
lctx.dflash.inputs.kq_mask_swa = nullptr;
lctx.dflash.kv.kq_mask_swa_tensor = nullptr;
if (have_swa_layers && hparams.n_swa > 0) {
lctx.dflash.inputs.kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
if (needs_swa_mask) {
lctx.dflash.inputs.kq_mask_swa = ggml_new_tensor_2d(ctx0, mask_type, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
lctx.dflash.kv.kq_mask_swa_tensor = lctx.dflash.inputs.kq_mask_swa;
ggml_set_input(lctx.dflash.inputs.kq_mask_swa);
cb(lctx.dflash.inputs.kq_mask_swa, "dflash_kq_mask_swa", -1);
dflash_kq_mask_swa = flash_attn ? ggml_cast(ctx0, lctx.dflash.inputs.kq_mask_swa, GGML_TYPE_F16) : lctx.dflash.inputs.kq_mask_swa;
dflash_kq_mask_swa = lctx.dflash.inputs.kq_mask_swa;
}
ggml_tensor * tok_embd = model.tok_embd;
if (tok_embd == nullptr) {
tok_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab);
}
GGML_ASSERT(tok_embd != nullptr);
ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, tok_embd, cb);
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = (n_tokens > 1 && n_outputs < n_tokens) ? build_inp_out_ids() : nullptr;
bool result_rows_selected = false;
const float kq_scale = 1.0f / std::sqrt((float) n_embd_head_k);
@ -327,24 +326,24 @@ ggml_cgraph * llm_build_context::build_dflash() {
Vcur_noise = ggml_reshape_3d(ctx0, Vcur_noise, n_embd_head_v, n_head_kv, n_tokens);
cb(Vcur_noise, "Vcur_noise", il);
GGML_ASSERT((size_t) il < lctx.dflash.kv.k_ctx_workspace.size());
GGML_ASSERT((size_t) il < lctx.dflash.kv.v_ctx_workspace.size());
GGML_ASSERT(lctx.dflash.kv.k_ctx_workspace[(size_t) il] != nullptr);
GGML_ASSERT(lctx.dflash.kv.v_ctx_workspace[(size_t) il] != nullptr);
GGML_ASSERT(il < (int32_t) lctx.dflash.kv.k_ctx_workspace.size());
GGML_ASSERT(il < (int32_t) lctx.dflash.kv.v_ctx_workspace.size());
GGML_ASSERT(lctx.dflash.kv.k_ctx_workspace[il] != nullptr);
GGML_ASSERT(lctx.dflash.kv.v_ctx_workspace[il] != nullptr);
ggml_tensor * Kcur_ctx = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_workspace[(size_t) il],
lctx.dflash.kv.k_ctx_workspace[(size_t) il]->ne[0],
ggml_tensor * Kcur_ctx = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_workspace[il],
lctx.dflash.kv.k_ctx_workspace[il]->ne[0],
ctx_len,
lctx.dflash.kv.k_ctx_workspace[(size_t) il]->ne[2],
lctx.dflash.kv.k_ctx_workspace[(size_t) il]->nb[1],
lctx.dflash.kv.k_ctx_workspace[(size_t) il]->nb[2],
lctx.dflash.kv.k_ctx_workspace[il]->ne[2],
lctx.dflash.kv.k_ctx_workspace[il]->nb[1],
lctx.dflash.kv.k_ctx_workspace[il]->nb[2],
0);
ggml_tensor * Vcur_ctx = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_workspace[(size_t) il],
lctx.dflash.kv.v_ctx_workspace[(size_t) il]->ne[0],
ggml_tensor * Vcur_ctx = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_workspace[il],
lctx.dflash.kv.v_ctx_workspace[il]->ne[0],
ctx_len,
lctx.dflash.kv.v_ctx_workspace[(size_t) il]->ne[2],
lctx.dflash.kv.v_ctx_workspace[(size_t) il]->nb[1],
lctx.dflash.kv.v_ctx_workspace[(size_t) il]->nb[2],
lctx.dflash.kv.v_ctx_workspace[il]->ne[2],
lctx.dflash.kv.v_ctx_workspace[il]->nb[1],
lctx.dflash.kv.v_ctx_workspace[il]->nb[2],
0);
cb(Kcur_ctx, "Kcur_ctx_workspace", il);
cb(Vcur_ctx, "Vcur_ctx_workspace", il);
@ -400,7 +399,6 @@ ggml_cgraph * llm_build_context::build_dflash() {
if (inp_out_ids != nullptr && il == n_layer - 1) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
cb(cur, "result_output_rows", -1);
result_rows_selected = true;
}
ggml_tensor * ffn_residual = cur;
@ -421,18 +419,8 @@ ggml_cgraph * llm_build_context::build_dflash() {
inpL = cur;
}
ggml_tensor * output = const_cast<ggml_tensor *>(llama_model_dflash_output_tensor(&model));
if (output == nullptr) {
output = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab);
}
ggml_tensor * result_input = inpL;
if (inp_out_ids && !result_rows_selected) {
result_input = ggml_get_rows(ctx0, result_input, inp_out_ids);
cb(result_input, "result_output_rows", -1);
}
ggml_tensor * result = build_output(lctx, ctx0, result_input, output, model.output_norm, cb);
GGML_ASSERT(model.output_mtp != nullptr);
ggml_tensor * result = build_output(lctx, ctx0, inpL, model.output_mtp, model.output_norm, cb);
cb(result, "result_output", -1);
ggml_build_forward_expand(gf, result);

View File

@ -25,12 +25,12 @@ void llama_sync_dflash_workspace_if_pending(struct llama_context & lctx) {
}
static ggml_backend_buffer_type_t llama_dflash_kv_cache_layer_buft(const llama_context & lctx, int32_t il) {
if (il >= 0 && (size_t) il < lctx.model.buft_layer.size() && lctx.model.buft_layer[(size_t) il].buft != nullptr) {
return lctx.model.buft_layer[(size_t) il].buft;
if (il >= 0 && il < (int32_t) lctx.model.buft_layer.size() && lctx.model.buft_layer[il].buft != nullptr) {
return lctx.model.buft_layer[il].buft;
}
if (il >= 0 && (size_t) il < lctx.model.layers.size()) {
const ggml_tensor * wk = lctx.model.layers[(size_t) il].wk;
if (il >= 0 && il < (int32_t) lctx.model.layers.size()) {
const ggml_tensor * wk = lctx.model.layers[il].wk;
if (wk != nullptr && wk->buffer != nullptr) {
return ggml_backend_buffer_get_type(wk->buffer);
}
@ -123,6 +123,11 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
dflash.kv.cache_bufs.reserve((size_t) std::max(1, n_layer) * 4);
for (int32_t il = 0; il < n_layer; ++il) {
ggml_backend_buffer_type_t layer_buft = llama_dflash_kv_cache_layer_buft(*this, il);
ggml_tensor *& k_ctx_cache = dflash.kv.k_ctx_cache[il];
ggml_tensor *& v_ctx_cache = dflash.kv.v_ctx_cache[il];
ggml_tensor *& k_ctx_workspace = dflash.kv.k_ctx_workspace[il];
ggml_tensor *& v_ctx_workspace = dflash.kv.v_ctx_workspace[il];
auto alloc_kv_input = [&](ggml_tensor *& tensor, const char * tensor_tag, const char * tensor_name,
int64_t ne0, int64_t ne1, int64_t ne2) -> bool {
tensor = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, ne0, ne1, ne2);
@ -150,13 +155,13 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
return true;
};
if (!alloc_kv_input(dflash.kv.k_ctx_cache[(size_t) il], "dflash_k_ctx_cache", "dflash_k_ctx_cache_%d",
if (!alloc_kv_input(k_ctx_cache, "dflash_k_ctx_cache", "dflash_k_ctx_cache_%d",
n_embd_head_k, n_head_kv, target_cross_ctx) ||
!alloc_kv_input(dflash.kv.v_ctx_cache[(size_t) il], "dflash_v_ctx_cache", "dflash_v_ctx_cache_%d",
!alloc_kv_input(v_ctx_cache, "dflash_v_ctx_cache", "dflash_v_ctx_cache_%d",
n_embd_head_v, n_head_kv, target_cross_ctx) ||
!alloc_kv_input(dflash.kv.k_ctx_workspace[(size_t) il], "dflash_k_ctx_workspace", "dflash_k_ctx_workspace_%d",
!alloc_kv_input(k_ctx_workspace, "dflash_k_ctx_workspace", "dflash_k_ctx_workspace_%d",
n_embd_head_k, target_workspace_n_kv_total, n_head_kv) ||
!alloc_kv_input(dflash.kv.v_ctx_workspace[(size_t) il], "dflash_v_ctx_workspace", "dflash_v_ctx_workspace_%d",
!alloc_kv_input(v_ctx_workspace, "dflash_v_ctx_workspace", "dflash_v_ctx_workspace_%d",
n_embd_head_v, target_workspace_n_kv_total, n_head_kv)) {
free_dflash_kv_cache_tensors();
return false;
@ -267,23 +272,23 @@ static bool validate_dflash_graph_contract(const llama_context & lctx) {
const auto & hparams = model.hparams;
auto rope_dim_for_layer = [&hparams](int32_t il) -> uint32_t {
if (hparams.rope_dim_per_layer[(size_t) il] != 0) {
return hparams.rope_dim_per_layer[(size_t) il];
if (hparams.rope_dim_per_layer[il] != 0) {
return hparams.rope_dim_per_layer[il];
}
return hparams.swa_layers[(size_t) il] ? hparams.n_rot_swa : hparams.n_rot;
return hparams.swa_layers[il] ? hparams.n_rot_swa : hparams.n_rot;
};
auto rope_base_for_layer = [&hparams](int32_t il) -> float {
if (hparams.has_rope_freq_base_per_layer) {
return hparams.rope_freq_base_per_layer[(size_t) il];
return hparams.rope_freq_base_per_layer[il];
}
return hparams.swa_layers[(size_t) il] ? hparams.rope_freq_base_train_swa : hparams.rope_freq_base_train;
return hparams.swa_layers[il] ? hparams.rope_freq_base_train_swa : hparams.rope_freq_base_train;
};
auto rope_scale_for_layer = [&hparams](int32_t il) -> float {
return hparams.swa_layers[(size_t) il] ? hparams.rope_freq_scale_train_swa : hparams.rope_freq_scale_train;
return hparams.swa_layers[il] ? hparams.rope_freq_scale_train_swa : hparams.rope_freq_scale_train;
};
const uint32_t ref_n_head = hparams.n_head(0);
@ -322,31 +327,31 @@ static bool validate_dflash_graph_contract(const llama_context & lctx) {
return false;
}
if (model.layers[(size_t) il].attn_norm == nullptr ||
model.layers[(size_t) il].attn_q_norm == nullptr ||
model.layers[(size_t) il].attn_k_norm == nullptr) {
if (model.layers[il].attn_norm == nullptr ||
model.layers[il].attn_q_norm == nullptr ||
model.layers[il].attn_k_norm == nullptr) {
LLAMA_LOG_ERROR("%s: DFlash graph requires attn_norm, attn_q_norm, and attn_k_norm weights, but layer %d is missing one or more of them\n",
__func__, il);
return false;
}
const bool has_q_norm = model.layers[(size_t) il].attn_q_norm != nullptr;
const bool has_k_norm = model.layers[(size_t) il].attn_k_norm != nullptr;
const bool has_q_norm = model.layers[il].attn_q_norm != nullptr;
const bool has_k_norm = model.layers[il].attn_k_norm != nullptr;
if (has_q_norm != has_k_norm) {
LLAMA_LOG_ERROR("%s: DFlash graph requires symmetric Q/K norm presence, but layer %d has q_norm=%d k_norm=%d\n",
__func__, il, (int) has_q_norm, (int) has_k_norm);
return false;
}
if (model.layers[(size_t) il].attn_norm_b != nullptr ||
model.layers[(size_t) il].attn_q_norm_b != nullptr ||
model.layers[(size_t) il].attn_k_norm_b != nullptr) {
if (model.layers[il].attn_norm_b != nullptr ||
model.layers[il].attn_q_norm_b != nullptr ||
model.layers[il].attn_k_norm_b != nullptr) {
LLAMA_LOG_ERROR("%s: DFlash graph does not implement norm-bias tensors, but layer %d requires attn_norm_b/q_norm_b/k_norm_b\n",
__func__, il);
return false;
}
if (dflash_layer_has_attention_bias(model.layers[(size_t) il])) {
if (dflash_layer_has_attention_bias(model.layers[il])) {
LLAMA_LOG_ERROR("%s: DFlash graph does not implement attention bias tensors, but layer %d requires them\n",
__func__, il);
return false;
@ -655,19 +660,56 @@ bool llama_prepare_dflash_graph_inputs(
const int32_t full_visible_first = left_pad;
const int32_t full_visible_last = cross_ctx + (int32_t) n_tokens - 1;
lctx.dflash.target.kq_mask_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY);
const size_t mask_elems = (size_t) n_kv_total * (size_t) n_mask_tokens;
if (kq_mask->type == GGML_TYPE_F16) {
const ggml_fp16_t h_inf = ggml_fp32_to_fp16(-INFINITY);
const ggml_fp16_t h_zero = ggml_fp32_to_fp16(0.0f);
std::vector<ggml_fp16_t> mask_f16(mask_elems, h_inf);
std::vector<ggml_fp16_t> row_f16((size_t) n_kv_total, h_inf);
std::fill(row_f16.begin() + full_visible_first, row_f16.begin() + full_visible_last + 1, h_zero);
for (uint32_t j = 0; j < n_tokens; ++j) {
float * row = lctx.dflash.target.kq_mask_data.data() + (size_t) j * (size_t) n_kv_total;
for (int32_t i = full_visible_first; i <= full_visible_last; ++i) {
row[i] = 0.0f;
std::memcpy(mask_f16.data() + (size_t) j * (size_t) n_kv_total, row_f16.data(), (size_t) n_kv_total * sizeof(ggml_fp16_t));
}
ggml_backend_tensor_set(kq_mask, mask_f16.data(), 0, ggml_nbytes(kq_mask));
} else {
lctx.dflash.target.kq_mask_data.assign(mask_elems, -INFINITY);
std::vector<float> row_f32((size_t) n_kv_total, -INFINITY);
std::fill(row_f32.begin() + full_visible_first, row_f32.begin() + full_visible_last + 1, 0.0f);
for (uint32_t j = 0; j < n_tokens; ++j) {
std::memcpy(lctx.dflash.target.kq_mask_data.data() + (size_t) j * (size_t) n_kv_total, row_f32.data(), (size_t) n_kv_total * sizeof(float));
}
ggml_backend_tensor_set(kq_mask, lctx.dflash.target.kq_mask_data.data(), 0, ggml_nbytes(kq_mask));
}
if (kq_mask_swa != nullptr) {
lctx.dflash.target.kq_mask_swa_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY);
const int32_t swa_window = (int32_t) lctx.model.hparams.n_swa;
const int32_t draft_pos_base = (int32_t) last_target_pos;
if (kq_mask_swa->type == GGML_TYPE_F16) {
const ggml_fp16_t h_inf = ggml_fp32_to_fp16(-INFINITY);
const ggml_fp16_t h_zero = ggml_fp32_to_fp16(0.0f);
std::vector<ggml_fp16_t> mask_swa_f16(mask_elems, h_inf);
for (uint32_t j = 0; j < n_tokens; ++j) {
ggml_fp16_t * row = mask_swa_f16.data() + (size_t) j * (size_t) n_kv_total;
const int32_t q_pos = draft_pos_base + (int32_t) j;
for (int32_t k = left_pad; k < cross_ctx; ++k) {
const int32_t k_pos = (int32_t) lctx.dflash.target.pos_ctx_data[(size_t) k];
if (q_pos - k_pos < swa_window) {
row[k] = h_zero;
}
}
for (int32_t k = cross_ctx; k < cross_ctx + (int32_t) n_tokens; ++k) {
const int32_t block_k = k - cross_ctx;
if (block_k <= (int32_t) j) {
row[k] = h_zero;
}
}
}
ggml_backend_tensor_set(kq_mask_swa, mask_swa_f16.data(), 0, ggml_nbytes(kq_mask_swa));
} else {
lctx.dflash.target.kq_mask_swa_data.assign(mask_elems, -INFINITY);
for (uint32_t j = 0; j < n_tokens; ++j) {
float * row = lctx.dflash.target.kq_mask_swa_data.data() + (size_t) j * (size_t) n_kv_total;
const int32_t q_pos = draft_pos_base + (int32_t) j;
@ -686,9 +728,9 @@ bool llama_prepare_dflash_graph_inputs(
}
}
}
ggml_backend_tensor_set(kq_mask_swa, lctx.dflash.target.kq_mask_swa_data.data(), 0, ggml_nbytes(kq_mask_swa));
}
}
return true;
}

View File

@ -55,33 +55,39 @@ static bool load_dflash_target_layer_ids(
throw std::runtime_error(format("dflash: %s must be a uint32/int32 array", key.c_str()));
}
const size_t n = gguf_get_arr_n(ml.meta, kid);
uint32_t n = 0;
ml.get_arr_n(key, n, true);
if (n == 0) {
throw std::runtime_error(format("dflash: %s must not be empty", key.c_str()));
}
if (n > 8) {
throw std::runtime_error(format("dflash: %s has %zu entries, max is 8", key.c_str(), n));
throw std::runtime_error(format("dflash: %s has %u entries, max is 8", key.c_str(), n));
}
hparams.dflash_n_target_layers = (uint32_t) n;
hparams.dflash_n_target_layers = n;
for (uint32_t & id : hparams.dflash_target_layer_ids) {
id = 0;
}
const void * data = gguf_get_arr_data(ml.meta, kid);
for (uint32_t i = 0; i < hparams.dflash_n_target_layers; ++i) {
if (type == GGUF_TYPE_INT32) {
const int32_t id = ((const int32_t *) data)[i];
if (id < 0) {
throw std::runtime_error(format("dflash: %s contains negative layer id %d", key.c_str(), id));
std::array<int32_t, 8> layer_ids = {};
ml.get_arr(key, layer_ids, true);
for (uint32_t i = 0; i < hparams.dflash_n_target_layers; ++i) {
if (layer_ids[i] < 0) {
throw std::runtime_error(format("dflash: %s contains negative layer id %d", key.c_str(), layer_ids[i]));
}
hparams.dflash_target_layer_ids[i] = (uint32_t) layer_ids[i];
}
hparams.dflash_target_layer_ids[i] = (uint32_t) id;
} else {
hparams.dflash_target_layer_ids[i] = ((const uint32_t *) data)[i];
std::array<uint32_t, 8> layer_ids = {};
ml.get_arr(key, layer_ids, true);
for (uint32_t i = 0; i < hparams.dflash_n_target_layers; ++i) {
hparams.dflash_target_layer_ids[i] = layer_ids[i];
}
}
for (uint32_t i = 0; i < hparams.dflash_n_target_layers; ++i) {
const uint32_t id = hparams.dflash_target_layer_ids[i];
for (uint32_t j = 0; j < i; ++j) {
if (hparams.dflash_target_layer_ids[j] == id) {
throw std::runtime_error(format(

View File

@ -1260,5 +1260,7 @@ template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv
template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
template bool llama_model_loader::get_key_or_arr<std::array<float, 512>>(enum llm_kv kid, std::array<float, 512> & result, uint32_t n, bool required);
template std::enable_if<std::is_integral<unsigned int>::value, bool>::type llama_model_loader::get_arr_n<unsigned int>(const std::string &, unsigned int &, bool);
template std::enable_if<std::is_integral<unsigned int>::value, bool>::type llama_model_loader::get_arr_n<unsigned int>(enum llm_kv, unsigned int&, bool);
template bool llama_model_loader::get_arr<int32_t, 8>(const std::string &, std::array<int32_t, 8> &, bool);
template bool llama_model_loader::get_arr<uint32_t, 8>(const std::string &, std::array<uint32_t, 8> &, bool);

View File

@ -118,7 +118,7 @@ int32_t llama_model_dflash_target_mask_token_id(const struct llama_model * model
return (int32_t) model->vocab.token_mask();
}
const struct ggml_tensor * llama_model_dflash_output_tensor(
static const ggml_tensor * llama_dflash_output_tensor(
const struct llama_model * model) {
if (model == nullptr) {
return nullptr;
@ -142,8 +142,8 @@ int32_t llama_model_dflash_io_mode(
return LLAMA_DFLASH_IO_MODE_INVALID;
}
const ggml_tensor * draft_output = llama_model_dflash_output_tensor(draft_model);
const ggml_tensor * target_output = llama_model_dflash_output_tensor(target_model);
const ggml_tensor * draft_output = llama_dflash_output_tensor(draft_model);
const ggml_tensor * target_output = llama_dflash_output_tensor(target_model);
if (draft_model->tok_embd == nullptr || draft_output == nullptr || target_model->tok_embd == nullptr || target_output == nullptr) {
return LLAMA_DFLASH_IO_MODE_INVALID;
}
@ -165,7 +165,7 @@ bool llama_model_dflash_io_tensors_match(
const struct llama_model * draft_model,
int32_t n_embd,
int32_t n_vocab) {
const ggml_tensor * output = llama_model_dflash_output_tensor(draft_model);
const ggml_tensor * output = llama_dflash_output_tensor(draft_model);
if (draft_model == nullptr || draft_model->tok_embd == nullptr || output == nullptr || n_embd <= 0 || n_vocab <= 0) {
return false;
}
@ -202,11 +202,17 @@ bool llama_model_share_dflash_io_tensors(
const bool uses_shared_output = draft_model->output == target_model->output ||
draft_model->output == target_model->tok_embd;
if (draft_model->output_mtp == nullptr && target_model->output_mtp != nullptr && uses_shared_tok && uses_shared_output) {
if (draft_model->output_mtp == nullptr) {
if (target_model->output_mtp != nullptr && uses_shared_tok && uses_shared_output) {
draft_model->output_mtp = target_model->output_mtp;
} else if (draft_model->output != nullptr) {
draft_model->output_mtp = draft_model->output;
} else {
draft_model->output_mtp = draft_model->tok_embd;
}
}
const struct ggml_tensor * output = llama_model_dflash_output_tensor(draft_model);
const struct ggml_tensor * output = llama_dflash_output_tensor(draft_model);
return draft_model->tok_embd != nullptr && output != nullptr;
}

View File

@ -85,7 +85,6 @@ int32_t llama_model_dflash_n_target_layers(const struct llama_model * model);
int32_t llama_model_dflash_n_target_features(const struct llama_model * model);
int32_t llama_model_dflash_target_layer_ids(const struct llama_model * model, int32_t * layer_ids, int32_t capacity);
int32_t llama_model_dflash_target_mask_token_id(const struct llama_model * model);
const struct ggml_tensor * llama_model_dflash_output_tensor(const struct llama_model * model);
enum llama_dflash_io_mode {
LLAMA_DFLASH_IO_MODE_INVALID = 0,