mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Add llama_context to MTP (#1601)
* wip: separate llama_context for MTP with graph reuse * wip: fix KV cache desync with separate MTP context * refactor: remove dead mtp logic code, encapsulate KV mirroring * mtp-context: derive args directly from the main model's context * mtp: fix kv cache positions * clean small comments * minor refactor for context shift
This commit is contained in:
parent
9b5785ad6b
commit
557b674f63
@ -148,11 +148,13 @@ struct common_speculative_state {
|
||||
|
||||
struct common_speculative_state_mtp : public common_speculative_state {
|
||||
llama_context * ctx_tgt;
|
||||
llama_context * ctx_mtp = nullptr;
|
||||
common_sampler * smpl;
|
||||
|
||||
common_speculative_state_mtp(
|
||||
enum common_speculative_type type,
|
||||
llama_context * ctx_tgt)
|
||||
llama_context * ctx_tgt,
|
||||
const llama_context_params & mtp_cparams)
|
||||
: common_speculative_state(type)
|
||||
, ctx_tgt(ctx_tgt)
|
||||
{
|
||||
@ -161,10 +163,21 @@ struct common_speculative_state_mtp : public common_speculative_state {
|
||||
llama_sampler_type::DIST,
|
||||
};
|
||||
smpl = common_sampler_init(llama_get_model(ctx_tgt), params);
|
||||
|
||||
const llama_model * model = llama_get_model(ctx_tgt);
|
||||
ctx_mtp = llama_init_from_model(const_cast<llama_model *>(model), mtp_cparams);
|
||||
if (ctx_mtp) {
|
||||
LOG_INF("%s: created MTP context (n_ctx=%d)\n", __func__, llama_n_ctx(ctx_mtp));
|
||||
} else {
|
||||
LOG_ERR("%s: failed to create MTP context, falling back to shared context\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
~common_speculative_state_mtp() override {
|
||||
common_sampler_free(smpl);
|
||||
if (ctx_mtp) {
|
||||
llama_free(ctx_mtp);
|
||||
}
|
||||
}
|
||||
|
||||
void begin(const llama_tokens & prompt) override {
|
||||
@ -178,12 +191,20 @@ struct common_speculative_state_mtp : public common_speculative_state {
|
||||
llama_tokens & result) override {
|
||||
|
||||
int32_t n_past = (int32_t)prompt_tgt.size();
|
||||
|
||||
llama_seq_id seq_id = 0;
|
||||
|
||||
if (ctx_mtp) {
|
||||
llama_pos mtp_pos_max = llama_kv_cache_seq_pos_max(ctx_mtp, seq_id);
|
||||
if (mtp_pos_max >= n_past) {
|
||||
llama_kv_cache_seq_rm(ctx_mtp, seq_id, n_past, -1);
|
||||
}
|
||||
}
|
||||
|
||||
llama_context * ctx = ctx_mtp ? ctx_mtp : ctx_tgt;
|
||||
|
||||
result = mtp_speculative_gen_draft(
|
||||
smpl,
|
||||
ctx_tgt,
|
||||
ctx,
|
||||
params.n_max,
|
||||
params.p_min,
|
||||
id_last,
|
||||
@ -954,7 +975,8 @@ common_speculative * common_speculative_init(
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_MTP: {
|
||||
impls.push_back(std::make_unique<common_speculative_state_mtp>(config.type,
|
||||
/* .ctx_tgt = */ ctx_tgt
|
||||
/* .ctx_tgt = */ ctx_tgt,
|
||||
/* .mtp_cparams = */ params.cparams_dft
|
||||
));
|
||||
break;
|
||||
}
|
||||
@ -1166,6 +1188,33 @@ void common_speculative_print_stats(const common_speculative * spec, double slot
|
||||
// ----------------------------------------------------------------------------
|
||||
// MTP
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
llama_context * common_speculative_get_mtp_ctx(common_speculative * spec) {
|
||||
if (!spec) return nullptr;
|
||||
|
||||
for (auto & impl : spec->impls) {
|
||||
if (impl->type == COMMON_SPECULATIVE_TYPE_MTP) {
|
||||
auto * mtp_state = dynamic_cast<common_speculative_state_mtp *>(impl.get());
|
||||
if (mtp_state) {
|
||||
return mtp_state->ctx_mtp;
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void common_speculative_context_shift(
|
||||
common_speculative * spec,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos kv_keep,
|
||||
llama_pos kv_discard,
|
||||
llama_pos kv_past) {
|
||||
if (auto * ctx_mtp = common_speculative_get_mtp_ctx(spec); ctx_mtp != nullptr) {
|
||||
llama_kv_cache_seq_rm (ctx_mtp, seq_id, kv_keep, kv_keep + kv_discard);
|
||||
llama_kv_cache_seq_add(ctx_mtp, seq_id, kv_keep + kv_discard, kv_past, -kv_discard);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
struct common_sampler * smpl,
|
||||
struct llama_context * ctx,
|
||||
@ -1231,7 +1280,15 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b
|
||||
return;
|
||||
}
|
||||
|
||||
LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);
|
||||
llama_seq_id seq_id = batch.seq_id[0][0];
|
||||
llama_pos start_pos = batch.pos[0];
|
||||
|
||||
if (llama_kv_cache_seq_pos_max(ctx, seq_id) >= start_pos) {
|
||||
llama_kv_cache_seq_rm(ctx, seq_id, start_pos, -1);
|
||||
}
|
||||
|
||||
LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens from pos %d...\n",
|
||||
is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens, (int)start_pos);
|
||||
|
||||
llama_batch mtp_batch = batch;
|
||||
if (is_prompt_warmup) {
|
||||
|
||||
@ -41,6 +41,17 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
|
||||
// print statistics about the speculative decoding
|
||||
void common_speculative_print_stats(const common_speculative * spec, double slot_tps = 0.0, int n_decoded = 0, int n_past = 0, common_params_speculative * active_params = nullptr);
|
||||
|
||||
// get the MTP context from the speculative object (nullptr if not MTP type)
|
||||
llama_context * common_speculative_get_mtp_ctx(common_speculative * spec);
|
||||
|
||||
// Context shift for MTP to match how server handle main model
|
||||
void common_speculative_context_shift(
|
||||
common_speculative * spec,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos kv_keep,
|
||||
llama_pos kv_discard,
|
||||
llama_pos kv_past);
|
||||
|
||||
// Generates speculative draft tokens using the Multi-Token Prediction (MTP) architecture.
|
||||
std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
struct common_sampler * smpl,
|
||||
|
||||
@ -214,9 +214,15 @@ void server_context::init() {
|
||||
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
|
||||
params_base.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
params_base.speculative.cparams_dft = common_context_params_to_llama(params_base);
|
||||
params_base.speculative.cparams_dft.mtp = true;
|
||||
params_base.speculative.cparams_dft.mtp_op_type = MTP_OP_WARMUP;
|
||||
params_base.speculative.cparams_dft.embeddings = true;
|
||||
|
||||
slot.has_mtp = true;
|
||||
slot.params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
|
||||
slot.params.speculative.n_min = 0;
|
||||
slot.params.speculative.cparams_dft = params_base.speculative.cparams_dft;
|
||||
|
||||
slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
|
||||
SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);
|
||||
@ -2622,6 +2628,9 @@ void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_sl
|
||||
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, kv_keep, kv_keep + kv_discard);
|
||||
llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard);
|
||||
if (slot.has_mtp && slot.spec) {
|
||||
common_speculative_context_shift(slot.spec, slot.id, kv_keep, kv_discard, kv_past);
|
||||
}
|
||||
if (slot.params.cache_prompt) {
|
||||
slot.cache_tokens.discard_n_tokens(n_keep, n_discard);
|
||||
}
|
||||
@ -2838,10 +2847,12 @@ void server_context::add_sampled_tokens() {
|
||||
auto & params_spec = slot.params.speculative;
|
||||
|
||||
if (slot.has_mtp) {
|
||||
llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec);
|
||||
llama_context * hs_ctx = mtp_ctx ? mtp_ctx : ctx;
|
||||
if (!slot.mtp_hidden_state.empty()) {
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
const int n_hidden = slot.mtp_hidden_state.size() / n_embd;
|
||||
llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data() + (n_hidden - 1) * n_embd);
|
||||
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data() + (n_hidden - 1) * n_embd);
|
||||
} else {
|
||||
LOG_ERROR("MTP hidden state is empty during speculation", {});
|
||||
const float* emb_neg1 = llama_get_embeddings_ith(ctx, -1);
|
||||
@ -2849,7 +2860,7 @@ void server_context::add_sampled_tokens() {
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
slot.mtp_hidden_state.resize(n_embd);
|
||||
memcpy(slot.mtp_hidden_state.data(), emb_neg1, n_embd * sizeof(float));
|
||||
llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data());
|
||||
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3415,6 +3426,9 @@ void server_context::speculative_decoding_accept() {
|
||||
const auto ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);
|
||||
|
||||
if (slot.has_mtp) {
|
||||
llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec);
|
||||
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
|
||||
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
if (!ids.empty()) {
|
||||
const float* emb = llama_get_embeddings(ctx);
|
||||
@ -3430,10 +3444,10 @@ void server_context::speculative_decoding_accept() {
|
||||
}
|
||||
}
|
||||
|
||||
llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data());
|
||||
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data());
|
||||
|
||||
int32_t n_past_base = slot.n_past - (slot.drafted.size() + 1);
|
||||
mtp_accept_tokens(ctx, ids, n_past_base, slot.id);
|
||||
mtp_accept_tokens(mtp_target, ids, n_past_base, slot.id);
|
||||
}
|
||||
|
||||
slot.i_batch_dft.clear();
|
||||
@ -3933,8 +3947,16 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
slot.i_batch = -1;
|
||||
}
|
||||
if (mtp_warmup_needed && !batch_mtp_hidden_state.empty()) {
|
||||
llama_set_draft_input_hidden_state(ctx, batch_mtp_hidden_state.data());
|
||||
mtp_update_kv_cache(ctx, batch_view, true);
|
||||
llama_context * mtp_ctx = nullptr;
|
||||
for (auto & slot : slots) {
|
||||
if (slot.spec && slot.has_mtp) {
|
||||
llama_context * mc = common_speculative_get_mtp_ctx(slot.spec);
|
||||
if (mc) { mtp_ctx = mc; break; }
|
||||
}
|
||||
}
|
||||
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
|
||||
llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data());
|
||||
mtp_update_kv_cache(mtp_target, batch_view, true);
|
||||
}
|
||||
|
||||
// speculative decoding - main model sample and accept
|
||||
|
||||
@ -47,9 +47,6 @@ struct llama_kv_cache {
|
||||
uint32_t size = 0;
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||
|
||||
// Track's main model's head position for MTP KV cache operations
|
||||
uint32_t mtp_kv_head_hint = 0;
|
||||
|
||||
// computed before each graph build
|
||||
uint32_t n = 0;
|
||||
|
||||
|
||||
@ -850,7 +850,16 @@ static bool llama_kv_cache_init(
|
||||
}
|
||||
|
||||
int n_mla = 0;
|
||||
const int64_t n_mtp_first_layer = n_layer - hparams.nextn_predict_layers;
|
||||
for (int i = 0; i < (int) n_layer; i++) {
|
||||
// For MTP-only context, skip KV allocation for non-MTP layers
|
||||
if (cparams.mtp_op_type != MTP_OP_NONE && i < (int)n_mtp_first_layer) {
|
||||
cache.k_l.push_back(nullptr);
|
||||
if (!is_mla_attn || !cparams.mla_attn || (cparams.mla_attn == 1 && !cparams.flash_attn)) {
|
||||
cache.v_l.push_back(nullptr);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
const bool qnext_recurrent = llama_is_recurrent_layer(hparams, i);
|
||||
const uint32_t n_embd_v_row = llama_kv_v_row_embd(model, hparams, i);
|
||||
const uint32_t n_head_kv = hparams.n_head_kv(i);
|
||||
@ -1066,8 +1075,7 @@ static bool llama_kv_cache_init(
|
||||
// to the first cell of the slot.
|
||||
static bool llama_kv_cache_find_slot(
|
||||
struct llama_kv_cache & cache,
|
||||
const struct llama_batch & batch,
|
||||
enum llama_mtp_op_type op_type) {
|
||||
const struct llama_batch & batch) {
|
||||
const uint32_t n_tokens = batch.n_tokens;
|
||||
|
||||
if (cache.recurrent) {
|
||||
@ -1118,51 +1126,6 @@ static bool llama_kv_cache_find_slot(
|
||||
}
|
||||
// otherwise, one cell per token.
|
||||
|
||||
bool is_mtp_special_op = (op_type == MTP_OP_WARMUP ||
|
||||
op_type == MTP_OP_UPDATE_ACCEPTED);
|
||||
if (is_mtp_special_op) {
|
||||
const llama_pos target_pos = batch.pos[0];
|
||||
const llama_seq_id target_seq = batch.seq_id[0][0];
|
||||
|
||||
bool found = false;
|
||||
|
||||
if (cache.mtp_kv_head_hint < cache.size &&
|
||||
cache.cells[cache.mtp_kv_head_hint].pos == target_pos &&
|
||||
cache.cells[cache.mtp_kv_head_hint].has_seq_id(target_seq)) {
|
||||
cache.head = cache.mtp_kv_head_hint;
|
||||
found = true;
|
||||
}
|
||||
else if (cache.head < cache.size &&
|
||||
cache.cells[cache.head].pos == target_pos &&
|
||||
cache.cells[cache.head].has_seq_id(target_seq)) {
|
||||
found = true;
|
||||
}
|
||||
else {
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (cache.cells[i].pos == target_pos &&
|
||||
cache.cells[i].has_seq_id(target_seq)) {
|
||||
|
||||
cache.head = i;
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!found) {
|
||||
LLAMA_LOG_ERROR("%s: MTP Update failed - slot for seq %d pos %d not found\n",
|
||||
__func__, target_seq, target_pos);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cache.head + n_tokens > cache.size) {
|
||||
LLAMA_LOG_ERROR("%s: MTP Update out of bounds\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
if (n_tokens > cache.size) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
|
||||
return false;
|
||||
@ -3922,14 +3885,10 @@ static int llama_decode_internal(
|
||||
kv_self.head = 0;
|
||||
}
|
||||
|
||||
if (!llama_kv_cache_find_slot(kv_self, u_batch, cparams.mtp_op_type)) {
|
||||
if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (cparams.mtp_op_type == MTP_OP_NONE) {
|
||||
kv_self.mtp_kv_head_hint = kv_self.head;
|
||||
}
|
||||
|
||||
if (!kv_self.recurrent) {
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// after enough generations, the benefit from this heuristic disappears
|
||||
@ -6842,7 +6801,7 @@ struct llama_data_read {
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = dest_seq_id;
|
||||
}
|
||||
if (!llama_kv_cache_find_slot(kv_self, batch, ctx->cparams.mtp_op_type)) {
|
||||
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
||||
llama_batch_free(batch);
|
||||
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
||||
return false;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user