minor refactor in DFlash kv cache graph

This commit is contained in:
SamuelOliveirads 2026-06-15 18:22:56 -03:00
parent 6cae8c7ba2
commit ad24046b51
3 changed files with 143 additions and 162 deletions

View File

@ -12,6 +12,7 @@
#include <algorithm>
#include <cmath>
#include <cstring>
#include <type_traits>
#include <vector>
void llama_sync_dflash_workspace_if_pending(struct llama_context & lctx) {
@ -70,12 +71,12 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
const int64_t n_embd_head_v = model.hparams.n_embd_head_v(0);
const int64_t n_head_kv = model.hparams.n_head_kv();
if (dflash.kv.cache_ctx != nullptr && !dflash.kv.k_ctx_cache.empty()) {
const bool cache_matches = (int32_t) dflash.kv.k_ctx_cache.size() == n_layer &&
dflash.kv.k_ctx_cache.front() != nullptr &&
if (dflash.kv.cache_ctx != nullptr &&
(int32_t) dflash.kv.k_ctx_cache.size() == n_layer &&
(int32_t) dflash.kv.k_ctx_workspace.size() == n_layer) {
const bool cache_matches =
(int32_t) dflash.kv.k_ctx_cache.front()->ne[2] == target_cross_ctx;
const bool workspace_matches = (int32_t) dflash.kv.k_ctx_workspace.size() == n_layer &&
dflash.kv.k_ctx_workspace.front() != nullptr &&
const bool workspace_matches =
(int32_t) dflash.kv.k_ctx_workspace.front()->ne[1] == target_workspace_n_kv_total;
if (cache_matches && workspace_matches) {
@ -98,8 +99,6 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
dflash.kv.workspace_graph_rows = 0;
dflash.kv.workspace_graph_write_pos = 0;
dflash.kv.workspace_reserved_rows = 0;
dflash.kv.cache_compute_meta.clear();
dflash.kv.workspace_compute_meta.clear();
}
ggml_init_params params = {
@ -110,6 +109,7 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
dflash.kv.cache_ctx = ggml_init(params);
if (dflash.kv.cache_ctx == nullptr) {
LLAMA_LOG_ERROR("%s: failed to allocate DFlash K/V cache context\n", __func__);
return false;
}
@ -123,74 +123,44 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
dflash.kv.cache_bufs.reserve((size_t) std::max(1, n_layer) * 4);
for (int32_t il = 0; il < n_layer; ++il) {
ggml_backend_buffer_type_t layer_buft = llama_dflash_kv_cache_layer_buft(*this, il);
dflash.kv.k_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, n_embd_head_k, n_head_kv, target_cross_ctx);
dflash.kv.v_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, n_embd_head_v, n_head_kv, target_cross_ctx);
if (dflash.kv.k_ctx_cache[(size_t) il] == nullptr || dflash.kv.v_ctx_cache[(size_t) il] == nullptr) {
free_dflash_kv_cache_tensors();
auto alloc_kv_input = [&](ggml_tensor *& tensor, const char * tensor_tag, const char * tensor_name,
int64_t ne0, int64_t ne1, int64_t ne2) -> bool {
tensor = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, ne0, ne1, ne2);
if (tensor == nullptr) {
LLAMA_LOG_ERROR("%s: failed to create %s for layer %d\n", __func__, tensor_tag, il);
return false;
}
ggml_set_input(dflash.kv.k_ctx_cache[(size_t) il]);
ggml_set_input(dflash.kv.v_ctx_cache[(size_t) il]);
ggml_format_name(dflash.kv.k_ctx_cache[(size_t) il], "dflash_k_ctx_cache_%d", il);
ggml_format_name(dflash.kv.v_ctx_cache[(size_t) il], "dflash_v_ctx_cache_%d", il);
ggml_set_input(tensor);
ggml_format_name(tensor, tensor_name, il);
const size_t k_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash.kv.k_ctx_cache[(size_t) il]);
ggml_backend_buffer_t k_buf = ggml_backend_buft_alloc_buffer(layer_buft, k_bytes);
if (k_buf == nullptr) {
free_dflash_kv_cache_tensors();
return false;
}
ggml_backend_buffer_set_usage(k_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE);
ggml_backend_tensor_alloc(k_buf, dflash.kv.k_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(k_buf));
ggml_backend_buffer_clear(k_buf, 0);
dflash.kv.cache_bufs.push_back(k_buf);
const size_t v_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash.kv.v_ctx_cache[(size_t) il]);
ggml_backend_buffer_t v_buf = ggml_backend_buft_alloc_buffer(layer_buft, v_bytes);
if (v_buf == nullptr) {
free_dflash_kv_cache_tensors();
return false;
}
ggml_backend_buffer_set_usage(v_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE);
ggml_backend_tensor_alloc(v_buf, dflash.kv.v_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(v_buf));
ggml_backend_buffer_clear(v_buf, 0);
dflash.kv.cache_bufs.push_back(v_buf);
dflash.kv.k_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, n_embd_head_k, target_workspace_n_kv_total, n_head_kv);
dflash.kv.v_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, n_embd_head_v, target_workspace_n_kv_total, n_head_kv);
if (dflash.kv.k_ctx_workspace[(size_t) il] == nullptr || dflash.kv.v_ctx_workspace[(size_t) il] == nullptr) {
free_dflash_kv_cache_tensors();
const size_t tensor_bytes = ggml_backend_buft_get_alloc_size(layer_buft, tensor);
ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(layer_buft, tensor_bytes);
if (buf == nullptr) {
LLAMA_LOG_ERROR("%s: failed to allocate %s buffer for layer %d (%zu bytes)\n",
__func__, tensor_tag, il, tensor_bytes);
return false;
}
ggml_set_input(dflash.kv.k_ctx_workspace[(size_t) il]);
ggml_set_input(dflash.kv.v_ctx_workspace[(size_t) il]);
ggml_format_name(dflash.kv.k_ctx_workspace[(size_t) il], "dflash_k_ctx_workspace_%d", il);
ggml_format_name(dflash.kv.v_ctx_workspace[(size_t) il], "dflash_v_ctx_workspace_%d", il);
ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE);
ggml_backend_tensor_alloc(buf, tensor, ggml_backend_buffer_get_base(buf));
ggml_backend_buffer_clear(buf, 0);
dflash.kv.cache_bufs.push_back(buf);
const size_t k_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash.kv.k_ctx_workspace[(size_t) il]);
ggml_backend_buffer_t k_workspace_buf = ggml_backend_buft_alloc_buffer(layer_buft, k_workspace_bytes);
if (k_workspace_buf == nullptr) {
return true;
};
if (!alloc_kv_input(dflash.kv.k_ctx_cache[(size_t) il], "dflash_k_ctx_cache", "dflash_k_ctx_cache_%d",
n_embd_head_k, n_head_kv, target_cross_ctx) ||
!alloc_kv_input(dflash.kv.v_ctx_cache[(size_t) il], "dflash_v_ctx_cache", "dflash_v_ctx_cache_%d",
n_embd_head_v, n_head_kv, target_cross_ctx) ||
!alloc_kv_input(dflash.kv.k_ctx_workspace[(size_t) il], "dflash_k_ctx_workspace", "dflash_k_ctx_workspace_%d",
n_embd_head_k, target_workspace_n_kv_total, n_head_kv) ||
!alloc_kv_input(dflash.kv.v_ctx_workspace[(size_t) il], "dflash_v_ctx_workspace", "dflash_v_ctx_workspace_%d",
n_embd_head_v, target_workspace_n_kv_total, n_head_kv)) {
free_dflash_kv_cache_tensors();
return false;
}
ggml_backend_buffer_set_usage(k_workspace_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE);
ggml_backend_tensor_alloc(k_workspace_buf, dflash.kv.k_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(k_workspace_buf));
ggml_backend_buffer_clear(k_workspace_buf, 0);
dflash.kv.cache_bufs.push_back(k_workspace_buf);
const size_t v_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash.kv.v_ctx_workspace[(size_t) il]);
ggml_backend_buffer_t v_workspace_buf = ggml_backend_buft_alloc_buffer(layer_buft, v_workspace_bytes);
if (v_workspace_buf == nullptr) {
free_dflash_kv_cache_tensors();
return false;
}
ggml_backend_buffer_set_usage(v_workspace_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE);
ggml_backend_tensor_alloc(v_workspace_buf, dflash.kv.v_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(v_workspace_buf));
ggml_backend_buffer_clear(v_workspace_buf, 0);
dflash.kv.cache_bufs.push_back(v_workspace_buf);
}
dflash.kv.workspace_token_capacity = target_token_capacity;
@ -201,10 +171,15 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
}
void llama_context::free_dflash_kv_cache_tensors() {
dflash.kv.k_ctx_cache.clear();
dflash.kv.v_ctx_cache.clear();
dflash.kv.k_ctx_workspace.clear();
dflash.kv.v_ctx_workspace.clear();
auto release_vector = [](auto & v) {
using vec_type = std::decay_t<decltype(v)>;
vec_type().swap(v);
};
release_vector(dflash.kv.k_ctx_cache);
release_vector(dflash.kv.v_ctx_cache);
release_vector(dflash.kv.k_ctx_workspace);
release_vector(dflash.kv.v_ctx_workspace);
dflash.kv.cache_write_pos = 0;
dflash.kv.cache_n_filled = 0;
dflash.kv.cache_update_rows = 0;
@ -244,7 +219,9 @@ void llama_context::free_dflash_kv_cache_tensors() {
ggml_backend_buffer_free(buf);
}
}
dflash.kv.cache_bufs.clear();
release_vector(dflash.kv.cache_bufs);
release_vector(dflash.kv.cache_compute_meta);
release_vector(dflash.kv.workspace_compute_meta);
if (dflash.kv.cache_ctx != nullptr) {
ggml_free(dflash.kv.cache_ctx);
dflash.kv.cache_ctx = nullptr;

View File

@ -2257,10 +2257,14 @@ bool create_tensors_helper::create_dflash_tensors(const LLM_TN & tn) {
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
model.output_mtp = create_tensor(ctx_output, "output_extra.weight", {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
auto output_extra = create_tensor(ctx_output, "output_extra.weight", {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
if (output_extra != nullptr) {
model.output = output_extra;
}
if (model.output == nullptr && model.tok_embd != nullptr) {
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
}
model.output_mtp = model.output;
model.dflash_fc = create_tensor(ctx_output, tn(LLM_TENSOR_DFLASH_FC, "weight"), {(int64_t) hparams.dflash_n_target_features, n_embd}, 0);
model.dflash_hidden_norm = create_tensor(ctx_output, tn(LLM_TENSOR_DFLASH_HIDDEN_NORM, "weight"), {n_embd}, 0);