diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 704b8d4290..2802103bdd 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -95,13 +95,16 @@ llama_kv_cache::llama_kv_cache( const layer_reuse_cb & reuse, const layer_share_cb & share) : model(model), hparams(hparams), v_trans(v_trans), - n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { + n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type), + other(static_cast(mem_other)), + v_cells_impl(other ? other->v_cells_impl : std::make_shared()), + v_cells(*v_cells_impl) { // shared cells view the source cache's K/V tensors, so the cell count // follows the source allocation: a fitted target can be smaller than the // draft default and oversized views would overflow the source tensors - if (mem_other) { - const uint32_t size_other = static_cast(mem_other)->get_size(); + if (other) { + const uint32_t size_other = other->get_size(); if (kv_size != size_other) { LLAMA_LOG_WARN("%s: kv_size = %u overridden to %u to match the shared source cache\n", __func__, kv_size, size_other); kv_size = size_other; @@ -173,8 +176,6 @@ llama_kv_cache::llama_kv_cache( const bool is_mla = hparams.is_mla(); - other = static_cast(mem_other); - for (uint32_t il = 0; il < n_layer; il++) { if (!hparams.has_kv(il)) { LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il); @@ -1105,7 +1106,6 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) { // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] if (other) { - v_cells = other->v_cells; return; } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index f5ace6ae35..3d68f98c14 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -269,7 +269,9 @@ private: // TODO: temporary until we refactor to be able to share the same cells between 2 kv caches [TAG_KV_CACHE_SHARE_CELLS] llama_kv_cache * other; - std::vector v_cells; + std::shared_ptr v_cells_impl; + + llama_kv_cells_vec & v_cells; // maps from a sequence id to a stream id std::vector seq_to_stream; diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h index 10063bf427..fddd31a0b2 100644 --- a/src/llama-kv-cells.h +++ b/src/llama-kv-cells.h @@ -531,3 +531,5 @@ private: } } }; + +using llama_kv_cells_vec = std::vector;