mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
minor refactor in DFlash kv cache graph
This commit is contained in:
parent
6cae8c7ba2
commit
ad24046b51
@ -12,6 +12,7 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
#include <type_traits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
void llama_sync_dflash_workspace_if_pending(struct llama_context & lctx) {
|
void llama_sync_dflash_workspace_if_pending(struct llama_context & lctx) {
|
||||||
@ -70,12 +71,12 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
|
|||||||
const int64_t n_embd_head_v = model.hparams.n_embd_head_v(0);
|
const int64_t n_embd_head_v = model.hparams.n_embd_head_v(0);
|
||||||
const int64_t n_head_kv = model.hparams.n_head_kv();
|
const int64_t n_head_kv = model.hparams.n_head_kv();
|
||||||
|
|
||||||
if (dflash.kv.cache_ctx != nullptr && !dflash.kv.k_ctx_cache.empty()) {
|
if (dflash.kv.cache_ctx != nullptr &&
|
||||||
const bool cache_matches = (int32_t) dflash.kv.k_ctx_cache.size() == n_layer &&
|
(int32_t) dflash.kv.k_ctx_cache.size() == n_layer &&
|
||||||
dflash.kv.k_ctx_cache.front() != nullptr &&
|
(int32_t) dflash.kv.k_ctx_workspace.size() == n_layer) {
|
||||||
|
const bool cache_matches =
|
||||||
(int32_t) dflash.kv.k_ctx_cache.front()->ne[2] == target_cross_ctx;
|
(int32_t) dflash.kv.k_ctx_cache.front()->ne[2] == target_cross_ctx;
|
||||||
const bool workspace_matches = (int32_t) dflash.kv.k_ctx_workspace.size() == n_layer &&
|
const bool workspace_matches =
|
||||||
dflash.kv.k_ctx_workspace.front() != nullptr &&
|
|
||||||
(int32_t) dflash.kv.k_ctx_workspace.front()->ne[1] == target_workspace_n_kv_total;
|
(int32_t) dflash.kv.k_ctx_workspace.front()->ne[1] == target_workspace_n_kv_total;
|
||||||
|
|
||||||
if (cache_matches && workspace_matches) {
|
if (cache_matches && workspace_matches) {
|
||||||
@ -98,8 +99,6 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
|
|||||||
dflash.kv.workspace_graph_rows = 0;
|
dflash.kv.workspace_graph_rows = 0;
|
||||||
dflash.kv.workspace_graph_write_pos = 0;
|
dflash.kv.workspace_graph_write_pos = 0;
|
||||||
dflash.kv.workspace_reserved_rows = 0;
|
dflash.kv.workspace_reserved_rows = 0;
|
||||||
dflash.kv.cache_compute_meta.clear();
|
|
||||||
dflash.kv.workspace_compute_meta.clear();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_init_params params = {
|
ggml_init_params params = {
|
||||||
@ -110,6 +109,7 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
|
|||||||
|
|
||||||
dflash.kv.cache_ctx = ggml_init(params);
|
dflash.kv.cache_ctx = ggml_init(params);
|
||||||
if (dflash.kv.cache_ctx == nullptr) {
|
if (dflash.kv.cache_ctx == nullptr) {
|
||||||
|
LLAMA_LOG_ERROR("%s: failed to allocate DFlash K/V cache context\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -123,74 +123,44 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
|
|||||||
dflash.kv.cache_bufs.reserve((size_t) std::max(1, n_layer) * 4);
|
dflash.kv.cache_bufs.reserve((size_t) std::max(1, n_layer) * 4);
|
||||||
for (int32_t il = 0; il < n_layer; ++il) {
|
for (int32_t il = 0; il < n_layer; ++il) {
|
||||||
ggml_backend_buffer_type_t layer_buft = llama_dflash_kv_cache_layer_buft(*this, il);
|
ggml_backend_buffer_type_t layer_buft = llama_dflash_kv_cache_layer_buft(*this, il);
|
||||||
|
auto alloc_kv_input = [&](ggml_tensor *& tensor, const char * tensor_tag, const char * tensor_name,
|
||||||
dflash.kv.k_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, n_embd_head_k, n_head_kv, target_cross_ctx);
|
int64_t ne0, int64_t ne1, int64_t ne2) -> bool {
|
||||||
dflash.kv.v_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, n_embd_head_v, n_head_kv, target_cross_ctx);
|
tensor = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, ne0, ne1, ne2);
|
||||||
if (dflash.kv.k_ctx_cache[(size_t) il] == nullptr || dflash.kv.v_ctx_cache[(size_t) il] == nullptr) {
|
if (tensor == nullptr) {
|
||||||
free_dflash_kv_cache_tensors();
|
LLAMA_LOG_ERROR("%s: failed to create %s for layer %d\n", __func__, tensor_tag, il);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_set_input(dflash.kv.k_ctx_cache[(size_t) il]);
|
ggml_set_input(tensor);
|
||||||
ggml_set_input(dflash.kv.v_ctx_cache[(size_t) il]);
|
ggml_format_name(tensor, tensor_name, il);
|
||||||
ggml_format_name(dflash.kv.k_ctx_cache[(size_t) il], "dflash_k_ctx_cache_%d", il);
|
|
||||||
ggml_format_name(dflash.kv.v_ctx_cache[(size_t) il], "dflash_v_ctx_cache_%d", il);
|
|
||||||
|
|
||||||
const size_t k_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash.kv.k_ctx_cache[(size_t) il]);
|
const size_t tensor_bytes = ggml_backend_buft_get_alloc_size(layer_buft, tensor);
|
||||||
ggml_backend_buffer_t k_buf = ggml_backend_buft_alloc_buffer(layer_buft, k_bytes);
|
ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(layer_buft, tensor_bytes);
|
||||||
if (k_buf == nullptr) {
|
if (buf == nullptr) {
|
||||||
free_dflash_kv_cache_tensors();
|
LLAMA_LOG_ERROR("%s: failed to allocate %s buffer for layer %d (%zu bytes)\n",
|
||||||
return false;
|
__func__, tensor_tag, il, tensor_bytes);
|
||||||
}
|
|
||||||
ggml_backend_buffer_set_usage(k_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE);
|
|
||||||
ggml_backend_tensor_alloc(k_buf, dflash.kv.k_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(k_buf));
|
|
||||||
ggml_backend_buffer_clear(k_buf, 0);
|
|
||||||
dflash.kv.cache_bufs.push_back(k_buf);
|
|
||||||
|
|
||||||
const size_t v_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash.kv.v_ctx_cache[(size_t) il]);
|
|
||||||
ggml_backend_buffer_t v_buf = ggml_backend_buft_alloc_buffer(layer_buft, v_bytes);
|
|
||||||
if (v_buf == nullptr) {
|
|
||||||
free_dflash_kv_cache_tensors();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
ggml_backend_buffer_set_usage(v_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE);
|
|
||||||
ggml_backend_tensor_alloc(v_buf, dflash.kv.v_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(v_buf));
|
|
||||||
ggml_backend_buffer_clear(v_buf, 0);
|
|
||||||
dflash.kv.cache_bufs.push_back(v_buf);
|
|
||||||
|
|
||||||
dflash.kv.k_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, n_embd_head_k, target_workspace_n_kv_total, n_head_kv);
|
|
||||||
dflash.kv.v_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, n_embd_head_v, target_workspace_n_kv_total, n_head_kv);
|
|
||||||
if (dflash.kv.k_ctx_workspace[(size_t) il] == nullptr || dflash.kv.v_ctx_workspace[(size_t) il] == nullptr) {
|
|
||||||
free_dflash_kv_cache_tensors();
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_set_input(dflash.kv.k_ctx_workspace[(size_t) il]);
|
ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE);
|
||||||
ggml_set_input(dflash.kv.v_ctx_workspace[(size_t) il]);
|
ggml_backend_tensor_alloc(buf, tensor, ggml_backend_buffer_get_base(buf));
|
||||||
ggml_format_name(dflash.kv.k_ctx_workspace[(size_t) il], "dflash_k_ctx_workspace_%d", il);
|
ggml_backend_buffer_clear(buf, 0);
|
||||||
ggml_format_name(dflash.kv.v_ctx_workspace[(size_t) il], "dflash_v_ctx_workspace_%d", il);
|
dflash.kv.cache_bufs.push_back(buf);
|
||||||
|
|
||||||
const size_t k_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash.kv.k_ctx_workspace[(size_t) il]);
|
return true;
|
||||||
ggml_backend_buffer_t k_workspace_buf = ggml_backend_buft_alloc_buffer(layer_buft, k_workspace_bytes);
|
};
|
||||||
if (k_workspace_buf == nullptr) {
|
|
||||||
|
if (!alloc_kv_input(dflash.kv.k_ctx_cache[(size_t) il], "dflash_k_ctx_cache", "dflash_k_ctx_cache_%d",
|
||||||
|
n_embd_head_k, n_head_kv, target_cross_ctx) ||
|
||||||
|
!alloc_kv_input(dflash.kv.v_ctx_cache[(size_t) il], "dflash_v_ctx_cache", "dflash_v_ctx_cache_%d",
|
||||||
|
n_embd_head_v, n_head_kv, target_cross_ctx) ||
|
||||||
|
!alloc_kv_input(dflash.kv.k_ctx_workspace[(size_t) il], "dflash_k_ctx_workspace", "dflash_k_ctx_workspace_%d",
|
||||||
|
n_embd_head_k, target_workspace_n_kv_total, n_head_kv) ||
|
||||||
|
!alloc_kv_input(dflash.kv.v_ctx_workspace[(size_t) il], "dflash_v_ctx_workspace", "dflash_v_ctx_workspace_%d",
|
||||||
|
n_embd_head_v, target_workspace_n_kv_total, n_head_kv)) {
|
||||||
free_dflash_kv_cache_tensors();
|
free_dflash_kv_cache_tensors();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
ggml_backend_buffer_set_usage(k_workspace_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE);
|
|
||||||
ggml_backend_tensor_alloc(k_workspace_buf, dflash.kv.k_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(k_workspace_buf));
|
|
||||||
ggml_backend_buffer_clear(k_workspace_buf, 0);
|
|
||||||
dflash.kv.cache_bufs.push_back(k_workspace_buf);
|
|
||||||
|
|
||||||
const size_t v_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash.kv.v_ctx_workspace[(size_t) il]);
|
|
||||||
ggml_backend_buffer_t v_workspace_buf = ggml_backend_buft_alloc_buffer(layer_buft, v_workspace_bytes);
|
|
||||||
if (v_workspace_buf == nullptr) {
|
|
||||||
free_dflash_kv_cache_tensors();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
ggml_backend_buffer_set_usage(v_workspace_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE);
|
|
||||||
ggml_backend_tensor_alloc(v_workspace_buf, dflash.kv.v_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(v_workspace_buf));
|
|
||||||
ggml_backend_buffer_clear(v_workspace_buf, 0);
|
|
||||||
dflash.kv.cache_bufs.push_back(v_workspace_buf);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dflash.kv.workspace_token_capacity = target_token_capacity;
|
dflash.kv.workspace_token_capacity = target_token_capacity;
|
||||||
@ -201,10 +171,15 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void llama_context::free_dflash_kv_cache_tensors() {
|
void llama_context::free_dflash_kv_cache_tensors() {
|
||||||
dflash.kv.k_ctx_cache.clear();
|
auto release_vector = [](auto & v) {
|
||||||
dflash.kv.v_ctx_cache.clear();
|
using vec_type = std::decay_t<decltype(v)>;
|
||||||
dflash.kv.k_ctx_workspace.clear();
|
vec_type().swap(v);
|
||||||
dflash.kv.v_ctx_workspace.clear();
|
};
|
||||||
|
|
||||||
|
release_vector(dflash.kv.k_ctx_cache);
|
||||||
|
release_vector(dflash.kv.v_ctx_cache);
|
||||||
|
release_vector(dflash.kv.k_ctx_workspace);
|
||||||
|
release_vector(dflash.kv.v_ctx_workspace);
|
||||||
dflash.kv.cache_write_pos = 0;
|
dflash.kv.cache_write_pos = 0;
|
||||||
dflash.kv.cache_n_filled = 0;
|
dflash.kv.cache_n_filled = 0;
|
||||||
dflash.kv.cache_update_rows = 0;
|
dflash.kv.cache_update_rows = 0;
|
||||||
@ -244,7 +219,9 @@ void llama_context::free_dflash_kv_cache_tensors() {
|
|||||||
ggml_backend_buffer_free(buf);
|
ggml_backend_buffer_free(buf);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
dflash.kv.cache_bufs.clear();
|
release_vector(dflash.kv.cache_bufs);
|
||||||
|
release_vector(dflash.kv.cache_compute_meta);
|
||||||
|
release_vector(dflash.kv.workspace_compute_meta);
|
||||||
if (dflash.kv.cache_ctx != nullptr) {
|
if (dflash.kv.cache_ctx != nullptr) {
|
||||||
ggml_free(dflash.kv.cache_ctx);
|
ggml_free(dflash.kv.cache_ctx);
|
||||||
dflash.kv.cache_ctx = nullptr;
|
dflash.kv.cache_ctx = nullptr;
|
||||||
|
|||||||
@ -2257,10 +2257,14 @@ bool create_tensors_helper::create_dflash_tensors(const LLM_TN & tn) {
|
|||||||
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||||
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
model.output_mtp = create_tensor(ctx_output, "output_extra.weight", {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
auto output_extra = create_tensor(ctx_output, "output_extra.weight", {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
|
if (output_extra != nullptr) {
|
||||||
|
model.output = output_extra;
|
||||||
|
}
|
||||||
if (model.output == nullptr && model.tok_embd != nullptr) {
|
if (model.output == nullptr && model.tok_embd != nullptr) {
|
||||||
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
||||||
}
|
}
|
||||||
|
model.output_mtp = model.output;
|
||||||
model.dflash_fc = create_tensor(ctx_output, tn(LLM_TENSOR_DFLASH_FC, "weight"), {(int64_t) hparams.dflash_n_target_features, n_embd}, 0);
|
model.dflash_fc = create_tensor(ctx_output, tn(LLM_TENSOR_DFLASH_FC, "weight"), {(int64_t) hparams.dflash_n_target_features, n_embd}, 0);
|
||||||
model.dflash_hidden_norm = create_tensor(ctx_output, tn(LLM_TENSOR_DFLASH_HIDDEN_NORM, "weight"), {n_embd}, 0);
|
model.dflash_hidden_norm = create_tensor(ctx_output, tn(LLM_TENSOR_DFLASH_HIDDEN_NORM, "weight"), {n_embd}, 0);
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user