mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
wip
This commit is contained in:
parent
1c4a91c0f3
commit
37c56c245e
@ -418,7 +418,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
int32_t n_embd = 0;
|
||||
|
||||
bool kv_shared_with_target = false;
|
||||
bool is_mem_shared = false;
|
||||
|
||||
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
|
||||
// The last h-row of one process() call needs the first token of the NEXT
|
||||
@ -494,13 +494,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false);
|
||||
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
|
||||
|
||||
kv_shared_with_target = llama_get_memory(ctx_dft) == nullptr;
|
||||
if (kv_shared_with_target) {
|
||||
llama_set_memory(ctx_dft, ctx_tgt);
|
||||
|
||||
// TODO: avoid the const cast
|
||||
llama_model_set_tok_embd(const_cast<llama_model *>(llama_get_model(ctx_dft)), llama_model_get_tok_embd(llama_get_model(ctx_tgt)));
|
||||
}
|
||||
// TODO: fix this
|
||||
is_mem_shared = true;
|
||||
|
||||
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
|
||||
|
||||
@ -541,7 +536,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
|
||||
if (pos_max < N - 1 && !kv_shared_with_target) {
|
||||
|
||||
if (pos_max < N - 1 && !is_mem_shared) {
|
||||
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
|
||||
"process() hook may not have run on every prefill ubatch "
|
||||
"(need_embd / logits=1 on every prompt position?). "
|
||||
@ -585,7 +581,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
const size_t row_bytes = (size_t) n_embd * sizeof(float);
|
||||
|
||||
// if kv is shared with target (e.g Gemma4), then we can skip this catch-up decode
|
||||
if (!kv_shared_with_target) {
|
||||
if (!is_mem_shared) {
|
||||
common_batch_clear(batch);
|
||||
|
||||
for (int k = 0; k < n_tokens; ++k) {
|
||||
@ -645,7 +641,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
void draft(common_speculative_draft_params_vec & dparams) override {
|
||||
auto & ctx_dft = params.ctx_dft;
|
||||
auto & ctx_tgt = params.ctx_tgt;
|
||||
|
||||
common_batch_clear(batch);
|
||||
|
||||
@ -671,8 +666,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
h_row = pending_h[seq_id].data();
|
||||
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, dp.n_past, -1);
|
||||
}
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
@ -733,8 +726,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
|
||||
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, dp.n_past, -1);
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
@ -751,8 +742,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
++i;
|
||||
}
|
||||
|
||||
llama_synchronize(ctx_dft);
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
auto & dp = dparams[seq_id];
|
||||
if (!dp.drafting) {
|
||||
|
||||
@ -388,6 +388,8 @@ extern "C" {
|
||||
// note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
|
||||
struct llama_sampler_seq_config * samplers;
|
||||
size_t n_samplers;
|
||||
|
||||
struct llama_context * ctx_src;
|
||||
};
|
||||
|
||||
struct llama_model_tensor_override {
|
||||
|
||||
@ -126,9 +126,10 @@ llama_context::llama_context(
|
||||
cparams.embeddings_nextn_masked = false;
|
||||
cparams.offload_kqv = params.offload_kqv;
|
||||
cparams.no_perf = params.no_perf;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
cparams.warmup = false;
|
||||
|
||||
cparams.ctx_type = params.ctx_type;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
|
||||
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
||||
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
|
||||
@ -141,7 +142,7 @@ llama_context::llama_context(
|
||||
cparams.cb_eval = params.cb_eval;
|
||||
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
||||
|
||||
cparams.ctx_type = params.ctx_type;
|
||||
cparams.ctx_src = params.ctx_src;
|
||||
|
||||
// Initialize backend samplers here so they are part of the sampling graph
|
||||
// before the reserve passes run later in this function. This avoids a later
|
||||
@ -360,7 +361,8 @@ llama_context::llama_context(
|
||||
/*.type_k =*/ params.type_k,
|
||||
/*.type_v =*/ params.type_v,
|
||||
/*.swa_full =*/ params.swa_full,
|
||||
/*.ctx_type= */ cparams.ctx_type,
|
||||
/*.ctx_type =*/ cparams.ctx_type,
|
||||
/*.mem_src =*/ params.ctx_src ? params.ctx_src->memory.get() : nullptr,
|
||||
};
|
||||
|
||||
memory.reset(model.create_memory(params_mem, cparams));
|
||||
@ -792,14 +794,8 @@ uint32_t llama_context::n_threads_batch() const {
|
||||
return cparams.n_threads_batch;
|
||||
}
|
||||
|
||||
llama_memory_ptr llama_context::get_memory() const {
|
||||
return memory;
|
||||
}
|
||||
|
||||
void llama_context::set_memory(llama_memory_ptr memory) {
|
||||
this->memory = std::move(memory);
|
||||
|
||||
sched_need_reserve = true;
|
||||
llama_memory_t llama_context::get_memory() const {
|
||||
return memory.get();
|
||||
}
|
||||
|
||||
bool llama_context::memory_update(bool optimize) {
|
||||
@ -3441,6 +3437,7 @@ llama_context_params llama_context_default_params() {
|
||||
/*.kv_unified =*/ false,
|
||||
/*.sampler =*/ nullptr,
|
||||
/*.n_sampler =*/ 0,
|
||||
/*.ctx_src =*/ nullptr,
|
||||
};
|
||||
|
||||
return result;
|
||||
@ -3658,12 +3655,8 @@ void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) {
|
||||
ctx->set_embeddings_nextn(value, masked);
|
||||
}
|
||||
|
||||
void llama_set_memory(llama_context * ctx, llama_context * src) {
|
||||
ctx->set_memory(src->get_memory());
|
||||
}
|
||||
|
||||
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
|
||||
return ctx->get_memory().get();
|
||||
return ctx->get_memory();
|
||||
}
|
||||
|
||||
float * llama_get_embeddings_nextn(llama_context * ctx) {
|
||||
|
||||
@ -71,8 +71,7 @@ struct llama_context {
|
||||
uint32_t n_threads() const;
|
||||
uint32_t n_threads_batch() const;
|
||||
|
||||
llama_memory_ptr get_memory() const;
|
||||
void set_memory(llama_memory_ptr memory);
|
||||
llama_memory_t get_memory() const;
|
||||
|
||||
// return true if the memory was updated
|
||||
bool memory_update(bool optimize);
|
||||
|
||||
@ -49,4 +49,6 @@ struct llama_cparams {
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval;
|
||||
void * cb_eval_user_data;
|
||||
|
||||
llama_context * ctx_src;
|
||||
};
|
||||
|
||||
@ -93,7 +93,6 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c
|
||||
// If masked == true, output the embeddings only for the tokens with batch.logits != 0
|
||||
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
|
||||
LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked);
|
||||
LLAMA_API void llama_set_memory(struct llama_context * ctx, struct llama_context * src);
|
||||
|
||||
// mirrors:
|
||||
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||
@ -101,6 +100,3 @@ LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);
|
||||
|
||||
// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
||||
LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
LLAMA_API ggml_tensor * llama_model_get_tok_embd(const struct llama_model * model);
|
||||
LLAMA_API void llama_model_set_tok_embd( struct llama_model * model, ggml_tensor * tensor);
|
||||
|
||||
@ -2543,16 +2543,6 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_build_forward_expand(gf, v_cur);
|
||||
}
|
||||
|
||||
int il_save = il;
|
||||
|
||||
if (arch == LLM_ARCH_GEMMA4_ASSISTANT) {
|
||||
if (il == n_layer - 1) {
|
||||
il = 59;
|
||||
} else {
|
||||
il = 58;
|
||||
}
|
||||
}
|
||||
|
||||
const auto * mctx_iswa = inp->mctx;
|
||||
|
||||
const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
|
||||
@ -2576,8 +2566,6 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
||||
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
||||
|
||||
il = il_save;
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
|
||||
@ -32,7 +32,7 @@ llama_kv_cache_dsa::llama_kv_cache_dsa(
|
||||
kv_mla = std::make_unique<llama_kv_cache>(
|
||||
model, model.hparams, type_k, type_v,
|
||||
v_trans, offload, unified, kv_size, n_seq_max, n_pad,
|
||||
n_swa, swa_type, filter, reuse);
|
||||
n_swa, swa_type, nullptr, filter, reuse, nullptr);
|
||||
|
||||
// we use llama_kv_cache for caching indexer keys
|
||||
// by hand-tweaking some hparams we fool it to create
|
||||
@ -49,7 +49,7 @@ llama_kv_cache_dsa::llama_kv_cache_dsa(
|
||||
kv_lid = std::make_unique<llama_kv_cache>(
|
||||
model, hparams_lid, type_k, type_v,
|
||||
v_trans, offload, unified, kv_size, n_seq_max, n_pad,
|
||||
n_swa, swa_type, filter, reuse);
|
||||
n_swa, swa_type, nullptr, filter, reuse, nullptr);
|
||||
}
|
||||
|
||||
void llama_kv_cache_dsa::clear(bool data) {
|
||||
|
||||
@ -23,8 +23,10 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad,
|
||||
llama_memory_t mem_src,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
|
||||
const layer_reuse_cb & reuse,
|
||||
const layer_share_cb & share) : hparams(model.hparams), unified(unified) {
|
||||
|
||||
// chain filters
|
||||
const layer_filter_cb filter_base = [&](int32_t il) {
|
||||
@ -59,17 +61,27 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
|
||||
|
||||
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
|
||||
|
||||
llama_memory_t mem_src_base = nullptr;
|
||||
if (mem_src) {
|
||||
mem_src_base = static_cast<llama_kv_cache_iswa *>(mem_src)->get_base();
|
||||
}
|
||||
|
||||
llama_memory_t mem_src_swa = nullptr;
|
||||
if (mem_src) {
|
||||
mem_src_swa = static_cast<llama_kv_cache_iswa *>(mem_src)->get_swa();
|
||||
}
|
||||
|
||||
kv_base = std::make_unique<llama_kv_cache>(
|
||||
model, hparams, type_k, type_v,
|
||||
v_trans, offload, unified, size_base, n_seq_max, n_pad,
|
||||
0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
|
||||
0, LLAMA_SWA_TYPE_NONE, mem_src_base, filter_base, reuse, share);
|
||||
|
||||
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
|
||||
|
||||
kv_swa = std::make_unique<llama_kv_cache>(
|
||||
model, hparams, type_k, type_v,
|
||||
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
|
||||
hparams.n_swa, hparams.swa_type, filter_swa, reuse);
|
||||
hparams.n_swa, hparams.swa_type, mem_src_swa, filter_swa, reuse, share);
|
||||
}
|
||||
|
||||
void llama_kv_cache_iswa::clear(bool data) {
|
||||
|
||||
@ -25,8 +25,10 @@ public:
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad,
|
||||
llama_memory_t mem_src,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse);
|
||||
const layer_reuse_cb & reuse,
|
||||
const layer_share_cb & share);
|
||||
|
||||
~llama_kv_cache_iswa() = default;
|
||||
|
||||
|
||||
@ -90,8 +90,10 @@ llama_kv_cache::llama_kv_cache(
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
llama_memory_t mem_src,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse) :
|
||||
const layer_reuse_cb & reuse,
|
||||
const layer_share_cb & share) :
|
||||
model(model), hparams(hparams), v_trans(v_trans),
|
||||
n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
|
||||
|
||||
@ -160,6 +162,8 @@ llama_kv_cache::llama_kv_cache(
|
||||
|
||||
const bool is_mla = hparams.is_mla();
|
||||
|
||||
other = static_cast<llama_kv_cache *>(mem_src);
|
||||
|
||||
for (uint32_t il = 0; il < n_layer; il++) {
|
||||
if (!hparams.has_kv(il)) {
|
||||
LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
|
||||
@ -171,6 +175,24 @@ llama_kv_cache::llama_kv_cache(
|
||||
continue;
|
||||
}
|
||||
|
||||
if (share && other) {
|
||||
const int32_t il_share = share(il);
|
||||
|
||||
if (il_share >= 0) {
|
||||
const auto & layer_share = other->layers[other->map_layer_ids[il_share]];
|
||||
|
||||
LLAMA_LOG_WARN("%s: layer %3d: sharing with layer %d. k = %p, v = %p\n", __func__, il, il_share,
|
||||
layer_share.k->data, layer_share.v->data);
|
||||
|
||||
map_layer_ids[il] = layers.size();
|
||||
|
||||
layers.push_back(layer_share);
|
||||
layers.back().il = il;
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (n_embd_head_k_all == 0) {
|
||||
n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il);
|
||||
} else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) {
|
||||
@ -347,6 +369,11 @@ void llama_kv_cache::clear(bool data) {
|
||||
}
|
||||
|
||||
bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
return true;
|
||||
}
|
||||
|
||||
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
|
||||
|
||||
if (p0 < 0) {
|
||||
@ -410,6 +437,11 @@ bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
}
|
||||
|
||||
void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
|
||||
GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
|
||||
|
||||
@ -497,6 +529,11 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
|
||||
}
|
||||
|
||||
void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
||||
|
||||
auto & cells = v_cells[seq_to_stream[seq_id]];
|
||||
@ -519,6 +556,11 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
|
||||
}
|
||||
|
||||
void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
||||
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
|
||||
|
||||
@ -564,6 +606,11 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
|
||||
}
|
||||
|
||||
void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
||||
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
|
||||
|
||||
@ -746,6 +793,11 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_
|
||||
}
|
||||
|
||||
bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool updated = false;
|
||||
|
||||
auto * sched = lctx->get_sched();
|
||||
@ -1021,6 +1073,12 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
|
||||
}
|
||||
|
||||
void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
v_cells = other->v_cells;
|
||||
return;
|
||||
}
|
||||
|
||||
// keep track of the max sequence position that we would overwrite with this ubatch
|
||||
// for non-SWA cache, this would be always empty
|
||||
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
|
||||
@ -1552,10 +1610,6 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data
|
||||
goto skip;
|
||||
}
|
||||
|
||||
if (ubatch->n_tokens == 1 && p0 == p1) {
|
||||
goto skip;
|
||||
}
|
||||
|
||||
// M-RoPE causal mask
|
||||
if (is_2d) {
|
||||
if (p0 == p1) {
|
||||
@ -1819,6 +1873,9 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
|
||||
ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
GGML_ASSERT(!other);
|
||||
|
||||
auto * ctx = res->get_ctx();
|
||||
auto * gf = res->get_gf();
|
||||
|
||||
@ -1864,6 +1921,11 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
||||
}
|
||||
|
||||
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_UNUSED(flags);
|
||||
|
||||
io.write(&n_stream, sizeof(n_stream));
|
||||
@ -1929,6 +1991,11 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, lla
|
||||
}
|
||||
|
||||
void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_UNUSED(flags);
|
||||
|
||||
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
|
||||
|
||||
@ -98,7 +98,7 @@ public:
|
||||
// likely through `struct llama_memory_params`
|
||||
llama_kv_cache(
|
||||
const llama_model & model,
|
||||
const llama_hparams & hparams,
|
||||
const llama_hparams & hparams,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
@ -109,8 +109,10 @@ public:
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
llama_memory_t mem_src,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse);
|
||||
const layer_reuse_cb & reuse,
|
||||
const layer_share_cb & share);
|
||||
|
||||
~llama_kv_cache() = default;
|
||||
|
||||
@ -264,6 +266,9 @@ private:
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
std::vector<uint32_t> v_heads;
|
||||
|
||||
// TODO: temporary until we refactor to be able to share the same cells between 2 kv caches [TAG_KV_CACHE_SHARE_CELLS]
|
||||
llama_kv_cache * other;
|
||||
|
||||
std::vector<llama_kv_cells> v_cells;
|
||||
|
||||
// maps from a sequence id to a stream id
|
||||
|
||||
@ -43,9 +43,11 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
|
||||
n_seq_max,
|
||||
n_ubatch,
|
||||
n_pad,
|
||||
nullptr,
|
||||
filter_attn == nullptr ?
|
||||
[&](int32_t il) { return !hparams.is_recr(il); }
|
||||
: filter_attn,
|
||||
nullptr,
|
||||
nullptr
|
||||
)),
|
||||
mem_recr(new llama_memory_recurrent(
|
||||
|
||||
@ -44,9 +44,11 @@ llama_memory_hybrid::llama_memory_hybrid(
|
||||
n_pad,
|
||||
n_swa,
|
||||
swa_type,
|
||||
nullptr,
|
||||
filter_attn == nullptr ?
|
||||
[&](int32_t il) { return !hparams.is_recr(il); }
|
||||
: filter_attn,
|
||||
nullptr,
|
||||
nullptr
|
||||
)),
|
||||
mem_recr(new llama_memory_recurrent(
|
||||
|
||||
@ -23,6 +23,8 @@ struct llama_memory_params {
|
||||
bool swa_full;
|
||||
|
||||
llama_context_type ctx_type;
|
||||
|
||||
llama_memory_t mem_src;
|
||||
};
|
||||
|
||||
enum llama_memory_status {
|
||||
@ -76,6 +78,8 @@ struct llama_memory_i {
|
||||
// return negative value to indicate that the layer il should not reuse memory
|
||||
using layer_reuse_cb = std::function<int32_t(int32_t il)>;
|
||||
|
||||
using layer_share_cb = std::function<int32_t(int32_t il)>;
|
||||
|
||||
virtual ~llama_memory_i() = default;
|
||||
|
||||
// split the input batch into a set of ubatches and verify that they can fit into the cache
|
||||
@ -122,4 +126,4 @@ struct llama_memory_i {
|
||||
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
|
||||
};
|
||||
|
||||
using llama_memory_ptr = std::shared_ptr<llama_memory_i>;
|
||||
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
|
||||
|
||||
@ -1987,7 +1987,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||
case LLM_ARCH_MODERN_BERT:
|
||||
case LLM_ARCH_GEMMA_EMBEDDING:
|
||||
case LLM_ARCH_GEMMA4_ASSISTANT:
|
||||
case LLM_ARCH_DREAM:
|
||||
case LLM_ARCH_LLADA:
|
||||
case LLM_ARCH_LLADA_MOE:
|
||||
@ -2097,8 +2096,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
/* filter_recr */ std::move(filter_recr));
|
||||
}
|
||||
} else {
|
||||
llama_memory_i::layer_reuse_cb reuse = nullptr;
|
||||
llama_kv_cache::layer_filter_cb filter = nullptr;
|
||||
llama_memory_i::layer_reuse_cb reuse = nullptr;
|
||||
llama_kv_cache::layer_share_cb share = nullptr;
|
||||
|
||||
if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) {
|
||||
reuse = [&](uint32_t il) {
|
||||
@ -2124,23 +2124,56 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
}
|
||||
}
|
||||
|
||||
llama_memory_t mem_src = cparams.ctx_src ? llama_get_memory(cparams.ctx_src) : nullptr;
|
||||
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
GGML_ASSERT(hparams.is_swa_any());
|
||||
|
||||
res = new llama_kv_cache_iswa(
|
||||
*this,
|
||||
params.type_k,
|
||||
params.type_v,
|
||||
!cparams.flash_attn,
|
||||
cparams.offload_kqv,
|
||||
params.swa_full,
|
||||
cparams.kv_unified,
|
||||
cparams.n_ctx_seq,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_ubatch,
|
||||
1,
|
||||
filter,
|
||||
reuse);
|
||||
if (arch == LLM_ARCH_GEMMA4_ASSISTANT) {
|
||||
share = [&](int32_t il) {
|
||||
const llama_model * model_src = llama_get_model(cparams.ctx_src);
|
||||
|
||||
if (hparams.is_swa(il)) {
|
||||
return llama_model_n_layer(model_src) - 2;
|
||||
}
|
||||
|
||||
return llama_model_n_layer(model_src) - 1;
|
||||
};
|
||||
|
||||
res = new llama_kv_cache_iswa(
|
||||
*this,
|
||||
params.type_k,
|
||||
params.type_v,
|
||||
!cparams.flash_attn,
|
||||
cparams.offload_kqv,
|
||||
params.swa_full,
|
||||
cparams.kv_unified,
|
||||
cparams.n_ctx_seq,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_ubatch,
|
||||
1,
|
||||
mem_src,
|
||||
filter,
|
||||
reuse,
|
||||
share);
|
||||
} else {
|
||||
res = new llama_kv_cache_iswa(
|
||||
*this,
|
||||
params.type_k,
|
||||
params.type_v,
|
||||
!cparams.flash_attn,
|
||||
cparams.offload_kqv,
|
||||
params.swa_full,
|
||||
cparams.kv_unified,
|
||||
cparams.n_ctx_seq,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_ubatch,
|
||||
1,
|
||||
mem_src,
|
||||
filter,
|
||||
reuse,
|
||||
share);
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(!hparams.is_swa_any());
|
||||
|
||||
@ -2157,7 +2190,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
1,
|
||||
hparams.n_swa,
|
||||
hparams.swa_type,
|
||||
mem_src,
|
||||
filter,
|
||||
nullptr,
|
||||
nullptr);
|
||||
}
|
||||
}
|
||||
@ -2631,11 +2666,3 @@ void llama_model_base::create_tensor_qkv(llama_layer & layer, int bid,
|
||||
layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor * llama_model_get_tok_embd(const struct llama_model * model) {
|
||||
return model->tok_embd;
|
||||
}
|
||||
|
||||
void llama_model_set_tok_embd(struct llama_model * model, ggml_tensor * tensor) {
|
||||
model->tok_embd = tensor;
|
||||
}
|
||||
|
||||
@ -104,7 +104,10 @@ llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_
|
||||
res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
ggml_tensor * x = ggml_get_rows(ctx0, model.tok_embd, inp_tokens);
|
||||
GGML_ASSERT(cparams.ctx_src != nullptr);
|
||||
const auto * model_src = llama_get_model(cparams.ctx_src);
|
||||
|
||||
ggml_tensor * x = ggml_get_rows(ctx0, model_src->tok_embd, inp_tokens);
|
||||
x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone));
|
||||
cb(x, "inp_embd_target", -1);
|
||||
|
||||
|
||||
@ -10,8 +10,6 @@
|
||||
#include "common.h"
|
||||
#include "fit.h"
|
||||
#include "llama.h"
|
||||
#include "../../src/llama-ext.h" // staging API: llama_set_mtp_source
|
||||
#include "ggml-cpp.h"
|
||||
#include "log.h"
|
||||
#include "sampling.h"
|
||||
#include "speculative.h"
|
||||
@ -847,30 +845,14 @@ private:
|
||||
}
|
||||
cparams_dft.n_rs_seq = 0;
|
||||
|
||||
bool skip_measure = false;
|
||||
//TODO: remove this
|
||||
if (spec_mtp && has_draft) {
|
||||
struct gguf_init_params meta_params = {
|
||||
/* .no_alloc = */ true,
|
||||
/* .ctx = */ nullptr,
|
||||
};
|
||||
gguf_context_ptr meta(gguf_init_from_file(params_dft.model.path.c_str(), meta_params));
|
||||
|
||||
if (std::string(gguf_get_val_str(meta.get(), gguf_find_key(meta.get(), "general.architecture"))) == "gemma4-assistant") {
|
||||
skip_measure = true;
|
||||
SRV_WRN("[spec] skipping --fit memory measurement for Gemma 4 assistant draft model '%s'\n",
|
||||
params_dft.model.path.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ggml_backend_dev_t> devs;
|
||||
uint32_t hp_ngl = 0;
|
||||
uint32_t hp_nct = 0;
|
||||
uint32_t hp_nex = 0;
|
||||
if (!skip_measure) try {
|
||||
try {
|
||||
auto dmd = common_get_device_memory_data(
|
||||
params_dft.model.path.c_str(), &mparams_dft, &cparams_dft,
|
||||
devs, hp_ngl, hp_nct, hp_nex, GGML_LOG_LEVEL_ERROR);
|
||||
devs, hp_ngl, hp_nct, hp_nex, GGML_LOG_LEVEL_DEBUG);
|
||||
|
||||
GGML_ASSERT(!params_base.fit_params_target.empty());
|
||||
size_t total = 0;
|
||||
@ -972,6 +954,8 @@ private:
|
||||
// note: for small models maybe we can set this to the maximum possible draft from all speculative types
|
||||
// the extra memory for small models is likely negligible?
|
||||
cparams.n_rs_seq = 0;
|
||||
cparams.ctx_src = ctx_tgt;
|
||||
|
||||
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
|
||||
|
||||
params_base.speculative.draft.ctx_tgt = ctx_tgt;
|
||||
@ -987,6 +971,7 @@ private:
|
||||
cparams_mtp.type_v = params_base.speculative.draft.cache_type_v;
|
||||
cparams_mtp.n_rs_seq = 0;
|
||||
cparams_mtp.n_outputs_max = params_base.n_parallel;
|
||||
cparams_mtp.ctx_src = ctx_tgt;
|
||||
|
||||
ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp));
|
||||
if (ctx_dft == nullptr) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user