mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
wip
This commit is contained in:
parent
65eef9549c
commit
1c4a91c0f3
@ -491,11 +491,16 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
}
|
||||
}
|
||||
|
||||
kv_shared_with_target = llama_get_memory(ctx_dft) == nullptr;
|
||||
|
||||
llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false);
|
||||
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
|
||||
llama_set_mtp_source(ctx_dft, ctx_tgt);
|
||||
|
||||
kv_shared_with_target = llama_get_memory(ctx_dft) == nullptr;
|
||||
if (kv_shared_with_target) {
|
||||
llama_set_memory(ctx_dft, ctx_tgt);
|
||||
|
||||
// TODO: avoid the const cast
|
||||
llama_model_set_tok_embd(const_cast<llama_model *>(llama_get_model(ctx_dft)), llama_model_get_tok_embd(llama_get_model(ctx_tgt)));
|
||||
}
|
||||
|
||||
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
|
||||
|
||||
@ -640,6 +645,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
void draft(common_speculative_draft_params_vec & dparams) override {
|
||||
auto & ctx_dft = params.ctx_dft;
|
||||
auto & ctx_tgt = params.ctx_tgt;
|
||||
|
||||
common_batch_clear(batch);
|
||||
|
||||
@ -665,6 +671,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
h_row = pending_h[seq_id].data();
|
||||
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, dp.n_past, -1);
|
||||
}
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
@ -725,6 +733,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
|
||||
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, dp.n_past, -1);
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
@ -741,6 +751,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
++i;
|
||||
}
|
||||
|
||||
llama_synchronize(ctx_dft);
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
auto & dp = dparams[seq_id];
|
||||
if (!dp.drafting) {
|
||||
|
||||
@ -301,16 +301,16 @@ bool llama_batch_allocr::init(
|
||||
ok = false;
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
LLAMA_LOG_ERROR(
|
||||
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
|
||||
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
|
||||
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
|
||||
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
|
||||
__func__, s, s, p0, s, seq_pos_min(s));
|
||||
//if (!ok) {
|
||||
// LLAMA_LOG_ERROR(
|
||||
// "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
|
||||
// " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
|
||||
// " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
|
||||
// " it is required that the sequence positions remain consecutive: Y = X + 1\n",
|
||||
// __func__, s, s, p0, s, seq_pos_min(s));
|
||||
|
||||
return false;
|
||||
}
|
||||
// return false;
|
||||
//}
|
||||
}
|
||||
|
||||
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
|
||||
|
||||
@ -38,12 +38,6 @@ static uint32_t ctx_type_to_embd_inp(const llama_hparams & hparams, llama_contex
|
||||
throw std::runtime_error("Unsupported ctx type");
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct src_mctx_reset_on_exit {
|
||||
llama_memory_context_ptr * slot;
|
||||
~src_mctx_reset_on_exit() { if (slot) slot->reset(); }
|
||||
};
|
||||
|
||||
static void llama_assert_gemma4_mtp_source_placement(
|
||||
const llama_context * ctx,
|
||||
const llama_context * src) {
|
||||
@ -92,7 +86,6 @@ static void llama_assert_gemma4_mtp_source_placement(
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llama_context::llama_context(
|
||||
const llama_model & model,
|
||||
@ -514,23 +507,6 @@ void llama_context::sched_reserve() {
|
||||
}
|
||||
}
|
||||
|
||||
// When called from decode(), src_mctx_for_decode is already populated and
|
||||
// we must not drop it on exit (process_ubatch still needs it). Snapshot
|
||||
// only when sched_reserve runs standalone (e.g. lazy first-decode reserve
|
||||
// when set_mtp_source flipped sched_need_reserve).
|
||||
const bool owns_src_snapshot = src_ctx && !src_mctx_for_decode;
|
||||
if (owns_src_snapshot) {
|
||||
auto * src_memory = src_ctx->get_memory();
|
||||
if (!src_memory) {
|
||||
throw std::runtime_error("MTP source context has no memory module");
|
||||
}
|
||||
src_mctx_for_decode = src_memory->init_full();
|
||||
if (!src_mctx_for_decode) {
|
||||
throw std::runtime_error("failed to initialize MTP source memory snapshot");
|
||||
}
|
||||
}
|
||||
src_mctx_reset_on_exit reserve_src_drop{owns_src_snapshot ? &src_mctx_for_decode : nullptr};
|
||||
|
||||
// avoid reserving graphs with zero outputs - assume one output per sequence
|
||||
const int n_outputs = n_seqs;
|
||||
|
||||
@ -816,8 +792,14 @@ uint32_t llama_context::n_threads_batch() const {
|
||||
return cparams.n_threads_batch;
|
||||
}
|
||||
|
||||
llama_memory_t llama_context::get_memory() const {
|
||||
return memory.get();
|
||||
llama_memory_ptr llama_context::get_memory() const {
|
||||
return memory;
|
||||
}
|
||||
|
||||
void llama_context::set_memory(llama_memory_ptr memory) {
|
||||
this->memory = std::move(memory);
|
||||
|
||||
sched_need_reserve = true;
|
||||
}
|
||||
|
||||
bool llama_context::memory_update(bool optimize) {
|
||||
@ -1198,18 +1180,6 @@ void llama_context::set_embeddings_nextn(bool value, bool masked) {
|
||||
cparams.embeddings_nextn_masked = masked;
|
||||
}
|
||||
|
||||
void llama_context::set_mtp_source(llama_context * src) {
|
||||
if (src_ctx == src) {
|
||||
return;
|
||||
}
|
||||
llama_assert_gemma4_mtp_source_placement(this, src);
|
||||
src_ctx = src;
|
||||
src_mctx_for_decode.reset();
|
||||
// worst-case compute buffers were reserved without knowing about the source
|
||||
// memory; force a re-reserve so the next decode sees src views
|
||||
sched_need_reserve = true;
|
||||
}
|
||||
|
||||
void llama_context::set_causal_attn(bool value) {
|
||||
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
||||
|
||||
@ -1807,20 +1777,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
embd_seq.clear();
|
||||
output_swaps.clear();
|
||||
|
||||
src_mctx_reset_on_exit decode_src_drop{&src_mctx_for_decode};
|
||||
if (src_ctx) {
|
||||
auto * src_memory = src_ctx->get_memory();
|
||||
if (!src_memory) {
|
||||
LLAMA_LOG_ERROR("%s: MTP source context has no memory module\n", __func__);
|
||||
return -2;
|
||||
}
|
||||
src_mctx_for_decode = src_memory->init_full();
|
||||
if (!src_mctx_for_decode) {
|
||||
LLAMA_LOG_ERROR("%s: failed to snapshot MTP source memory\n", __func__);
|
||||
return -2;
|
||||
}
|
||||
}
|
||||
|
||||
sched_reserve();
|
||||
|
||||
bool did_optimize = false;
|
||||
@ -2414,8 +2370,6 @@ llm_graph_params llama_context::graph_params(
|
||||
/*.cvec =*/ cvec.get(),
|
||||
/*.loras =*/ loras.get(),
|
||||
/*.mctx =*/ mctx,
|
||||
/*.src_mctx =*/ src_mctx_for_decode.get(),
|
||||
/*.src_model =*/ src_ctx ? &src_ctx->get_model() : nullptr,
|
||||
/*.cross =*/ &cross,
|
||||
/*.samplers =*/ sampling.samplers,
|
||||
/*.n_outputs =*/ n_outputs,
|
||||
@ -3704,8 +3658,12 @@ void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) {
|
||||
ctx->set_embeddings_nextn(value, masked);
|
||||
}
|
||||
|
||||
void llama_set_mtp_source(llama_context * ctx, llama_context * src) {
|
||||
ctx->set_mtp_source(src);
|
||||
void llama_set_memory(llama_context * ctx, llama_context * src) {
|
||||
ctx->set_memory(src->get_memory());
|
||||
}
|
||||
|
||||
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
|
||||
return ctx->get_memory().get();
|
||||
}
|
||||
|
||||
float * llama_get_embeddings_nextn(llama_context * ctx) {
|
||||
@ -3771,7 +3729,7 @@ struct ggml_cgraph * llama_graph_reserve(
|
||||
uint32_t n_tokens,
|
||||
uint32_t n_seqs,
|
||||
uint32_t n_outputs) {
|
||||
auto * memory = ctx->get_memory();
|
||||
auto memory = ctx->get_memory();
|
||||
llama_memory_context_ptr mctx;
|
||||
if (memory) {
|
||||
mctx = memory->init_full();
|
||||
@ -3811,10 +3769,6 @@ int32_t llama_set_adapter_cvec(
|
||||
// memory
|
||||
//
|
||||
|
||||
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
|
||||
return ctx->get_memory();
|
||||
}
|
||||
|
||||
void llama_memory_clear(llama_memory_t mem, bool data) {
|
||||
if (!mem) {
|
||||
return;
|
||||
|
||||
@ -71,7 +71,8 @@ struct llama_context {
|
||||
uint32_t n_threads() const;
|
||||
uint32_t n_threads_batch() const;
|
||||
|
||||
llama_memory_t get_memory() const;
|
||||
llama_memory_ptr get_memory() const;
|
||||
void set_memory(llama_memory_ptr memory);
|
||||
|
||||
// return true if the memory was updated
|
||||
bool memory_update(bool optimize);
|
||||
@ -112,7 +113,6 @@ struct llama_context {
|
||||
|
||||
void set_embeddings (bool value);
|
||||
void set_embeddings_nextn(bool value, bool masked);
|
||||
void set_mtp_source(llama_context * src);
|
||||
void set_causal_attn(bool value);
|
||||
void set_warmup(bool value);
|
||||
|
||||
@ -275,13 +275,7 @@ private:
|
||||
|
||||
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
|
||||
|
||||
std::unique_ptr<llama_memory_i> memory;
|
||||
|
||||
// external KV source used by MTP draft contexts. src_ctx is the target
|
||||
// context whose memory we read; src_mctx_for_decode is a per-decode
|
||||
// snapshot held for the duration of one decode/sched_reserve call.
|
||||
llama_context * src_ctx = nullptr;
|
||||
llama_memory_context_ptr src_mctx_for_decode;
|
||||
llama_memory_ptr memory;
|
||||
|
||||
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
||||
buffer_view<float> logits = {nullptr, 0};
|
||||
|
||||
@ -93,7 +93,7 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c
|
||||
// If masked == true, output the embeddings only for the tokens with batch.logits != 0
|
||||
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
|
||||
LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked);
|
||||
LLAMA_API void llama_set_mtp_source(struct llama_context * ctx, struct llama_context * src);
|
||||
LLAMA_API void llama_set_memory(struct llama_context * ctx, struct llama_context * src);
|
||||
|
||||
// mirrors:
|
||||
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||
@ -101,3 +101,6 @@ LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);
|
||||
|
||||
// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
||||
LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
LLAMA_API ggml_tensor * llama_model_get_tok_embd(const struct llama_model * model);
|
||||
LLAMA_API void llama_model_set_tok_embd( struct llama_model * model, ggml_tensor * tensor);
|
||||
|
||||
@ -397,7 +397,7 @@ static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n
|
||||
case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
|
||||
};
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
|
||||
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
|
||||
LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
|
||||
LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
|
||||
|
||||
@ -565,18 +565,18 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
if (self_k_idxs && self_k_idxs->buffer) {
|
||||
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
|
||||
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
|
||||
|
||||
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
|
||||
// swa tensors may not be allocated if there are no SWA attention layers
|
||||
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
|
||||
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
|
||||
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
|
||||
|
||||
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
|
||||
if (self_k_rot) {
|
||||
mctx->get_base()->set_input_k_rot(self_k_rot);
|
||||
}
|
||||
@ -605,47 +605,18 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
||||
if (self_k_idxs && self_k_idxs->buffer) {
|
||||
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
|
||||
}
|
||||
|
||||
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
|
||||
|
||||
// swa tensors may not be allocated if there are no SWA attention layers
|
||||
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
|
||||
res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
|
||||
|
||||
void llm_graph_input_attn_src_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
src_mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
src_mctx->get_swa() ->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
|
||||
if (self_k_rot) {
|
||||
src_mctx->get_base()->set_input_k_rot(self_k_rot);
|
||||
}
|
||||
if (self_v_rot) {
|
||||
src_mctx->get_base()->set_input_v_rot(self_v_rot);
|
||||
}
|
||||
if (self_k_rot_swa) {
|
||||
src_mctx->get_swa()->set_input_k_rot(self_k_rot_swa);
|
||||
}
|
||||
if (self_v_rot_swa) {
|
||||
src_mctx->get_swa()->set_input_v_rot(self_v_rot_swa);
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_input_attn_src_kv_iswa::can_reuse(const llm_graph_params & params) {
|
||||
const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.src_mctx);
|
||||
|
||||
this->src_mctx = mctx;
|
||||
|
||||
bool res = true;
|
||||
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
|
||||
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
|
||||
return res;
|
||||
}
|
||||
|
||||
@ -785,18 +756,18 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
|
||||
attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
|
||||
attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
|
||||
|
||||
attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
|
||||
|
||||
// swa tensors may not be allocated if there are no SWA attention layers
|
||||
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
|
||||
attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
|
||||
attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
|
||||
|
||||
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
|
||||
if (inp_attn->self_k_rot) {
|
||||
attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot);
|
||||
}
|
||||
@ -839,18 +810,18 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params)
|
||||
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
|
||||
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
|
||||
}
|
||||
|
||||
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
|
||||
|
||||
// swa tensors may not be allocated if there are no SWA attention layers
|
||||
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
|
||||
res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
|
||||
}
|
||||
|
||||
res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
|
||||
|
||||
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
||||
|
||||
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
|
||||
@ -1063,8 +1034,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
||||
cvec (params.cvec),
|
||||
loras (params.loras),
|
||||
mctx (params.mctx),
|
||||
src_mctx (params.src_mctx),
|
||||
src_model (params.src_model),
|
||||
cross (params.cross),
|
||||
samplers (params.samplers),
|
||||
cb_func (params.cb),
|
||||
@ -2574,6 +2543,16 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_build_forward_expand(gf, v_cur);
|
||||
}
|
||||
|
||||
int il_save = il;
|
||||
|
||||
if (arch == LLM_ARCH_GEMMA4_ASSISTANT) {
|
||||
if (il == n_layer - 1) {
|
||||
il = 59;
|
||||
} else {
|
||||
il = 58;
|
||||
}
|
||||
}
|
||||
|
||||
const auto * mctx_iswa = inp->mctx;
|
||||
|
||||
const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
|
||||
@ -2597,6 +2576,8 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
||||
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
||||
|
||||
il = il_save;
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
@ -2635,98 +2616,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
|
||||
return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
llm_graph_input_attn_src_kv_iswa * llm_graph_context::build_attn_inp_src_kv_iswa() const {
|
||||
GGML_ASSERT(src_mctx && "MTP draft graph requires src_mctx (set via llama_set_mtp_source)");
|
||||
|
||||
const auto * src_iswa = static_cast<const llama_kv_cache_iswa_context *>(src_mctx);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_src_kv_iswa>(hparams, cparams, src_iswa);
|
||||
|
||||
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, src_iswa->get_base(), ubatch, cparams);
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
|
||||
inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, src_iswa->get_swa(), ubatch, cparams);
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
|
||||
inp->self_k_rot = src_iswa->get_base()->build_input_k_rot(ctx0);
|
||||
inp->self_v_rot = src_iswa->get_base()->build_input_v_rot(ctx0);
|
||||
inp->self_k_rot_swa = src_iswa->get_swa()->build_input_k_rot(ctx0);
|
||||
inp->self_v_rot_swa = src_iswa->get_swa()->build_input_v_rot(ctx0);
|
||||
|
||||
return (llm_graph_input_attn_src_kv_iswa *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_attn(
|
||||
llm_graph_input_attn_src_kv_iswa * inp,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * wo_s,
|
||||
ggml_tensor * q_cur,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il_assist,
|
||||
int il_src) const {
|
||||
const bool is_swa = hparams.is_swa(il_assist);
|
||||
|
||||
const auto * src_iswa = inp->src_mctx;
|
||||
const auto * src_cur = is_swa ? src_iswa->get_swa() : src_iswa->get_base();
|
||||
|
||||
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
||||
|
||||
auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot;
|
||||
auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot;
|
||||
|
||||
if (k_rot) {
|
||||
q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot);
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, q_cur);
|
||||
|
||||
ggml_tensor * q = q_cur;
|
||||
ggml_tensor * k = src_cur->get_k(ctx0, il_src);
|
||||
ggml_tensor * v = src_cur->get_v(ctx0, il_src);
|
||||
|
||||
// build_attn_mha splits q across k->ne[3] (the trunk's stream count). When the
|
||||
// trunk runs kv_unified=false the assistant's ubatch only references a subset
|
||||
// of streams (one per active draft seq); q->ne[2] is not divisible by the full
|
||||
// n_stream and the view collapses tokens. Slice k/v down to exactly the streams
|
||||
// referenced by this ubatch. Requires those streams to form a contiguous range.
|
||||
if (k->ne[3] > 1 && (uint32_t) k->ne[3] != ubatch.n_seqs_unq) {
|
||||
GGML_ASSERT(ubatch.n_seqs_unq > 0 && ubatch.seq_id_unq);
|
||||
llama_seq_id min_s = ubatch.seq_id_unq[0];
|
||||
llama_seq_id max_s = ubatch.seq_id_unq[0];
|
||||
for (uint32_t s = 1; s < ubatch.n_seqs_unq; ++s) {
|
||||
min_s = std::min(min_s, ubatch.seq_id_unq[s]);
|
||||
max_s = std::max(max_s, ubatch.seq_id_unq[s]);
|
||||
}
|
||||
GGML_ASSERT((uint32_t)(max_s - min_s + 1) == ubatch.n_seqs_unq &&
|
||||
"MTP src-kv attn requires the active draft seq_ids to be contiguous");
|
||||
GGML_ASSERT((int64_t) max_s < k->ne[3] && "MTP assistant seq_id beyond trunk stream count");
|
||||
|
||||
k = ggml_view_4d(ctx0, k, k->ne[0], k->ne[1], k->ne[2], (int64_t) ubatch.n_seqs_unq,
|
||||
k->nb[1], k->nb[2], k->nb[3], (size_t) min_s * k->nb[3]);
|
||||
v = ggml_view_4d(ctx0, v, v->ne[0], v->ne[1], v->ne[2], (int64_t) ubatch.n_seqs_unq,
|
||||
v->nb[1], v->nb[2], v->nb[3], (size_t) min_s * v->nb[3]);
|
||||
}
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il_assist);
|
||||
cb(cur, "kqv_out", il_assist);
|
||||
|
||||
if (v_rot) {
|
||||
cur = ggml_mul_mat_aux(ctx0, cur, v_rot);
|
||||
}
|
||||
|
||||
if (wo) {
|
||||
cur = build_lora_mm(wo, cur, wo_s);
|
||||
}
|
||||
if (wo_b) {
|
||||
cur = ggml_add(ctx0, cur, wo_b);
|
||||
}
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_attn(
|
||||
llm_graph_input_attn_cross * inp,
|
||||
ggml_tensor * wo,
|
||||
|
||||
@ -459,42 +459,6 @@ public:
|
||||
const llama_kv_cache_iswa_context * mctx;
|
||||
};
|
||||
|
||||
// mask-only input for attention against an external (read-only) ISWA KV cache.
|
||||
// used by MTP draft graphs that attend to the target's KV without owning any.
|
||||
class llm_graph_input_attn_src_kv_iswa : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_attn_src_kv_iswa(
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_iswa_context * src_mctx) :
|
||||
hparams(hparams),
|
||||
cparams(cparams),
|
||||
src_mctx(src_mctx) {
|
||||
}
|
||||
~llm_graph_input_attn_src_kv_iswa() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
bool can_reuse(const llm_graph_params & params) override;
|
||||
|
||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr;
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr;
|
||||
ggml_tensor * self_kq_mask_swa = nullptr;
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr;
|
||||
|
||||
ggml_tensor * self_k_rot = nullptr;
|
||||
ggml_tensor * self_v_rot = nullptr;
|
||||
ggml_tensor * self_k_rot_swa = nullptr;
|
||||
ggml_tensor * self_v_rot_swa = nullptr;
|
||||
|
||||
const llama_hparams hparams;
|
||||
const llama_cparams cparams;
|
||||
|
||||
const llama_kv_cache_iswa_context * src_mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_attn_cross : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
|
||||
@ -637,11 +601,6 @@ struct llm_graph_params {
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
const llama_memory_context_i * mctx;
|
||||
// per-decode snapshot of an external memory module the graph reads from
|
||||
// (never writes) — e.g. ctx_dft reading target KV during MTP draft.
|
||||
// nullptr for a main decode. Rebound inside reuse-aware input classes.
|
||||
const llama_memory_context_i * src_mctx;
|
||||
const llama_model * src_model;
|
||||
const llama_cross * cross;
|
||||
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
@ -859,8 +818,6 @@ struct llm_graph_context {
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
const llama_memory_context_i * mctx;
|
||||
const llama_memory_context_i * src_mctx;
|
||||
const llama_model * src_model;
|
||||
const llama_cross * cross;
|
||||
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
@ -1090,24 +1047,6 @@ struct llm_graph_context {
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
llm_graph_input_attn_src_kv_iswa * build_attn_inp_src_kv_iswa() const;
|
||||
|
||||
// Q-only attention against an external ISWA KV cache (no K/V projections,
|
||||
// no writes). il_assist labels the attention block in the local graph for
|
||||
// logging; il_src indexes the source K/V layer to attend to.
|
||||
ggml_tensor * build_attn(
|
||||
llm_graph_input_attn_src_kv_iswa * inp,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * wo_s,
|
||||
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks, // [n_head_q]
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
int il_assist,
|
||||
int il_src) const;
|
||||
|
||||
llm_graph_input_attn_cross * build_attn_inp_cross() const;
|
||||
|
||||
ggml_tensor * build_attn(
|
||||
|
||||
@ -1552,6 +1552,10 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data
|
||||
goto skip;
|
||||
}
|
||||
|
||||
if (ubatch->n_tokens == 1 && p0 == p1) {
|
||||
goto skip;
|
||||
}
|
||||
|
||||
// M-RoPE causal mask
|
||||
if (is_2d) {
|
||||
if (p0 == p1) {
|
||||
|
||||
@ -122,4 +122,4 @@ struct llama_memory_i {
|
||||
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
|
||||
};
|
||||
|
||||
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
|
||||
using llama_memory_ptr = std::shared_ptr<llama_memory_i>;
|
||||
|
||||
@ -1987,6 +1987,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||
case LLM_ARCH_MODERN_BERT:
|
||||
case LLM_ARCH_GEMMA_EMBEDDING:
|
||||
case LLM_ARCH_GEMMA4_ASSISTANT:
|
||||
case LLM_ARCH_DREAM:
|
||||
case LLM_ARCH_LLADA:
|
||||
case LLM_ARCH_LLADA_MOE:
|
||||
@ -2630,3 +2631,11 @@ void llama_model_base::create_tensor_qkv(llama_layer & layer, int bid,
|
||||
layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor * llama_model_get_tok_embd(const struct llama_model * model) {
|
||||
return model->tok_embd;
|
||||
}
|
||||
|
||||
void llama_model_set_tok_embd(struct llama_model * model, ggml_tensor * tensor) {
|
||||
model->tok_embd = tensor;
|
||||
}
|
||||
|
||||
@ -706,6 +706,7 @@ const char * llm_type_name(llm_type type);
|
||||
#define LLAMA_LOAD_LOCALS \
|
||||
const int n_layer = hparams.n_layer(); GGML_UNUSED(n_layer); \
|
||||
const int n_layer_all = hparams.n_layer_all; GGML_UNUSED(n_layer_all); \
|
||||
const int n_layer_nextn = hparams.n_layer_nextn; GGML_UNUSED(n_layer_nextn); \
|
||||
const int64_t n_head = hparams.n_head(); GGML_UNUSED(n_head); \
|
||||
const int64_t n_head_kv = hparams.n_head_kv(); GGML_UNUSED(n_head_kv); \
|
||||
const int64_t n_embd = hparams.n_embd; GGML_UNUSED(n_embd); \
|
||||
|
||||
@ -82,18 +82,6 @@ std::unique_ptr<llm_graph_context> llama_model_gemma4_assistant::build_arch_grap
|
||||
|
||||
llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
GGML_ASSERT(src_mctx && "Gemma 4 assistant graph requires an MTP source (llama_set_mtp_source)");
|
||||
GGML_ASSERT(src_model && "Gemma 4 assistant graph requires a source model");
|
||||
GGML_ASSERT(src_model->tok_embd && "source model missing tok_embd");
|
||||
|
||||
const auto & src_hparams = src_model->hparams;
|
||||
|
||||
// By convention the MTP draft reads from the trunk's final SWA and full layers.
|
||||
const int32_t src_layer_full = (int32_t) src_hparams.n_layer() - 1;
|
||||
const int32_t src_layer_swa = (int32_t) src_hparams.n_layer() - 2;
|
||||
GGML_ASSERT(!src_hparams.is_swa(src_layer_full) && "trunk's last layer must be full attention");
|
||||
GGML_ASSERT( src_hparams.is_swa(src_layer_swa) && "trunk's penultimate layer must be SWA");
|
||||
|
||||
const int64_t n_embd_backbone = hparams.n_embd_out();
|
||||
|
||||
ggml_tensor * inp_tokens;
|
||||
@ -116,7 +104,7 @@ llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_
|
||||
res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
ggml_tensor * x = ggml_get_rows(ctx0, src_model->tok_embd, inp_tokens);
|
||||
ggml_tensor * x = ggml_get_rows(ctx0, model.tok_embd, inp_tokens);
|
||||
x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone));
|
||||
cb(x, "inp_embd_target", -1);
|
||||
|
||||
@ -126,15 +114,14 @@ llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_
|
||||
ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_pre_proj, xh);
|
||||
cb(cur, "pre_proj", -1);
|
||||
|
||||
auto * inp_attn = build_attn_inp_src_kv_iswa();
|
||||
auto * inp_attn = build_attn_inp_kv_iswa();
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
ggml_tensor * inpL = cur;
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
const int32_t il_src = is_swa ? src_layer_swa : src_layer_full;
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k(il);
|
||||
const int64_t n_head = hparams.n_head(il);
|
||||
@ -157,7 +144,7 @@ llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_
|
||||
cb(Qcur, "Qcur_pos", il);
|
||||
|
||||
cur = build_attn(inp_attn, model.layers[il].wo, nullptr, nullptr,
|
||||
Qcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il, il_src);
|
||||
Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
|
||||
@ -974,13 +974,6 @@ private:
|
||||
cparams.n_rs_seq = 0;
|
||||
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
|
||||
|
||||
if (spec_mtp) {
|
||||
// MTP draft must know its target before the first decode
|
||||
llama_set_mtp_source(ctx_dft.get(), ctx_tgt);
|
||||
}
|
||||
|
||||
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
|
||||
|
||||
params_base.speculative.draft.ctx_tgt = ctx_tgt;
|
||||
params_base.speculative.draft.ctx_dft = ctx_dft.get();
|
||||
} else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(),
|
||||
@ -1001,12 +994,6 @@ private:
|
||||
return false;
|
||||
}
|
||||
|
||||
// wire the source before any decode (the seq-rm probe below
|
||||
// triggers sched_reserve which needs src for Gemma4-style MTP)
|
||||
llama_set_mtp_source(ctx_dft.get(), ctx_tgt);
|
||||
|
||||
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
|
||||
|
||||
params_base.speculative.draft.ctx_tgt = ctx_tgt;
|
||||
params_base.speculative.draft.ctx_dft = ctx_dft.get();
|
||||
}
|
||||
@ -1094,6 +1081,10 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
if (ctx_dft) {
|
||||
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
|
||||
}
|
||||
|
||||
if (spec) {
|
||||
SRV_INF("%s", "speculative decoding context initialized\n");
|
||||
} else {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user