ik_llama.cpp/common/speculative.cpp
Kawrakow f5e5753c32
Fix Qwen35 mtp warmup (#1987)
* Use hidden state from prev token from qwen mtp

* Fix Qwen35 MTP warmup

* Cleanup + remove unnecessary crippling performance by not using accept to sample draft token

* Provide API to get the model arch string

---------

Co-authored-by: SamuelOliveirads <samueloliveira32df@gmail.com>
2026-06-18 09:03:40 +02:00

2975 lines
103 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "speculative.h"
#include "common.h"
#include "ggml.h"
#include "llama.h"
#include "log.h"
#include "ngram-cache.h"
#include "ngram-map.h"
#include "ngram-mod.h"
#include "sampling.h"
#include "suffix-tree.h"
#include <algorithm>
#include <atomic>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <limits>
#include <map>
#include <sstream>
#include <unordered_map>
// Largest absolute difference between the target and draft vocabulary sizes that
// common_speculative_are_compatible() still accepts.
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
// First token id whose text is compared between the two vocabularies by
// common_speculative_are_compatible(); lower ids are not compared.
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
// Declared locally; the MTP state constructor calls it with (ctx_mtp, ctx_tgt).
void llama_set_mtp_target_context(struct llama_context * ctx, struct llama_context * target_ctx);
// List of the speculative decoding types; common_speculative_type_name_str() walks it
// in this order to build the comma-separated list of names.
const std::vector<enum common_speculative_type> common_speculative_types = {
COMMON_SPECULATIVE_TYPE_NONE,
COMMON_SPECULATIVE_TYPE_DRAFT,
COMMON_SPECULATIVE_TYPE_DFLASH,
COMMON_SPECULATIVE_TYPE_MTP,
COMMON_SPECULATIVE_TYPE_EAGLE3,
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
COMMON_SPECULATIVE_TYPE_NGRAM_MOD,
COMMON_SPECULATIVE_TYPE_NGRAM_CACHE,
COMMON_SPECULATIVE_TYPE_SUFFIX
};
// Lookup table from a type name to its enum value. Keys use underscores;
// common_speculative_type_from_name() replaces '-' with '_' before consulting this map.
const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
{"none", COMMON_SPECULATIVE_TYPE_NONE},
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
{"dflash", COMMON_SPECULATIVE_TYPE_DFLASH},
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
{"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
{"ngram_mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD},
{"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE},
{"suffix", COMMON_SPECULATIVE_TYPE_SUFFIX}
};
// Return the checkpoint to its empty state, freeing the sampler it holds (if any).
void common_speculative_checkpoint::clear() {
    // release the held sampler before wiping the scalar fields
    if (sampler != nullptr) {
        common_sampler_free(sampler);
    }
    sampler          = nullptr;
    sampled          = LLAMA_TOKEN_NULL;
    n_past           = 0;
    per_step_enabled = false;
    valid            = false;
}
// Configuration of one speculative implementation: the stage parameters it was created
// from, the stage's type, and a set of speculative parameters.
struct common_speculative_config {
common_speculative_stage_params stage;
// always initialized from stage.type by the constructor
common_speculative_type type;
common_params_speculative params;
// `p` may be omitted, in which case a default-constructed common_params_speculative is stored.
common_speculative_config(
const common_speculative_stage_params & s,
const common_params_speculative & p = common_params_speculative{})
: stage(s), type(s.type), params(p) {}
};
static bool common_speculative_are_compatible(
const llama_model * model_tgt,
const llama_model * model_dft) {
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
const auto vocab_type_tgt = llama_vocab_type(vocab_tgt);
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
const auto vocab_type_dft = llama_vocab_type(vocab_dft);
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) {
LOG_DBG("%s: draft model vocab type must match target model to use speculation but ", __func__);
LOG_DBG("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
return false;
}
if (
llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
) {
LOG_DBG("%s: draft model special tokens must match target model to use speculation\n", __func__);
return false;
}
{
const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
const int vocab_diff = n_vocab_tgt > n_vocab_dft
? n_vocab_tgt - n_vocab_dft
: n_vocab_dft - n_vocab_tgt;
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return false;
}
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
common_token_to_piece(vocab_tgt, i).c_str(),
common_token_to_piece(vocab_dft, i).c_str());
return false;
}
}
}
return true;
}
// state of an implementation of speculative decoding
//
// each implementation has a unique type and a state that is implementation-specific
// in a subclass of common_speculative_state
//
// The base class only carries the type tag and call/time counters; nothing in this struct
// updates the counters itself.
struct common_speculative_state {
const enum common_speculative_type type;
size_t n_call_begin = 0; // number of times this implementation was called for refresh.
size_t n_call_draft = 0; // number of times this implementation was called for generation.
size_t n_call_accept = 0; // number of times this implementation was called for accumulation.
size_t n_gen_drafts = 0; // number of times a draft or part was generated by this implementation.
size_t n_acc_drafts = 0; // number of times a draft or part was accepted by the target model.
size_t n_gen_tokens = 0; // number of tokens generated by this implementation.
size_t n_acc_tokens = 0; // number of tokens accepted by the target model.
// TODO: track performance of most recent calls
const bool gen_perf = true; // whether to generate performance stats.
int64_t t_begin_us = 0; // total time spent in refresh of this implementation in microseconds.
int64_t t_draft_us = 0; // total time spent in generating drafts in this implementation in microseconds.
int64_t t_accept_us = 0; // total time spent in accumulation of this implementation in microseconds.
common_speculative_state(enum common_speculative_type type) : type(type) {}
virtual ~common_speculative_state() = default;
// Called with the prompt tokens; implementations in this file use it to reset or
// (re)build their per-prompt state.
virtual void begin(const llama_tokens & prompt) = 0;
// Produce draft tokens into `result`, given the target-side prompt and the last token.
virtual void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) = 0;
// Position/sequence-aware variant. The default implementation ignores draft_base_pos
// and draft_seq_id and forwards to the four-argument draft().
virtual void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_pos draft_base_pos,
llama_seq_id draft_seq_id,
llama_tokens & result) {
GGML_UNUSED(draft_base_pos);
GGML_UNUSED(draft_seq_id);
draft(params, prompt_tgt, id_last, result);
}
// Feedback hook receiving the number of accepted draft tokens.
virtual void accept(uint16_t n_accepted) = 0;
};
// Forward declarations of the MTP / DFlash state types and of static helpers, so they can
// be referenced before their definitions (e.g. mtp_speculative_gen_draft() is called from
// common_speculative_state_mtp::draft()).
struct common_speculative_state_mtp;
struct common_speculative_state_dflash;
// Accessors for the MTP / DFlash state held by a common_speculative, in mutable and const
// flavours. NOTE(review): presumably they return nullptr when no such state exists - confirm
// against the definitions.
static common_speculative_state_mtp * common_speculative_get_mtp_state(common_speculative * spec);
static const common_speculative_state_mtp * common_speculative_get_mtp_state(const common_speculative * spec);
static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec);
static const common_speculative_state_dflash * common_speculative_get_dflash_state(const common_speculative * spec);
static int32_t common_speculative_feature_width(const common_speculative * spec);
static void mtp_invalidate_cached_drafts(common_speculative_state_mtp & state);
static bool common_speculative_checkpoint_save(
common_speculative_checkpoint & ckpt,
llama_model * model,
llama_context * ctx,
common_sampler * sampler_src,
const common_params_sampling & sparams,
llama_seq_id seq_id,
llama_pos n_past,
llama_token sampled,
int max_tokens,
int ckpt_mode);
// Generates the MTP draft; returns the drafted tokens by value.
// constant_draft_positions defaults to false.
static std::vector<llama_token> mtp_speculative_gen_draft(
common_speculative_state_mtp & state,
struct common_sampler * smpl,
struct llama_context * ctx,
int n_draft,
float p_min,
llama_token id_last,
llama_pos n_past,
llama_seq_id seq_id,
bool constant_draft_positions = false);
static int32_t mtp_update_kv_cache(struct llama_context * ctx, const llama_batch & batch, bool is_prompt_warmup);
// Per-sequence entry type of common_speculative_state_mtp::draft_cache_by_seq.
// NOTE(review): field meanings below are inferred from the names - confirm against the
// code that fills the cache.
struct mtp_last_embd {
// presumably an embedding / hidden-state vector
std::vector<float> embd;
// presumably a probability associated with the cached entry; defaults to 0
float prob = 0.0f;
// presumably a token id; -1 is the default (nothing cached)
int last_id = -1;
};
// Speculative decoding state for MTP drafting.
//
// Holds the target context, the MTP context (freed by the destructor when non-null), a
// sampler used for drafting, and two per-sequence maps: the target hidden state that is fed
// to the MTP context before drafting, and a draft cache.
struct common_speculative_state_mtp : public common_speculative_state {
llama_context * ctx_tgt;
llama_context * ctx_mtp = nullptr;
common_sampler * smpl;
// For Gemma 4 external MTP assistant: draft positions are held constant
bool constant_draft_positions = false;
// value returned by llama_mtp_state_n_embd(ctx_mtp); draft() requires the stored hidden
// state of a sequence to have exactly this many elements
int n_embd = 0;
std::unordered_map<llama_seq_id, std::vector<float>> target_hidden_by_seq;
std::unordered_map<llama_seq_id, mtp_last_embd> draft_cache_by_seq;
// Creates a sampler whose sampler sequence contains only DIST, registers ctx_tgt as the
// target context of ctx_mtp, and caches n_embd.
// Note: ctx_mtp is used unconditionally here, so it must not be null.
common_speculative_state_mtp(
enum common_speculative_type type,
llama_context * ctx_tgt,
llama_context * ctx_mtp,
bool constant_draft_positions = false)
: common_speculative_state(type)
, ctx_tgt(ctx_tgt)
, ctx_mtp(ctx_mtp)
, constant_draft_positions(constant_draft_positions)
{
struct common_params_sampling sparams;
sparams.samplers_sequence = {
llama_sampler_type::DIST,
};
smpl = common_sampler_init(llama_get_model(ctx_mtp), sparams);
llama_set_mtp_target_context(ctx_mtp, ctx_tgt);
n_embd = llama_mtp_state_n_embd(ctx_mtp);
LOG_INF("%s: MTP context ready (n_ctx=%d, constant_draft_positions=%s)\n", __func__,
llama_n_ctx(ctx_mtp), constant_draft_positions ? "true" : "false");
}
// Frees the sampler and, when set, the MTP context. ctx_tgt is not freed here.
~common_speculative_state_mtp() override {
common_sampler_free(smpl);
if (ctx_mtp) {
llama_free(ctx_mtp);
}
}
// Drops all per-sequence hidden states and cached drafts; the prompt itself is not used.
void begin(const llama_tokens & prompt) override {
GGML_UNUSED(prompt);
target_hidden_by_seq.clear();
draft_cache_by_seq.clear();
}
// Convenience overload: no explicit base position (-1) and sequence id 0.
void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
draft(params, prompt_tgt, id_last, -1, 0, result);
}
// Draft up to params.n_max tokens for seq_id.
//
// draft_base_pos >= 0 is used as n_past directly; otherwise n_past is derived from the MTP
// context's max KV position for the sequence, falling back to prompt_tgt.size() when that
// position is negative. `result` is cleared and the call returns early when no hidden
// state of the expected size is stored for seq_id, or when handing the hidden state to
// the context fails.
void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_pos draft_base_pos,
llama_seq_id seq_id,
llama_tokens & result) override {
const llama_pos mtp_pos_max = llama_kv_cache_seq_pos_max(ctx_mtp, seq_id);
const bool has_draft_base_pos = draft_base_pos >= 0;
// Prefer the target slot position when the caller has it. Gemma4 external MTP reads
// the target KV cache directly, so ctx_mtp's own KV position is not authoritative.
const llama_pos n_past = has_draft_base_pos
? draft_base_pos
: (mtp_pos_max >= 0 ? mtp_pos_max + 1 : (llama_pos) prompt_tgt.size());
// The two checks below only log; drafting continues even if the MTP KV cache lags behind.
if (!has_draft_base_pos && !prompt_tgt.empty() && mtp_pos_max < (llama_pos)prompt_tgt.size() - 1) {
LOG_WRN("%s: MTP context not fully warmed up: pos_max = %d, expected = %d\n",
__func__, (int)mtp_pos_max, (int)prompt_tgt.size() - 1);
}
if (has_draft_base_pos && !constant_draft_positions && mtp_pos_max < n_past - 1) {
LOG_WRN("%s: MTP context not fully warmed up: pos_max = %d, expected >= %d\n",
__func__, (int)mtp_pos_max, (int)n_past - 1);
}
llama_context * ctx = ctx_mtp;
// a hidden state of exactly n_embd floats must have been stored for this sequence
const auto hidden_it = target_hidden_by_seq.find(seq_id);
if (hidden_it == target_hidden_by_seq.end() || (int) hidden_it->second.size() != n_embd) {
LOG_WRN("%s: missing target hidden state for seq_id %d\n", __func__, (int) seq_id);
result.clear();
return;
}
// note: a failure here yields an empty draft without any log message
if (!llama_set_draft_input_hidden_state_copy(ctx, hidden_it->second.data(), hidden_it->second.size())) {
result.clear();
return;
}
result = mtp_speculative_gen_draft(
*this,
smpl,
ctx,
params.n_max,
params.p_min,
id_last,
n_past,
seq_id,
constant_draft_positions
);
}
// No acceptance feedback is tracked by this implementation.
void accept(uint16_t n_accepted) override {
GGML_UNUSED(n_accepted);
}
};
#include "speculative-dflash-impl.h"
// Speculative decoding state backed by a separate draft model.
//
// Holds the draft context (freed by the destructor), a top-k sampler and a batch used to
// evaluate tokens on the draft model. prompt_dft mirrors the tokens currently evaluated in
// the draft context so that as much of it as possible can be reused between calls.
//
// When the target and draft vocabularies are not compatible, text is translated between
// the two by detokenizing, applying the configured string replacements and re-tokenizing.
struct common_speculative_state_draft : public common_speculative_state {
    llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
    llama_context * ctx_dft; // draft context; freed by the destructor

    common_sampler * smpl;

    llama_batch  batch;
    llama_tokens prompt_dft; // tokens currently evaluated in the draft context (draft vocab)

    bool vocab_cmpt = true; // true when the vocabs are compatible, i.e. no retokenization is needed
    std::unordered_map<std::string, std::string> vocab_map; // target-side text -> draft-side text

    // `replacements` is only consulted when the two vocabularies turn out to be incompatible.
    common_speculative_state_draft(
            enum common_speculative_type type,
            llama_context * ctx_tgt,
            llama_context * ctx_dft,
            const std::vector<std::pair<std::string, std::string>> & replacements)
        : common_speculative_state(type)
        , ctx_tgt(ctx_tgt)
        , ctx_dft(ctx_dft)
    {
        batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
        smpl = nullptr;

        {
            struct common_params_sampling params;
            params.top_k = 10;
            params.samplers_sequence = {
                llama_sampler_type::TOP_K,
                llama_sampler_type::DIST, // needed to get probabilities
            };
            smpl = common_sampler_init(llama_get_model(ctx_dft), params);
        }

        vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft));
        LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt);

        if (!vocab_cmpt) {
            LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n");
            for (const auto & pair : replacements) {
                vocab_map[pair.first] = pair.second;
            }
        }
    }

    ~common_speculative_state_draft() override {
        llama_free(ctx_dft);
        common_sampler_free(smpl);
        llama_batch_free(batch);
    }

    void begin(const llama_tokens & prompt) override {
        GGML_UNUSED(prompt);
    }

    // Draft up to params.n_max tokens that continue prompt_tgt + id_last.
    //
    // `result` always holds target-vocab tokens on return. It is left empty when id_last
    // cannot be represented in the draft vocab.
    void draft(
            const common_params_speculative & params,
            const llama_tokens & prompt_tgt,
            llama_token id_last,
            llama_tokens & result) override {
        int reuse_i = 0;
        int reuse_n = 0;

        const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max;

        llama_tokens prompt_cnv;
        if (!vocab_cmpt) {
            const auto * model_tgt = llama_get_model(ctx_tgt);
            const auto * vocab_tgt = llama_model_get_vocab(model_tgt);

            // translate the whole prompt into the draft vocab
            std::string text;
            text = common_detokenize(ctx_tgt, prompt_tgt, true);
            text = replace_to_dft(text);
            LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
            prompt_cnv = common_tokenize(ctx_dft, text, false, true);

            // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation
            int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false);
            GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last");
            text.resize(-n_chars);
            llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false);
            text = replace_to_dft(text);
            LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());

            // the translated text may tokenize to nothing (e.g. a replacement that maps to an
            // empty string); indexing [0] would then be out of bounds, so give up on this draft
            const llama_tokens id_last_cnv = common_tokenize(ctx_dft, text, false, true);
            if (id_last_cnv.empty()) {
                LOG_WRN("%s: id_last (%d) has no representation in the draft vocab - skipping draft\n", __func__, id_last);
                result.clear();
                return;
            }
            id_last = id_last_cnv[0];
        }

        const llama_tokens & prompt_cur = vocab_cmpt ? prompt_tgt : prompt_cnv;

        const int i_start = std::max<int>(0, (int) prompt_cur.size() - n_ctx);

        // reuse as much as possible from the old draft context
        // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
        for (int i = 0; i < (int) prompt_dft.size(); ++i) {
            int cur = 0;
            while (i_start + cur < (int) prompt_cur.size() &&
                   i       + cur < (int) prompt_dft.size() &&
                   prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
                cur++;
            }

            if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) {
                reuse_i = i;
                reuse_n = cur;
            }
        }

        LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());

        result.clear();
        result.reserve(params.n_max);

        if (reuse_n == 0) {
            llama_kv_cache_clear(ctx_dft);
            prompt_dft.clear();
        } else {
            // this happens when a previous draft has been discarded (for example, due to being too small), but the
            // target model agreed with it. in this case, we simply pass back the previous results to save compute
            if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
                for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
                    result.push_back(prompt_dft[i]);

                    if (params.n_max <= (int) result.size()) {
                        break;
                    }
                }

                // prompt_dft holds draft-vocab tokens, so the cached draft needs the same
                // translation as a freshly sampled one
                result_to_tgt(params, result);
                return;
            }

            if (reuse_i > 0) {
                llama_kv_cache_seq_rm (ctx_dft, 0, 0, reuse_i);
                llama_kv_cache_seq_add(ctx_dft, 0, reuse_i, -1, -reuse_i);

                prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
            }

            if (reuse_n < (int) prompt_dft.size()) {
                llama_kv_cache_seq_rm (ctx_dft, 0, reuse_n, -1);
                prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
            }
        }

        // prepare a batch to evaluate any new tokens in the prompt
        common_batch_clear(batch);

        for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) {
            //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]);
            common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false);

            prompt_dft.push_back(prompt_cur[i]);
        }

        // we should rarely end-up here during normal decoding
        if (batch.n_tokens > 0) {
            //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());

            llama_decode(ctx_dft, batch);
        }

        const llama_pos n_past = prompt_dft.size();

        LOG_DBG("%s: n_past = %d\n", __func__, n_past);

        common_batch_clear(batch);
        common_batch_add  (batch, id_last, n_past, { 0 }, true);

        prompt_dft.push_back(id_last);

        //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());

        llama_decode(ctx_dft, batch);

        common_sampler_reset(smpl);

        // sample n_draft tokens from the draft model
        for (int i = 0; i < params.n_max; ++i) {
            common_batch_clear(batch);

            common_sampler_sample(smpl, ctx_dft, 0, true);

            const auto * cur_p = common_sampler_get_candidates(smpl, true);

            for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
                LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
                        k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
            }

            // add drafted token for each sequence
            const llama_token id = cur_p->data[0].id;

            common_sampler_accept(smpl, nullptr, id, true);

            // only collect very high-confidence draft tokens
            // (the very first token is kept regardless of its probability)
            if (cur_p->data[0].p < params.p_min) {
                if (i == 0) {
                    result.push_back(id);
                }
                break;
            }

            result.push_back(id);

            if (params.n_max <= (int) result.size()) {
                break;
            }

            common_batch_add(batch, id, n_past + i + 1, { 0 }, true);

            // evaluate the drafted tokens on the draft model
            llama_decode(ctx_dft, batch);

            prompt_dft.push_back(id);
        }

        result_to_tgt(params, result);
    }

    void accept(uint16_t n_accepted) override {
        // noop
        GGML_UNUSED(n_accepted);
    }

    // Apply every vocab_map replacement in the target -> draft direction.
    std::string replace_to_dft(const std::string & input) const {
        std::string result = input;
        for (const auto & pair : this->vocab_map) {
            replace_all(result, pair.first, pair.second);
        }
        return result;
    }

    // Apply every vocab_map replacement in the draft -> target direction.
    std::string replace_to_tgt(const std::string & input) const {
        std::string result = input;
        for (const auto & pair : this->vocab_map) {
            replace_all(result, pair.second, pair.first);
        }
        return result;
    }

private:
    // Replace each occurrence of `from` in `text` with `to`, resuming the search after the
    // inserted text. An empty `from` is ignored: it matches at every position and would
    // otherwise make the loop insert `to` forever.
    static void replace_all(std::string & text, const std::string & from, const std::string & to) {
        if (from.empty()) {
            return;
        }
        size_t pos = text.find(from);
        while (pos != std::string::npos) {
            text.replace(pos, from.length(), to);
            pos = text.find(from, pos + to.length());
        }
    }

    // Translate drafted tokens from the draft vocab into the target vocab, keeping at most
    // params.n_max tokens. No-op when the vocabs are compatible or nothing was drafted.
    void result_to_tgt(const common_params_speculative & params, llama_tokens & result) const {
        if (vocab_cmpt || result.empty()) {
            return;
        }

        std::string detokenized = common_detokenize(ctx_dft, result, true);
        detokenized = replace_to_tgt(detokenized);
        LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
        result = common_tokenize(ctx_tgt, detokenized, false, true);
        if (result.size() > (size_t)params.n_max) {
            result.resize(params.n_max);
        }
    }
};
// EAGLE3 speculative state. Drafting is not implemented yet: every hook ignores its
// arguments and draft() leaves the output untouched.
struct common_speculative_state_eagle3 : public common_speculative_state {
    common_speculative_state_eagle3(enum common_speculative_type type)
        : common_speculative_state(type) {
    }

    void begin(const llama_tokens & prompt) override {
        GGML_UNUSED(prompt);
    }

    void draft(
            const common_params_speculative & params,
            const llama_tokens & prompt_tgt,
            llama_token id_last,
            llama_tokens & draft_tokens) override {
        // TODO: implement
        GGML_UNUSED(draft_tokens);
        GGML_UNUSED(id_last);
        GGML_UNUSED(prompt_tgt);
        GGML_UNUSED(params);
    }

    void accept(uint16_t n_accepted) override {
        // nothing to track
        GGML_UNUSED(n_accepted);
    }
};
// Self-speculation state for the simple n-gram drafter (not the ngram-map variants).
// It only stores the drafter configuration; drafting is delegated to
// common_ngram_simple_draft().
struct common_speculative_state_ngram_simple : public common_speculative_state {
    common_ngram_simple_config config;

    common_speculative_state_ngram_simple(
            enum common_speculative_type type,
            common_ngram_simple_config config)
        : common_speculative_state(type)
        , config(config) {
    }

    void begin(const llama_tokens & prompt) override {
        // no per-prompt state
        GGML_UNUSED(prompt);
    }

    void draft(
            const common_params_speculative & params,
            const llama_tokens & prompt_tgt,
            llama_token id_last,
            llama_tokens & result) override {
        // the stored config drives the draft; the runtime params are not consulted
        GGML_UNUSED(params);
        result = common_ngram_simple_draft(config, prompt_tgt, id_last);
    }

    void accept(uint16_t n_accepted) override {
        // no feedback is used
        GGML_UNUSED(n_accepted);
    }
};
// Speculative state that drafts from an n-gram map, without a draft model. All three
// hooks forward to the matching common_ngram_map_* function.
struct common_speculative_state_ngram_map_k : public common_speculative_state {
    // the n-gram map the drafts are looked up in
    common_ngram_map map;

    common_speculative_state_ngram_map_k(
            enum common_speculative_type type,
            common_ngram_map map)
        : common_speculative_state(type)
        , map(std::move(map)) {
    }

    void begin(const llama_tokens & prompt) override {
        common_ngram_map_begin(map, prompt);
    }

    void draft(
            const common_params_speculative & params,
            const llama_tokens & prompt_tgt,
            llama_token id_last,
            llama_tokens & result) override {
        // the map carries its own settings; the runtime params are not consulted
        GGML_UNUSED(params);
        common_ngram_map_draft(map, prompt_tgt, id_last, result);
    }

    void accept(uint16_t n_accepted) override {
        common_ngram_map_accept(map, n_accepted);
    }
};
// Speculative state that drafts from a common_ngram_mod container held by reference
// (this struct does not own it). The container is fed with the n-grams of the prompt and
// is reset when it becomes too full or when drafts keep being rejected.
struct common_speculative_state_ngram_mod : public common_speculative_state {
common_ngram_mod & mod;
// the last position in the prompt that was added to the ngram container
size_t i_last = 0;
// length of the last drafted ngram (number of tokens returned by draft)
size_t n_draft_last = 0;
// consecutive accept rounds with low acceptance fraction (< 0.5)
int n_low = 0;
// enable trace logging if LLAMA_TRACE is set
const bool verbose;
common_speculative_state_ngram_mod(enum common_speculative_type type, common_ngram_mod & mod)
: common_speculative_state(type), mod(mod), verbose(std::getenv("LLAMA_TRACE") != nullptr) {
// tokens are handed to the container through llama_token pointers, so the two types must have the same size
static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t));
}
// Reset the bookkeeping and add the prompt's n-grams (start offsets 0 .. size-n-1) to the
// container. Prompts shorter than n tokens are skipped. Afterwards the container is reset
// if its occupancy exceeds 25%; note that i_last is not rewound in that case.
void begin(const llama_tokens & prompt) override {
i_last = 0;
n_draft_last = 0;
n_low = 0;
const size_t n = mod.get_n();
if (prompt.size() < n) {
return;
}
for (size_t i = 0; i < prompt.size() - n; ++i) {
mod.add(prompt.data() + i);
}
i_last = prompt.size() - n;
const double f = (double)mod.get_used() / (double)mod.size();
LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f);
constexpr double f_thold = 0.25;
if (f > f_thold) {
LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold);
mod.reset();
}
}
// Draft up to params.n_max tokens by repeatedly looking up the most recent n tokens.
// `result` is left untouched when the prompt is shorter than n tokens, and is cleared
// when fewer than params.n_min tokens could be drafted.
void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
// note: params is in fact used below (n_max, n_min)
GGML_UNUSED(params);
n_draft_last = 0;
const size_t cur_len = prompt_tgt.size();
if (cur_len < mod.get_n()) {
return;
}
const size_t n = mod.get_n();
// add new ngrams in chunks
// (only once more than 32 tokens have accumulated past i_last)
if (i_last + 32 < cur_len) {
for (size_t i = i_last; i < cur_len - n; ++i) {
mod.add(prompt_tgt.data() + i);
}
i_last = cur_len - n;
}
// result is used as a scratch buffer: the first n entries are the lookup key
// (the last n-1 prompt tokens followed by id_last), the drafted tokens are appended after it
result.resize(n + params.n_max);
for (size_t i = 0; i < n - 1; ++i) {
result[i] = prompt_tgt[cur_len - n + 1 + i];
}
result[n - 1] = id_last;
// each lookup uses the n tokens starting at offset i, so the window slides over the
// tokens drafted so far
for (int i = 0; i < params.n_max; ++i) {
const llama_token token = mod.get(result.data() + i);
if (token == common_ngram_mod::EMPTY) {
// no continuation known: drop the draft if it is shorter than n_min, else truncate it
if (i < params.n_min) {
result.clear();
return;
}
result.resize(n + i);
break;
}
result[n + i] = token;
}
// only return the m tokens that were drafted
for (size_t i = 0; n + i < result.size(); ++i) {
result[i] = result[n + i];
}
result.resize(result.size() - n);
// store length of drafted ngram for later acceptance analysis
n_draft_last = result.size();
}
// Track acceptance of the last draft. After three consecutive rounds with an acceptance
// fraction below 0.5 the container is reset and i_last is rewound to 0.
void accept(uint16_t n_accepted) override {
if (verbose) {
LOG_INF("%s: accepted %d tokens from %zu drafted tokens\n", __func__, n_accepted, n_draft_last);
}
// compute acceptance fraction if we have a recorded draft length
if (n_draft_last > 0) {
const double f_acc = (double)n_accepted / (double)n_draft_last;
if (f_acc < 0.5) {
n_low++;
if (n_low >= 3) {
LOG_WRN("%s: low acceptance streak (%d) resetting ngram_mod\n", __func__, n_low);
mod.reset();
n_low = 0;
i_last = 0;
}
} else {
n_low = 0;
}
}
}
};
struct common_speculative_state_ngram_cache : public common_speculative_state {
uint16_t n_draft;
bool save_dynamic;
bool save_static;
common_ngram_cache ngram_cache_context;
common_ngram_cache ngram_cache_dynamic;
common_ngram_cache ngram_cache_static;
size_t cache_size = 0; // number of tokens in n-gram cache
common_speculative_state_ngram_cache(
const enum common_speculative_type type,
const std::string & path_static,
const std::string & path_dynamic,
uint16_t n_draft,
bool save_dynamic,
bool save_static)
: common_speculative_state(type)
, n_draft(n_draft)
, save_dynamic(save_dynamic)
, save_static(save_static)
{
if (!path_static.empty()) {
try {
ngram_cache_static = common_ngram_cache_load(path_static);
} catch (...) {
LOG_ERR("failed to open static lookup cache: %s", path_static.c_str());
GGML_ABORT("Couldn't read static lookup cache");
}
}
if (!path_dynamic.empty()) {
try {
ngram_cache_dynamic = common_ngram_cache_load(path_dynamic);
} catch (...) {
LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str());
GGML_ABORT("Couldn't read dynamic lookup cache");
}
}
}
void begin(const llama_tokens & prompt) override {
GGML_UNUSED(prompt);
}
void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
GGML_UNUSED(params);
if (cache_size < prompt_tgt.size() + 1) {
llama_tokens tokens_new;
tokens_new.reserve(prompt_tgt.size() + 1 - cache_size);
for (size_t j = cache_size; j < prompt_tgt.size(); ++j) {
tokens_new.push_back(prompt_tgt[j]);
}
tokens_new.push_back(id_last); // add the last token
// Update context ngram cache with new prompt_tgt:
common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
tokens_new, tokens_new.size(), false);
cache_size = prompt_tgt.size() + 1;
}
llama_tokens inp;
inp.reserve(prompt_tgt.size() + 1);
for (size_t j = 0; j < prompt_tgt.size(); ++j) {
inp.push_back(prompt_tgt[j]);
}
inp.push_back(id_last);
result.push_back(id_last);
common_ngram_cache_draft(inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
ngram_cache_context,
ngram_cache_dynamic,
ngram_cache_static);
if (result.size() > 0) {
// delete first token in result (which is the id_last token)
result.erase(result.begin());
}
}
void accept(uint16_t n_accepted) override {
// TODO: noop
GGML_UNUSED(n_accepted);
}
};
// Speculative state that drafts from suffix trees: `tree` is extended with the tokens seen
// during generation, `corpus_tree` is optionally loaded from a corpus file. The effective
// probability threshold (eff_p_min) is adapted from acceptance feedback.
struct common_speculative_state_suffix : public common_speculative_state {
common_suffix_tree tree;
common_suffix_tree corpus_tree;
// set from the return value of corpus_tree.load_corpus(); stays false without a corpus path
bool has_corpus = false;
// number of tokens (prompt + last token) already handed to tree.extend()
size_t cache_size = 0;
// Acceptance feedback
// size of the most recent draft
size_t n_draft_last = 0;
// set by accept() (only when the last draft was non-empty), cleared at the start of each draft()
bool had_accept = false;
// counter of low-acceptance rounds; reaching 3 raises eff_p_min
int n_low = 0;
// params.p_min as seen by the latest draft(); lower bound when eff_p_min decays in accept()
float base_p_min = 0.1f;
// threshold actually passed to speculate()
// NOTE(review): eff_p_min starts at 0.1 and is never initialized from params.p_min, and
// begin() does not reset it - confirm that a configured p_min above 0.1 is meant to act
// only as the decay floor.
float eff_p_min = 0.1f;
// Builds both trees with max_depth. When corpus_path is non-empty the corpus is loaded
// into corpus_tree, passing a tokenizer callback bound to `model` (the callback is left
// empty when model is null).
common_speculative_state_suffix(
enum common_speculative_type type,
int max_depth,
const std::string & corpus_path,
const llama_model * model)
: common_speculative_state(type)
, tree(max_depth)
, corpus_tree(max_depth)
{
if (!corpus_path.empty()) {
std::function<std::vector<llama_token>(const std::string &)> tokenize_fn;
if (model) {
tokenize_fn = [model](const std::string & text) -> std::vector<llama_token> {
return common_tokenize(model, text, false, true);
};
}
has_corpus = corpus_tree.load_corpus(corpus_path, tokenize_fn);
}
}
// Resets the bookkeeping counters.
// NOTE(review): cache_size goes back to 0 but `tree` is not visibly reset here, so the
// next draft() passes the whole prompt to tree.extend() again - confirm this is intended.
void begin(const llama_tokens & prompt) override {
cache_size = 0;
n_draft_last = 0;
had_accept = false;
n_low = 0;
GGML_UNUSED(prompt);
}
// Extend the tree with any tokens not seen yet, then speculate up to params.n_max tokens
// from the last tree.max_depth() tokens of (prompt_tgt + id_last). With a corpus loaded,
// the corpus tree is queried as well and the longer of the two drafts is returned.
void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
base_p_min = params.p_min;
// a previous non-empty draft for which accept() was never called counts as a
// low-acceptance round
if (n_draft_last > 0 && !had_accept) {
if (++n_low >= 3) {
eff_p_min = std::min(eff_p_min + 0.1f, 0.5f);
n_low = 0;
}
}
had_accept = false;
if (cache_size < prompt_tgt.size() + 1) {
llama_tokens tokens_new;
tokens_new.reserve(prompt_tgt.size() + 1 - cache_size);
for (size_t j = cache_size; j < prompt_tgt.size(); ++j) {
tokens_new.push_back(prompt_tgt[j]);
}
tokens_new.push_back(id_last);
tree.extend(tokens_new.data(), (int)tokens_new.size());
cache_size = prompt_tgt.size() + 1;
}
// the lookup context is the last ctx_len tokens, ending with id_last
const int ctx_len = std::min((int)(prompt_tgt.size() + 1), tree.max_depth());
llama_tokens context;
context.reserve(ctx_len);
const int ctx_start = (int)prompt_tgt.size() + 1 - ctx_len;
for (int j = ctx_start; j < (int)prompt_tgt.size(); ++j) {
context.push_back(prompt_tgt[j]);
}
context.push_back(id_last);
// a match length below 1 is raised to 1
const int min_match_len = std::max(1, params.suffix_min_match_len);
result = tree.speculate(
context.data(), (int)context.size(),
params.n_max,
eff_p_min,
1,
min_match_len);
if (has_corpus) {
auto corpus_result = corpus_tree.speculate(
context.data(), (int)context.size(),
params.n_max,
eff_p_min,
1,
min_match_len);
// prefer the corpus draft only when it is strictly longer
if (corpus_result.size() > result.size()) {
result = std::move(corpus_result);
}
}
n_draft_last = result.size();
}
// Acceptance feedback for the last draft (ignored when that draft was empty).
// Below 50% acceptance: every third such round raises eff_p_min by 0.1, capped at 0.5.
// Otherwise the counter is cleared and eff_p_min decays by 0.05 towards base_p_min.
void accept(uint16_t n_accepted) override {
if (n_draft_last == 0) {
return;
}
had_accept = true;
const double f_acc = (double)n_accepted / (double)n_draft_last;
if (f_acc < 0.5) {
if (++n_low >= 3) {
eff_p_min = std::min(eff_p_min + 0.1f, 0.5f);
n_low = 0;
}
} else {
n_low = 0;
if (eff_p_min > base_p_min) {
eff_p_min = std::max(eff_p_min - 0.05f, base_p_min);
}
}
}
};
// Top-level speculative decoding object: the list of implementations with their
// configurations, plus a checkpoint and bookkeeping fields.
struct common_speculative {
std::vector<common_speculative_config> configs; // resolved stage config for each implementation
std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states
common_speculative_checkpoint checkpoint;
// non-owning; the implementations are owned by `impls`
common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
// NOTE(review): spec_tuner is declared elsewhere; presumably optional - confirm
std::unique_ptr<spec_tuner> tuner;
// NOTE(review): presumably the size of the most recent draft - confirm against the code that sets it
int last_n_drafted = 0;
// NOTE(review): presumably a timestamp in microseconds, going by the _us suffix - confirm
int64_t t_step_start_us = 0;
};
// True when `stages` and `configs` have the same length and the same type at every index.
static bool common_speculative_stage_chain_matches(
        const std::vector<common_speculative_stage_params> & stages,
        const std::vector<common_speculative_config> & configs) {
    const size_t n_stages = stages.size();
    if (n_stages != configs.size()) {
        return false;
    }

    // advance while the types agree; reaching the end means every stage matched
    size_t idx = 0;
    while (idx < n_stages && !(stages[idx].type != configs[idx].type)) {
        ++idx;
    }
    return idx == n_stages;
}
// Build the parameters one implementation runs with.
//
// Starts from the implementation's stored params, then takes n_max / n_min / p_min (and,
// for the suffix type, suffix_min_match_len) from `stage` when it overrides them and from
// `params` otherwise. n_max is clamped to >= 0, n_min to [0, n_max], and the returned
// stage list is emptied.
static common_params_speculative common_speculative_get_runtime_params(
        const common_speculative_config & config,
        const common_params_speculative & params,
        const common_speculative_stage_params & stage) {
    common_params_speculative runtime = config.params;
    runtime.type = config.type;

    if (stage.has_n_max_override()) {
        runtime.n_max = stage.n_max;
    } else {
        runtime.n_max = params.n_max;
    }

    if (stage.has_n_min_override()) {
        runtime.n_min = stage.n_min;
    } else {
        runtime.n_min = params.n_min;
    }

    if (stage.has_p_min_override()) {
        runtime.p_min = stage.p_min;
    } else {
        runtime.p_min = params.p_min;
    }

    if (config.type == COMMON_SPECULATIVE_TYPE_SUFFIX) {
        if (stage.has_suffix_min_match_len_override()) {
            runtime.suffix_min_match_len = stage.suffix_min_match_len;
        } else {
            runtime.suffix_min_match_len = params.suffix_min_match_len;
        }
    }

    // sanitize: 0 <= n_min <= n_max
    runtime.n_max = std::max(runtime.n_max, 0);
    runtime.n_min = std::min(runtime.n_min, runtime.n_max);
    runtime.n_min = std::max(0, runtime.n_min);

    runtime.stages.clear();
    return runtime;
}
// Create the n-gram map for a config: key size from ngram_size_n, value size from
// ngram_size_m, minimum hits from ngram_min_hits. The map is key-only exactly when the
// config type is NGRAM_MAP_K.
static common_ngram_map get_common_ngram_map(const common_speculative_config & config) {
    const auto & sp = config.params;

    const uint16_t key_len     = sp.ngram_size_n;
    const uint16_t value_len   = sp.ngram_size_m;
    const uint16_t hits_needed = sp.ngram_min_hits;
    const bool     keys_only   = config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;

    return common_ngram_map(key_len, value_len, keys_only, hits_needed);
}
// Create the n-gram cache state for `config`, loading the static / dynamic lookup caches
// from the given paths (an empty path skips the corresponding cache).
//
// The draft length is fixed at 8 and both save flags are off for now.
static common_speculative_state_ngram_cache create_state_ngram_cache(
        const std::string & path_static, const std::string & path_dynamic,
        const common_speculative_config & config) {
    uint16_t n_draft = 8; // TODO get from config?

    // TODO bool param in common/common.h to set save_static/save_dynamic?
    bool save_static = false;
    bool save_dynamic = false;

    // the constructor takes (..., save_dynamic, save_static) - pass the flags in that order
    // so they do not get swapped once they stop being both false
    common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_dynamic, save_static);

    return state;
}
// Returns the names of all entries in common_speculative_types joined by ", ".
std::string common_speculative_type_name_str() {
    std::string joined;
    const char * separator = "";
    for (const auto type : common_speculative_types) {
        joined += separator;
        joined += common_speculative_type_to_str(type);
        separator = ", "; // every entry after the first is preceded by a separator
    }
    return joined;
}
// Maps a speculative type to its textual name; any value without a dedicated
// case yields "unknown".
std::string common_speculative_type_to_str(enum common_speculative_type type) {
    const char * name = "unknown";
    switch (type) {
        case COMMON_SPECULATIVE_TYPE_NONE:          name = "none";          break;
        case COMMON_SPECULATIVE_TYPE_DRAFT:         name = "draft";         break;
        case COMMON_SPECULATIVE_TYPE_DFLASH:        name = "dflash";        break;
        case COMMON_SPECULATIVE_TYPE_MTP:           name = "mtp";           break;
        case COMMON_SPECULATIVE_TYPE_EAGLE3:        name = "eagle3";        break;
        case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:  name = "ngram_simple";  break;
        case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:   name = "ngram_map_k";   break;
        case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: name = "ngram_map_k4v"; break;
        case COMMON_SPECULATIVE_TYPE_NGRAM_MOD:     name = "ngram_mod";     break;
        case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE:   name = "ngram_cache";   break;
        case COMMON_SPECULATIVE_TYPE_SUFFIX:        name = "suffix";        break;
        default:                                                            break;
    }
    return name;
}
// Looks up a speculative type by name. Dashes in `name` are treated as
// underscores. Returns COMMON_SPECULATIVE_TYPE_COUNT when the name is not in
// common_speculative_type_from_name_map.
enum common_speculative_type common_speculative_type_from_name(const std::string & name) {
    std::string key(name);
    for (char & ch : key) {
        if (ch == '-') {
            ch = '_';
        }
    }

    const auto found = common_speculative_type_from_name_map.find(key);
    if (found != common_speculative_type_from_name_map.end()) {
        return found->second;
    }
    return COMMON_SPECULATIVE_TYPE_COUNT;
}
// Probes whether the target context can be used for speculative decoding.
//
// Clears the KV cache, decodes two placeholder tokens and then tries to remove
// the trailing part of sequence 0. Returns false if the decode fails or the
// removal is rejected. The KV cache is cleared and the context synchronized
// before returning in every case.
bool common_speculative_is_compat(llama_context * ctx_tgt) {
    llama_kv_cache_clear(ctx_tgt);

    // eval 2 tokens to check if the context is compatible
    std::vector<llama_token> probe(2, 0);

    bool compatible = false;
    const int rc = llama_decode(ctx_tgt, llama_batch_get_one(probe.data(), probe.size(), 0, 0));
    if (rc != 0) {
        LOG_ERR("%s: llama_decode() failed: %d\n", __func__, rc);
    } else if (!llama_kv_cache_seq_rm(ctx_tgt, 0, 1, -1)) {
        // try to remove the last tokens
        LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
    } else {
        compatible = true;
    }

    llama_kv_cache_clear(ctx_tgt);
    llama_synchronize(ctx_tgt);
    return compatible;
}
// initialization of the speculative decoding system
//
// Validates the resolved stage chain in `params`, creates a draft context when
// one of the stages needs it, builds one implementation object per stage and
// returns a heap-allocated common_speculative that holds them.
//
// params  - speculative parameters; may be modified: recurrent_ckpt_mode is
//           overwritten with the mode actually selected by llama_spec_ckpt_init
// ctx_tgt - target context the implementations are bound to
//
// Returns nullptr when validation fails, a required context cannot be created,
// or no implementation results from the stage chain.
common_speculative * common_speculative_init(
common_params_speculative & params,
llama_context * ctx_tgt) {
std::string chain_error;
if (!common_speculative_validate_chain(params, &chain_error)) {
LOG_ERR("%s: invalid speculative stage chain: %s\n", __func__, chain_error.c_str());
return nullptr;
}
const auto stages = params.get_resolved_stages();
// A Gemma4 MTP assistant draft model cannot be combined with a plain DRAFT stage.
if (params.model_dft && llama_model_is_gemma4_mtp_assistant(params.model_dft)) {
const bool has_draft_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) {
return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT;
});
if (has_draft_stage) {
LOG_ERR("%s: Gemma4 assistant models only support MTP stages; omit -md for self-spec-only runs or use -mtp/--spec-stage mtp for assistant-backed MTP\n", __func__);
return nullptr;
}
}
const bool has_dflash_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) {
return stage.type == COMMON_SPECULATIVE_TYPE_DFLASH;
});
// A draft context is needed for DRAFT and DFLASH stages, and for MTP stages
// when a separate draft model has been loaded.
const bool needs_draft_ctx = std::any_of(stages.begin(), stages.end(), [&params](const common_speculative_stage_params & stage) {
return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT ||
stage.type == COMMON_SPECULATIVE_TYPE_DFLASH ||
(stage.type == COMMON_SPECULATIVE_TYPE_MTP && params.model_dft != nullptr);
});
llama_context * ctx_dft = nullptr;
if (needs_draft_ctx) {
if (!params.model_dft) {
LOG_ERR("%s: draft speculative stage requires a loaded draft model\n", __func__);
return nullptr;
}
llama_context_params cparams_dft = params.cparams_dft;
if (has_dflash_stage) {
if (!llama_model_share_dflash_io_tensors(params.model_dft, llama_get_model(ctx_tgt))) {
LOG_ERR("%s: failed to share target IO tensors with DFlash draft model\n", __func__);
return nullptr;
}
// The DFlash draft context is sized as the largest dflash_cross_ctx over all
// DFlash stages plus the draft model's block size.
int32_t max_cross_ctx = 0;
for (const auto & stage : stages) {
if (stage.type != COMMON_SPECULATIVE_TYPE_DFLASH) {
continue;
}
max_cross_ctx = std::max(max_cross_ctx, params.with_stage_overrides(stage).dflash_cross_ctx);
}
const int32_t block_size = llama_model_dflash_block_size(params.model_dft);
if (block_size <= 0) {
LOG_ERR("%s: invalid DFlash draft block size\n", __func__);
return nullptr;
}
// the sum is formed in 64 bits so that an out-of-range value can be rejected
const int64_t required_n_ctx = (int64_t) max_cross_ctx + (int64_t) block_size;
if (required_n_ctx > std::numeric_limits<int32_t>::max()) {
LOG_ERR("%s: invalid DFlash draft context size cross_ctx=%d block_size=%d required_n_ctx=%lld\n",
__func__, max_cross_ctx, block_size, (long long) required_n_ctx);
return nullptr;
}
cparams_dft.n_ctx = (uint32_t) required_n_ctx;
}
ctx_dft = llama_init_from_model(params.model_dft, cparams_dft);
if (ctx_dft == nullptr) {
LOG_ERR("%s", "failed to create draft context\n");
return nullptr;
}
}
// Compute the implementations to use based on the resolved stage chain.
std::vector<common_speculative_config> configs = {};
configs.reserve(stages.size());
for (const auto & stage : stages) {
common_params_speculative stage_params = params.with_stage_overrides(stage);
// create the ngram_mod table for this stage if none was supplied
if (stage.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD && !stage_params.ngram_mod) {
stage_params.ngram_mod = std::make_shared<common_ngram_mod>(stage_params.ngram_size_n, 4*1024*1024);
LOG_INF("%s: initialized ngram_mod with n=%d, size=%zu (%.3f MB)\n", __func__,
stage_params.ngram_size_n, stage_params.ngram_mod->size(),
(float)(stage_params.ngram_mod->size_bytes())/1024/1024);
if (stage_params.ngram_size_n < 16) {
LOG_WRN("%s: ngram_mod n=%d is too small - poor quality is possible, see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, stage_params.ngram_size_n);
}
}
configs.push_back(common_speculative_config(stage, stage_params));
}
// For target models reported as recurrent, prepare the speculative checkpoint
// up front with room for the largest per-stage n_max plus one token, and
// record the mode that was actually chosen.
if (!configs.empty() && llama_model_has_recurrent(llama_get_model(ctx_tgt))) {
const int ckpt_tokens = std::max(1, params.get_max_stage_n_max() + 1);
const int actual_mode = llama_spec_ckpt_init(ctx_tgt, params.recurrent_ckpt_mode, ckpt_tokens);
if (actual_mode == LLAMA_SPEC_CKPT_NONE) {
LOG_ERR("%s: failed to prepare recurrent checkpoint mode '%s' during speculative init (max_tokens=%d)\n",
__func__,
params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_PER_STEP ? "per-step" :
params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_GPU_FALLBACK ? "gpu-fallback" :
params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_CPU ? "cpu" : "auto",
ckpt_tokens);
if (ctx_dft != nullptr) {
llama_free(ctx_dft);
}
return nullptr;
}
llama_spec_ckpt_discard(ctx_tgt);
params.recurrent_ckpt_mode = actual_mode;
}
// Instantiate one implementation per config. A NONE config adds no entry, so
// impls can end up shorter than configs.
std::vector<std::unique_ptr<common_speculative_state>> impls = {};
for (const common_speculative_config & config : configs) {
LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str());
switch (config.type) {
case COMMON_SPECULATIVE_TYPE_NONE:
break;
case COMMON_SPECULATIVE_TYPE_DRAFT: {
// NOTE(review): ctx_dft is not reset to nullptr here, unlike the DFLASH and
// MTP cases below - confirm the intended ownership when a chain mixes
// DRAFT with MTP.
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
/* .ctx_tgt = */ ctx_tgt,
/* .ctx_dft = */ ctx_dft,
/* .replacements = */ config.params.replacements
));
break;
}
case COMMON_SPECULATIVE_TYPE_DFLASH: {
auto state = std::make_unique<common_speculative_state_dflash>(
config.type,
ctx_tgt,
ctx_dft,
config.params.dflash_cross_ctx);
if (!state->ready) {
// NOTE(review): this path returns without an explicit llama_free(ctx_dft);
// presumably the state object releases it - verify.
LOG_ERR("%s: failed to initialize DFlash speculative state\n", __func__);
return nullptr;
}
impls.push_back(std::move(state));
ctx_dft = nullptr;
break;
}
case COMMON_SPECULATIVE_TYPE_MTP: {
// Use the draft context when one exists; otherwise create an MTP context
// from the target model itself.
llama_context * ctx_mtp = ctx_dft;
if (!ctx_mtp) {
const llama_model * model = llama_get_model(ctx_tgt);
ctx_mtp = llama_init_from_model(const_cast<llama_model *>(model), config.params.cparams_dft);
if (!ctx_mtp) {
LOG_ERR("%s: failed to create MTP context\n", __func__);
return nullptr;
}
}
ctx_dft = nullptr;
const bool use_constant_draft_positions = llama_model_is_gemma4_mtp_assistant(llama_get_model(ctx_mtp));
impls.push_back(std::make_unique<common_speculative_state_mtp>(
config.type, ctx_tgt, ctx_mtp, use_constant_draft_positions));
break;
}
case COMMON_SPECULATIVE_TYPE_EAGLE3: {
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
common_ngram_map ngram_map = get_common_ngram_map(config);
uint16_t ngram_size_key = ngram_map.size_key;
uint16_t mgram_size_value = ngram_map.size_value;
auto config_simple = common_ngram_simple_config {
/* .size_ngram = */ ngram_size_key,
/* .size_mgram = */ mgram_size_value
};
auto state = std::make_unique<common_speculative_state_ngram_simple>(
/* .type = */ config.type,
/* .state = */ config_simple
);
impls.push_back(std::move(state));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: {
impls.push_back(std::make_unique<common_speculative_state_ngram_map_k>(
(config.type),
get_common_ngram_map(config)
));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: {
// the table was created in the config loop above when it was missing
GGML_ASSERT(config.params.ngram_mod);
impls.push_back(std::make_unique<common_speculative_state_ngram_mod>(config.type, *config.params.ngram_mod));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: {
auto state = create_state_ngram_cache(
config.params.lookup_cache_static, config.params.lookup_cache_dynamic, config);
impls.push_back(std::make_unique<common_speculative_state_ngram_cache>(state));
break;
}
case COMMON_SPECULATIVE_TYPE_SUFFIX: {
// a non-positive suffix_max_depth falls back to a depth of 64
int depth = config.params.suffix_max_depth > 0 ? config.params.suffix_max_depth : 64;
const llama_model * model = llama_get_model(ctx_tgt);
impls.push_back(std::make_unique<common_speculative_state_suffix>(
config.type, depth, config.params.suffix_corpus, model));
break;
}
default:
break;
}
}
if (impls.empty()) {
// NOTE(review): ctx_dft is not freed on this path - confirm it cannot be
// non-null here.
LOG_WRN("%s", "no implementations specified for speculative decoding\n");
return nullptr;
}
auto * result = new common_speculative {
/* .configs = */ std::move(configs),
/* .impls = */ std::move(impls)
};
// initialize autotune if requested
// Autotune is keyed on the first implementation only, and is skipped for
// composite stage chains and for the NONE and EAGLE3 types.
if (params.autotune && params.has_composite_stage_chain()) {
LOG_WRN("Autotune disabled — explicit speculative stage chains are not supported yet\n");
} else if (params.autotune && !result->impls.empty()) {
auto actual_type = result->impls[0]->type;
if (actual_type != COMMON_SPECULATIVE_TYPE_NONE &&
actual_type != COMMON_SPECULATIVE_TYPE_EAGLE3) {
result->tuner = std::make_unique<spec_tuner>();
result->tuner->init(actual_type, params, llama_get_model(ctx_tgt));
LOG_DBG("Autotune initialized for %s, tuning %zu parameters\n",
common_speculative_type_to_str(actual_type).c_str(),
result->tuner->coords.size());
} else {
LOG_WRN("Autotune disabled — speculative type %s is not supported for autotuning\n",
common_speculative_type_to_str(actual_type).c_str());
}
}
return result;
}
// Clears the stored checkpoint and deletes `spec`. A null pointer is ignored.
void common_speculative_free(common_speculative * spec) {
    if (spec != nullptr) {
        spec->checkpoint.clear();
        delete spec;
    }
}
// Calls begin(prompt) on every implementation held by `spec`, timing each call
// into the implementation's t_begin_us and counting it in n_call_begin.
// Does nothing when `spec` is null.
void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt) {
    if (!spec) {
        return;
    }

    const size_t n_impls = spec->impls.size();
    for (size_t i = 0; i < n_impls; ++i) {
        common_speculative_state * state = spec->impls[i].get();

        // the timer lives for the rest of this iteration
        common_time_meas timer(state->t_begin_us, !state->gen_perf);
        state->begin(prompt);
        state->n_call_begin += 1;
    }
}
// Asks the configured implementations, in chain order, for a draft and returns
// the first one that is accepted.
//
// spec           - speculative state; a null pointer yields an empty draft
// params         - runtime parameters; may be modified by the autotuner
// prompt_tgt     - token history, in the target model vocabulary
// id_last        - last token, forwarded to the implementation
// draft_base_pos - base position forwarded to the implementation
// draft_seq_id   - sequence id forwarded to the implementation
//
// Returns the drafted tokens, or an empty vector when no implementation
// produced a usable draft. spec->curr_impl is set to the implementation that
// produced the returned draft, or nullptr if there was none.
llama_tokens common_speculative_draft(
        common_speculative * spec,
        common_params_speculative & params,
        const llama_tokens & prompt_tgt, // specified in target model vocab
        llama_token id_last,
        llama_pos draft_base_pos,
        llama_seq_id draft_seq_id) {
    llama_tokens result;

    // Without a speculative state there is nothing to draft with; other entry
    // points in this file treat a null state the same way.
    if (spec == nullptr) {
        return result;
    }

    spec->t_step_start_us = ggml_time_us();

    // apply autotune proposal if enabled
    if (spec->tuner && spec->tuner->enabled) {
        spec->tuner->propose(params);
    }

    // Per-stage overrides from `params` are only honoured when the runtime
    // stage chain still has the same types as the chain used at init time;
    // otherwise the stages captured in the configs are used.
    const auto runtime_stages = params.get_resolved_stages();
    const bool use_runtime_stage_overrides = common_speculative_stage_chain_matches(runtime_stages, spec->configs);

    spec->curr_impl = nullptr; // reset current implementation

    for (size_t i = 0; i < spec->impls.size(); ++i) {
        auto & impl = spec->impls[i];
        const auto & runtime_stage = use_runtime_stage_overrides ? runtime_stages[i] : spec->configs[i].stage;
        common_params_speculative impl_params = common_speculative_get_runtime_params(spec->configs[i], params, runtime_stage);

        result.clear();
        {
            common_time_meas tm(impl->t_draft_us, !impl->gen_perf);
            impl->draft(impl_params, prompt_tgt, id_last, draft_base_pos, draft_seq_id, result);
            impl->n_call_draft++;
        }

        if (result.empty()) {
            continue;
        }

        // Self-speculative implementations fall through to the next stage when
        // their draft is shorter than the configured minimum.
        if (common_speculative_type_is_self_spec(impl->type) && impl_params.n_min > 0 && (int)result.size() < impl_params.n_min) {
            LOG_DBG("%s: impl %s drafted %zu tokens, below fallback threshold %d - trying next implementation\n",
                    __func__, common_speculative_type_to_str(impl->type).c_str(), result.size(), impl_params.n_min);
            result.clear();
            continue;
        }

        LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
                common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(),
                impl.get()->n_call_draft, result.size());

        spec->curr_impl = impl.get();
        impl->n_gen_drafts++;
        impl->n_gen_tokens += result.size();
        break; // We have a draft, so break out of the loop and return it.
    }

    // store draft count for tuner feedback
    if (spec->tuner && spec->tuner->enabled) {
        spec->last_n_drafted = (int)result.size();
    }

    return result;
}
// Reports how many drafted tokens were accepted for the most recent draft.
//
// spec       - speculative state; a null pointer is ignored
// n_accepted - number of accepted draft tokens
//
// Feeds the autotuner (when enabled and a step is in progress), updates the
// acceptance statistics of the implementation that produced the last draft and
// forwards the count to it. When that implementation is not MTP, cached drafts
// of an MTP state, if present, are invalidated.
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
    // Nothing to update without a speculative state; other entry points in
    // this file treat a null state the same way.
    if (spec == nullptr) {
        return;
    }

    if (spec->tuner && spec->tuner->enabled && spec->t_step_start_us > 0) {
        int64_t step_time_us = ggml_time_us() - spec->t_step_start_us;
        // tokens per second for this step: accepted tokens plus one;
        // steps of 100 us or less are reported as 0
        double step_tps = (step_time_us > 100)
            ? (n_accepted + 1.0) * 1e6 / (double)step_time_us
            : 0.0;
        spec->tuner->accept_feedback(n_accepted, spec->last_n_drafted, step_tps);
        spec->t_step_start_us = 0; // mark the step as consumed
    }

    common_speculative_state * impl = spec->curr_impl;
    if (!impl) {
        return;
    }

    {
        common_time_meas tm(impl->t_accept_us, !impl->gen_perf);
        if (n_accepted > 0) {
            impl->n_acc_drafts++;
            impl->n_acc_tokens += n_accepted;
        }
        impl->accept(n_accepted);
        impl->n_call_accept++;
    }

    if (impl->type != COMMON_SPECULATIVE_TYPE_MTP) {
        if (auto * mtp_state = common_speculative_get_mtp_state(spec); mtp_state != nullptr) {
            mtp_invalidate_cached_drafts(*mtp_state);
        }
    }
}
static bool common_speculative_has_type(const common_speculative * spec, common_speculative_type type) {
if (spec == nullptr) {
return false;
}
return std::any_of(spec->configs.begin(), spec->configs.end(), [type](const common_speculative_config & config) {
return config.type == type;
});
}
// Returns llama_mtp_state_n_embd(ctx) as an int, or 0 when `ctx` is null.
static int common_speculative_ctx_mtp_n_embd(llama_context * ctx) {
    if (!ctx) {
        return 0;
    }
    return (int) llama_mtp_state_n_embd(ctx);
}
static bool common_speculative_batch_token_has_seq_id(
const llama_batch & batch,
int token_index,
llama_seq_id seq_id) {
if (batch.n_seq_id == nullptr || batch.seq_id == nullptr || batch.n_seq_id[token_index] <= 0 || batch.seq_id[token_index] == nullptr) {
return false;
}
for (int i = 0; i < batch.n_seq_id[token_index]; ++i) {
if (batch.seq_id[token_index][i] == seq_id) {
return true;
}
}
return false;
}
static bool common_speculative_batch_is_exact_single_seq(
const llama_batch & batch,
llama_seq_id seq_id) {
if (batch.n_tokens <= 0 || batch.n_seq_id == nullptr || batch.seq_id == nullptr) {
return false;
}
for (int i = 0; i < batch.n_tokens; ++i) {
if (batch.n_seq_id[i] != 1 || batch.seq_id[i] == nullptr || batch.seq_id[i][0] != seq_id) {
return false;
}
}
return true;
}
// Copies the tokens of `batch` that belong to `seq_id` into a newly
// initialized `seq_batch`, each tagged with `seq_id` only.
//
// Returns -1 when `batch` has no token or position array, 0 when there is
// nothing to copy (in which case `seq_batch` is left untouched), and otherwise
// the number of copied tokens. When the result is positive, `seq_batch` has
// been created with llama_batch_init.
static int common_speculative_copy_seq_batch(
        const llama_batch & batch,
        llama_seq_id seq_id,
        llama_batch & seq_batch) {
    if (batch.token == nullptr || batch.pos == nullptr) {
        return -1;
    }
    if (batch.n_tokens < 1) {
        return 0;
    }

    // first pass: collect the indices of the matching tokens
    std::vector<int> picked;
    picked.reserve(batch.n_tokens);
    for (int t = 0; t < batch.n_tokens; ++t) {
        if (common_speculative_batch_token_has_seq_id(batch, t, seq_id)) {
            picked.push_back(t);
        }
    }

    const int n_picked = (int) picked.size();
    if (n_picked == 0) {
        return 0;
    }

    // second pass: allocate and fill the per-sequence batch
    seq_batch = llama_batch_init(n_picked, 0, 1);
    const bool has_logits = batch.logits != nullptr;
    for (size_t k = 0; k < picked.size(); ++k) {
        const int src = picked[k];
        common_batch_add(seq_batch, batch.token[src], batch.pos[src], { seq_id }, has_logits && batch.logits[src]);
    }
    return n_picked;
}
static bool common_speculative_feature_view_copy_batch_rows(
const common_speculative_feature_view & view,
const llama_batch & batch,
llama_seq_id seq_id,
std::vector<float> * hidden_rows) {
if (hidden_rows == nullptr || view.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || view.width <= 0 || batch.n_tokens <= 0 || batch.pos == nullptr) {
return false;
}
std::unordered_map<llama_pos, const float *> rows_by_pos;
rows_by_pos.reserve(view.rows.size());
for (const auto & row : view.rows) {
if (row.seq_id == seq_id && row.data != nullptr) {
rows_by_pos[row.pos] = row.data;
}
}
hidden_rows->clear();
hidden_rows->reserve((size_t) batch.n_tokens * view.width);
for (int i = 0; i < batch.n_tokens; ++i) {
auto it = rows_by_pos.find(batch.pos[i]);
if (it == rows_by_pos.end()) {
hidden_rows->clear();
return false;
}
hidden_rows->insert(hidden_rows->end(), it->second, it->second + view.width);
}
return hidden_rows->size() == (size_t) batch.n_tokens * view.width;
}
// Forward declaration: the definition is not in this part of the file, but
// common_speculative_capture_output_hidden below calls it.
static bool common_speculative_capture_target_features(
common_speculative * spec,
const common_speculative_feature_view & features);
static bool common_speculative_feature_view_from_hidden_rows(
const std::vector<float> & hidden_rows,
int32_t width,
llama_seq_id seq_id,
llama_pos pos_base,
common_speculative_feature_view & view) {
view = {};
view.kind = COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE;
view.width = width;
if (width <= 0 || hidden_rows.empty() || hidden_rows.size() % (size_t) width != 0) {
return false;
}
const size_t n_rows = hidden_rows.size() / (size_t) width;
view.rows.reserve(n_rows);
for (size_t i = 0; i < n_rows; ++i) {
view.rows.push_back({
/* .seq_id = */ seq_id,
/* .pos = */ pos_base + (llama_pos) i,
/* .data = */ hidden_rows.data() + i * (size_t) width,
});
}
return true;
}
// Fills `features` with the target-side features for `batch`.
//
// A DFlash stage takes precedence and uses the DFlash feature view; otherwise
// an MTP stage uses the hidden-state feature view. With neither stage present
// `features` stays empty and the call succeeds.
static bool common_speculative_collect_target_batch_features(
        const common_speculative * spec,
        llama_context * ctx,
        const llama_batch & batch,
        common_speculative_feature_view & features) {
    features = {};

    if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
        return llama_spec_get_dflash_feature_view(ctx, batch, features);
    }
    if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
        const bool collected = llama_spec_get_hidden_feature_view(ctx, batch, features);
        return collected;
    }
    return true;
}
// Per-sequence variant of the batch feature collection: fills `features` with
// the target-side features of `batch` restricted to `seq_id`.
//
// A DFlash stage takes precedence; otherwise an MTP stage uses the hidden-state
// view. With neither stage present `features` stays empty and the call succeeds.
static bool common_speculative_collect_target_seq_batch_features(
        const common_speculative * spec,
        llama_context * ctx,
        const llama_batch & batch,
        llama_seq_id seq_id,
        common_speculative_feature_view & features) {
    features = {};

    if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
        return llama_spec_get_dflash_feature_view_for_seq(ctx, batch, seq_id, features);
    }
    if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
        const bool collected = llama_spec_get_hidden_feature_view_for_seq(ctx, batch, seq_id, features);
        return collected;
    }
    return true;
}
// Captures the hidden state of one output row of `ctx` for the MTP stage.
//
// Succeeds trivially when `spec` has no MTP stage. Otherwise builds a feature
// view from `output_index` for (`seq_id`, `pos`) and hands it to
// common_speculative_capture_target_features; fails if either step fails.
bool common_speculative_capture_output_hidden(
        common_speculative * spec,
        llama_context * ctx,
        int32_t output_index,
        llama_seq_id seq_id,
        llama_pos pos) {
    const bool has_mtp = common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP);
    if (!has_mtp) {
        return true;
    }

    common_speculative_feature_view features;
    const bool have_view =
        llama_spec_get_hidden_feature_view_from_output_index(ctx, output_index, seq_id, pos, features);
    if (have_view) {
        return common_speculative_capture_target_features(spec, features);
    }
    return false;
}
// Makes sure a hidden state is available for `seq_id` when an MTP stage is in
// use. If none is stored yet, one is captured from output index -1 of `ctx`
// at position `pos`. Returns true when nothing had to be done or the capture
// succeeded.
bool common_speculative_ensure_sequence_hidden(
        common_speculative * spec,
        llama_context * ctx,
        llama_seq_id seq_id,
        llama_pos pos) {
    const bool needs_capture =
        common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) &&
        !common_speculative_has_sequence_hidden(spec, seq_id);

    if (needs_capture) {
        return common_speculative_capture_output_hidden(spec, ctx, -1, seq_id, pos);
    }
    return true;
}
// Extended draft entry point: ensures the MTP hidden state for the sequence is
// available, runs common_speculative_draft and reports which implementation
// type produced the draft.
//
// spec           - speculative state; a null pointer yields an empty result of
//                  type COMMON_SPECULATIVE_TYPE_NONE
// ctx            - context used to capture a missing MTP hidden state
// params         - runtime parameters forwarded to common_speculative_draft
// prompt_tgt     - token history forwarded to common_speculative_draft
// id_last        - last token, forwarded unchanged
// draft_base_pos - base position of the draft; the hidden state is ensured
//                  for position draft_base_pos - 1
// draft_seq_id   - sequence id of the draft
//
// Returns the drafted tokens together with the type of the implementation that
// produced them (NONE when nothing was drafted). If the MTP hidden state
// cannot be obtained, a value-initialized result is returned.
common_speculative_draft_result common_speculative_draft_ex(
        common_speculative * spec,
        llama_context * ctx,
        common_params_speculative & params,
        const llama_tokens & prompt_tgt,
        llama_token id_last,
        llama_pos draft_base_pos,
        llama_seq_id draft_seq_id) {
    common_speculative_draft_result result = {};

    // Bail out before anything dereferences the state: the null check used to
    // come only after the draft call, which already requires a valid state.
    if (spec == nullptr) {
        result.type = COMMON_SPECULATIVE_TYPE_NONE;
        return result;
    }

    if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
        if (!common_speculative_ensure_sequence_hidden(spec, ctx, draft_seq_id, draft_base_pos - 1)) {
            LOG_ERR("%s: seq_id=%d MTP hidden state is empty during speculation\n",
                    __func__, (int) draft_seq_id);
            return result;
        }
    }

    result.tokens = common_speculative_draft(
        spec,
        params,
        prompt_tgt,
        id_last,
        draft_base_pos,
        draft_seq_id);

    result.type = spec->curr_impl != nullptr
        ? spec->curr_impl->type
        : COMMON_SPECULATIVE_TYPE_NONE;
    return result;
}
// Returns true when `spec` contains an MTP or a DFlash stage.
static bool common_speculative_has_target_features(const common_speculative * spec) {
    if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
        return true;
    }
    return common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH);
}
// Loads the draft model described by `params` and stores the model together
// with its context parameters back into `params`.
//
// params      - speculative parameters; on success model_dft, cparams_dft and
//               mparams_dft.path are filled in
// params_base - parameters of the main model, used as the source of the
//               settings the draft model inherits
//
// Returns true when no draft model is configured or loading succeeded, false
// when the extra draft arguments cannot be parsed or the model fails to load.
bool common_speculative_load_draft_model(
        common_params_speculative & params,
        const gpt_params & params_base) {
    if (!params.has_dft()) {
        return true;
    }

    gpt_params params_dft;
    params_dft.devices          = params.devices;
    params_dft.model            = params.model;
    params_dft.main_gpu         = params_base.main_gpu;
    params_dft.n_gpu_layers     = params.n_gpu_layers;
    params_dft.rpc_servers      = params_base.rpc_servers;
    // draft-specific cache types win; an empty value inherits the base setting
    params_dft.cache_type_k     = params.cache_type_k.empty() ? params_base.cache_type_k : params.cache_type_k;
    params_dft.cache_type_v     = params.cache_type_v.empty() ? params_base.cache_type_v : params.cache_type_v;
    params_dft.flash_attn       = params_base.flash_attn;
    params_dft.k_cache_hadamard = params_base.k_cache_hadamard;
    params_dft.v_cache_hadamard = params_base.v_cache_hadamard;

    // with a DFlash stage the draft model additionally inherits the split and
    // scheduling settings of the base model
    if (params.has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH)) {
        params_dft.split_mode = params_base.split_mode;
        for (size_t i = 0; i < std::size(params_dft.tensor_split); ++i) {
            params_dft.tensor_split[i] = params_base.tensor_split[i];
        }
        params_dft.attn_max_batch              = params_base.attn_max_batch;
        params_dft.graph_reuse                 = params_base.graph_reuse;
        params_dft.split_mode_graph_scheduling = params_base.split_mode_graph_scheduling;
        params_dft.scheduler_async             = params_base.scheduler_async;
        params_dft.max_extra_alloc_MiB         = params_base.max_extra_alloc_MiB;
        params_dft.reduce_type                 = params_base.reduce_type;
    }

    // extra draft arguments are parsed like a command line and applied on top
    if (!params.params.empty()) {
        auto [argc, argv] = parse_command_line("llama-server " + params.params);
        if (!gpt_params_parse(argc, argv, params_dft)) {
            gpt_params_print_usage(argc, argv, params_dft);
            free_command_line(argc, argv);
            return false;
        }
        free_command_line(argc, argv);
    }

    LOG_INF("%s: loading draft model '%s'\n", __func__, params_dft.model.c_str());

    // context size fallback chain: parsed value, then the speculative n_ctx,
    // then the per-slot share of the base context
    if (params_dft.n_ctx == 0) {
        params_dft.n_ctx = params.n_ctx;
    }
    if (params.has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH) && params_dft.n_gpu_layers < 0) {
        params_dft.n_gpu_layers = params_base.n_gpu_layers;
    }
    params_dft.n_ctx      = params_dft.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx;
    params_dft.n_parallel = 1;
    params_dft.n_batch    = params_dft.n_ctx;

    params.mparams_dft.path = params_dft.model;

    llama_model_params mparams_dft = common_model_params_to_llama(params_dft);
    llama_model * loaded_model = llama_model_load_from_file(params_dft.model.c_str(), mparams_dft);
    if (loaded_model == nullptr) {
        // report the path that was actually loaded; the extra draft arguments
        // parsed above may have changed it from params.model
        LOG_ERR("%s: failed to load draft model '%s'\n", __func__, params_dft.model.c_str());
        return false;
    }

    params.model_dft   = loaded_model;
    params.cparams_dft = common_context_params_to_llama(params_dft);
    return true;
}
// Prepares the draft-context parameters for an MTP stage.
//
// Returns false when no MTP stage is configured. Also returns false, after
// removing the MTP stage from the chain (and clearing the draft settings if no
// remaining stage needs a draft model), when the model reports zero NextN
// layers and there is no external MTP model. Otherwise the draft context
// parameters are set up for MTP warmup with embeddings enabled and true is
// returned. Without an external MTP model they are first rebuilt from
// `params_base` with pooling disabled.
bool common_speculative_prepare_mtp_runtime(
        common_params_speculative & params,
        const gpt_params & params_base,
        const llama_model * model,
        bool has_external_mtp) {
    if (!params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
        return false;
    }

    const bool model_has_no_nextn = (llama_model_n_nextn_layer(model) == 0);
    if (model_has_no_nextn && !has_external_mtp) {
        LOG_WRN("%s: MTP speculative stage requested, but model has 0 NextN layers. Removing MTP from the configured stage chain.\n",
                __func__);
        params.remove_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
        if (!params.needs_dft_model()) {
            params.clear_dft();
        }
        return false;
    }

    if (!has_external_mtp) {
        gpt_params mtp_base = params_base;
        mtp_base.pooling_type = LLAMA_POOLING_TYPE_NONE;
        params.cparams_dft = common_context_params_to_llama(mtp_base);
    }

    auto & cparams = params.cparams_dft;
    cparams.mtp         = true;
    cparams.mtp_op_type = MTP_OP_WARMUP;
    cparams.embeddings  = true;
    return true;
}
// Initializes speculative decoding and reports the outcome as a status code.
//
// `*out_spec` (when `out_spec` is non-null) is reset to nullptr first and
// receives the new state on success. Returns SKIPPED when no stage chain is
// configured, READY on success, and otherwise an error status: ERR_RECURRENT
// when the target model is recurrent, ERR_MTP when an MTP stage is configured,
// ERR_GENERIC in all other cases.
common_speculative_init_status common_speculative_try_init(
        common_params_speculative & params,
        llama_context * ctx_tgt,
        common_speculative ** out_spec) {
    if (out_spec) {
        *out_spec = nullptr;
    }
    if (!params.has_stage_chain()) {
        return COMMON_SPECULATIVE_INIT_SKIPPED;
    }

    common_speculative * created = common_speculative_init(params, ctx_tgt);
    if (created == nullptr) {
        // classify the failure
        const llama_model * model = nullptr;
        if (ctx_tgt != nullptr) {
            model = llama_get_model(ctx_tgt);
        }
        if (model != nullptr && llama_model_has_recurrent(model)) {
            return COMMON_SPECULATIVE_INIT_ERR_RECURRENT;
        }
        if (params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
            return COMMON_SPECULATIVE_INIT_ERR_MTP;
        }
        return COMMON_SPECULATIVE_INIT_ERR_GENERIC;
    }

    if (out_spec) {
        *out_spec = created;
    }
    return COMMON_SPECULATIVE_INIT_READY;
}
// Early startup adjustment of the speculative configuration.
//
// Removes the MTP stage when more than one parallel slot is requested and
// parallel MTP is not allowed, clears the draft settings when no stage needs a
// draft model, and records in `params_base.has_mtp` whether an MTP stage
// remains.
void common_speculative_prepare_startup(
        gpt_params & params_base,
        bool allow_parallel_mtp) {
    auto & spec_params = params_base.speculative;

    const bool drop_mtp =
        !allow_parallel_mtp &&
        params_base.n_parallel > 1 &&
        spec_params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
    if (drop_mtp) {
        LOG_WRN("%s: MTP is not supported with parallel slots yet, removing the MTP stage to avoid cross-slot corruption. n_parallel=%d, stage_chain=%s\n",
                __func__, params_base.n_parallel, common_speculative_stage_chain_to_str(spec_params).c_str());
        spec_params.remove_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
    }

    if (!spec_params.needs_dft_model()) {
        spec_params.clear_dft();
    }

    params_base.has_mtp = spec_params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
}
// Final startup step for speculative decoding.
//
// Clears the draft settings when no stage needs a draft model, loads the draft
// model when one is configured, and prepares the MTP runtime. The result of
// the MTP preparation is stored in `params_base.has_mtp`; when it is true the
// base pooling type is set to none. Returns false only when loading the draft
// model fails.
bool common_speculative_finalize_startup(
        gpt_params & params_base,
        const llama_model * model) {
    auto & spec_params = params_base.speculative;

    if (!spec_params.needs_dft_model()) {
        spec_params.clear_dft();
    }

    if (spec_params.has_dft()) {
        LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");
        const bool loaded = common_speculative_load_draft_model(spec_params, params_base);
        if (!loaded) {
            return false;
        }
    }

    params_base.has_mtp = spec_params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);

    // the draft model counts as an external MTP model only if an MTP stage is
    // configured and the model is a Gemma4 MTP assistant
    const bool has_external_mtp = params_base.has_mtp &&
        llama_model_is_gemma4_mtp_assistant(spec_params.model_dft);

    params_base.has_mtp = common_speculative_prepare_mtp_runtime(
        spec_params, params_base, model, has_external_mtp);

    if (params_base.has_mtp) {
        params_base.pooling_type = LLAMA_POOLING_TYPE_NONE;
    }
    return true;
}
// Saves a checkpoint into `spec` before drafting by forwarding all arguments
// to common_speculative_checkpoint_save. Returns false when `spec` is null or
// the checkpoint could not be saved.
bool common_speculative_before_draft(
        common_speculative * spec,
        llama_model * model,
        llama_context * ctx,
        common_sampler * sampler_src,
        const common_params_sampling & sparams,
        llama_seq_id seq_id,
        llama_pos n_past,
        llama_token sampled,
        int max_tokens,
        int ckpt_mode) {
    if (!spec) {
        return false;
    }

    const bool saved = common_speculative_checkpoint_save(
        spec->checkpoint, model, ctx, sampler_src, sparams,
        seq_id, n_past, sampled, max_tokens, ckpt_mode);
    return saved;
}
// Forwards a decoded target batch for one sequence to the speculative state.
//
// Returns 0 when there is nothing to do (null context, empty batch, or no
// tokens of `seq_id` in the batch), -1 on failure, and otherwise the result of
// common_speculative_on_target_batch.
//
// During prompt warmup, a batch that is not made up solely of single-sequence
// tokens of `seq_id` is first reduced to the tokens of that sequence.
int32_t common_speculative_on_target_seq_batch(
        common_speculative * spec,
        llama_context * ctx_tgt,
        const llama_batch & batch,
        llama_seq_id seq_id,
        bool is_prompt_warmup) {
    if (ctx_tgt == nullptr || batch.n_tokens <= 0) {
        return 0;
    }

    // without a DFlash stage, the hidden state widths of the target and the
    // companion context (the target itself if there is none) must agree
    if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
        llama_context * ctx_mtp = common_speculative_get_companion_ctx(spec);
        if (!ctx_mtp) {
            ctx_mtp = ctx_tgt;
        }
        if (ctx_mtp == nullptr) {
            return 0;
        }

        const int width_src = common_speculative_ctx_mtp_n_embd(ctx_tgt);
        const int width_dst = common_speculative_ctx_mtp_n_embd(ctx_mtp);
        if (width_src <= 0 || width_dst <= 0) {
            return -1;
        }
        if (width_src != width_dst) {
            LOG_ERR("MTP warmup hidden state width mismatch: n_embd_src = %d, n_embd_dst = %d\n", width_src, width_dst);
            return -1;
        }
    }

    common_speculative_feature_view feature_view;

    const bool split_by_seq = is_prompt_warmup && !common_speculative_batch_is_exact_single_seq(batch, seq_id);
    if (!split_by_seq) {
        // the batch can be handed over as it is
        if (!common_speculative_collect_target_batch_features(spec, ctx_tgt, batch, feature_view)) {
            return -1;
        }
        const int32_t status = common_speculative_on_target_batch(spec, batch, feature_view, is_prompt_warmup);
        return status;
    }

    // hand over only the tokens of this sequence
    llama_batch seq_batch = {};
    const int n_seq_tokens = common_speculative_copy_seq_batch(batch, seq_id, seq_batch);
    if (n_seq_tokens < 0) {
        return -1;
    }
    if (n_seq_tokens == 0) {
        return 0;
    }

    if (!common_speculative_collect_target_seq_batch_features(spec, ctx_tgt, batch, seq_id, feature_view)) {
        llama_batch_free(seq_batch);
        return -1;
    }

    const int32_t status = common_speculative_on_target_batch(spec, seq_batch, feature_view, is_prompt_warmup);
    llama_batch_free(seq_batch);
    return status;
}
// Copies the feature rows of the given output indices of `ctx` into
// `hidden_rows`, which is cleared first.
//
// A DFlash stage takes precedence and copies DFlash rows; otherwise an MTP
// stage copies hidden-state rows. With neither stage present `hidden_rows`
// stays empty and the call succeeds.
bool common_speculative_copy_output_hidden_rows(
        const common_speculative * spec,
        llama_context * ctx,
        const std::vector<int32_t> & output_indices,
        std::vector<float> & hidden_rows) {
    hidden_rows.clear();

    if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
        return llama_spec_copy_dflash_rows_from_output_indices(ctx, output_indices, hidden_rows);
    }
    if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
        return llama_spec_copy_hidden_rows_from_output_indices(ctx, output_indices, hidden_rows);
    }
    return true;
}
// Builds the token list that is committed together with the accepted rows.
//
// For MTP the list is `ids` unchanged. For every other type the list is
// `sampled_before` followed by all of `ids` except its last element, so it has
// the same length as `ids`. An empty `ids` yields an empty list.
// Returns true when the list was built with the expected length.
static bool common_speculative_build_commit_tokens(
        common_speculative_type spec_type_used,
        llama_token sampled_before,
        const std::vector<llama_token> & ids,
        std::vector<llama_token> & commit_tokens) {
    commit_tokens.clear();
    if (ids.empty()) {
        return true;
    }

    if (spec_type_used == COMMON_SPECULATIVE_TYPE_MTP) {
        commit_tokens = ids;
        return true;
    }

    // shift by one: the previously sampled token leads, the last id is dropped
    const size_t n_ids = ids.size();
    commit_tokens.reserve(n_ids);
    commit_tokens.push_back(sampled_before);
    for (size_t i = 0; i + 1 < n_ids; ++i) {
        commit_tokens.push_back(ids[i]);
    }
    return commit_tokens.size() == ids.size();
}
// Feeds accepted tokens and their feature rows to the speculative state.
//
// Builds a batch with token ids[i] at position `pos_base + i` for `seq_id`,
// wraps `hidden_rows` in a feature view and passes both to
// common_speculative_on_target_batch. Succeeds trivially when the feature
// width is not positive or `ids` is empty. Fails when `hidden_rows` does not
// hold exactly one row per id, the view cannot be built, or the target-batch
// call does not return 0.
static bool common_speculative_apply_hidden_rows(
        common_speculative * spec,
        llama_seq_id seq_id,
        llama_pos pos_base,
        const std::vector<llama_token> & ids,
        const std::vector<float> & hidden_rows) {
    const int32_t width = common_speculative_feature_width(spec);
    if (width <= 0 || ids.empty()) {
        return true;
    }

    if (hidden_rows.size() != ids.size() * (size_t) width) {
        return false;
    }

    llama_batch accepted = llama_batch_init(ids.size(), 0, 1);
    for (size_t k = 0; k < ids.size(); ++k) {
        common_batch_add(accepted, ids[k], pos_base + (llama_pos) k, { seq_id }, true);
    }

    common_speculative_feature_view view;
    int32_t status = -1;
    if (common_speculative_feature_view_from_hidden_rows(hidden_rows, width, seq_id, pos_base, view)) {
        status = common_speculative_on_target_batch(spec, accepted, view, false);
    }

    llama_batch_free(accepted);
    return status == 0;
}
// Commits the feature rows of accepted tokens to the speculative state.
//
// Derives the commit token list from `ids` (depending on `spec_type_used` and
// `sampled_before`) and applies it together with `hidden_rows` starting at
// `pos_base`. Succeeds trivially when the feature width is not positive or
// `ids` is empty.
bool common_speculative_commit_accepted_hidden_rows(
        common_speculative * spec,
        common_speculative_type spec_type_used,
        llama_seq_id seq_id,
        llama_pos pos_base,
        llama_token sampled_before,
        const std::vector<llama_token> & ids,
        const std::vector<float> & hidden_rows) {
    const bool nothing_to_commit = common_speculative_feature_width(spec) <= 0 || ids.empty();
    if (nothing_to_commit) {
        return true;
    }

    std::vector<llama_token> tokens_to_commit;
    const bool built = common_speculative_build_commit_tokens(spec_type_used, sampled_before, ids, tokens_to_commit);
    if (!built) {
        return false;
    }
    return common_speculative_apply_hidden_rows(spec, seq_id, pos_base, tokens_to_commit, hidden_rows);
}
// Commits accepted tokens using feature rows read from the outputs of `ctx`.
//
// Copies the rows at `output_indices` and passes them on to
// common_speculative_commit_accepted_hidden_rows. Succeeds trivially when the
// feature width is not positive or `ids` is empty; fails when the rows cannot
// be copied or the commit fails.
bool common_speculative_commit_accepted_output(
        common_speculative * spec,
        llama_context * ctx,
        common_speculative_type spec_type_used,
        llama_seq_id seq_id,
        llama_pos pos_base,
        llama_token sampled_before,
        const std::vector<llama_token> & ids,
        const std::vector<int32_t> & output_indices) {
    const bool nothing_to_commit = common_speculative_feature_width(spec) <= 0 || ids.empty();
    if (nothing_to_commit) {
        return true;
    }

    std::vector<float> rows;
    const bool copied = common_speculative_copy_output_hidden_rows(spec, ctx, output_indices, rows);
    if (!copied) {
        return false;
    }

    return common_speculative_commit_accepted_hidden_rows(
        spec, spec_type_used, seq_id, pos_base, sampled_before, ids, rows);
}
// Captures a speculative checkpoint for `seq_id` into `ckpt`.
//
// `ckpt` is cleared first and then records `n_past` / `sampled`. The context checkpoint is
// initialised with `ckpt_mode` / `max_tokens`; the mode actually granted decides whether
// per-step restore is available. A fresh sampler is created from `sparams` and, when
// `sampler_src` is non-null, `sampler_src` is cloned into it.
//
// Returns false when the context reports LLAMA_SPEC_CKPT_NONE, when saving fails, or when the
// sampler cannot be created; in the latter two cases the context checkpoint is discarded.
static bool common_speculative_checkpoint_save(
        common_speculative_checkpoint & ckpt,
        llama_model * model,
        llama_context * ctx,
        common_sampler * sampler_src,
        const common_params_sampling & sparams,
        llama_seq_id seq_id,
        llama_pos n_past,
        llama_token sampled,
        int max_tokens,
        int ckpt_mode) {
    ckpt.clear();
    ckpt.sampled = sampled;
    ckpt.n_past  = n_past;

    const int granted_mode = llama_spec_ckpt_init(ctx, ckpt_mode, max_tokens);
    if (granted_mode == LLAMA_SPEC_CKPT_NONE) {
        return false;
    }
    ckpt.per_step_enabled = granted_mode == LLAMA_SPEC_CKPT_PER_STEP;

    const bool saved = llama_spec_ckpt_save(ctx, seq_id);
    ckpt.valid = saved;
    if (!saved) {
        llama_spec_ckpt_discard(ctx);
        return false;
    }

    ckpt.sampler = common_sampler_init(model, sparams);
    if (ckpt.sampler == nullptr) {
        common_speculative_checkpoint_discard(ckpt, ctx);
        return false;
    }

    if (sampler_src != nullptr) {
        common_sampler_clone(sampler_src, ckpt.sampler);
    }
    return true;
}
// Read-only access to the checkpoint embedded in `spec`; nullptr when `spec` is null.
const common_speculative_checkpoint * common_speculative_get_checkpoint(const common_speculative * spec) {
    if (spec == nullptr) {
        return nullptr;
    }
    return &spec->checkpoint;
}
// Drops the speculative checkpoint: resets the bookkeeping held in `ckpt` through its clear()
// and then asks the llama context to discard its checkpoint via llama_spec_ckpt_discard().
void common_speculative_checkpoint_discard(
common_speculative_checkpoint & ckpt,
llama_context * ctx) {
ckpt.clear();
llama_spec_ckpt_discard(ctx);
}
// Rolls the target context back to the speculative checkpoint and re-applies the accepted tokens.
//
// Does nothing when `ckpt.valid` is false. Otherwise one of two strategies is used:
//   - per-step (`ckpt.per_step_enabled`): the context is restored to step `ids.size() - 1`;
//     no re-decode is issued by this function. If the drafter consumes target features and
//     `mtp_hidden_state_pre` is non-empty, those rows are committed at `mtp_n_past_base`.
//   - full: the context is restored to step 0 and, when `ids` is non-empty, `ckpt.sampled`
//     followed by all of `ids` except its last element is decoded again; hidden rows are then
//     taken from that re-decode (output indices 0..ids.size()-1) and committed at `ckpt.n_past`.
// In both strategies `sampler_dst` (when non-null) is overwritten from the checkpointed sampler
// and then accepts every token in `ids`. If committing hidden rows fails, the per-sequence
// hidden state is cleared. The checkpoint is always discarded before returning.
//
// `n_draft` is used for the debug log only.
// NOTE(review): the per-step path computes `ids.size() - 1` without checking that `ids` is
// non-empty; common_speculative_commit() asserts this before calling, other callers must too.
void common_speculative_checkpoint_restore(
common_speculative_checkpoint & ckpt,
common_speculative * spec,
llama_context * ctx,
common_sampler * sampler_dst,
llama_seq_id seq_id,
common_speculative_type spec_type_used,
llama_token sampled_before,
const std::vector<llama_token> & ids,
int n_draft,
const std::vector<float> & mtp_hidden_state_pre,
int32_t mtp_n_past_base) {
if (!ckpt.valid) {
return;
}
if (ckpt.per_step_enabled) {
// Step handed to the per-step restore: number of ids minus one.
const int step = (int) ids.size() - 1;
llama_spec_ckpt_restore(ctx, seq_id, ckpt.n_past, step);
// Rewind the live sampler to its checkpointed state, then replay the accepted tokens into it.
if (ckpt.sampler != nullptr && sampler_dst != nullptr) {
common_sampler_clone(ckpt.sampler, sampler_dst);
}
if (sampler_dst != nullptr) {
for (llama_token id : ids) {
common_sampler_accept(sampler_dst, ctx, id, true);
}
}
// Hidden rows were captured by the caller before the restore; commit them now.
if (common_speculative_has_target_features(spec) && !mtp_hidden_state_pre.empty()) {
if (!common_speculative_commit_accepted_hidden_rows(
spec,
spec_type_used,
seq_id,
mtp_n_past_base,
sampled_before,
ids,
mtp_hidden_state_pre)) {
common_speculative_clear_sequence_hidden(spec, seq_id);
} else if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) {
LOG_DBG("%s: seq_id=%d synced MTP target hidden state from accepted-prefix rows after per-step restore\n",
__func__, (int) seq_id);
}
}
LOG_DBG("%s: seq_id=%d per-step restore: step=%d (rejected %d drafts)\n",
__func__, (int) seq_id, step, (int) (n_draft - (ids.size() - 1)));
} else {
// Full restore: go back to the checkpoint itself (step 0) and re-decode the accepted prefix.
llama_spec_ckpt_restore(ctx, seq_id, ckpt.n_past, 0);
if (ckpt.sampler != nullptr && sampler_dst != nullptr) {
common_sampler_clone(ckpt.sampler, sampler_dst);
}
if (!ids.empty()) {
const int n_re = (int) ids.size();
// Batch layout: ckpt.sampled at ckpt.n_past, then ids[0 .. n_re-2] at the following
// positions. By default only the final entry requests output.
llama_batch re_batch = llama_batch_init(n_re, 0, 1);
common_batch_add(re_batch, ckpt.sampled, ckpt.n_past, { seq_id }, n_re == 1);
for (int j = 0; j < n_re - 1; ++j) {
common_batch_add(re_batch, ids[j], ckpt.n_past + 1 + j, { seq_id }, j == n_re - 2);
}
// With an MTP implementation present, request output for every re-decoded token and
// enable embeddings on the context.
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
for (int j = 0; j < re_batch.n_tokens; ++j) {
re_batch.logits[j] = true;
}
llama_set_embeddings(ctx, true);
}
// A decode failure is only logged; the steps below still run.
const int ret = llama_decode(ctx, re_batch);
if (ret != 0) {
LOG_ERR("%s: seq_id=%d failed to re-decode accepted tokens after checkpoint restore: %d\n",
__func__, (int) seq_id, ret);
}
if (common_speculative_has_target_features(spec)) {
// Output rows 0 .. n_re-1 of the re-decode are used as the accepted hidden rows.
std::vector<int32_t> redecoded_indices(n_re);
for (int j = 0; j < n_re; ++j) {
redecoded_indices[j] = j;
}
if (!common_speculative_commit_accepted_output(
spec,
ctx,
spec_type_used,
seq_id,
ckpt.n_past,
sampled_before,
ids,
redecoded_indices)) {
common_speculative_clear_sequence_hidden(spec, seq_id);
}
}
if (sampler_dst != nullptr) {
for (llama_token id : ids) {
common_sampler_accept(sampler_dst, ctx, id, true);
}
}
llama_batch_free(re_batch);
LOG_DBG("%s: seq_id=%d spec checkpoint restored: re-decoded %d tokens (rejected %d drafts)\n",
__func__, (int) seq_id, n_re, (int) (n_draft - (ids.size() - 1)));
}
}
// The checkpoint is single-use: drop it regardless of which path ran.
common_speculative_checkpoint_discard(ckpt, ctx);
}
// Forward declaration; the definition lives further down in the MTP section of this file.
static bool mtp_model_uses_recurrent_conditioning(const common_speculative_state_mtp & state);
// Finalises one speculative step once the target model has verified the draft.
//
// `ids` holds the tokens produced by verification (must be non-empty); `ids.size() - 1` is
// reported to the drafter as the number of accepted tokens and compared against `n_draft` to
// decide whether anything was rejected.
//
//  - Rejection with a valid checkpoint: hidden rows at `accepted_output_indices` are copied out
//    of `ctx` first (if the drafter consumes target features), then the checkpoint is restored.
//  - Otherwise: hidden rows are committed directly from `ctx`, the KV cache of `seq_id` is
//    truncated from `pos_base + ids.size() - 1` onwards, and the checkpoint is discarded.
void common_speculative_commit(
        common_speculative * spec,
        llama_context * ctx,
        common_sampler * sampler_dst,
        llama_seq_id seq_id,
        llama_token sampled_before,
        const std::vector<llama_token> & ids,
        int n_draft,
        llama_pos pos_base,
        const std::vector<int32_t> & accepted_output_indices) {
    GGML_ASSERT(spec != nullptr);
    GGML_ASSERT(!ids.empty());

    common_speculative_checkpoint & checkpoint = spec->checkpoint;

    // Remember which implementation produced the draft before notifying it of the outcome.
    common_speculative_type spec_type_used = COMMON_SPECULATIVE_TYPE_NONE;
    if (spec->curr_impl != nullptr) {
        spec_type_used = spec->curr_impl->type;
    }

    const bool any_rejected = (int) ids.size() - 1 < n_draft;

    common_speculative_accept(spec, ids.size() - 1);

    const bool has_features     = common_speculative_has_target_features(spec);
    const bool have_output_rows = !accepted_output_indices.empty();

    if (any_rejected && checkpoint.valid) {
        // The restore invalidates the context outputs, so grab the hidden rows beforehand.
        std::vector<float> mtp_hidden_state_pre;
        if (has_features && have_output_rows) {
            const bool copied = common_speculative_copy_output_hidden_rows(
                    spec, ctx, accepted_output_indices, mtp_hidden_state_pre);
            if (!copied) {
                mtp_hidden_state_pre.clear();
            }
        }

        common_speculative_checkpoint_restore(
                checkpoint,
                spec,
                ctx,
                sampler_dst,
                seq_id,
                spec_type_used,
                sampled_before,
                ids,
                n_draft,
                mtp_hidden_state_pre,
                pos_base);
        return;
    }

    if (has_features && have_output_rows) {
        const bool committed = common_speculative_commit_accepted_output(
                spec,
                ctx,
                spec_type_used,
                seq_id,
                pos_base,
                sampled_before,
                ids,
                accepted_output_indices);
        if (!committed) {
            common_speculative_clear_sequence_hidden(spec, seq_id);
        } else if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) {
            LOG_DBG("%s: seq_id=%d synced MTP target hidden state from accepted-prefix rows\n",
                    __func__, (int) seq_id);
        }
    }

    // Drop every KV cell past the accepted prefix.
    const llama_pos trim_from = pos_base + (llama_pos) (ids.size() - 1);
    llama_kv_cache_seq_rm(ctx, seq_id, trim_from, -1);

    common_speculative_checkpoint_discard(checkpoint, ctx);
}
// Logs one statistics line per speculative implementation and, when the tuner is enabled and
// the request produced output (`slot_tps > 0` and `n_decoded > 0`), reports the end of the
// request to the tuner. `active_params` is passed to the tuner when non-null; otherwise a
// default-constructed parameter set is used.
void common_speculative_print_stats(const common_speculative * spec, double slot_tps, int n_decoded, int n_past, common_params_speculative * active_params) {
    if (spec == nullptr) {
        return;
    }

    for (const auto & impl : spec->impls) {
        std::string str_perf;
        if (impl->gen_perf) {
            // std::fixed and the precision are sticky, so setting them once covers all three values.
            std::ostringstream timings;
            timings << std::fixed << std::setprecision(3)
                    << impl->t_begin_us  / 1000.0 << ", "
                    << impl->t_draft_us  / 1000.0 << ", "
                    << impl->t_accept_us / 1000.0;
            str_perf = ", dur(b,g,a) = " + timings.str() + " ms";
        }

        LOG_INF("statistics %s: #calls(b,g,a) = %zu %zu %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n",
                common_speculative_type_to_str(impl->type).c_str(),
                impl->n_call_begin, impl->n_call_draft, impl->n_call_accept,
                impl->n_gen_drafts,
                impl->n_acc_drafts,
                impl->n_gen_tokens,
                impl->n_acc_tokens,
                str_perf.c_str());
    }

    const bool tuner_active = spec->tuner && spec->tuner->enabled;
    if (!tuner_active || !(slot_tps > 0.0) || !(n_decoded > 0)) {
        return;
    }

    // The tuner is updated through a const `spec`, hence the cast.
    auto * mutable_spec = const_cast<common_speculative *>(spec);
    if (active_params != nullptr) {
        mutable_spec->tuner->end_of_request(slot_tps, n_past, *active_params);
        return;
    }
    common_params_speculative tmp_params;
    mutable_spec->tuner->end_of_request(slot_tps, n_past, tmp_params);
}
// ----------------------------------------------------------------------------
// MTP
// ----------------------------------------------------------------------------
// Returns the first implementation in `spec->impls` that is tagged COMMON_SPECULATIVE_TYPE_MTP
// and really is a common_speculative_state_mtp; nullptr when `spec` is null or none matches.
static common_speculative_state_mtp * common_speculative_get_mtp_state(common_speculative * spec) {
    if (spec == nullptr) {
        return nullptr;
    }
    for (auto & impl : spec->impls) {
        if (impl->type == COMMON_SPECULATIVE_TYPE_MTP) {
            auto * candidate = dynamic_cast<common_speculative_state_mtp *>(impl.get());
            if (candidate != nullptr) {
                return candidate;
            }
        }
    }
    return nullptr;
}
// Const overload: forwards to the mutable lookup and hands the result back as pointer-to-const.
static const common_speculative_state_mtp * common_speculative_get_mtp_state(const common_speculative * spec) {
    auto * mutable_spec = const_cast<common_speculative *>(spec);
    return common_speculative_get_mtp_state(mutable_spec);
}
// Returns the first implementation in `spec->impls` that is tagged COMMON_SPECULATIVE_TYPE_DFLASH
// and really is a common_speculative_state_dflash; nullptr when `spec` is null or none matches.
static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec) {
    if (spec == nullptr) {
        return nullptr;
    }
    for (auto & impl : spec->impls) {
        if (impl->type == COMMON_SPECULATIVE_TYPE_DFLASH) {
            auto * candidate = dynamic_cast<common_speculative_state_dflash *>(impl.get());
            if (candidate != nullptr) {
                return candidate;
            }
        }
    }
    return nullptr;
}
// Const overload: forwards to the mutable lookup and hands the result back as pointer-to-const.
static const common_speculative_state_dflash * common_speculative_get_dflash_state(const common_speculative * spec) {
    auto * mutable_spec = const_cast<common_speculative *>(spec);
    return common_speculative_get_dflash_state(mutable_spec);
}
// Number of floats per target feature row expected by the active feature-consuming drafter:
// DFlash's `n_target_features` if a DFlash state exists, else MTP's `n_embd`, else 0.
static int32_t common_speculative_feature_width(const common_speculative * spec) {
    const auto * dflash_state = common_speculative_get_dflash_state(spec);
    if (dflash_state != nullptr) {
        return dflash_state->n_target_features;
    }

    const auto * mtp_state = common_speculative_get_mtp_state(spec);
    if (mtp_state != nullptr) {
        return mtp_state->n_embd;
    }

    return 0;
}
// Returns the per-sequence draft cache entry for `seq_id`, creating it on first use and making
// sure its embedding buffer holds exactly `state.n_embd` floats.
static mtp_last_embd & mtp_get_last_embd(common_speculative_state_mtp & state, llama_seq_id seq_id) {
    mtp_last_embd & entry = state.draft_cache_by_seq[seq_id];

    const bool size_matches = (int) entry.embd.size() == state.n_embd;
    if (!size_matches) {
        entry.embd.resize(state.n_embd);
    }

    return entry;
}
// Marks the cached draft token of `seq_id` as unusable (id -1, probability 0). Sequences
// without a cache entry are left untouched; no entry is created.
static void mtp_invalidate_cached_draft(common_speculative_state_mtp & state, llama_seq_id seq_id) {
    const auto found = state.draft_cache_by_seq.find(seq_id);
    if (found != state.draft_cache_by_seq.end()) {
        auto & cached = found->second;
        cached.last_id = -1;
        cached.prob    = 0.0f;
    }
}
// Marks the cached draft token of every known sequence as unusable (id -1, probability 0).
static void mtp_invalidate_cached_drafts(common_speculative_state_mtp & state) {
    auto & cache = state.draft_cache_by_seq;
    for (auto it = cache.begin(); it != cache.end(); ++it) {
        auto & cached = it->second;
        cached.last_id = -1;
        cached.prob    = 0.0f;
    }
}
// Replaces the stored target hidden state of `seq_id` with the `width` floats at `hidden`.
// A null pointer or a non-positive width leaves the stored state unchanged.
static void mtp_store_target_hidden(
        common_speculative_state_mtp & state,
        llama_seq_id seq_id,
        const float * hidden,
        int32_t width) {
    const bool usable = hidden != nullptr && width > 0;
    if (usable) {
        state.target_hidden_by_seq[seq_id].assign(hidden, hidden + width);
    }
}
// True when the MTP context's model reports recurrent layers and its architecture string is
// "qwen35" or "qwen35moe". False when no MTP context is attached.
static bool mtp_model_uses_recurrent_conditioning(const common_speculative_state_mtp & state) {
    if (state.ctx_mtp == nullptr) {
        return false;
    }

    const llama_model * model = llama_get_model(state.ctx_mtp);
    if (!llama_model_has_recurrent(model)) {
        return false;
    }

    const std::string arch{llama_model_arch_string(model)};
    for (const char * recurrent_arch : { "qwen35", "qwen35moe" }) {
        if (arch == recurrent_arch) {
            return true;
        }
    }
    return false;
}
// Forgets everything cached for `seq_id`: both its draft cache entry and its stored target
// hidden state.
static void mtp_clear_target_hidden(common_speculative_state_mtp & state, llama_seq_id seq_id) {
    state.draft_cache_by_seq.erase(seq_id);
    state.target_hidden_by_seq.erase(seq_id);
}
// Stores each non-null hidden-state row of `features` as the target hidden state of the row's
// sequence and invalidates that sequence's cached draft.
//
// Returns true if at least one row was stored; false when there is no MTP state, the feature
// kind is not a hidden state, the width is non-positive, or every row was null.
static bool common_speculative_capture_target_features(common_speculative * spec, const common_speculative_feature_view & features) {
    auto * mtp_state = common_speculative_get_mtp_state(spec);

    const bool usable = mtp_state != nullptr &&
                        features.kind == COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE &&
                        features.width > 0;
    if (!usable) {
        return false;
    }

    size_t n_captured = 0;
    for (const auto & row : features.rows) {
        if (row.data != nullptr) {
            mtp_store_target_hidden(*mtp_state, row.seq_id, row.data, features.width);
            mtp_invalidate_cached_draft(*mtp_state, row.seq_id);
            ++n_captured;
        }
    }
    return n_captured > 0;
}
// True when an MTP state exists and holds a non-empty target hidden state for `seq_id`.
bool common_speculative_has_sequence_hidden(const common_speculative * spec, llama_seq_id seq_id) {
    const auto * mtp_state = common_speculative_get_mtp_state(spec);
    if (mtp_state == nullptr) {
        return false;
    }

    const auto & hidden_by_seq = mtp_state->target_hidden_by_seq;
    const auto   found         = hidden_by_seq.find(seq_id);
    if (found == hidden_by_seq.end()) {
        return false;
    }
    return !found->second.empty();
}
// Drops cached target features: the MTP hidden state and draft cache of `seq_id` (if an MTP
// state exists) and the DFlash target features (if a DFlash state exists; that call takes no
// sequence id).
void common_speculative_clear_sequence_hidden(common_speculative * spec, llama_seq_id seq_id) {
    if (auto * mtp_state = common_speculative_get_mtp_state(spec); mtp_state != nullptr) {
        mtp_clear_target_hidden(*mtp_state, seq_id);
    }

    auto * dflash_state = common_speculative_get_dflash_state(spec);
    if (dflash_state != nullptr) {
        dflash_clear_target_features(*dflash_state);
    }
}
// Resets the per-request speculative state: checkpoint bookkeeping, current implementation,
// draft counter and step timer (when `spec` is non-null), plus the cached target features of
// `seq_id`. With `clear_companion_ctx` set, the companion context's KV cache is cleared too.
void common_speculative_clear_sequence(
        common_speculative * spec,
        llama_seq_id seq_id,
        bool clear_companion_ctx) {
    if (spec != nullptr) {
        spec->checkpoint.clear();
        spec->t_step_start_us = 0;
        spec->last_n_drafted  = 0;
        spec->curr_impl       = nullptr;
    }

    common_speculative_clear_sequence_hidden(spec, seq_id);

    if (!clear_companion_ctx) {
        return;
    }

    llama_context * companion = common_speculative_get_companion_ctx(spec);
    if (companion != nullptr) {
        llama_kv_cache_clear(companion);
    }
}
// Removes the KV cells of `seq_id` from `pos_begin` onwards in the target context and, if that
// succeeded, in the companion context as well. Returns true only if every attempted removal
// succeeded; the companion context is not touched when the target removal fails.
bool common_speculative_trim_sequence(
        common_speculative * spec,
        llama_context * ctx,
        llama_seq_id seq_id,
        llama_pos pos_begin) {
    const bool target_trimmed = llama_kv_cache_seq_rm(ctx, seq_id, pos_begin, -1);

    llama_context * companion = common_speculative_get_companion_ctx(spec);
    if (companion == nullptr || !target_trimmed) {
        return target_trimmed;
    }

    return llama_kv_cache_seq_rm(companion, seq_id, pos_begin, -1);
}
// Resets the speculative state of `seq_id` and removes the whole sequence from the KV cache of
// the target context and, when present, of the companion context.
void common_speculative_clear_sequence_kv(
        common_speculative * spec,
        llama_context * ctx,
        llama_seq_id seq_id) {
    common_speculative_clear_sequence(spec, seq_id);

    llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);

    llama_context * companion = common_speculative_get_companion_ctx(spec);
    if (companion != nullptr) {
        llama_kv_cache_seq_rm(companion, seq_id, -1, -1);
    }
}
// Returns the secondary llama context owned by the drafter: the MTP context if an MTP state
// exists, otherwise the DFlash draft context, otherwise nullptr.
llama_context * common_speculative_get_companion_ctx(common_speculative * spec) {
    auto * mtp_state = common_speculative_get_mtp_state(spec);
    if (mtp_state != nullptr) {
        return mtp_state->ctx_mtp;
    }

    auto * dflash_state = common_speculative_get_dflash_state(spec);
    if (dflash_state != nullptr) {
        return dflash_state->ctx_dft;
    }

    return nullptr;
}
// Runs the accepted tokens through the MTP context and pre-computes the next draft token.
//
// `hidden_rows` must hold `accepted_batch.n_tokens * state.n_embd` floats. They are installed
// as the draft input hidden state, the MTP KV cache is updated with the batch, and - if the MTP
// context exposes an embedding for the last batch token - that embedding is cached for
// `seq_id`, installed as the next draft input, and a draft token plus its probability are
// sampled from the last batch position and cached.
//
// Returns 0 on success (also when there is nothing to do), -1 on any failure.
static int32_t mtp_accept_batch(
        common_speculative_state_mtp & state,
        const llama_batch & accepted_batch,
        llama_seq_id seq_id,
        const float * hidden_rows) {
    const int32_t n_rows = accepted_batch.n_tokens;
    if (n_rows == 0 || hidden_rows == nullptr) {
        return 0;
    }

    const size_t n_hidden_floats = (size_t) n_rows * state.n_embd;
    const bool   hidden_set      = llama_set_draft_input_hidden_state_copy(state.ctx_mtp, hidden_rows, n_hidden_floats);
    if (!hidden_set) {
        return -1;
    }

    const int32_t update_ret = mtp_update_kv_cache(state.ctx_mtp, accepted_batch, false);
    if (update_ret != 0) {
        return -1;
    }

    mtp_last_embd & cached   = mtp_get_last_embd(state, seq_id);
    const int32_t   last_row = n_rows - 1;

    const float * embd = llama_get_embeddings_ith(state.ctx_mtp, last_row);
    if (embd == nullptr) {
        return 0;
    }

    // Keep a private copy: the context's embedding storage is reused by later decodes.
    std::memcpy(cached.embd.data(), embd, cached.embd.size() * sizeof(float));
    if (!llama_set_draft_input_hidden_state_copy(state.ctx_mtp, cached.embd.data(), cached.embd.size())) {
        return -1;
    }

    cached.last_id = common_sampler_sample_speculative(nullptr, state.ctx_mtp, last_row, &cached.prob);
    return 0;
}
// Feeds a target batch, together with the target features captured for it, to the
// feature-consuming drafter.
//
// DFlash (checked first): the features are appended to the DFlash state.
// MTP: the hidden row of the last batch token becomes the stored target hidden state of the
// sequence, then
//   - with constant draft positions, only the cached draft is invalidated;
//   - for generation batches (`is_prompt_warmup == false`), mtp_accept_batch() is run;
//   - for prompt warm-up, the MTP KV cache is updated with the batch. Models for which
//     mtp_model_uses_recurrent_conditioning() is true get their hidden rows shifted by one
//     token: row i is conditioned on the hidden state of token i-1, and row 0 on the hidden
//     state stored for the sequence by the previous call (zeros if none is stored).
//
// All tokens of `batch` must belong to exactly one, identical sequence.
// Returns 0 on success or when there is nothing to do, -1 on validation failure, otherwise the
// result of the MTP update.
int32_t common_speculative_on_target_batch(
        common_speculative * spec,
        const llama_batch & batch,
        const common_speculative_feature_view & features,
        bool is_prompt_warmup) {
    // Validates the single-sequence requirement and reports the sequence id.
    const auto resolve_single_seq_id = [&batch](llama_seq_id & seq_id_out) -> bool {
        if (batch.n_seq_id == nullptr || batch.seq_id == nullptr || batch.n_seq_id[0] <= 0 || batch.seq_id[0] == nullptr) {
            return false;
        }
        seq_id_out = batch.seq_id[0][0];
        for (int i = 0; i < batch.n_tokens; ++i) {
            if (batch.n_seq_id[i] != 1 || batch.seq_id[i] == nullptr || batch.seq_id[i][0] != seq_id_out) {
                return false;
            }
        }
        return true;
    };

    if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) {
        if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || batch.n_tokens <= 0) {
            return 0;
        }
        if (features.width != dflash_state->n_target_features) {
            LOG_ERR("%s: DFlash feature width mismatch: got %d expected %d\n",
                    __func__, features.width, dflash_state->n_target_features);
            return -1;
        }
        llama_seq_id seq_id = 0;
        if (!resolve_single_seq_id(seq_id)) {
            return -1;
        }
        if (!dflash_append_target_features(*dflash_state, features, seq_id)) {
            return -1;
        }
        return 0;
    }

    auto * mtp_state = common_speculative_get_mtp_state(spec);
    if (mtp_state == nullptr) {
        return 0;
    }
    if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || features.width <= 0 || batch.n_tokens <= 0) {
        return 0;
    }
    if (features.width != mtp_state->n_embd) {
        LOG_ERR("%s: MTP feature width mismatch: got %d expected %d\n",
                __func__, features.width, mtp_state->n_embd);
        return -1;
    }
    llama_seq_id seq_id = 0;
    if (!resolve_single_seq_id(seq_id)) {
        return -1;
    }

    std::vector<float> hidden_rows_storage;
    if (!common_speculative_feature_view_copy_batch_rows(features, batch, seq_id, &hidden_rows_storage)) {
        return -1;
    }
    // The pointer arithmetic and copies below assume one row of `features.width` floats per token.
    if (hidden_rows_storage.size() != (size_t) batch.n_tokens * (size_t) features.width) {
        return -1;
    }

    // Snapshot the hidden state left behind by the previous batch BEFORE it is overwritten
    // below. Reading it afterwards would return this batch's own last row, so the first shifted
    // row would be conditioned on the wrong token and the zero fallback could never trigger.
    const bool uses_shifted_hidden_rows =
        is_prompt_warmup &&
        !mtp_state->constant_draft_positions &&
        mtp_model_uses_recurrent_conditioning(*mtp_state);
    std::vector<float> previous_hidden_storage;
    if (uses_shifted_hidden_rows) {
        const auto hidden_it = mtp_state->target_hidden_by_seq.find(seq_id);
        if (hidden_it != mtp_state->target_hidden_by_seq.end() && (int32_t) hidden_it->second.size() == features.width) {
            previous_hidden_storage = hidden_it->second;
        } else {
            previous_hidden_storage.assign(features.width, 0.0f);
        }
    }

    const float * last_hidden = hidden_rows_storage.data() + (size_t) (batch.n_tokens - 1) * features.width;
    mtp_store_target_hidden(*mtp_state, seq_id, last_hidden, features.width);

    if (mtp_state->constant_draft_positions) {
        mtp_invalidate_cached_draft(*mtp_state, seq_id);
        return 0;
    }
    if (!is_prompt_warmup) {
        return mtp_accept_batch(*mtp_state, batch, seq_id, hidden_rows_storage.data());
    }

    const float * conditioned_hidden_rows = hidden_rows_storage.data();
    std::vector<float> conditioned_hidden_storage;
    if (uses_shifted_hidden_rows) {
        // Row 0 <- previous batch's last hidden state; row i <- this batch's row i-1.
        conditioned_hidden_storage.resize(hidden_rows_storage.size());
        std::copy(previous_hidden_storage.begin(), previous_hidden_storage.end(), conditioned_hidden_storage.begin());
        if (batch.n_tokens > 1) {
            std::copy(
                hidden_rows_storage.begin(),
                hidden_rows_storage.begin() + (size_t) (batch.n_tokens - 1) * features.width,
                conditioned_hidden_storage.begin() + features.width);
        }
        conditioned_hidden_rows = conditioned_hidden_storage.data();
    }

    if (!llama_set_draft_input_hidden_state_copy(mtp_state->ctx_mtp, conditioned_hidden_rows, hidden_rows_storage.size())) {
        return -1;
    }
    const int32_t ret = mtp_update_kv_cache(mtp_state->ctx_mtp, batch, true);
    mtp_invalidate_cached_draft(*mtp_state, seq_id);
    return ret;
}
// Type of the implementation currently selected in `spec`, or COMMON_SPECULATIVE_TYPE_NONE when
// `spec` is null or nothing is selected.
common_speculative_type common_speculative_current_type(const common_speculative * spec) {
    const bool has_impl = spec != nullptr && spec->curr_impl != nullptr;
    return has_impl ? spec->curr_impl->type : COMMON_SPECULATIVE_TYPE_NONE;
}
// Applies a context shift to the drafter's own state: in the companion context the cells in
// [kv_keep, kv_keep + kv_discard) of `seq_id` are removed and the cells in
// [kv_keep + kv_discard, kv_past) are moved back by `kv_discard`; a DFlash state is shifted via
// dflash_context_shift().
void common_speculative_context_shift(
        common_speculative * spec,
        llama_seq_id seq_id,
        llama_pos kv_keep,
        llama_pos kv_discard,
        llama_pos kv_past) {
    llama_context * companion = common_speculative_get_companion_ctx(spec);
    if (companion != nullptr) {
        llama_kv_cache_seq_rm (companion, seq_id, kv_keep, kv_keep + kv_discard);
        llama_kv_cache_seq_add(companion, seq_id, kv_keep + kv_discard, kv_past, -kv_discard);
    }

    auto * dflash_state = common_speculative_get_dflash_state(spec);
    if (dflash_state != nullptr) {
        dflash_context_shift(*dflash_state, kv_keep, kv_discard, kv_past);
    }
}
// Generates up to `n_draft` draft tokens with the MTP context `ctx`.
//
// Starting from `id_last` at position `n_past`, each step decodes one token, samples the next
// draft with `smpl`, and feeds the resulting embedding back as the hidden-state input of the
// following step. If the per-sequence cache already holds a draft token (`last_id >= 0`), that
// token is emitted first without a decode; when its probability is below `p_min` it is the only
// token returned. With `p_min > 0`, drafting stops at the first sampled token whose probability
// falls below `p_min` (the very first decoded token is still emitted).
//
// With `constant_draft_positions`, every draft decode is issued at position `n_past`.
// Returns the drafted tokens; empty when `smpl` is null or `n_draft <= 0`.
std::vector<llama_token> mtp_speculative_gen_draft(
        common_speculative_state_mtp & state,
        struct common_sampler * smpl,
        struct llama_context * ctx,
        int n_draft,
        float p_min,
        llama_token id_last,
        llama_pos n_past,
        llama_seq_id seq_id,
        bool constant_draft_positions) {
    llama_tokens drafts;
    if (!smpl) return drafts;
    if (n_draft <= 0) {
        mtp_invalidate_cached_draft(state, seq_id);
        return drafts;
    }
    // Reserve only once n_draft is known to be positive: a negative value would convert to a
    // huge size_t and make reserve() throw std::length_error.
    drafts.reserve(n_draft);

    common_sampler_reset(smpl);
    llama_batch mtp_batch = llama_batch_init(1, 0, 1);
    llama_set_mtp_op_type(ctx, MTP_OP_DRAFT_GEN);

    // `prob` is only written by the sampler when `prob_ptr` is non-null, i.e. when p_min > 0.
    float prob = 0.0f;
    auto prob_ptr = p_min > 0 ? &prob : nullptr;

    llama_token current_input_id = id_last;
    llama_pos current_n_past = n_past;
    const int n_embd = llama_mtp_state_n_embd(ctx);
    auto & last = mtp_get_last_embd(state, seq_id);

    int i0 = 0;
    if (last.last_id >= 0) {
        // A draft token is already cached for this sequence: emit it without decoding.
        if (last.prob < p_min) {
            n_draft = 1;
        }
        current_input_id = last.last_id;
        last.last_id = -1;
        drafts.push_back(current_input_id);
        current_n_past++;
        if (!llama_set_draft_input_hidden_state_copy(ctx, last.embd.data(), last.embd.size())) {
            llama_batch_free(mtp_batch);
            llama_set_mtp_op_type(ctx, MTP_OP_NONE);
            return drafts;
        }
        i0 = 1;
    }

    int n_decode = 0;
    for (int i = i0; i < n_draft; ++i) {
        mtp_batch.n_tokens = 0;
        const llama_pos draft_pos = constant_draft_positions ? n_past : current_n_past;
        common_batch_add(mtp_batch, current_input_id, draft_pos, {seq_id}, true);
        ++n_decode;
        if (llama_decode(ctx, mtp_batch) != 0) {
            break;
        }
        llama_token id_next = common_sampler_sample_speculative(smpl, ctx, 0, prob_ptr);
        // Low-confidence tokens are dropped, except for the very first draft of the call.
        if (i > 0 && prob_ptr && prob < p_min) {
            break;
        }
        drafts.push_back(id_next);
        const float * emb = llama_get_embeddings_ith(ctx, 0);
        if (!emb) {
            break;
        }
        // Keep a stable copy because later decode steps reuse ctx->embd storage.
        memcpy(last.embd.data(), emb, n_embd * sizeof(float));
        if (!llama_set_draft_input_hidden_state_copy(ctx, last.embd.data(), last.embd.size())) {
            break;
        }
        current_input_id = id_next;
        current_n_past++;
        if (prob_ptr && prob < p_min) {
            break;
        }
    }

    llama_batch_free(mtp_batch);
    llama_set_mtp_op_type(ctx, MTP_OP_NONE);

    // Purge the metadata for the draft tokens.
    // This prevents cache state corruption where two cells map to the same logical position.
    // If the state contained in `last` had a valid token id and probability, it means that we
    // have previously run an "accept" batch, where the token sampled from the main model was included.
    // Even in that case, the token at `n_past` is already committed and must remain in the KV cache,
    // so we only discard the speculative tail starting at `n_past + 1`.
    if (n_decode > 0) {
        llama_kv_cache_seq_rm(ctx, seq_id, n_past + 1, n_past + n_decode + 2);
    }
    return drafts;
}
// Decodes `batch` on the MTP context `ctx` so that its KV cache covers the batch tokens.
//
// Any cells of the batch's sequence at or beyond the batch's first position are removed first.
// The decode runs with the MTP op type set to MTP_OP_WARMUP (`is_prompt_warmup`) or
// MTP_OP_UPDATE_ACCEPTED, and the op type is reset to MTP_OP_NONE afterwards. Output is
// requested for the last token only.
//
// `batch` is passed by const reference but its `logits` array is shared with the local copy,
// so the caller's output flags are saved before being overridden and restored after the decode.
//
// Returns 0 for an empty batch, otherwise the result of llama_decode().
int32_t mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) {
    if (batch.n_tokens <= 0) {
        return 0;
    }

    llama_seq_id seq_id = batch.seq_id[0][0];
    llama_pos start_pos = batch.pos[0];

    if (llama_kv_cache_seq_pos_max(ctx, seq_id) >= start_pos) {
        llama_kv_cache_seq_rm(ctx, seq_id, start_pos, -1);
    }

    LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens for seq_id %d from pos %d...\n",
        is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens, seq_id, (int)start_pos);

    // We never need all logits. We only need the logits of the last token so we can sample
    // the next draft token. In the MTP_OP_WARMUP case we do not need logits at all, but just
    // in case we also get the logits of the last token.
    llama_batch mtp_batch = batch; // shallow copy: the arrays still belong to the caller

    std::vector<int8_t> caller_logits;
    if (mtp_batch.logits != nullptr) {
        caller_logits.assign(mtp_batch.logits, mtp_batch.logits + mtp_batch.n_tokens);
        std::fill(mtp_batch.logits, mtp_batch.logits + mtp_batch.n_tokens - 1, 0);
        mtp_batch.logits[mtp_batch.n_tokens - 1] = true;
    }
    // NOTE(review): a null `logits` array is left as is; llama_decode is expected to output
    // only the last token in that case - confirm against llama.h.

    if (is_prompt_warmup) {
        llama_set_mtp_op_type(ctx, MTP_OP_WARMUP);
    } else {
        llama_set_mtp_op_type(ctx, MTP_OP_UPDATE_ACCEPTED);
    }

    const int32_t ret = llama_decode(ctx, mtp_batch);

    llama_set_mtp_op_type(ctx, MTP_OP_NONE);

    // Hand the batch back to the caller with the output flags it came with.
    if (!caller_logits.empty()) {
        std::copy(caller_logits.begin(), caller_logits.end(), mtp_batch.logits);
    }

    return ret;
}