mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
DFlash: use persistent FA-ready K/V cache (#1997)
* Prototype physical-order DFlash KV cache (cherry picked from commit f9093d9ee57cf66f6ce44c42524158bb1449d1c9) * Use persistent FA-ready DFlash KV cache (cherry picked from commit cfed6ae456b5448ac0053fbd5994037af845a69a) * Address DFlash review cleanup --------- Co-authored-by: Joel Farthing <262452229+joelfarthing@users.noreply.github.com>
This commit is contained in:
parent
72440a19fc
commit
64fceb70bc
@ -4,108 +4,6 @@
|
||||
|
||||
#include <cmath>
|
||||
|
||||
// Builds a graph that copies the DFlash context K/V cache into the per-layer
// workspace tensors.
//
// For every layer the cache tensor is first re-ordered along dimension 2 (the
// row dimension that the write cursor indexes), then dimensions 1 and 2 are
// swapped with ggml_permute, made contiguous, and copied into a ctx_len-wide
// view at the start of the matching workspace tensor.
//
// Returns the graph holding one K store and one V store per layer.
ggml_cgraph * llm_build_context::build_dflash_kv_workspace() {
    const int64_t n_embd_head_k = hparams.n_embd_head_k(0);
    const int64_t n_embd_head_v = hparams.n_embd_head_v(0);

    // Context length: the explicit visible_cross_ctx when it is positive,
    // otherwise n_ctx minus one DFlash block, never below 1.
    const int64_t fallback_ctx_len =
        std::max<int64_t>(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size);
    const int64_t ctx_len = lctx.dflash.visible_cross_ctx > 0
        ? (int64_t) lctx.dflash.visible_cross_ctx
        : fallback_ctx_len;

    const int32_t ring_len   = (int32_t) ctx_len;
    const int32_t cache_rows = std::clamp(lctx.dflash.kv.cache_view_n_filled, 0, ring_len);
    // Wrap the write cursor into [0, ring_len), also for negative inputs.
    const int32_t cache_write_pos = ctx_len > 0
        ? ((lctx.dflash.kv.cache_view_write_pos % ring_len) + ring_len) % ring_len
        : 0;

    GGML_ASSERT(n_embd_head_k == n_embd_head_v);
    GGML_ASSERT(lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len));
    GGML_ASSERT((int32_t) lctx.dflash.kv.k_ctx_workspace.size() == n_layer);
    GGML_ASSERT((int32_t) lctx.dflash.kv.v_ctx_workspace.size() == n_layer);

    ggml_cgraph * gf = ggml_new_graph_custom(
        ctx0, model.max_nodes((int) std::max<int64_t>(1, ctx_len)) + 16 * n_layer, false);

    // View of `n_rows` consecutive dim-2 slices of `src`, starting at `first_row`.
    auto rows_view = [&](ggml_tensor * src, int64_t first_row, int64_t n_rows) -> ggml_tensor * {
        return ggml_view_3d(ctx0, src,
                src->ne[0],
                src->ne[1],
                n_rows,
                src->nb[1],
                src->nb[2],
                (size_t) first_row * src->nb[2]);
    };

    // Re-orders a cache tensor along dim 2:
    //  - no valid view or no rows filled: the cache is used as-is;
    //  - partially filled: rows [cache_rows, ctx_len) are placed in front of
    //    rows [0, cache_rows) (the leading part was named zero_pad originally;
    //    presumably those rows are still zero -- confirm against the allocator);
    //  - full with the cursor at 0: already in order;
    //  - full otherwise: rows from the cursor to the end, then rows before it.
    auto build_ordered_cache_view = [&](ggml_tensor * cache) -> ggml_tensor * {
        const bool has_rows = lctx.dflash.kv.cache_view_valid && cache_rows > 0;
        if (!has_rows) {
            return cache;
        }

        if (cache_rows < ctx_len) {
            ggml_tensor * unwritten = rows_view(cache, cache_rows, ctx_len - cache_rows);
            ggml_tensor * written   = rows_view(cache, 0, cache_rows);
            return ggml_concat(ctx0, unwritten, written, 2);
        }

        if (cache_write_pos == 0) {
            return cache;
        }

        ggml_tensor * from_cursor   = rows_view(cache, cache_write_pos, ctx_len - cache_write_pos);
        ggml_tensor * before_cursor = rows_view(cache, 0, cache_write_pos);
        return ggml_concat(ctx0, from_cursor, before_cursor, 2);
    };

    // Leading ctx_len entries (dim 1) of a workspace tensor.
    auto workspace_ctx_view = [&](ggml_tensor * ws) -> ggml_tensor * {
        return ggml_view_3d(ctx0, ws,
                ws->ne[0],
                ctx_len,
                ws->ne[2],
                ws->nb[1],
                ws->nb[2],
                0);
    };

    for (int il = 0; il < n_layer; ++il) {
        GGML_ASSERT(il < (int32_t) lctx.dflash.kv.k_ctx_cache.size());
        GGML_ASSERT(il < (int32_t) lctx.dflash.kv.v_ctx_cache.size());

        ggml_tensor * k_ordered = build_ordered_cache_view(lctx.dflash.kv.k_ctx_cache[il]);
        ggml_tensor * v_ordered = build_ordered_cache_view(lctx.dflash.kv.v_ctx_cache[il]);
        cb(k_ordered, "dflash_workspace_k_ctx_view", il);
        cb(v_ordered, "dflash_workspace_v_ctx_view", il);

        // Swap dims 1 and 2 and materialize the result contiguously.
        ggml_tensor * k_perm = ggml_cont(ctx0, ggml_permute(ctx0, k_ordered, 0, 2, 1, 3));
        ggml_tensor * v_perm = ggml_cont(ctx0, ggml_permute(ctx0, v_ordered, 0, 2, 1, 3));
        cb(k_perm, "dflash_workspace_k_perm_cont", il);
        cb(v_perm, "dflash_workspace_v_perm_cont", il);

        ggml_tensor * k_dst = workspace_ctx_view(lctx.dflash.kv.k_ctx_workspace[il]);
        ggml_tensor * v_dst = workspace_ctx_view(lctx.dflash.kv.v_ctx_workspace[il]);

        ggml_tensor * k_store = ggml_cpy(ctx0, k_perm, k_dst);
        ggml_tensor * v_store = ggml_cpy(ctx0, v_perm, v_dst);
        cb(k_store, "dflash_workspace_k_store", il);
        cb(v_store, "dflash_workspace_v_store", il);

        ggml_build_forward_expand(gf, k_store);
        ggml_build_forward_expand(gf, v_store);
    }

    return gf;
}
|
||||
|
||||
ggml_cgraph * llm_build_context::build_dflash_kv_cache() {
|
||||
const int64_t n_embd_head_k = hparams.n_embd_head_k(0);
|
||||
const int64_t n_embd_head_v = hparams.n_embd_head_v(0);
|
||||
@ -150,10 +48,14 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() {
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Kcur_ctx, "dflash_kv_k_rope", il);
|
||||
Kcur_ctx = ggml_cont(ctx0, ggml_permute(ctx0, Kcur_ctx, 0, 2, 1, 3));
|
||||
cb(Kcur_ctx, "dflash_kv_k_physical", il);
|
||||
|
||||
ggml_tensor * Vcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, fused_target);
|
||||
cb(Vcur_ctx, "dflash_kv_v_proj", il);
|
||||
Vcur_ctx = ggml_reshape_3d(ctx0, Vcur_ctx, n_embd_head_v, n_head_kv, update_rows);
|
||||
Vcur_ctx = ggml_cont(ctx0, ggml_permute(ctx0, Vcur_ctx, 0, 2, 1, 3));
|
||||
cb(Vcur_ctx, "dflash_kv_v_physical", il);
|
||||
|
||||
const int32_t first_rows = std::min<int32_t>((int32_t) update_rows, (int32_t) ctx_len - write_pos);
|
||||
const int32_t second_rows = (int32_t) update_rows - first_rows;
|
||||
@ -163,8 +65,8 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() {
|
||||
? Kcur_ctx
|
||||
: ggml_view_3d(ctx0, Kcur_ctx,
|
||||
Kcur_ctx->ne[0],
|
||||
Kcur_ctx->ne[1],
|
||||
first_rows,
|
||||
Kcur_ctx->ne[2],
|
||||
Kcur_ctx->nb[1],
|
||||
Kcur_ctx->nb[2],
|
||||
0);
|
||||
@ -172,25 +74,25 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() {
|
||||
? Vcur_ctx
|
||||
: ggml_view_3d(ctx0, Vcur_ctx,
|
||||
Vcur_ctx->ne[0],
|
||||
Vcur_ctx->ne[1],
|
||||
first_rows,
|
||||
Vcur_ctx->ne[2],
|
||||
Vcur_ctx->nb[1],
|
||||
Vcur_ctx->nb[2],
|
||||
0);
|
||||
ggml_tensor * Kdst_first = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_cache[il],
|
||||
lctx.dflash.kv.k_ctx_cache[il]->ne[0],
|
||||
lctx.dflash.kv.k_ctx_cache[il]->ne[1],
|
||||
first_rows,
|
||||
lctx.dflash.kv.k_ctx_cache[il]->ne[2],
|
||||
lctx.dflash.kv.k_ctx_cache[il]->nb[1],
|
||||
lctx.dflash.kv.k_ctx_cache[il]->nb[2],
|
||||
(size_t) write_pos * lctx.dflash.kv.k_ctx_cache[il]->nb[2]);
|
||||
(size_t) write_pos * lctx.dflash.kv.k_ctx_cache[il]->nb[1]);
|
||||
ggml_tensor * Vdst_first = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_cache[il],
|
||||
lctx.dflash.kv.v_ctx_cache[il]->ne[0],
|
||||
lctx.dflash.kv.v_ctx_cache[il]->ne[1],
|
||||
first_rows,
|
||||
lctx.dflash.kv.v_ctx_cache[il]->ne[2],
|
||||
lctx.dflash.kv.v_ctx_cache[il]->nb[1],
|
||||
lctx.dflash.kv.v_ctx_cache[il]->nb[2],
|
||||
(size_t) write_pos * lctx.dflash.kv.v_ctx_cache[il]->nb[2]);
|
||||
(size_t) write_pos * lctx.dflash.kv.v_ctx_cache[il]->nb[1]);
|
||||
|
||||
ggml_tensor * Kstore_first = ggml_cpy(ctx0, Ksrc_first, Kdst_first);
|
||||
cb(Kstore_first, "dflash_kv_k_store", il);
|
||||
@ -204,29 +106,29 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() {
|
||||
if (second_rows > 0) {
|
||||
ggml_tensor * Ksrc_second = ggml_view_3d(ctx0, Kcur_ctx,
|
||||
Kcur_ctx->ne[0],
|
||||
Kcur_ctx->ne[1],
|
||||
second_rows,
|
||||
Kcur_ctx->ne[2],
|
||||
Kcur_ctx->nb[1],
|
||||
Kcur_ctx->nb[2],
|
||||
(size_t) first_rows * Kcur_ctx->nb[2]);
|
||||
(size_t) first_rows * Kcur_ctx->nb[1]);
|
||||
ggml_tensor * Vsrc_second = ggml_view_3d(ctx0, Vcur_ctx,
|
||||
Vcur_ctx->ne[0],
|
||||
Vcur_ctx->ne[1],
|
||||
second_rows,
|
||||
Vcur_ctx->ne[2],
|
||||
Vcur_ctx->nb[1],
|
||||
Vcur_ctx->nb[2],
|
||||
(size_t) first_rows * Vcur_ctx->nb[2]);
|
||||
(size_t) first_rows * Vcur_ctx->nb[1]);
|
||||
ggml_tensor * Kdst_second = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_cache[il],
|
||||
lctx.dflash.kv.k_ctx_cache[il]->ne[0],
|
||||
lctx.dflash.kv.k_ctx_cache[il]->ne[1],
|
||||
second_rows,
|
||||
lctx.dflash.kv.k_ctx_cache[il]->ne[2],
|
||||
lctx.dflash.kv.k_ctx_cache[il]->nb[1],
|
||||
lctx.dflash.kv.k_ctx_cache[il]->nb[2],
|
||||
0);
|
||||
ggml_tensor * Vdst_second = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_cache[il],
|
||||
lctx.dflash.kv.v_ctx_cache[il]->ne[0],
|
||||
lctx.dflash.kv.v_ctx_cache[il]->ne[1],
|
||||
second_rows,
|
||||
lctx.dflash.kv.v_ctx_cache[il]->ne[2],
|
||||
lctx.dflash.kv.v_ctx_cache[il]->nb[1],
|
||||
lctx.dflash.kv.v_ctx_cache[il]->nb[2],
|
||||
0);
|
||||
@ -251,16 +153,11 @@ ggml_cgraph * llm_build_context::build_dflash() {
|
||||
const int64_t ctx_len = lctx.dflash.visible_cross_ctx > 0
|
||||
? (int64_t) lctx.dflash.visible_cross_ctx
|
||||
: std::max<int64_t>(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size);
|
||||
const int32_t cache_write_pos = ctx_len > 0
|
||||
? ((lctx.dflash.kv.cache_view_write_pos % (int32_t) ctx_len) + (int32_t) ctx_len) % (int32_t) ctx_len
|
||||
: 0;
|
||||
const int64_t n_kv_total = GGML_PAD(ctx_len + n_tokens, flash_attn ? 256 : 32);
|
||||
const int64_t n_kv_pad = n_kv_total - (ctx_len + n_tokens);
|
||||
|
||||
GGML_ASSERT(n_embd_head_k == n_embd_head_v);
|
||||
GGML_ASSERT(n_target_features > 0);
|
||||
GGML_ASSERT(lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len));
|
||||
GGML_ASSERT(cache_write_pos >= 0 && cache_write_pos < ctx_len);
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max<int64_t>(n_tokens, ctx_len)) + 32 * n_layer, false);
|
||||
|
||||
@ -279,6 +176,10 @@ ggml_cgraph * llm_build_context::build_dflash() {
|
||||
ggml_set_input(lctx.dflash.inputs.kq_mask);
|
||||
cb(lctx.dflash.inputs.kq_mask, "dflash_kq_mask", -1);
|
||||
|
||||
lctx.dflash.kv.draft_tail_rows_tensor = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||
ggml_set_input(lctx.dflash.kv.draft_tail_rows_tensor);
|
||||
cb(lctx.dflash.kv.draft_tail_rows_tensor, "dflash_draft_tail_rows", -1);
|
||||
|
||||
ggml_tensor * dflash_kq_mask_full = lctx.dflash.inputs.kq_mask;
|
||||
ggml_tensor * dflash_kq_mask_swa = nullptr;
|
||||
lctx.dflash.inputs.kq_mask_swa = nullptr;
|
||||
@ -326,43 +227,43 @@ ggml_cgraph * llm_build_context::build_dflash() {
|
||||
Vcur_noise = ggml_reshape_3d(ctx0, Vcur_noise, n_embd_head_v, n_head_kv, n_tokens);
|
||||
cb(Vcur_noise, "Vcur_noise", il);
|
||||
|
||||
GGML_ASSERT(il < (int32_t) lctx.dflash.kv.k_ctx_workspace.size());
|
||||
GGML_ASSERT(il < (int32_t) lctx.dflash.kv.v_ctx_workspace.size());
|
||||
GGML_ASSERT(lctx.dflash.kv.k_ctx_workspace[il] != nullptr);
|
||||
GGML_ASSERT(lctx.dflash.kv.v_ctx_workspace[il] != nullptr);
|
||||
|
||||
ggml_tensor * Kcur_ctx = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_workspace[il],
|
||||
lctx.dflash.kv.k_ctx_workspace[il]->ne[0],
|
||||
ctx_len,
|
||||
lctx.dflash.kv.k_ctx_workspace[il]->ne[2],
|
||||
lctx.dflash.kv.k_ctx_workspace[il]->nb[1],
|
||||
lctx.dflash.kv.k_ctx_workspace[il]->nb[2],
|
||||
0);
|
||||
ggml_tensor * Vcur_ctx = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_workspace[il],
|
||||
lctx.dflash.kv.v_ctx_workspace[il]->ne[0],
|
||||
ctx_len,
|
||||
lctx.dflash.kv.v_ctx_workspace[il]->ne[2],
|
||||
lctx.dflash.kv.v_ctx_workspace[il]->nb[1],
|
||||
lctx.dflash.kv.v_ctx_workspace[il]->nb[2],
|
||||
0);
|
||||
cb(Kcur_ctx, "Kcur_ctx_workspace", il);
|
||||
cb(Vcur_ctx, "Vcur_ctx_workspace", il);
|
||||
GGML_ASSERT(il < (int32_t) lctx.dflash.kv.k_ctx_cache.size());
|
||||
GGML_ASSERT(il < (int32_t) lctx.dflash.kv.v_ctx_cache.size());
|
||||
GGML_ASSERT(lctx.dflash.kv.k_ctx_cache[il] != nullptr);
|
||||
GGML_ASSERT(lctx.dflash.kv.v_ctx_cache[il] != nullptr);
|
||||
GGML_ASSERT(lctx.dflash.kv.k_ctx_cache[il]->type == lctx.dflash.kv.v_ctx_cache[il]->type);
|
||||
GGML_ASSERT(lctx.dflash.kv.k_ctx_cache[il]->ne[1] >= n_kv_total);
|
||||
GGML_ASSERT(lctx.dflash.kv.v_ctx_cache[il]->ne[1] >= n_kv_total);
|
||||
|
||||
ggml_tensor * Kcur_draft = ggml_cont(ctx0, ggml_permute(ctx0, Kcur_noise, 0, 2, 1, 3));
|
||||
ggml_tensor * Vcur_draft = ggml_cont(ctx0, ggml_permute(ctx0, Vcur_noise, 0, 2, 1, 3));
|
||||
cb(Kcur_draft, "dflash_main_k_perm_cont", il);
|
||||
cb(Vcur_draft, "dflash_main_v_perm_cont", il);
|
||||
|
||||
ggml_tensor * Kcur = ggml_concat(ctx0, Kcur_ctx, Kcur_draft, 1);
|
||||
ggml_tensor * Vcur = ggml_concat(ctx0, Vcur_ctx, Vcur_draft, 1);
|
||||
cb(Kcur, "dflash_main_k_concat", il);
|
||||
cb(Vcur, "dflash_main_v_concat", il);
|
||||
ggml_tensor * Kcur = ggml_set_rows(ctx0, lctx.dflash.kv.k_ctx_cache[il], Kcur_draft, lctx.dflash.kv.draft_tail_rows_tensor);
|
||||
ggml_tensor * Vcur = ggml_set_rows(ctx0, lctx.dflash.kv.v_ctx_cache[il], Vcur_draft, lctx.dflash.kv.draft_tail_rows_tensor);
|
||||
cb(Kcur, "dflash_main_k_set_tail", il);
|
||||
cb(Vcur, "dflash_main_v_set_tail", il);
|
||||
|
||||
if (n_kv_pad > 0) {
|
||||
Kcur = ggml_pad(ctx0, Kcur, 0, (int) n_kv_pad, 0, 0);
|
||||
Vcur = ggml_pad(ctx0, Vcur, 0, (int) n_kv_pad, 0, 0);
|
||||
cb(Kcur, "dflash_main_k_pad", il);
|
||||
cb(Vcur, "dflash_main_v_pad", il);
|
||||
if (Kcur->ne[1] != n_kv_total) {
|
||||
Kcur = ggml_view_3d(ctx0, Kcur,
|
||||
Kcur->ne[0],
|
||||
n_kv_total,
|
||||
Kcur->ne[2],
|
||||
Kcur->nb[1],
|
||||
Kcur->nb[2],
|
||||
0);
|
||||
cb(Kcur, "dflash_main_k_active_view", il);
|
||||
}
|
||||
if (Vcur->ne[1] != n_kv_total) {
|
||||
Vcur = ggml_view_3d(ctx0, Vcur,
|
||||
Vcur->ne[0],
|
||||
n_kv_total,
|
||||
Vcur->ne[2],
|
||||
Vcur->nb[1],
|
||||
Vcur->nb[2],
|
||||
0);
|
||||
cb(Vcur, "dflash_main_v_active_view", il);
|
||||
}
|
||||
|
||||
if (Kcur->type == GGML_TYPE_F32) {
|
||||
|
||||
@ -2244,43 +2244,6 @@ struct ggml_cgraph * llm_build_context::llama_build_graph_dflash_kv_cache(llama_
|
||||
return result;
|
||||
}
|
||||
|
||||
// Builds the DFlash K/V workspace graph for `lctx`.
//
// A build context is created over an empty batch (only n_tokens is set) using
// lctx.dflash.kv.workspace_compute_meta, build_dflash_kv_workspace() is run,
// and the build context is released before the graph is returned.
struct ggml_cgraph * llm_build_context::llama_build_graph_dflash_kv_workspace(llama_context & lctx) {
    llama_batch dummy;
    dummy.n_tokens = 0;

    // Tensor-naming callback: layer tensors (il >= 0) are named "<name>-<il>",
    // truncated to fit GGML_MAX_NAME; all others receive the plain name.
    llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) {
        if (il < 0) {
            ggml_set_name(cur, name);
            return;
        }

        // Copy the base name, leaving room for the terminator.
        int len = 0;
        while (len < GGML_MAX_NAME - 1 && name[len] != '\0') {
            cur->name[len] = name[len];
            ++len;
        }

        // The layer suffix is appended only when at least "-" plus one digit fit.
        if (len < GGML_MAX_NAME - 3) {
            cur->name[len++] = '-';
            const auto layer_digits = std::to_string(il);
            for (const char digit : layer_digits) {
                if (len >= GGML_MAX_NAME - 1) {
                    break;
                }
                cur->name[len++] = digit;
            }
        }

        cur->name[len] = 0;
    };

    struct llm_build_context llm(lctx, dummy, cb, false, false, 0, false, &lctx.dflash.kv.workspace_compute_meta);

    llm.init();
    struct ggml_cgraph * result = llm.build_dflash_kv_workspace();
    llm.free();

    return result;
}
|
||||
|
||||
ggml_cgraph * llm_build_context::llama_build_graph(
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch,
|
||||
|
||||
@ -251,8 +251,6 @@ struct llm_build_context {
|
||||
|
||||
ggml_cgraph * build_dflash_kv_cache();
|
||||
|
||||
ggml_cgraph * build_dflash_kv_workspace();
|
||||
|
||||
ggml_cgraph * build_starcoder2();
|
||||
|
||||
ggml_cgraph * build_mamba();
|
||||
@ -474,8 +472,6 @@ llm_expert_gating_func_type gating_op,
|
||||
|
||||
static ggml_cgraph * llama_build_graph_dflash_kv_cache(llama_context & lctx);
|
||||
|
||||
static ggml_cgraph * llama_build_graph_dflash_kv_workspace(llama_context & lctx);
|
||||
|
||||
static ggml_cgraph * llama_build_graph(llama_context & lctx, const llama_batch & batch, bool worst_case, int n_outputs = 0);
|
||||
|
||||
ggml_tensor * build_std_attention(ggml_cgraph * gf, ggml_tensor * attn_norm, ggml_tensor * cur,
|
||||
|
||||
@ -301,10 +301,10 @@ struct llama_context {
|
||||
struct kv_runtime_state {
|
||||
std::vector<struct ggml_tensor *> k_ctx_cache;
|
||||
std::vector<struct ggml_tensor *> v_ctx_cache;
|
||||
std::vector<struct ggml_tensor *> k_ctx_workspace;
|
||||
std::vector<struct ggml_tensor *> v_ctx_workspace;
|
||||
struct ggml_context * cache_ctx = nullptr;
|
||||
std::vector<ggml_backend_buffer_t> cache_bufs;
|
||||
std::vector<llama_pos> cache_pos;
|
||||
std::vector<uint8_t> cache_slot_valid;
|
||||
int32_t cache_write_pos = 0;
|
||||
int32_t cache_n_filled = 0;
|
||||
int32_t cache_update_rows = 0;
|
||||
@ -314,28 +314,16 @@ struct llama_context {
|
||||
uint64_t cache_applied_window_version = 0;
|
||||
bool cache_valid = false;
|
||||
bool cache_view_valid = false;
|
||||
int32_t workspace_write_pos = 0;
|
||||
int32_t workspace_n_filled = 0;
|
||||
int32_t workspace_reserved_rows = 0;
|
||||
int32_t workspace_token_capacity = 0;
|
||||
int32_t workspace_n_kv_total = 0;
|
||||
uint64_t workspace_applied_window_version = 0;
|
||||
bool workspace_valid = false;
|
||||
bool workspace_sync_pending = false;
|
||||
std::vector<uint8_t> cache_compute_meta;
|
||||
std::vector<uint8_t> workspace_compute_meta;
|
||||
ggml_backend_sched_t cache_sched = nullptr;
|
||||
ggml_backend_sched_t workspace_sched = nullptr;
|
||||
ggml_cgraph * cache_graph = nullptr;
|
||||
ggml_cgraph * workspace_graph = nullptr;
|
||||
int32_t cache_graph_rows = 0;
|
||||
int32_t cache_graph_write_pos = 0;
|
||||
int32_t workspace_graph_rows = 0;
|
||||
int32_t workspace_graph_write_pos = 0;
|
||||
struct ggml_tensor * cache_input_target_features = nullptr;
|
||||
struct ggml_tensor * cache_input_pos_ctx = nullptr;
|
||||
struct ggml_tensor * kq_mask_tensor = nullptr;
|
||||
struct ggml_tensor * kq_mask_swa_tensor = nullptr;
|
||||
struct ggml_tensor * draft_tail_rows_tensor = nullptr;
|
||||
};
|
||||
|
||||
struct capture_state {
|
||||
|
||||
@ -15,15 +15,6 @@
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
// Synchronizes the DFlash workspace scheduler when a sync is flagged as
// pending and a scheduler exists, then clears the pending flag.
// Does nothing otherwise.
void llama_sync_dflash_workspace_if_pending(struct llama_context & lctx) {
    auto & kv = lctx.dflash.kv;

    const bool must_sync = kv.workspace_sync_pending && kv.workspace_sched != nullptr;
    if (must_sync) {
        ggml_backend_sched_synchronize(kv.workspace_sched);
        kv.workspace_sync_pending = false;
    }
}
|
||||
|
||||
static ggml_backend_buffer_type_t llama_dflash_kv_cache_layer_buft(const llama_context & lctx, int32_t il) {
|
||||
if (il >= 0 && il < (int32_t) lctx.model.buft_layer.size() && lctx.model.buft_layer[il].buft != nullptr) {
|
||||
return lctx.model.buft_layer[il].buft;
|
||||
@ -64,8 +55,11 @@ static ggml_backend_t llama_backend_for_tensor(const llama_context & lctx, const
|
||||
|
||||
bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
|
||||
const int32_t target_cross_ctx = std::max<int32_t>(1, cross_ctx);
|
||||
const int32_t target_token_capacity = std::max<int32_t>(1, (int32_t) model.hparams.dflash_block_size);
|
||||
const int32_t target_workspace_n_kv_total = GGML_PAD(target_cross_ctx + target_token_capacity, cparams.flash_attn ? 256 : 32);
|
||||
const int32_t target_token_capacity = std::max<int32_t>(
|
||||
1,
|
||||
std::max<int32_t>((int32_t) model.hparams.dflash_block_size, (int32_t) cparams.n_ubatch));
|
||||
const int32_t target_cache_n_kv_total = GGML_PAD(target_cross_ctx + target_token_capacity, cparams.flash_attn ? 256 : 32);
|
||||
const ggml_type target_cache_type = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
const int32_t n_layer = model.hparams.n_layer;
|
||||
const int64_t n_embd_head_k = model.hparams.n_embd_head_k(0);
|
||||
const int64_t n_embd_head_v = model.hparams.n_embd_head_v(0);
|
||||
@ -73,13 +67,16 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
|
||||
|
||||
if (dflash.kv.cache_ctx != nullptr &&
|
||||
(int32_t) dflash.kv.k_ctx_cache.size() == n_layer &&
|
||||
(int32_t) dflash.kv.k_ctx_workspace.size() == n_layer) {
|
||||
(int32_t) dflash.kv.cache_pos.size() == target_cross_ctx &&
|
||||
(int32_t) dflash.kv.cache_slot_valid.size() == target_cross_ctx) {
|
||||
const bool cache_matches =
|
||||
(int32_t) dflash.kv.k_ctx_cache.front()->ne[2] == target_cross_ctx;
|
||||
const bool workspace_matches =
|
||||
(int32_t) dflash.kv.k_ctx_workspace.front()->ne[1] == target_workspace_n_kv_total;
|
||||
|
||||
if (cache_matches && workspace_matches) {
|
||||
dflash.kv.k_ctx_cache.front() != nullptr &&
|
||||
dflash.kv.v_ctx_cache.front() != nullptr &&
|
||||
dflash.kv.k_ctx_cache.front()->type == target_cache_type &&
|
||||
dflash.kv.v_ctx_cache.front()->type == target_cache_type &&
|
||||
(int32_t) dflash.kv.k_ctx_cache.front()->ne[1] == target_cache_n_kv_total &&
|
||||
(int32_t) dflash.kv.v_ctx_cache.front()->ne[1] == target_cache_n_kv_total;
|
||||
if (cache_matches) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -88,17 +85,9 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
|
||||
ggml_backend_sched_free(dflash.kv.cache_sched);
|
||||
dflash.kv.cache_sched = nullptr;
|
||||
}
|
||||
if (dflash.kv.workspace_sched != nullptr) {
|
||||
ggml_backend_sched_free(dflash.kv.workspace_sched);
|
||||
dflash.kv.workspace_sched = nullptr;
|
||||
}
|
||||
dflash.kv.cache_graph = nullptr;
|
||||
dflash.kv.workspace_graph = nullptr;
|
||||
dflash.kv.cache_graph_rows = 0;
|
||||
dflash.kv.cache_graph_write_pos = 0;
|
||||
dflash.kv.workspace_graph_rows = 0;
|
||||
dflash.kv.workspace_graph_write_pos = 0;
|
||||
dflash.kv.workspace_reserved_rows = 0;
|
||||
}
|
||||
|
||||
ggml_init_params params = {
|
||||
@ -115,22 +104,18 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
|
||||
|
||||
dflash.kv.k_ctx_cache.resize((size_t) n_layer);
|
||||
dflash.kv.v_ctx_cache.resize((size_t) n_layer);
|
||||
dflash.kv.k_ctx_workspace.clear();
|
||||
dflash.kv.v_ctx_workspace.clear();
|
||||
dflash.kv.k_ctx_workspace.resize((size_t) n_layer);
|
||||
dflash.kv.v_ctx_workspace.resize((size_t) n_layer);
|
||||
dflash.kv.cache_pos.assign((size_t) target_cross_ctx, 0);
|
||||
dflash.kv.cache_slot_valid.assign((size_t) target_cross_ctx, 0);
|
||||
dflash.kv.cache_bufs.clear();
|
||||
dflash.kv.cache_bufs.reserve((size_t) std::max(1, n_layer) * 4);
|
||||
dflash.kv.cache_bufs.reserve((size_t) std::max(1, n_layer) * 2);
|
||||
for (int32_t il = 0; il < n_layer; ++il) {
|
||||
ggml_backend_buffer_type_t layer_buft = llama_dflash_kv_cache_layer_buft(*this, il);
|
||||
ggml_tensor *& k_ctx_cache = dflash.kv.k_ctx_cache[il];
|
||||
ggml_tensor *& v_ctx_cache = dflash.kv.v_ctx_cache[il];
|
||||
ggml_tensor *& k_ctx_workspace = dflash.kv.k_ctx_workspace[il];
|
||||
ggml_tensor *& v_ctx_workspace = dflash.kv.v_ctx_workspace[il];
|
||||
|
||||
auto alloc_kv_input = [&](ggml_tensor *& tensor, const char * tensor_tag, const char * tensor_name,
|
||||
int64_t ne0, int64_t ne1, int64_t ne2) -> bool {
|
||||
tensor = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, ne0, ne1, ne2);
|
||||
ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2) -> bool {
|
||||
tensor = ggml_new_tensor_3d(dflash.kv.cache_ctx, type, ne0, ne1, ne2);
|
||||
if (tensor == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to create %s for layer %d\n", __func__, tensor_tag, il);
|
||||
return false;
|
||||
@ -156,20 +141,14 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
|
||||
};
|
||||
|
||||
if (!alloc_kv_input(k_ctx_cache, "dflash_k_ctx_cache", "dflash_k_ctx_cache_%d",
|
||||
n_embd_head_k, n_head_kv, target_cross_ctx) ||
|
||||
target_cache_type, n_embd_head_k, target_cache_n_kv_total, n_head_kv) ||
|
||||
!alloc_kv_input(v_ctx_cache, "dflash_v_ctx_cache", "dflash_v_ctx_cache_%d",
|
||||
n_embd_head_v, n_head_kv, target_cross_ctx) ||
|
||||
!alloc_kv_input(k_ctx_workspace, "dflash_k_ctx_workspace", "dflash_k_ctx_workspace_%d",
|
||||
n_embd_head_k, target_workspace_n_kv_total, n_head_kv) ||
|
||||
!alloc_kv_input(v_ctx_workspace, "dflash_v_ctx_workspace", "dflash_v_ctx_workspace_%d",
|
||||
n_embd_head_v, target_workspace_n_kv_total, n_head_kv)) {
|
||||
target_cache_type, n_embd_head_v, target_cache_n_kv_total, n_head_kv)) {
|
||||
free_dflash_kv_cache_tensors();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
dflash.kv.workspace_token_capacity = target_token_capacity;
|
||||
dflash.kv.workspace_n_kv_total = target_workspace_n_kv_total;
|
||||
llama_reset_dflash_kv_cache_state(this);
|
||||
|
||||
return true;
|
||||
@ -183,8 +162,8 @@ void llama_context::free_dflash_kv_cache_tensors() {
|
||||
|
||||
release_vector(dflash.kv.k_ctx_cache);
|
||||
release_vector(dflash.kv.v_ctx_cache);
|
||||
release_vector(dflash.kv.k_ctx_workspace);
|
||||
release_vector(dflash.kv.v_ctx_workspace);
|
||||
release_vector(dflash.kv.cache_pos);
|
||||
release_vector(dflash.kv.cache_slot_valid);
|
||||
dflash.kv.cache_write_pos = 0;
|
||||
dflash.kv.cache_n_filled = 0;
|
||||
dflash.kv.cache_update_rows = 0;
|
||||
@ -194,30 +173,14 @@ void llama_context::free_dflash_kv_cache_tensors() {
|
||||
dflash.kv.cache_applied_window_version = 0;
|
||||
dflash.kv.cache_valid = false;
|
||||
dflash.kv.cache_view_valid = false;
|
||||
dflash.kv.workspace_write_pos = 0;
|
||||
dflash.kv.workspace_n_filled = 0;
|
||||
dflash.kv.workspace_reserved_rows = 0;
|
||||
dflash.kv.workspace_token_capacity = 0;
|
||||
dflash.kv.workspace_n_kv_total = 0;
|
||||
dflash.kv.workspace_applied_window_version = 0;
|
||||
dflash.kv.workspace_valid = false;
|
||||
dflash.kv.workspace_sync_pending = false;
|
||||
dflash.kv.cache_graph = nullptr;
|
||||
dflash.kv.workspace_graph = nullptr;
|
||||
dflash.kv.cache_graph_rows = 0;
|
||||
dflash.kv.cache_graph_write_pos = 0;
|
||||
dflash.kv.workspace_graph_rows = 0;
|
||||
dflash.kv.workspace_graph_write_pos = 0;
|
||||
dflash.kv.cache_input_target_features = nullptr;
|
||||
dflash.kv.cache_input_pos_ctx = nullptr;
|
||||
dflash.kv.kq_mask_tensor = nullptr;
|
||||
dflash.kv.kq_mask_swa_tensor = nullptr;
|
||||
|
||||
if (dflash.kv.workspace_sched != nullptr) {
|
||||
ggml_backend_sched_synchronize(dflash.kv.workspace_sched);
|
||||
ggml_backend_sched_free(dflash.kv.workspace_sched);
|
||||
dflash.kv.workspace_sched = nullptr;
|
||||
}
|
||||
dflash.kv.draft_tail_rows_tensor = nullptr;
|
||||
|
||||
for (ggml_backend_buffer_t buf : dflash.kv.cache_bufs) {
|
||||
if (buf != nullptr) {
|
||||
@ -226,7 +189,6 @@ void llama_context::free_dflash_kv_cache_tensors() {
|
||||
}
|
||||
release_vector(dflash.kv.cache_bufs);
|
||||
release_vector(dflash.kv.cache_compute_meta);
|
||||
release_vector(dflash.kv.workspace_compute_meta);
|
||||
if (dflash.kv.cache_ctx != nullptr) {
|
||||
ggml_free(dflash.kv.cache_ctx);
|
||||
dflash.kv.cache_ctx = nullptr;
|
||||
@ -388,13 +350,10 @@ bool llama_prepare_dflash_graph_inputs(
|
||||
const int32_t n_rows = lctx.dflash.target.features_n_rows;
|
||||
const int32_t append_rows_available = lctx.dflash.target.append_features_n_rows;
|
||||
const int32_t width = (int32_t) lctx.model.hparams.dflash_n_target_features;
|
||||
const int32_t graph_cross_ctx = lctx.dflash.kv.k_ctx_cache.front() != nullptr
|
||||
? (int32_t) lctx.dflash.kv.k_ctx_cache.front()->ne[2]
|
||||
: 0;
|
||||
const int32_t graph_cross_ctx = (int32_t) lctx.dflash.kv.cache_pos.size();
|
||||
const int32_t n_mask_tokens = (int32_t) kq_mask->ne[1];
|
||||
const int32_t n_kv_total = (int32_t) kq_mask->ne[0];
|
||||
|
||||
llama_sync_dflash_workspace_if_pending(lctx);
|
||||
ggml_tensor * draft_tail_rows = lctx.dflash.kv.draft_tail_rows_tensor;
|
||||
|
||||
if (graph_cross_ctx != cross_ctx) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash graph cross_ctx drift (graph=%d configured=%d)\n",
|
||||
@ -418,8 +377,10 @@ bool llama_prepare_dflash_graph_inputs(
|
||||
__func__, n_kv_total, cross_ctx + (int32_t) n_tokens);
|
||||
return false;
|
||||
}
|
||||
|
||||
const int32_t left_pad = cross_ctx - n_rows;
|
||||
if (draft_tail_rows == nullptr || draft_tail_rows->type != GGML_TYPE_I32 || draft_tail_rows->ne[0] != (int64_t) n_tokens) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash draft tail row input is not initialized for n_tokens=%u\n", __func__, n_tokens);
|
||||
return false;
|
||||
}
|
||||
|
||||
lctx.dflash.target.pos_ctx_data.resize((size_t) cross_ctx);
|
||||
std::fill(lctx.dflash.target.pos_ctx_data.begin(), lctx.dflash.target.pos_ctx_data.end(), 0);
|
||||
@ -437,7 +398,6 @@ bool llama_prepare_dflash_graph_inputs(
|
||||
return false;
|
||||
}
|
||||
}
|
||||
std::copy(src_pos, src_pos + n_rows, lctx.dflash.target.pos_ctx_data.begin() + (ptrdiff_t) left_pad);
|
||||
|
||||
const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition(
|
||||
cross_ctx,
|
||||
@ -520,6 +480,7 @@ bool llama_prepare_dflash_graph_inputs(
|
||||
llama_reset_dflash_kv_cache_state(&lctx);
|
||||
}
|
||||
|
||||
const int32_t cache_write_start = lctx.dflash.kv.cache_write_pos;
|
||||
lctx.dflash.kv.cache_update_rows = update_rows;
|
||||
ggml_cgraph * gf_kv = nullptr;
|
||||
const bool can_reuse_kv_graph = lctx.dflash.kv.cache_graph != nullptr &&
|
||||
@ -558,6 +519,18 @@ bool llama_prepare_dflash_graph_inputs(
|
||||
llama_graph_compute_sched(lctx, lctx.dflash.kv.cache_sched, gf_kv, lctx.cparams.n_threads);
|
||||
ggml_backend_sched_synchronize(lctx.dflash.kv.cache_sched);
|
||||
|
||||
if ((int32_t) lctx.dflash.kv.cache_pos.size() != cross_ctx) {
|
||||
lctx.dflash.kv.cache_pos.assign((size_t) cross_ctx, 0);
|
||||
}
|
||||
if ((int32_t) lctx.dflash.kv.cache_slot_valid.size() != cross_ctx) {
|
||||
lctx.dflash.kv.cache_slot_valid.assign((size_t) cross_ctx, 0);
|
||||
}
|
||||
for (int32_t i = 0; i < update_rows; ++i) {
|
||||
const int32_t slot = (cache_write_start + i) % cross_ctx;
|
||||
lctx.dflash.kv.cache_pos[(size_t) slot] = update_pos[i];
|
||||
lctx.dflash.kv.cache_slot_valid[(size_t) slot] = 1;
|
||||
}
|
||||
|
||||
lctx.dflash.kv.cache_n_filled = std::min(cross_ctx, lctx.dflash.kv.cache_n_filled + update_rows);
|
||||
lctx.dflash.kv.cache_write_pos = (lctx.dflash.kv.cache_write_pos + update_rows) % cross_ctx;
|
||||
lctx.dflash.kv.cache_applied_window_version = lctx.dflash.target.version;
|
||||
@ -567,101 +540,36 @@ bool llama_prepare_dflash_graph_inputs(
|
||||
lctx.dflash.kv.cache_view_valid = true;
|
||||
}
|
||||
|
||||
if (lctx.dflash.kv.cache_view_valid &&
|
||||
!lctx.dflash.kv.k_ctx_workspace.empty() && !lctx.dflash.kv.v_ctx_workspace.empty()) {
|
||||
const bool need_workspace_refresh = !lctx.dflash.kv.workspace_valid ||
|
||||
lctx.dflash.kv.workspace_n_filled != lctx.dflash.kv.cache_view_n_filled ||
|
||||
lctx.dflash.kv.workspace_write_pos != lctx.dflash.kv.cache_view_write_pos ||
|
||||
lctx.dflash.kv.workspace_applied_window_version != lctx.dflash.kv.cache_applied_window_version;
|
||||
|
||||
if (need_workspace_refresh) {
|
||||
const size_t max_nodes = lctx.model.max_nodes((int) std::max<int32_t>(1, cross_ctx)) + 16 * lctx.model.hparams.n_layer;
|
||||
const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false);
|
||||
if (lctx.dflash.kv.workspace_compute_meta.size() != meta_size) {
|
||||
lctx.dflash.kv.workspace_compute_meta.resize(meta_size);
|
||||
}
|
||||
|
||||
ggml_cgraph * gf_workspace = nullptr;
|
||||
const bool can_reuse_workspace_graph = lctx.dflash.kv.workspace_graph != nullptr &&
|
||||
lctx.dflash.kv.workspace_graph_rows == lctx.dflash.kv.cache_view_n_filled &&
|
||||
lctx.dflash.kv.workspace_graph_write_pos == lctx.dflash.kv.cache_view_write_pos;
|
||||
|
||||
if (can_reuse_workspace_graph) {
|
||||
gf_workspace = lctx.dflash.kv.workspace_graph;
|
||||
} else {
|
||||
gf_workspace = llm_build_context::llama_build_graph_dflash_kv_workspace(lctx);
|
||||
if (gf_workspace == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to build DFlash K/V workspace graph\n", __func__);
|
||||
if ((int32_t) lctx.dflash.kv.cache_pos.size() != cross_ctx ||
|
||||
(int32_t) lctx.dflash.kv.cache_slot_valid.size() != cross_ctx) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash physical cache slot map is not initialized\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<ggml_backend_buffer_type_t> backend_buft;
|
||||
backend_buft.reserve(lctx.backends.size());
|
||||
for (auto * backend : lctx.backends) {
|
||||
if (ggml_backend_is_cpu(backend)) {
|
||||
backend_buft.push_back(llama_default_buffer_type_cpu(true));
|
||||
} else {
|
||||
backend_buft.push_back(ggml_backend_get_default_buffer_type(backend));
|
||||
for (int32_t i = 0; i < cross_ctx; ++i) {
|
||||
if (lctx.dflash.kv.cache_slot_valid[(size_t) i]) {
|
||||
lctx.dflash.target.pos_ctx_data[(size_t) i] = lctx.dflash.kv.cache_pos[(size_t) i];
|
||||
}
|
||||
}
|
||||
|
||||
if (lctx.dflash.kv.workspace_sched == nullptr) {
|
||||
lctx.dflash.kv.workspace_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false);
|
||||
std::vector<int32_t> draft_tail_rows_data((size_t) n_tokens);
|
||||
for (uint32_t i = 0; i < n_tokens; ++i) {
|
||||
draft_tail_rows_data[(size_t) i] = cross_ctx + (int32_t) i;
|
||||
}
|
||||
ggml_backend_tensor_set(draft_tail_rows, draft_tail_rows_data.data(), 0, ggml_nbytes(draft_tail_rows));
|
||||
|
||||
if (lctx.dflash.kv.workspace_reserved_rows != cross_ctx) {
|
||||
const bool saved_view_valid = lctx.dflash.kv.cache_view_valid;
|
||||
const int32_t saved_view_rows = lctx.dflash.kv.cache_view_n_filled;
|
||||
const int32_t saved_view_write_pos = lctx.dflash.kv.cache_view_write_pos;
|
||||
|
||||
lctx.dflash.kv.cache_view_valid = true;
|
||||
lctx.dflash.kv.cache_view_n_filled = cross_ctx;
|
||||
lctx.dflash.kv.cache_view_write_pos = cross_ctx > 1 ? 1 : 0;
|
||||
|
||||
ggml_cgraph * gf_workspace_reserve = llm_build_context::llama_build_graph_dflash_kv_workspace(lctx);
|
||||
|
||||
lctx.dflash.kv.cache_view_valid = saved_view_valid;
|
||||
lctx.dflash.kv.cache_view_n_filled = saved_view_rows;
|
||||
lctx.dflash.kv.cache_view_write_pos = saved_view_write_pos;
|
||||
|
||||
const bool reserved = lctx.dflash.kv.workspace_sched != nullptr &&
|
||||
gf_workspace_reserve != nullptr &&
|
||||
ggml_backend_sched_reserve(lctx.dflash.kv.workspace_sched, gf_workspace_reserve);
|
||||
if (!reserved) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V workspace scheduler\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
lctx.dflash.kv.workspace_reserved_rows = cross_ctx;
|
||||
}
|
||||
|
||||
ggml_backend_sched_reset(lctx.dflash.kv.workspace_sched);
|
||||
ggml_backend_sched_alloc_graph(lctx.dflash.kv.workspace_sched, gf_workspace);
|
||||
|
||||
lctx.dflash.kv.workspace_graph = gf_workspace;
|
||||
lctx.dflash.kv.workspace_graph_rows = lctx.dflash.kv.cache_view_n_filled;
|
||||
lctx.dflash.kv.workspace_graph_write_pos = lctx.dflash.kv.cache_view_write_pos;
|
||||
}
|
||||
|
||||
llama_graph_compute_sched(lctx, lctx.dflash.kv.workspace_sched, gf_workspace, lctx.cparams.n_threads);
|
||||
lctx.dflash.kv.workspace_sync_pending = true;
|
||||
|
||||
lctx.dflash.kv.workspace_n_filled = lctx.dflash.kv.cache_view_n_filled;
|
||||
lctx.dflash.kv.workspace_write_pos = lctx.dflash.kv.cache_view_write_pos;
|
||||
lctx.dflash.kv.workspace_applied_window_version = lctx.dflash.kv.cache_applied_window_version;
|
||||
lctx.dflash.kv.workspace_valid = true;
|
||||
}
|
||||
}
|
||||
|
||||
const int32_t full_visible_first = left_pad;
|
||||
const int32_t full_visible_last = cross_ctx + (int32_t) n_tokens - 1;
|
||||
const size_t mask_elems = (size_t) n_kv_total * (size_t) n_mask_tokens;
|
||||
if (kq_mask->type == GGML_TYPE_F16) {
|
||||
const ggml_fp16_t h_inf = ggml_fp32_to_fp16(-INFINITY);
|
||||
const ggml_fp16_t h_zero = ggml_fp32_to_fp16(0.0f);
|
||||
std::vector<ggml_fp16_t> mask_f16(mask_elems, h_inf);
|
||||
std::vector<ggml_fp16_t> row_f16((size_t) n_kv_total, h_inf);
|
||||
std::fill(row_f16.begin() + full_visible_first, row_f16.begin() + full_visible_last + 1, h_zero);
|
||||
for (int32_t i = 0; i < cross_ctx; ++i) {
|
||||
if (lctx.dflash.kv.cache_slot_valid[(size_t) i]) {
|
||||
row_f16[(size_t) i] = h_zero;
|
||||
}
|
||||
}
|
||||
std::fill(row_f16.begin() + cross_ctx, row_f16.begin() + cross_ctx + n_tokens, h_zero);
|
||||
for (uint32_t j = 0; j < n_tokens; ++j) {
|
||||
std::memcpy(mask_f16.data() + (size_t) j * (size_t) n_kv_total, row_f16.data(), (size_t) n_kv_total * sizeof(ggml_fp16_t));
|
||||
}
|
||||
@ -669,7 +577,12 @@ bool llama_prepare_dflash_graph_inputs(
|
||||
} else {
|
||||
lctx.dflash.target.kq_mask_data.assign(mask_elems, -INFINITY);
|
||||
std::vector<float> row_f32((size_t) n_kv_total, -INFINITY);
|
||||
std::fill(row_f32.begin() + full_visible_first, row_f32.begin() + full_visible_last + 1, 0.0f);
|
||||
for (int32_t i = 0; i < cross_ctx; ++i) {
|
||||
if (lctx.dflash.kv.cache_slot_valid[(size_t) i]) {
|
||||
row_f32[(size_t) i] = 0.0f;
|
||||
}
|
||||
}
|
||||
std::fill(row_f32.begin() + cross_ctx, row_f32.begin() + cross_ctx + n_tokens, 0.0f);
|
||||
for (uint32_t j = 0; j < n_tokens; ++j) {
|
||||
std::memcpy(lctx.dflash.target.kq_mask_data.data() + (size_t) j * (size_t) n_kv_total, row_f32.data(), (size_t) n_kv_total * sizeof(float));
|
||||
}
|
||||
@ -688,7 +601,10 @@ bool llama_prepare_dflash_graph_inputs(
|
||||
ggml_fp16_t * row = mask_swa_f16.data() + (size_t) j * (size_t) n_kv_total;
|
||||
const int32_t q_pos = draft_pos_base + (int32_t) j;
|
||||
|
||||
for (int32_t k = left_pad; k < cross_ctx; ++k) {
|
||||
for (int32_t k = 0; k < cross_ctx; ++k) {
|
||||
if (!lctx.dflash.kv.cache_slot_valid[(size_t) k]) {
|
||||
continue;
|
||||
}
|
||||
const int32_t k_pos = (int32_t) lctx.dflash.target.pos_ctx_data[(size_t) k];
|
||||
if (q_pos - k_pos < swa_window) {
|
||||
row[k] = h_zero;
|
||||
@ -709,7 +625,10 @@ bool llama_prepare_dflash_graph_inputs(
|
||||
float * row = lctx.dflash.target.kq_mask_swa_data.data() + (size_t) j * (size_t) n_kv_total;
|
||||
const int32_t q_pos = draft_pos_base + (int32_t) j;
|
||||
|
||||
for (int32_t k = left_pad; k < cross_ctx; ++k) {
|
||||
for (int32_t k = 0; k < cross_ctx; ++k) {
|
||||
if (!lctx.dflash.kv.cache_slot_valid[(size_t) k]) {
|
||||
continue;
|
||||
}
|
||||
const int32_t k_pos = (int32_t) lctx.dflash.target.pos_ctx_data[(size_t) k];
|
||||
if (q_pos - k_pos < swa_window) {
|
||||
row[k] = 0.0f;
|
||||
|
||||
@ -5,4 +5,3 @@
|
||||
struct llama_context;
|
||||
|
||||
bool llama_prepare_dflash_graph_inputs(llama_context & lctx, uint32_t n_tokens);
|
||||
void llama_sync_dflash_workspace_if_pending(llama_context & lctx);
|
||||
|
||||
@ -21,11 +21,8 @@ void llama_reset_dflash_kv_cache_state(struct llama_context * ctx) {
|
||||
ctx->dflash.kv.cache_applied_window_version = 0;
|
||||
ctx->dflash.kv.cache_valid = false;
|
||||
ctx->dflash.kv.cache_view_valid = false;
|
||||
ctx->dflash.kv.workspace_write_pos = 0;
|
||||
ctx->dflash.kv.workspace_n_filled = 0;
|
||||
ctx->dflash.kv.workspace_applied_window_version = 0;
|
||||
ctx->dflash.kv.workspace_valid = false;
|
||||
ctx->dflash.kv.workspace_sync_pending = false;
|
||||
std::fill(ctx->dflash.kv.cache_pos.begin(), ctx->dflash.kv.cache_pos.end(), 0);
|
||||
std::fill(ctx->dflash.kv.cache_slot_valid.begin(), ctx->dflash.kv.cache_slot_valid.end(), 0);
|
||||
|
||||
for (ggml_backend_buffer_t buf : ctx->dflash.kv.cache_bufs) {
|
||||
if (buf != nullptr) {
|
||||
|
||||
@ -5493,9 +5493,6 @@ static int llama_decode_internal(
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
if (lctx.dflash.kv.workspace_sync_pending) {
|
||||
llama_sync_dflash_workspace_if_pending(lctx);
|
||||
}
|
||||
llama_graph_compute(lctx, gf, n_threads);
|
||||
#if IK_PRINT_TIMING
|
||||
llama_synchronize(&lctx);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user