Refactor: Move spec outside server (#1949)

* Refactor speculative decoding: move logic outside of server

* remove duplicated tokens in mtp kv cache

* narrow to only discard draft cells in MTP

* revert mtp_speculative_gen_draft
This commit is contained in:
Samuel Oliveira Alves 2026-06-12 13:12:39 -03:00 committed by GitHub
parent d1339249d7
commit 8a38025174
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 776 additions and 425 deletions

View File

@ -195,10 +195,38 @@ bool common_params_speculative::has_stage_type(common_speculative_type stage_typ
});
}
// Removes every stage of `stage_type` from the configured stage list. If the
// single `type` field named that same stage, `type` is re-derived from the
// stages that remain (COMMON_SPECULATIVE_TYPE_NONE when nothing is left).
void common_params_speculative::remove_stage_type(common_speculative_type stage_type) {
    stages.erase(
        std::remove_if(stages.begin(), stages.end(),
            [stage_type](const common_speculative_stage_params & stage) {
                return stage.type == stage_type;
            }),
        stages.end());

    if (type != stage_type) {
        return;
    }

    // Drop the removed type *before* resolving. get_resolved_stages() is not
    // visible from here and may fall back to `type` when `stages` is empty;
    // resolving with the stale value would hand the removed stage right back.
    type = COMMON_SPECULATIVE_TYPE_NONE;

    const auto resolved = get_resolved_stages();
    if (!resolved.empty()) {
        type = resolved.front().type;
    }
}
// True when the resolved stage chain consists of more than one stage.
bool common_params_speculative::has_composite_stage_chain() const {
    const auto n_resolved = get_resolved_stages().size();
    return n_resolved > 1;
}
// A draft model is needed when a DRAFT stage is configured, or when an MTP
// stage is configured and has_dft() reports a draft model.
bool common_params_speculative::needs_dft_model() const {
    if (has_stage_type(COMMON_SPECULATIVE_TYPE_DRAFT)) {
        return true;
    }
    if (!has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
        return false;
    }
    return has_dft();
}
// Frees the loaded draft model (if any) and resets all draft-model settings:
// model path, extra argument string, model-params path and context params.
void common_params_speculative::clear_dft() {
    // release the model handle first, then forget it
    if (model_dft != nullptr) {
        llama_free_model(model_dft);
    }
    model_dft = nullptr;

    // textual configuration
    model.clear();
    params.clear();
    mparams_dft.path.clear();

    // context params go back to the library defaults
    cparams_dft = llama_context_default_params();
}
int32_t common_params_speculative::get_max_stage_n_max() const {
const auto resolved = get_resolved_stages();
if (resolved.empty()) {

View File

@ -252,7 +252,10 @@ struct common_params_speculative {
common_params_speculative with_stage_overrides(const common_speculative_stage_params & stage) const;
bool has_stage_chain() const;
bool has_stage_type(common_speculative_type stage_type) const;
void remove_stage_type(common_speculative_type stage_type);
bool has_composite_stage_chain() const;
bool needs_dft_model() const;
void clear_dft();
int32_t get_max_stage_n_max() const;
int32_t get_min_usable_stage_n_min() const;

View File

@ -47,6 +47,18 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
{"suffix", COMMON_SPECULATIVE_TYPE_SUFFIX}
};
// Resets the checkpoint to its default (invalid) state and frees the sampler
// snapshot, if one is held.
void common_speculative_checkpoint::clear() {
    valid            = false;
    per_step_enabled = false;
    n_past           = 0;
    sampled          = LLAMA_TOKEN_NULL;

    if (sampler == nullptr) {
        return;
    }
    common_sampler_free(sampler);
    sampler = nullptr;
}
struct common_speculative_config {
common_speculative_stage_params stage;
common_speculative_type type;
@ -172,6 +184,17 @@ struct common_speculative_state_mtp;
static common_speculative_state_mtp * common_speculative_get_mtp_state(common_speculative * spec);
static const common_speculative_state_mtp * common_speculative_get_mtp_state(const common_speculative * spec);
static void mtp_invalidate_cached_drafts(common_speculative_state_mtp & state);
static bool common_speculative_checkpoint_save(
common_speculative_checkpoint & ckpt,
llama_model * model,
llama_context * ctx,
common_sampler * sampler_src,
const common_params_sampling & sparams,
llama_seq_id seq_id,
llama_pos n_past,
llama_token sampled,
int max_tokens,
int ckpt_mode);
static std::vector<llama_token> mtp_speculative_gen_draft(
common_speculative_state_mtp & state,
@ -1002,12 +1025,17 @@ struct common_speculative_state_suffix : public common_speculative_state {
};
// Runtime state of a speculative decoding instance: the configured
// implementations plus per-generation bookkeeping.
struct common_speculative {
// checkpoint taken before drafting; cleared by the destructor below
common_speculative_checkpoint checkpoint;
std::vector<common_speculative_config> configs; // resolved stage config for each implementation
std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states
common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
std::unique_ptr<spec_tuner> tuner;
// NOTE(review): presumably the number of tokens produced by the last draft call — confirm at the write site
int last_n_drafted = 0;
// NOTE(review): presumably a microsecond timestamp for the start of the current step — confirm at the write site
int64_t t_step_start_us = 0;
// the checkpoint may hold a sampler snapshot; clear() frees it
~common_speculative() {
checkpoint.clear();
}
};
static bool common_speculative_stage_chain_matches(
@ -1315,6 +1343,7 @@ common_speculative * common_speculative_init(
}
auto * result = new common_speculative {
/* .checkpoint = */ {},
/* .configs = */ std::move(configs),
/* .impls = */ std::move(impls)
};
@ -1340,6 +1369,170 @@ common_speculative * common_speculative_init(
return result;
}
// Attempts to create a speculative decoding instance and reports the outcome
// as a status code.
//
//   params   - speculative configuration
//   ctx_tgt  - target context; may be null (only used to classify failures here)
//   out_spec - optional; receives the created instance on success, nullptr otherwise
//
// Returns SKIPPED when no stage chain is configured, READY when initialization
// succeeded, and otherwise an error status: ERR_RECURRENT if the target model is
// recurrent, ERR_MTP if an MTP stage is configured, ERR_GENERIC for anything else.
common_speculative_init_status common_speculative_try_init(
        common_params_speculative & params,
        llama_context * ctx_tgt,
        common_speculative ** out_spec) {
    if (out_spec != nullptr) {
        *out_spec = nullptr;
    }

    if (!params.has_stage_chain()) {
        return COMMON_SPECULATIVE_INIT_SKIPPED;
    }

    common_speculative * spec = common_speculative_init(params, ctx_tgt);
    if (spec != nullptr) {
        if (out_spec != nullptr) {
            *out_spec = spec;
        } else {
            // nobody can take ownership: release the instance instead of leaking it
            common_speculative_free(spec);
        }
        return COMMON_SPECULATIVE_INIT_READY;
    }

    // initialization failed: classify the most likely cause for the caller
    const llama_model * model = ctx_tgt != nullptr ? llama_get_model(ctx_tgt) : nullptr;
    if (model != nullptr && llama_model_has_recurrent(model)) {
        return COMMON_SPECULATIVE_INIT_ERR_RECURRENT;
    }
    if (params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
        return COMMON_SPECULATIVE_INIT_ERR_MTP;
    }
    return COMMON_SPECULATIVE_INIT_ERR_GENERIC;
}
// Normalizes the speculative configuration before any model is loaded:
//  - drops the MTP stage when parallel MTP is disallowed and n_parallel > 1,
//  - clears draft-model settings that no remaining stage needs,
//  - refreshes params_base.has_mtp from the resulting stage chain.
void common_speculative_prepare_startup(
        gpt_params & params_base,
        bool allow_parallel_mtp) {
    auto & spec_params = params_base.speculative;

    const bool drop_mtp =
        !allow_parallel_mtp &&
        params_base.n_parallel > 1 &&
        spec_params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);

    if (drop_mtp) {
        LOG_WRN("%s: MTP is not supported with parallel slots yet, removing the MTP stage to avoid cross-slot corruption. n_parallel=%d, stage_chain=%s\n",
                __func__, params_base.n_parallel, common_speculative_stage_chain_to_str(spec_params).c_str());
        spec_params.remove_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
    }

    const bool keep_dft = spec_params.needs_dft_model();
    if (!keep_dft) {
        spec_params.clear_dft();
    }

    params_base.has_mtp = spec_params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
}
// Completes speculative setup once the target model is available: loads the
// draft model when one is configured, prepares the MTP runtime parameters and
// updates params_base.has_mtp / pooling_type accordingly.
// Returns false only when loading the draft model fails.
bool common_speculative_finalize_startup(
        gpt_params & params_base,
        const llama_model * model) {
    auto & spec_params = params_base.speculative;

    // discard draft-model settings that no configured stage needs
    if (!spec_params.needs_dft_model()) {
        spec_params.clear_dft();
    }

    if (spec_params.has_dft()) {
        LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");
        const bool dft_loaded = common_speculative_load_draft_model(spec_params, params_base);
        if (!dft_loaded) {
            return false;
        }
    }

    // has_mtp is set before the runtime preparation because params_base is passed into it
    const bool mtp_requested = spec_params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
    params_base.has_mtp = mtp_requested;

    const bool has_external_mtp =
        mtp_requested && llama_model_is_gemma4_mtp_assistant(spec_params.model_dft);

    const bool mtp_ready = common_speculative_prepare_mtp_runtime(
        spec_params, params_base, model, has_external_mtp);
    params_base.has_mtp = mtp_ready;

    if (mtp_ready) {
        params_base.pooling_type = LLAMA_POOLING_TYPE_NONE;
    }
    return true;
}
// Loads the draft model described by `params` and stores the result in
// `params.model_dft` / `params.cparams_dft`.
//
// The draft settings start from the speculative params (devices, model path,
// GPU layers, KV cache types with fallback to `params_base`) plus a few values
// taken from `params_base` (main GPU, RPC servers, flash-attn and hadamard
// flags). The extra argument string `params.params`, when present, is parsed
// on top of those settings.
//
// Returns true when no draft model is configured or the model was loaded;
// false when the extra arguments fail to parse or the model fails to load.
bool common_speculative_load_draft_model(
        common_params_speculative & params,
        const gpt_params & params_base) {
    if (!params.has_dft()) {
        return true;
    }

    gpt_params params_dft;
    params_dft.devices          = params.devices;
    params_dft.model            = params.model;
    params_dft.main_gpu         = params_base.main_gpu;
    params_dft.n_gpu_layers     = params.n_gpu_layers;
    params_dft.rpc_servers      = params_base.rpc_servers;
    params_dft.cache_type_k     = params.cache_type_k.empty() ? params_base.cache_type_k : params.cache_type_k;
    params_dft.cache_type_v     = params.cache_type_v.empty() ? params_base.cache_type_v : params.cache_type_v;
    params_dft.flash_attn       = params_base.flash_attn;
    params_dft.k_cache_hadamard = params_base.k_cache_hadamard;
    params_dft.v_cache_hadamard = params_base.v_cache_hadamard;

    if (!params.params.empty()) {
        auto [argc, argv] = parse_command_line("llama-server " + params.params);
        if (!gpt_params_parse(argc, argv, params_dft)) {
            gpt_params_print_usage(argc, argv, params_dft);
            free_command_line(argc, argv);
            return false;
        }
        free_command_line(argc, argv);
    }

    LOG_INF("%s: loading draft model '%s'\n", __func__, params_dft.model.c_str());

    // context size priority: parsed extra args, then speculative n_ctx,
    // then an even share of the base context per parallel slot
    if (params_dft.n_ctx == 0) {
        params_dft.n_ctx = params.n_ctx;
    }
    params_dft.n_ctx      = params_dft.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx;
    params_dft.n_parallel = 1;
    params_dft.n_batch    = params_dft.n_ctx;

    params.mparams_dft.path = params_dft.model;

    llama_model_params mparams_dft = common_model_params_to_llama(params_dft);
    llama_model * loaded_model = llama_model_load_from_file(params_dft.model.c_str(), mparams_dft);
    if (loaded_model == nullptr) {
        // report the path that was actually passed to the loader; it can differ
        // from params.model once the extra arguments have been applied
        LOG_ERR("%s: failed to load draft model '%s'\n", __func__, params_dft.model.c_str());
        return false;
    }

    params.model_dft   = loaded_model;
    params.cparams_dft = common_context_params_to_llama(params_dft);
    return true;
}
// Prepares the context parameters used for MTP and reports whether MTP stays
// enabled.
//
// Returns false when no MTP stage is configured, or when the model reports zero
// NextN layers and there is no external MTP model; in the latter case the MTP
// stage is removed and unneeded draft settings are cleared. Otherwise
// `params.cparams_dft` is set up for MTP warmup with embeddings and true is returned.
bool common_speculative_prepare_mtp_runtime(
        common_params_speculative & params,
        const gpt_params & params_base,
        const llama_model * model,
        bool has_external_mtp) {
    const bool mtp_requested = params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
    if (!mtp_requested) {
        return false;
    }

    const bool no_nextn_layers = llama_model_n_nextn_layer(model) == 0;
    if (no_nextn_layers && !has_external_mtp) {
        LOG_WRN("%s: MTP speculative stage requested, but model has 0 NextN layers. Removing MTP from the configured stage chain.\n",
                __func__);
        params.remove_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
        if (!params.needs_dft_model()) {
            params.clear_dft();
        }
        return false;
    }

    // without an external MTP model the context params are derived from the
    // base params, with pooling disabled
    if (!has_external_mtp) {
        gpt_params params_mtp = params_base;
        params_mtp.pooling_type = LLAMA_POOLING_TYPE_NONE;
        params.cparams_dft = common_context_params_to_llama(params_mtp);
    }

    auto & cparams = params.cparams_dft;
    cparams.mtp         = true;
    cparams.mtp_op_type = MTP_OP_WARMUP;
    cparams.embeddings  = true;
    return true;
}
void common_speculative_free(common_speculative * spec) {
if (spec == nullptr) {
return;
@ -1353,6 +1546,11 @@ void common_speculative_begin(common_speculative * spec, const llama_tokens & pr
return;
}
spec->checkpoint.clear();
spec->curr_impl = nullptr;
spec->last_n_drafted = 0;
spec->t_step_start_us = 0;
for (auto & impl : spec->impls) {
common_time_meas tm(impl->t_begin_us, !impl->gen_perf);
impl->begin(prompt);
@ -1456,6 +1654,34 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
}
}
// Saves a checkpoint into `spec` before a draft is generated.
// Returns false when `spec` is null or the checkpoint could not be saved.
bool common_speculative_before_draft(
        common_speculative * spec,
        llama_model * model,
        llama_context * ctx,
        common_sampler * sampler_src,
        const common_params_sampling & sparams,
        llama_seq_id seq_id,
        llama_pos n_past,
        llama_token sampled,
        int max_tokens,
        int ckpt_mode) {
    return spec != nullptr &&
           common_speculative_checkpoint_save(
               spec->checkpoint, model, ctx, sampler_src, sparams,
               seq_id, n_past, sampled, max_tokens, ckpt_mode);
}
static bool common_speculative_has_type(const common_speculative * spec, common_speculative_type type) {
if (spec == nullptr) {
return false;
@ -1663,6 +1889,38 @@ bool common_speculative_ensure_sequence_hidden(
return common_speculative_capture_output_hidden(spec, ctx, -1, seq_id, pos);
}
// Generates a draft and reports which implementation type produced it.
//
// When an MTP stage is present, the hidden state for `draft_seq_id` at
// `draft_base_pos - 1` must be available first; if it is not, an empty result
// (no tokens, type NONE) is returned. A null `spec` also yields an empty result.
common_speculative_draft_result common_speculative_draft_ex(
        common_speculative * spec,
        llama_context * ctx,
        common_params_speculative & params,
        const llama_tokens & prompt_tgt,
        llama_token id_last,
        llama_pos draft_base_pos,
        llama_seq_id draft_seq_id) {
    common_speculative_draft_result result = {};

    // Nothing to draft with. Guard here rather than relying on
    // common_speculative_draft(), whose handling of a null spec is not visible
    // from this function.
    if (spec == nullptr) {
        return result;
    }

    if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
        if (!common_speculative_ensure_sequence_hidden(spec, ctx, draft_seq_id, draft_base_pos - 1)) {
            LOG_ERR("%s: seq_id=%d MTP hidden state is empty during speculation\n",
                    __func__, (int) draft_seq_id);
            return result;
        }
    }

    result.tokens = common_speculative_draft(
        spec,
        params,
        prompt_tgt,
        id_last,
        draft_base_pos,
        draft_seq_id);

    // curr_impl is read after the draft call, so it reflects that call
    result.type = spec->curr_impl != nullptr
        ? spec->curr_impl->type
        : COMMON_SPECULATIVE_TYPE_NONE;
    return result;
}
int32_t common_speculative_on_target_seq_batch(
common_speculative * spec,
llama_context * ctx_tgt,
@ -1834,6 +2092,234 @@ bool common_speculative_commit_accepted_output(
hidden_rows);
}
// Captures a checkpoint of the context and sampler before drafting.
//
// `ckpt` is cleared first, then `n_past` / `sampled` are recorded. Returns false
// when the checkpoint backend reports mode NONE, when saving the context state
// fails, or when the sampler snapshot cannot be allocated; true otherwise.
static bool common_speculative_checkpoint_save(
        common_speculative_checkpoint & ckpt,
        llama_model * model,
        llama_context * ctx,
        common_sampler * sampler_src,
        const common_params_sampling & sparams,
        llama_seq_id seq_id,
        llama_pos n_past,
        llama_token sampled,
        int max_tokens,
        int ckpt_mode) {
    // start from a clean checkpoint and remember where drafting begins
    ckpt.clear();
    ckpt.n_past  = n_past;
    ckpt.sampled = sampled;

    const int mode_in_use = llama_spec_ckpt_init(ctx, ckpt_mode, max_tokens);
    if (mode_in_use == LLAMA_SPEC_CKPT_NONE) {
        return false;
    }
    ckpt.per_step_enabled = mode_in_use == LLAMA_SPEC_CKPT_PER_STEP;

    ckpt.valid = llama_spec_ckpt_save(ctx, seq_id);
    if (!ckpt.valid) {
        llama_spec_ckpt_discard(ctx);
        return false;
    }

    // snapshot the sampler so it can be rolled back together with the context
    ckpt.sampler = common_sampler_init(model, sparams);
    if (ckpt.sampler == nullptr) {
        common_speculative_checkpoint_discard(ckpt, ctx);
        return false;
    }

    if (sampler_src != nullptr) {
        common_sampler_clone(sampler_src, ckpt.sampler);
    }
    return true;
}
// Read-only access to the checkpoint held by `spec`; nullptr when `spec` is null.
const common_speculative_checkpoint * common_speculative_get_checkpoint(const common_speculative * spec) {
    if (spec == nullptr) {
        return nullptr;
    }
    return &spec->checkpoint;
}
// Drops a checkpoint: resets the bookkeeping in `ckpt` (freeing its sampler
// snapshot) and then discards the checkpoint state held by `ctx`.
void common_speculative_checkpoint_discard(
        common_speculative_checkpoint & ckpt,
        llama_context * ctx) {
    // local bookkeeping first ...
    ckpt.clear();

    // ... then the state kept on the context side
    llama_spec_ckpt_discard(ctx);
}
// Rolls the context (and, if given, the sampler) back to the state captured in
// `ckpt`, re-applies the accepted tokens `ids`, and finally discards the
// checkpoint. Does nothing when `ckpt` is not valid.
//
// Two paths, selected by ckpt.per_step_enabled:
//  - per-step: restore to step ids.size() - 1; no decode is issued here.
//  - otherwise: restore to step 0, then re-decode ckpt.sampled followed by all
//    but the last entry of `ids`.
// When an MTP stage is present, the accepted hidden rows are committed as well;
// if that fails, the hidden state of `seq_id` is cleared.
//
//   sampler_dst          - optional sampler to roll back and feed with `ids`
//   spec_type_used       - type of the implementation that produced the draft
//   n_draft              - number of drafted tokens (used for logging only here)
//   mtp_hidden_state_pre - hidden rows copied before the restore (per-step path only)
//   mtp_n_past_base      - base position passed on when committing those rows
void common_speculative_checkpoint_restore(
common_speculative_checkpoint & ckpt,
common_speculative * spec,
llama_context * ctx,
common_sampler * sampler_dst,
llama_seq_id seq_id,
common_speculative_type spec_type_used,
llama_token sampled_before,
const std::vector<llama_token> & ids,
int n_draft,
const std::vector<float> & mtp_hidden_state_pre,
int32_t mtp_n_past_base) {
if (!ckpt.valid) {
return;
}
if (ckpt.per_step_enabled) {
// step index = number of accepted draft tokens (ids also contains the final sampled token)
const int step = (int) ids.size() - 1;
llama_spec_ckpt_restore(ctx, seq_id, ckpt.n_past, step);
// roll the sampler back to the snapshot, then replay the accepted tokens
if (ckpt.sampler != nullptr && sampler_dst != nullptr) {
common_sampler_clone(ckpt.sampler, sampler_dst);
}
if (sampler_dst != nullptr) {
for (llama_token id : ids) {
common_sampler_accept(sampler_dst, ctx, id, true);
}
}
// no re-decode on this path, so MTP hidden rows come from the pre-restore copy
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) && !mtp_hidden_state_pre.empty()) {
if (!common_speculative_commit_accepted_hidden_rows(
spec,
spec_type_used,
seq_id,
mtp_n_past_base,
sampled_before,
ids,
mtp_hidden_state_pre)) {
common_speculative_clear_sequence_hidden(spec, seq_id);
} else if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) {
LOG_DBG("%s: seq_id=%d synced MTP target hidden state from accepted-prefix rows after per-step restore\n",
__func__, (int) seq_id);
}
}
LOG_DBG("%s: seq_id=%d per-step restore: step=%d (rejected %d drafts)\n",
__func__, (int) seq_id, step, (int) (n_draft - (ids.size() - 1)));
} else {
// full restore to the checkpointed state; accepted tokens are re-decoded below
llama_spec_ckpt_restore(ctx, seq_id, ckpt.n_past, 0);
if (ckpt.sampler != nullptr && sampler_dst != nullptr) {
common_sampler_clone(ckpt.sampler, sampler_dst);
}
if (!ids.empty()) {
const int n_re = (int) ids.size();
llama_batch re_batch = llama_batch_init(n_re, 0, 1);
// batch = [ckpt.sampled, ids[0], ..., ids[n_re - 2]] at positions n_past .. n_past + n_re - 1;
// only the last entry requests logits
common_batch_add(re_batch, ckpt.sampled, ckpt.n_past, { seq_id }, n_re == 1);
for (int j = 0; j < n_re - 1; ++j) {
common_batch_add(re_batch, ids[j], ckpt.n_past + 1 + j, { seq_id }, j == n_re - 2);
}
// with MTP every position is flagged for output and embeddings are enabled on ctx
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
for (int j = 0; j < re_batch.n_tokens; ++j) {
re_batch.logits[j] = true;
}
llama_set_embeddings(ctx, true);
}
const int ret = llama_decode(ctx, re_batch);
// NOTE(review): a decode failure is only logged; the commit and sampler replay below still run — confirm this is intended
if (ret != 0) {
LOG_ERR("%s: seq_id=%d failed to re-decode accepted tokens after checkpoint restore: %d\n",
__func__, (int) seq_id, ret);
}
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
// output rows of the re-decoded batch are simply 0 .. n_re - 1
std::vector<int32_t> redecoded_indices(n_re);
for (int j = 0; j < n_re; ++j) {
redecoded_indices[j] = j;
}
if (!common_speculative_commit_accepted_output(
spec,
ctx,
spec_type_used,
seq_id,
ckpt.n_past,
sampled_before,
ids,
redecoded_indices)) {
common_speculative_clear_sequence_hidden(spec, seq_id);
}
}
// replay the accepted tokens into the restored sampler
if (sampler_dst != nullptr) {
for (llama_token id : ids) {
common_sampler_accept(sampler_dst, ctx, id, true);
}
}
llama_batch_free(re_batch);
LOG_DBG("%s: seq_id=%d spec checkpoint restored: re-decoded %d tokens (rejected %d drafts)\n",
__func__, (int) seq_id, n_re, (int) (n_draft - (ids.size() - 1)));
}
}
// the checkpoint is single-use: always drop it after a restore
common_speculative_checkpoint_discard(ckpt, ctx);
}
// Finalizes one speculative step after the target model has verified a draft.
//
//   spec                    - must not be null
//   sampler_dst             - sampler to roll back when a checkpoint restore is needed
//   sampled_before          - token passed through to the hidden-state commit helpers
//   ids                     - accepted tokens; must not be empty (ids.size() - 1 is
//                             reported as the number of accepted draft tokens)
//   n_draft                 - number of tokens that were drafted
//   pos_base                - base position used for committing hidden state and for
//                             trimming the KV cache
//   accepted_output_indices - output rows of the accepted tokens (used with MTP)
//
// If any draft token was rejected and a valid checkpoint exists, the checkpoint
// is restored (which also discards it) and the function returns. Otherwise the
// accepted MTP output is committed when applicable, the KV cache of `seq_id` is
// trimmed, and the checkpoint is discarded.
void common_speculative_commit(
common_speculative * spec,
llama_context * ctx,
common_sampler * sampler_dst,
llama_seq_id seq_id,
llama_token sampled_before,
const std::vector<llama_token> & ids,
int n_draft,
llama_pos pos_base,
const std::vector<int32_t> & accepted_output_indices) {
GGML_ASSERT(spec != nullptr);
GGML_ASSERT(!ids.empty());
common_speculative_checkpoint & ckpt = spec->checkpoint;
const common_speculative_type spec_type_used = spec->curr_impl != nullptr
? spec->curr_impl->type
: COMMON_SPECULATIVE_TYPE_NONE;
// fewer accepted draft tokens than drafted => at least one rejection
const bool any_rejected = (int) ids.size() - 1 < n_draft;
std::vector<float> mtp_hidden_state_pre;
common_speculative_accept(spec, ids.size() - 1);
// copy the accepted hidden rows out of ctx before the restore below changes its state
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) &&
any_rejected &&
ckpt.valid &&
!accepted_output_indices.empty()) {
if (!common_speculative_copy_output_hidden_rows(spec, ctx, accepted_output_indices, mtp_hidden_state_pre)) {
mtp_hidden_state_pre.clear();
}
}
if (any_rejected && ckpt.valid) {
common_speculative_checkpoint_restore(
ckpt,
spec,
ctx,
sampler_dst,
seq_id,
spec_type_used,
sampled_before,
ids,
n_draft,
mtp_hidden_state_pre,
pos_base);
return;
}
// no restore needed (or no valid checkpoint): commit MTP hidden state straight from ctx
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) && !accepted_output_indices.empty()) {
if (!common_speculative_commit_accepted_output(
spec,
ctx,
spec_type_used,
seq_id,
pos_base,
sampled_before,
ids,
accepted_output_indices)) {
common_speculative_clear_sequence_hidden(spec, seq_id);
} else if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) {
LOG_DBG("%s: seq_id=%d synced MTP target hidden state from accepted-prefix rows\n",
__func__, (int) seq_id);
}
}
// drop KV cells from pos_base + (ids.size() - 1) to the end of the sequence
// NOTE(review): the exact cut position depends on what pos_base denotes at the call site — confirm against the caller
llama_kv_cache_seq_rm(ctx, seq_id, pos_base + (llama_pos) (ids.size() - 1), -1);
common_speculative_checkpoint_discard(ckpt, ctx);
}
void common_speculative_print_stats(const common_speculative * spec, double slot_tps, int n_decoded, int n_past, common_params_speculative * active_params) {
if (spec == nullptr) {
return;
@ -1980,6 +2466,50 @@ void common_speculative_clear_sequence_hidden(common_speculative * spec, llama_s
mtp_clear_target_hidden(*mtp_state, seq_id);
}
// Resets the per-generation state of `spec` (checkpoint, current implementation,
// draft counters) and clears the hidden state stored for `seq_id`. When
// `clear_companion_ctx` is set, the companion context's KV cache is cleared too.
void common_speculative_clear_sequence(
        common_speculative * spec,
        llama_seq_id seq_id,
        bool clear_companion_ctx) {
    if (spec != nullptr) {
        spec->checkpoint.clear();
        spec->curr_impl       = nullptr;
        spec->last_n_drafted  = 0;
        spec->t_step_start_us = 0;
    }

    common_speculative_clear_sequence_hidden(spec, seq_id);

    if (!clear_companion_ctx) {
        return;
    }

    llama_context * companion = common_speculative_get_companion_ctx(spec);
    if (companion != nullptr) {
        llama_kv_cache_clear(companion);
    }
}
// Removes the KV cells of `seq_id` from `pos_begin` onward in the target context
// and, if the target trim succeeded, in the companion context as well.
// Returns true only when every attempted trim succeeded.
bool common_speculative_trim_sequence(
        common_speculative * spec,
        llama_context * ctx,
        llama_seq_id seq_id,
        llama_pos pos_begin) {
    const bool target_ok = llama_kv_cache_seq_rm(ctx, seq_id, pos_begin, -1);

    llama_context * companion = common_speculative_get_companion_ctx(spec);
    if (companion == nullptr) {
        return target_ok;
    }

    // the companion context is only touched when the target trim went through
    if (!target_ok) {
        return false;
    }
    return llama_kv_cache_seq_rm(companion, seq_id, pos_begin, -1);
}
// Clears the speculative state for `seq_id` and removes the sequence from the
// KV cache of the target context and of the companion context, if there is one.
void common_speculative_clear_sequence_kv(
        common_speculative * spec,
        llama_context * ctx,
        llama_seq_id seq_id) {
    common_speculative_clear_sequence(spec, seq_id);

    llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);

    llama_context * companion = common_speculative_get_companion_ctx(spec);
    if (companion == nullptr) {
        return;
    }
    llama_kv_cache_seq_rm(companion, seq_id, -1, -1);
}
llama_context * common_speculative_get_companion_ctx(common_speculative * spec) {
if (auto * mtp_state = common_speculative_get_mtp_state(spec); mtp_state != nullptr) {
return mtp_state->ctx_mtp;
@ -2184,13 +2714,10 @@ std::vector<llama_token> mtp_speculative_gen_draft(
// This prevents cache state corruption where two cells map to the same logical position.
// If the state contained in `last` had a valid token id and probability, it means that we
// have previously run an "accept" batch, where the token sampled from the main model was included.
// In that case, we need to discard all tokens that we ran here to get the KV cache to the correct state.
// => for i0 = 1 we discard from n_past
// But if we did not have a valid last token_id, it means the first token we run was sampled from the
// main model. Hence we want to keep this token in the KV cache and discard all other tokens.
// => for i0 = 0 we discard from n_past + 1
// Even in that case, the token at `n_past` is already committed and must remain in the KV cache,
// so we only discard the speculative tail starting at `n_past + 1`.
if (n_decode > 0) {
llama_kv_cache_seq_rm(ctx, seq_id, n_past + 1 - i0, n_past + n_decode + 2);
llama_kv_cache_seq_rm(ctx, seq_id, n_past + 1, n_past + n_decode + 2);
}
return drafts;

View File

@ -7,6 +7,14 @@
struct common_speculative;
// Outcome of common_speculative_try_init().
enum common_speculative_init_status {
// no stage chain is configured; nothing was initialized
COMMON_SPECULATIVE_INIT_SKIPPED,
// initialization succeeded
COMMON_SPECULATIVE_INIT_READY,
// initialization failed and the target model is recurrent
COMMON_SPECULATIVE_INIT_ERR_RECURRENT,
// initialization failed and an MTP stage is configured
COMMON_SPECULATIVE_INIT_ERR_MTP,
// initialization failed for any other reason
COMMON_SPECULATIVE_INIT_ERR_GENERIC,
};
using common_speculative_feature_kind = llama_spec_feature_kind;
using common_speculative_feature_row_view = llama_spec_feature_row_view;
using common_speculative_feature_view = llama_spec_feature_view;
@ -14,6 +22,21 @@ using common_speculative_feature_view = llama_spec_feature_view;
static constexpr common_speculative_feature_kind COMMON_SPECULATIVE_FEATURE_NONE = LLAMA_SPEC_FEATURE_NONE;
static constexpr common_speculative_feature_kind COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE = LLAMA_SPEC_FEATURE_HIDDEN_STATE;
// Bookkeeping for a checkpoint taken before drafting, so the context and the
// sampler can be rolled back when draft tokens are rejected.
// NOTE(review): `sampler` is freed by clear(); copying this struct would leave two
// owners of the same pointer — confirm instances are never copied.
struct common_speculative_checkpoint {
// true once the context state was saved successfully
bool valid = false;
// true when the checkpoint backend runs in per-step mode
bool per_step_enabled = false;
// position recorded when the checkpoint was taken
llama_pos n_past = 0;
// sampled token recorded when the checkpoint was taken
llama_token sampled = LLAMA_TOKEN_NULL;
// sampler snapshot; freed by clear()
common_sampler * sampler = nullptr;
// resets all fields to their defaults and frees `sampler`
void clear();
};
// Result of common_speculative_draft_ex().
struct common_speculative_draft_result {
// drafted tokens
llama_tokens tokens;
// type of the implementation that was current after drafting; NONE when there is none
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE;
};
// comma separated list of all types
std::string common_speculative_type_name_str();
@ -31,6 +54,29 @@ common_speculative * common_speculative_init(
common_params_speculative & params,
llama_context * ctx_tgt);
common_speculative_init_status common_speculative_try_init(
common_params_speculative & params,
llama_context * ctx_tgt,
common_speculative ** out_spec);
void common_speculative_prepare_startup(
gpt_params & params_base,
bool allow_parallel_mtp = true);
bool common_speculative_finalize_startup(
gpt_params & params_base,
const llama_model * model);
bool common_speculative_load_draft_model(
common_params_speculative & params,
const gpt_params & params_base);
bool common_speculative_prepare_mtp_runtime(
common_params_speculative & params,
const gpt_params & params_base,
const llama_model * model,
bool has_external_mtp);
void common_speculative_free(common_speculative * spec);
// optionally call once at the beginning of a new generation
@ -46,9 +92,30 @@ llama_tokens common_speculative_draft(
llama_pos draft_base_pos = -1,
llama_seq_id draft_seq_id = 0);
common_speculative_draft_result common_speculative_draft_ex(
common_speculative * spec,
llama_context * ctx,
common_params_speculative & params,
const llama_tokens & prompt,
llama_token id_last,
llama_pos draft_base_pos = -1,
llama_seq_id draft_seq_id = 0);
// informs the speculative decoder that n_accepted tokens were accepted by the target model
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
bool common_speculative_before_draft(
common_speculative * spec,
llama_model * model,
llama_context * ctx,
common_sampler * sampler_src,
const common_params_sampling & sparams,
llama_seq_id seq_id,
llama_pos n_past,
llama_token sampled,
int max_tokens,
int ckpt_mode);
bool common_speculative_ensure_sequence_hidden(
common_speculative * spec,
llama_context * ctx,
@ -87,10 +154,56 @@ bool common_speculative_commit_accepted_output(
const std::vector<llama_token> & ids,
const std::vector<int32_t> & output_indices);
const common_speculative_checkpoint * common_speculative_get_checkpoint(const common_speculative * spec);
void common_speculative_checkpoint_discard(
common_speculative_checkpoint & ckpt,
llama_context * ctx);
void common_speculative_checkpoint_restore(
common_speculative_checkpoint & ckpt,
common_speculative * spec,
llama_context * ctx,
common_sampler * sampler_dst,
llama_seq_id seq_id,
common_speculative_type spec_type_used,
llama_token sampled_before,
const std::vector<llama_token> & ids,
int n_draft,
const std::vector<float> & mtp_hidden_state_pre,
int32_t mtp_n_past_base);
void common_speculative_commit(
common_speculative * spec,
llama_context * ctx,
common_sampler * sampler_dst,
llama_seq_id seq_id,
llama_token sampled_before,
const std::vector<llama_token> & ids,
int n_draft,
llama_pos pos_base,
const std::vector<int32_t> & accepted_output_indices);
bool common_speculative_has_sequence_hidden(const common_speculative * spec, llama_seq_id seq_id);
void common_speculative_clear_sequence_hidden(common_speculative * spec, llama_seq_id seq_id);
void common_speculative_clear_sequence(
common_speculative * spec,
llama_seq_id seq_id,
bool clear_companion_ctx = false);
bool common_speculative_trim_sequence(
common_speculative * spec,
llama_context * ctx,
llama_seq_id seq_id,
llama_pos pos_begin);
void common_speculative_clear_sequence_kv(
common_speculative * spec,
llama_context * ctx,
llama_seq_id seq_id);
llama_context * common_speculative_get_companion_ctx(common_speculative * spec);
int32_t common_speculative_on_target_seq_batch(

View File

@ -45,11 +45,6 @@ static void log_text(const gpt_params & params_base, const std::string & text) {
}
}
static bool params_use_gemma4_external_mtp(const gpt_params & params_base) {
return params_base.has_mtp &&
llama_model_is_gemma4_mtp_assistant(params_base.speculative.model_dft);
}
struct server_mtp_warmup {
llama_context * ctx_tgt;
server_slot * slot;
@ -72,72 +67,6 @@ static bool server_response_needs_chat_parse(oaicompat_type oaicompat) {
oaicompat == OAICOMPAT_TYPE_RESP;
}
void server_speculative_checkpoint::clear() {
valid = false;
per_step_enabled = false;
n_past = 0;
sampled = LLAMA_TOKEN_NULL;
if (sampler != nullptr) {
common_sampler_free(sampler);
sampler = nullptr;
}
}
static void discard_speculative_checkpoint(server_slot & slot, llama_context * ctx) {
slot.spec_ckpt.clear();
llama_spec_ckpt_discard(ctx);
}
static bool save_speculative_checkpoint(server_slot & slot, llama_model * model, llama_context * ctx, int ckpt_mode) {
slot.spec_ckpt.clear();
const int32_t n_pre_spec_tokens = slot.cache_tokens.n_tokens() - (int32_t)(slot.drafted.size() + 1);
slot.spec_ckpt.n_past = slot.cache_tokens.pos_next(n_pre_spec_tokens);
slot.spec_ckpt.sampled = slot.sampled;
const int max_tokens = (int)slot.drafted.size() + 1;
const int actual_mode = llama_spec_ckpt_init(ctx, ckpt_mode, max_tokens);
if (actual_mode == LLAMA_SPEC_CKPT_NONE) {
return false;
}
slot.spec_ckpt.per_step_enabled = (actual_mode == LLAMA_SPEC_CKPT_PER_STEP);
slot.spec_ckpt.valid = llama_spec_ckpt_save(ctx, slot.id);
if (!slot.spec_ckpt.valid) {
llama_spec_ckpt_discard(ctx);
return false;
}
slot.spec_ckpt.sampler = common_sampler_init(model, slot.sparams);
if (slot.spec_ckpt.sampler == nullptr) {
discard_speculative_checkpoint(slot, ctx);
return false;
}
common_sampler_clone(slot.ctx_sampling, slot.spec_ckpt.sampler);
return true;
}
static void server_remove_speculative_stage(common_params_speculative & spec, common_speculative_type type) {
spec.stages.erase(std::remove_if(spec.stages.begin(), spec.stages.end(), [type](const common_speculative_stage_params & stage) {
return stage.type == type;
}), spec.stages.end());
if (spec.type == type) {
spec.type = COMMON_SPECULATIVE_TYPE_NONE;
const auto resolved = spec.get_resolved_stages();
spec.type = resolved.empty() ? COMMON_SPECULATIVE_TYPE_NONE : resolved.front().type;
}
}
static bool server_speculative_needs_draft_model(const common_params_speculative & spec) {
return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_DRAFT);
}
static bool server_speculative_has_mtp(const common_params_speculative & spec) {
return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
}
static bool server_speculative_same_stage_types(
const common_params_speculative & lhs,
const common_params_speculative & rhs) {
@ -233,29 +162,17 @@ server_context::~server_context() {
}
// Free multimodal
mtmd_free(mctx);
// Free draft model and context if they exist
if (ctx_draft) {
llama_free(ctx_draft);
ctx_draft = nullptr;
}
if (model_draft) {
llama_free_model(model_draft);
model_draft = nullptr;
}
// Clear any sampling context
for (server_slot& slot : slots) {
if (slot.ctx_sampling != nullptr) {
common_sampler_free(slot.ctx_sampling);
}
slot.spec_ckpt.clear();
if (slot.ctx_dft) {
llama_free(slot.ctx_dft);
}
common_speculative_free(slot.spec);
llama_batch_free(slot.batch_spec);
}
params_base.speculative.clear_dft();
llama_batch_free(batch);
}
@ -278,24 +195,10 @@ bool server_context::load_model(const gpt_params& params_) {
add_bos_token = llama_should_add_bos_token(model);
has_eos_token = llama_add_eos_token(model) != 1;
if (params_base.n_parallel > 1 && server_speculative_has_mtp(params_base.speculative)) {
LOG_WARNING("MTP is not supported with parallel slots yet, removing the MTP stage to avoid cross-slot corruption.\n", {
{"n_parallel", params_base.n_parallel},
{"stage_chain", common_speculative_stage_chain_to_str(params_base.speculative)},
});
common_speculative_prepare_startup(params_base, false);
params_base.has_mtp = false;
server_remove_speculative_stage(params_base.speculative, COMMON_SPECULATIVE_TYPE_MTP);
if (!server_speculative_needs_draft_model(params_base.speculative)) {
params_base.speculative.model.clear();
params_base.speculative.params.clear();
params_base.speculative.model_dft = nullptr;
}
}
bool has_draft_model = !params_base.speculative.model.empty() || !params_base.speculative.params.empty();
std::string& mmproj_path = params_base.mmproj.path;
const bool has_draft_model = params_base.speculative.has_dft();
std::string & mmproj_path = params_base.mmproj.path;
if (!mmproj_path.empty()) {
mtmd_context_params mparams = mtmd_context_params_default();
mparams.use_gpu = params_base.mmproj_use_gpu;
@ -309,10 +212,10 @@ bool server_context::load_model(const gpt_params& params_) {
mparams.image_max_tokens = params_base.image_max_tokens;
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
if (mctx == nullptr) {
LOG_ERROR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
LOG_ERROR("failed to load multimodal model, %s\n", mmproj_path.c_str());
return false;
}
LOG_INFO("loaded multimodal model, '%s'\n", mmproj_path.c_str());
LOG_INFO("loaded multimodal model, %s\n", mmproj_path.c_str());
//if (params.n_cache_reuse) {
// params_base.n_cache_reuse = 0;
@ -323,71 +226,22 @@ bool server_context::load_model(const gpt_params& params_) {
LOG_ERROR("%s\n", "err: speculative decode is not supported by multimodal");
return false;
}
const auto spec_stages = params_base.speculative.get_resolved_stages();
const bool multimodal_spec_supported = spec_stages.empty() ||
(spec_stages.size() == 1 && spec_stages.front().type == COMMON_SPECULATIVE_TYPE_MTP);
if (!multimodal_spec_supported) {
const auto spec_stages = params_base.speculative.get_resolved_stages();
const bool multimodal_spec_supported = spec_stages.empty() ||
(spec_stages.size() == 1 && spec_stages.front().type == COMMON_SPECULATIVE_TYPE_MTP);
if (!multimodal_spec_supported) {
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
params_base.speculative.stages.clear();
params_base.has_mtp = false;
SRV_WRN("%s\n", "speculative decoding is not supported by multimodal, it will be disabled");
}
}
// Load draft model for speculative decoding if specified
if (has_draft_model) {
LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");
gpt_params params_dft;
params_dft.devices = params_base.speculative.devices;
params_dft.model = params_base.speculative.model;
params_dft.main_gpu = params_base.main_gpu;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
params_dft.rpc_servers = params_base.rpc_servers;
params_dft.cache_type_k = params_base.speculative.cache_type_k.empty() ? params_base.cache_type_k : params_base.speculative.cache_type_k;
params_dft.cache_type_v = params_base.speculative.cache_type_v.empty() ? params_base.cache_type_v : params_base.speculative.cache_type_v;
params_dft.flash_attn = params_base.flash_attn;
params_dft.k_cache_hadamard = params_base.k_cache_hadamard;
params_dft.v_cache_hadamard = params_base.v_cache_hadamard;
if (!params_base.speculative.params.empty()) {
auto [argc, argv] = parse_command_line("llama-server " + params_base.speculative.params);
if (!gpt_params_parse(argc, argv, params_dft)) {
gpt_params_print_usage(argc, argv, params_dft);
free_command_line(argc, argv);
return false;
};
free_command_line(argc, argv);
}
LOG_INFO("", { {"model", params_dft.model} });
if (params_dft.n_ctx == 0) {
params_dft.n_ctx = params_base.speculative.n_ctx;
}
params_dft.n_ctx = params_dft.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx;
params_dft.n_parallel = 1;
params_dft.n_batch = params_dft.n_ctx;
params_base.speculative.mparams_dft.path = params_dft.model; //
llama_model_params mparams_dft = common_model_params_to_llama(params_dft);
llama_model * model_dft = llama_model_load_from_file(params_dft.model.c_str(), mparams_dft);
if (model_dft == nullptr) {
LOG_ERROR("failed to load draft model", { {"model", params_base.speculative.model} });
return false;
}
cparams_dft = common_context_params_to_llama(params_dft);
params_base.speculative.model_dft = model_dft;
params_base.speculative.cparams_dft = cparams_dft;
if (!common_speculative_finalize_startup(params_base, model)) {
return false;
}
if (server_speculative_has_mtp(params_base.speculative) &&
llama_model_n_nextn_layer(model) == 0 &&
!params_use_gemma4_external_mtp(params_base)) {
LOG_WARNING("WARNING: MTP speculative stage requested, but model has 0 NextN layers. MTP will be disabled.\n", {});
params_base.has_mtp = false;
server_remove_speculative_stage(params_base.speculative, COMMON_SPECULATIVE_TYPE_MTP);
}
return true;
}
@ -396,6 +250,20 @@ void server_context::init() {
LOG_INFO("initializing slots", { {"n_slots", params_base.n_parallel} });
if (params_base.has_mtp) {
SRV_INF("%s\n", "MTP needs embeddings on decode, enabling");
llama_set_embeddings(ctx, true);
}
const bool requested_spec = params_base.speculative.has_stage_chain();
bool can_spec = true;
if (!params_base.dry_run) {
can_spec = common_speculative_is_compat(ctx);
}
if (!can_spec && requested_spec) {
SRV_WRN("%s", "speculative decoding not supported by this context\n");
}
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;
@ -440,68 +308,27 @@ void server_context::init() {
slot.params.speculative = params_base.speculative;
slot.sparams = params_base.sparams;
const bool wants_mtp_stage = server_speculative_has_mtp(params_base.speculative);
if (wants_mtp_stage) {
const bool has_external_mtp = params_use_gemma4_external_mtp(params_base);
if (llama_model_n_nextn_layer(model) > 0 || has_external_mtp) {
params_base.pooling_type = LLAMA_POOLING_TYPE_NONE;
if (!has_external_mtp) {
params_base.speculative.cparams_dft = common_context_params_to_llama(params_base);
}
params_base.speculative.cparams_dft.mtp = true;
params_base.speculative.cparams_dft.mtp_op_type = MTP_OP_WARMUP;
params_base.speculative.cparams_dft.embeddings = true;
slot.has_mtp = true;
slot.params.speculative.cparams_dft = params_base.speculative.cparams_dft;
slot.batch_spec = llama_batch_init(slot.params.speculative.get_max_stage_n_max() + 1, 0, 1);
SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);
SRV_INF("%s\n", "MTP needs embeddings on decode, enabling");
llama_set_embeddings(ctx, true);
}
else {
SRV_WRN("%s\n", "MTP speculative stage requested, but model has 0 NextN layers. Removing MTP from the configured stage chain.");
params_base.has_mtp = false;
server_remove_speculative_stage(params_base.speculative, COMMON_SPECULATIVE_TYPE_MTP);
slot.params.speculative = params_base.speculative;
slot.has_mtp = false;
}
}
const bool requested_spec = !params_base.speculative.get_resolved_stages().empty();
bool can_spec = true;
if (!params_base.dry_run) {
can_spec = common_speculative_is_compat(ctx);
}
if (!can_spec) {
SRV_WRN("%s", "speculative decoding not supported by this context\n");
}
// try speculative decoding
if (can_spec && requested_spec) {
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
if (slot.spec) {
if (mctx && !slot.has_mtp) {
switch (common_speculative_try_init(params_base.speculative, slot.ctx, &slot.spec)) {
case COMMON_SPECULATIVE_INIT_READY:
if (mctx && !slot.uses_mtp()) {
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
return;
}
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
} else {
if (llama_model_has_recurrent(model)) {
SRV_ERR("%s", "failed to initialize recurrent speculative context\n");
throw std::runtime_error("recurrent speculative context initialization failed");
} else if (slot.has_mtp) {
SRV_ERR("%s", "failed to initialize MTP speculative context\n");
throw std::runtime_error("MTP speculative context initialization failed");
} else {
SRV_ERR("%s", "failed to initialize speculative decoding context\n");
throw std::runtime_error("speculative decoding context initialization failed");
}
break;
case COMMON_SPECULATIVE_INIT_ERR_RECURRENT:
SRV_ERR("%s", "failed to initialize recurrent speculative context\n");
throw std::runtime_error("recurrent speculative context initialization failed");
case COMMON_SPECULATIVE_INIT_ERR_MTP:
SRV_ERR("%s", "failed to initialize MTP speculative context\n");
throw std::runtime_error("MTP speculative context initialization failed");
case COMMON_SPECULATIVE_INIT_ERR_GENERIC:
SRV_ERR("%s", "failed to initialize speculative decoding context\n");
throw std::runtime_error("speculative decoding context initialization failed");
case COMMON_SPECULATIVE_INIT_SKIPPED:
break;
}
}
@ -620,9 +447,7 @@ void server_slot::reset() {
n_kept_prompt = 0;
n_sent_text = 0;
drafted.clear();
drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE;
i_batch_dft.clear();
spec_ckpt.clear();
n_sent_token_probs = 0;
infill = false;
ga_i = 0;
@ -640,7 +465,7 @@ void server_slot::reset() {
image_just_processed = false;
do_checkpoint = false;
if (spec != nullptr) {
common_speculative_clear_sequence_hidden(spec, id);
common_speculative_clear_sequence(spec, id);
}
positional_bans.clear();
@ -675,7 +500,11 @@ void server_slot::reset() {
}
// Reports whether decoding for this slot has to produce embeddings: either the
// slot was asked for embeddings directly, or its speculative configuration
// contains an MTP stage (see uses_mtp()).
//
// The MTP check goes through uses_mtp() rather than a cached per-slot flag so
// that it agrees with every other MTP check performed on the slot; the stale
// second `return` that used to follow was unreachable and has been removed.
bool server_slot::need_embd() const {
    return embedding || uses_mtp();
}
// True when the speculative parameters attached to this slot list an MTP stage.
bool server_slot::uses_mtp() const {
    const auto & spec_params = params.speculative;
    return spec_params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
}
bool server_slot::has_budget(gpt_params& global_params) {
@ -711,7 +540,7 @@ void server_slot::add_token_string(const completion_token_output& token) {
}
// Reports whether this slot is able to run speculative decoding: it either
// owns a speculative context (`spec`) or its speculative parameters include an
// MTP stage.
//
// Only one return is kept; the previous body had a second, unreachable return
// and consulted a per-slot flag instead of uses_mtp().
bool server_slot::can_speculate() const {
    return spec != nullptr || uses_mtp();
}
int server_slot::get_n_draft_max() const {
@ -1327,7 +1156,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
throw std::runtime_error("Error: per-request speculative stages must match the server startup stage types; only stage parameter overrides are supported");
}
if (slot.params.speculative.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) && !slot.has_mtp) {
if (slot.params.speculative.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) && !params_base.has_mtp) {
throw std::runtime_error("Error: MTP speculative stage requested, but the server was not started with MTP support");
}
@ -2107,10 +1936,7 @@ void server_context::kv_cache_clear() {
continue;
}
common_speculative_clear_sequence_hidden(slot.spec, slot.id);
if (auto * ctx_companion = common_speculative_get_companion_ctx(slot.spec); ctx_companion != nullptr) {
llama_kv_cache_clear(ctx_companion);
}
common_speculative_clear_sequence(slot.spec, slot.id, true);
}
clean_kv_cache = false;
}
@ -3360,7 +3186,7 @@ void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_sl
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(kv_keep), slot.cache_tokens.pos_next(kv_keep + kv_discard));
llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard);
if (slot.has_mtp && slot.spec) {
if (slot.uses_mtp() && slot.spec) {
common_speculative_context_shift(slot.spec, slot.id, kv_keep, kv_discard, kv_past);
}
if (slot.params.cache_prompt) {
@ -3569,33 +3395,27 @@ void server_context::add_sampled_tokens() {
// perform the speculative drafting for all sequences at the same time in a single batch
const int n_draft_max_pre = slot.get_n_draft_max();
if (n_draft_max_pre > 0) {
if (mctx && !slot.has_mtp) {
if (mctx && !slot.uses_mtp()) {
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
GGML_ABORT("not supported by multimodal");
}
static const llama_tokens empty_prompt;
const llama_tokens & cached_text_tokens = slot.has_mtp && !slot.params.speculative.has_composite_stage_chain()
const llama_tokens & cached_text_tokens = slot.uses_mtp() && !slot.params.speculative.has_composite_stage_chain()
? empty_prompt
: slot.cache_tokens.get_text_tokens();
auto & params_spec = slot.params.speculative;
const llama_pos draft_base_pos = slot.has_mtp ? slot.cache_tokens.pos_next() : -1;
if (slot.has_mtp) {
if (!common_speculative_ensure_sequence_hidden(slot.spec, ctx, slot.id, draft_base_pos - 1)) {
LOG_ERROR("MTP hidden state is empty during speculation", {});
}
}
llama_tokens draft = common_speculative_draft(
const llama_pos draft_base_pos = slot.uses_mtp() ? slot.cache_tokens.pos_next() : -1;
common_speculative_draft_result draft_result = common_speculative_draft_ex(
slot.spec,
ctx,
params_spec,
cached_text_tokens,
slot.sampled,
draft_base_pos,
slot.id);
slot.drafted_spec_type = common_speculative_current_type(slot.spec);
llama_tokens & draft = draft_result.tokens;
const int n_draft_max = slot.get_n_draft_max();
@ -3620,7 +3440,6 @@ void server_context::add_sampled_tokens() {
// fallback to normal decoding
slot.i_batch = slot.i_batch_dft[0];
slot.drafted.clear();
slot.drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE;
slot.i_batch_dft.clear();
} else {
// keep track of total number of drafted tokens tested
@ -3637,7 +3456,6 @@ void server_context::add_sampled_tokens() {
}
else {
// no speculative decoding
slot.drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE;
slot.i_batch = batch.n_tokens;
common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
@ -3977,15 +3795,10 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
slot.cache_tokens.keep_first(slot.n_past);
int p0 = (int)system_tokens.size() + slot.n_past;
p0 = system_tokens.size() + slot.cache_tokens.pos_next();
auto * ctx_companion = slot.spec ? common_speculative_get_companion_ctx(slot.spec) : nullptr;
const bool target_trimmed = llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
const bool companion_trimmed = ctx_companion == nullptr || llama_kv_cache_seq_rm(ctx_companion, slot.id, p0, -1);
if (!target_trimmed || !companion_trimmed) {
const bool trimmed = common_speculative_trim_sequence(slot.spec, ctx, slot.id, p0);
if (!trimmed) {
// could not partially delete (likely using a non-Transformer model)
llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
if (ctx_companion != nullptr) {
llama_kv_cache_seq_rm(ctx_companion, slot.id, -1, -1);
}
common_speculative_clear_sequence_kv(slot.spec, ctx, slot.id);
p0 = (int)system_tokens.size();
if (p0 != 0) {
@ -4022,7 +3835,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
llama_pos p1 = slot.cache_tokens.pos_next() + slot.n_past_prompt - slot.n_past; // add offset to prompt
server_mtp_warmup mtp_media_warmup {
ctx,
slot.has_mtp && slot.spec ? &slot : nullptr,
slot.uses_mtp() && slot.spec ? &slot : nullptr,
};
mtmd_helper_eval_batch_callback mtp_media_callback =
mtp_media_warmup.slot ? server_mtp_media_warmup_callback : nullptr;
@ -4164,103 +3977,6 @@ void server_context::extend_context(const int32_t n_tokens) {
}
}
// Restore the slot's speculative checkpoint after some drafted tokens were rejected,
// bring the sampler back in line with the accepted tokens `ids`, and finally discard
// the checkpoint.
//
// Two strategies, selected by slot.spec_ckpt.per_step_enabled:
//   - per-step: restore directly at step `ids.size() - 1`; no re-decode is issued.
//   - otherwise: restore at step 0 and re-decode the accepted tokens through `ctx`.
//
// Parameters as used below:
//   slot                 - slot whose checkpoint, sampler and speculative context are touched
//   ctx                  - context that is restored and, in the second strategy, re-decoded on
//   model                - not referenced in this function body
//   spec_type_used       - forwarded to the common_speculative_commit_* helpers
//   sampled_before       - forwarded to the common_speculative_commit_* helpers
//   ids                  - accepted tokens; each one is fed to common_sampler_accept
//   n_draft              - number of drafted tokens; only used for the debug messages
//   mtp_hidden_state_pre - hidden rows captured by the caller before the restore (per-step path only)
//   mtp_n_past_base      - base position passed to the hidden-row commit (per-step path only)
static void restore_speculative_checkpoint(
server_slot & slot, llama_context * ctx, llama_model * model,
common_speculative_type spec_type_used,
llama_token sampled_before,
const std::vector<llama_token> & ids, int n_draft,
const std::vector<float> & mtp_hidden_state_pre, int32_t mtp_n_past_base) {
if (slot.spec_ckpt.per_step_enabled) {
// Per-step path: the step index is the number of accepted tokens minus one.
const int step = (int)ids.size() - 1;
llama_spec_ckpt_restore(ctx, slot.id, slot.spec_ckpt.n_past, step);
// NOTE(review): presumably copies the checkpointed sampler state into slot.ctx_sampling
// (argument order looks like src, dst) - confirm against common_sampler_clone.
if (slot.spec_ckpt.sampler) {
common_sampler_clone(slot.spec_ckpt.sampler, slot.ctx_sampling);
}
// Feed every accepted token to the sampler.
for (llama_token id : ids) {
common_sampler_accept(slot.ctx_sampling, ctx, id, true);
}
// Update MTP KV cache and hidden state using embeddings collected before checkpoint restore.
if (slot.has_mtp && !mtp_hidden_state_pre.empty()) {
if (!common_speculative_commit_accepted_hidden_rows(
slot.spec,
spec_type_used,
slot.id,
mtp_n_past_base,
sampled_before,
ids,
mtp_hidden_state_pre)) {
// Commit failed: drop the stored hidden state for this sequence.
common_speculative_clear_sequence_hidden(slot.spec, slot.id);
} else if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) {
SLT_DBG(slot, "%s", "synced MTP target hidden state from accepted-prefix rows after per-step restore");
}
}
SLT_DBG(slot, "per-step restore: step=%d (rejected %d drafts)\n",
step, (int)(n_draft - (ids.size() - 1)));
} else {
// Restore pre-speculation recurrent state then re-decode accepted tokens.
llama_spec_ckpt_restore(ctx, slot.id, slot.spec_ckpt.n_past, 0);
if (slot.spec_ckpt.sampler) {
common_sampler_clone(slot.spec_ckpt.sampler, slot.ctx_sampling);
}
if (!ids.empty()) {
// Re-decode to advance recurrent state to the accepted position.
// The batch holds the checkpointed sampled token at spec_ckpt.n_past followed by
// ids[0 .. n_re-2] at the next positions; the last element of `ids` is not decoded here.
const int n_re = (int)ids.size();
llama_batch re_batch = llama_batch_init(n_re, 0, 1);
// The final argument is true only for the last entry of the batch.
common_batch_add(re_batch, slot.spec_ckpt.sampled, slot.spec_ckpt.n_past, { slot.id }, n_re == 1);
for (int j = 0; j < n_re - 1; j++) {
common_batch_add(re_batch, ids[j], slot.spec_ckpt.n_past + 1 + j, { slot.id }, j == n_re - 2);
}
// With MTP, every batch entry is flagged in `logits` and embeddings are enabled on ctx.
if (slot.has_mtp) {
for (int j = 0; j < re_batch.n_tokens; j++) {
re_batch.logits[j] = true;
}
llama_set_embeddings(ctx, true);
}
// A failed decode is only logged; execution continues with the steps below.
const int ret = llama_decode(ctx, re_batch);
if (ret != 0) {
SLT_ERR(slot, "failed to re-decode accepted tokens after checkpoint restore: %d\n", ret);
}
if (slot.has_mtp) {
// Output indices 0..n_accepted-1 refer to the rows of the batch decoded just above.
const int n_accepted = (int)ids.size();
std::vector<int32_t> redecoded_indices(n_accepted);
for (int j = 0; j < n_accepted; ++j) {
redecoded_indices[j] = j;
}
if (!common_speculative_commit_accepted_output(
slot.spec,
ctx,
spec_type_used,
slot.id,
slot.spec_ckpt.n_past,
sampled_before,
ids,
redecoded_indices)) {
// Commit failed: drop the stored hidden state for this sequence.
common_speculative_clear_sequence_hidden(slot.spec, slot.id);
}
}
// Feed every accepted token to the sampler.
for (llama_token id : ids) {
common_sampler_accept(slot.ctx_sampling, ctx, id, true);
}
llama_batch_free(re_batch);
SLT_DBG(slot, "spec checkpoint restored: re-decoded %d tokens (rejected %d drafts)\n",
n_re, (int)(n_draft - (ids.size() - 1)));
}
}
// Both paths end by discarding the checkpoint.
discard_speculative_checkpoint(slot, ctx);
}
void server_context::speculative_decoding_accept() {
for (auto& slot : slots) {
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
@ -4268,7 +3984,6 @@ void server_context::speculative_decoding_accept() {
}
const llama_token sampled_before = slot.sampled;
const common_speculative_type spec_type_used = slot.drafted_spec_type;
size_t n_draft = slot.drafted.size();
slot.ctx_sampling->to_generated_text = &slot.generated_text;
@ -4298,28 +4013,15 @@ void server_context::speculative_decoding_accept() {
continue;
}
const bool any_rejected = (ids.size() - 1) < n_draft;
int32_t mtp_n_past_base = 0;
std::vector<float> mtp_hidden_state_pre;
std::vector<int32_t> accepted_output_indices;
if (slot.has_mtp) {
const int32_t n_pre_spec_tokens = slot.cache_tokens.n_tokens() - (int32_t)(slot.drafted.size() + 1);
mtp_n_past_base = slot.cache_tokens.pos_next(n_pre_spec_tokens);
if (slot.uses_mtp()) {
if (!ids.empty()) {
accepted_output_indices.assign(slot.i_batch_dft.begin(), slot.i_batch_dft.begin() + ids.size());
}
if (any_rejected && slot.spec_ckpt.valid && !accepted_output_indices.empty()) {
if (!common_speculative_copy_output_hidden_rows(slot.spec, ctx, accepted_output_indices, mtp_hidden_state_pre)) {
mtp_hidden_state_pre.clear();
}
}
}
slot.i_batch_dft.clear();
slot.drafted.clear();
slot.drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE;
slot.n_past += ids.size();
slot.n_decoded += ids.size();
@ -4329,11 +4031,9 @@ void server_context::speculative_decoding_accept() {
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
// inform the speculative decoding about the number of accepted tokens
common_speculative_accept(slot.spec, ids.size() - 1);
// rollback to the state before sampling the draft tokens
slot.cache_tokens.keep_first(slot.cache_tokens.n_tokens() - n_draft);
const llama_pos spec_pos_base = slot.cache_tokens.pos_next();
// add accepted tokens to the prompt
for (auto it = ids.begin(); it != ids.end() - 1; ++it) {
@ -4342,28 +4042,16 @@ void server_context::speculative_decoding_accept() {
slot.sampled = ids.back(); // last accepted token
slot.n_past = slot.cache_tokens.n_tokens();
// for recurrent/hybrid models: if any drafts were rejected, restore recurrent state
if (any_rejected && slot.spec_ckpt.valid) {
restore_speculative_checkpoint(slot, ctx, model, spec_type_used, sampled_before, ids, n_draft, mtp_hidden_state_pre, mtp_n_past_base);
} else {
if (slot.has_mtp && !accepted_output_indices.empty()) {
if (!common_speculative_commit_accepted_output(
slot.spec,
ctx,
spec_type_used,
slot.id,
mtp_n_past_base,
sampled_before,
ids,
accepted_output_indices)) {
common_speculative_clear_sequence_hidden(slot.spec, slot.id);
} else if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) {
SLT_DBG(slot, "%s", "synced MTP target hidden state from accepted-prefix rows");
}
}
llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(slot.n_past), -1);
discard_speculative_checkpoint(slot, ctx);
}
common_speculative_commit(
slot.spec,
ctx,
slot.ctx_sampling,
slot.id,
sampled_before,
ids,
n_draft,
spec_pos_base,
accepted_output_indices);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
@ -4737,9 +4425,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
continue; // continue loop of n_batch
}
if (server_speculative_has_mtp(params_base.speculative)) {
if (params_base.speculative.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
for (auto & slot : slots) {
if (!slot.spec || !slot.has_mtp) {
if (!slot.spec || !slot.uses_mtp()) {
continue;
}
@ -4779,7 +4467,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
if (slot.n_decoded == 0 && slot.can_speculate()) {
static const llama_tokens empty_prompt;
const llama_tokens & spec_prompt = slot.has_mtp && !slot.params.speculative.has_composite_stage_chain()
const llama_tokens & spec_prompt = slot.uses_mtp() && !slot.params.speculative.has_composite_stage_chain()
? empty_prompt
: slot.cache_tokens.get_text_tokens();
common_speculative_begin(slot.spec, spec_prompt);
@ -4803,7 +4491,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
completion_token_output result;
const int tok_idx = slot.i_batch - i;
if (slot.has_mtp && slot.n_decoded == 0) {
if (slot.uses_mtp() && slot.n_decoded == 0) {
(void) common_speculative_capture_output_hidden(slot.spec, ctx, tok_idx, slot.id, slot.n_past);
}
@ -4935,10 +4623,25 @@ void server_context::update_slots() {
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
continue;
}
if (save_speculative_checkpoint(slot, model, ctx, ckpt_mode)) {
const char * mode_name = slot.spec_ckpt.per_step_enabled ? "per-step" : "shadow/cpu";
const int32_t n_pre_spec_tokens = slot.cache_tokens.n_tokens() - (int32_t) (slot.drafted.size() + 1);
const llama_pos n_past_pre_spec = slot.cache_tokens.pos_next(n_pre_spec_tokens);
const int max_tokens = (int) slot.drafted.size() + 1;
if (common_speculative_before_draft(
slot.spec,
model,
ctx,
slot.ctx_sampling,
slot.sparams,
slot.id,
n_past_pre_spec,
slot.sampled,
max_tokens,
ckpt_mode)) {
const common_speculative_checkpoint * ckpt = common_speculative_get_checkpoint(slot.spec);
GGML_ASSERT(ckpt != nullptr);
const char * mode_name = ckpt->per_step_enabled ? "per-step" : "shadow/cpu";
SLT_DBG(slot, "spec checkpoint saved (mode=%s), n_past_pre_spec=%d\n",
mode_name, slot.spec_ckpt.n_past);
mode_name, ckpt->n_past);
} else {
SLT_WRN(slot, "%s", "failed to save spec checkpoint\n");
}

View File

@ -22,16 +22,6 @@ enum slot_command {
SLOT_COMMAND_RELEASE,
};
// Per-slot checkpoint taken around a speculative batch.
struct server_speculative_checkpoint {
    bool valid{false};
    // per-step SSM checkpoints active
    bool per_step_enabled{false};
    llama_pos n_past{0};
    llama_token sampled{LLAMA_TOKEN_NULL};
    // saved sampler state
    common_sampler * sampler{nullptr};

    void clear();
};
struct server_slot {
int id;
int id_task = -1;
@ -39,9 +29,6 @@ struct server_slot {
struct slot_params params;
llama_batch batch_spec = {};
llama_context * ctx_dft = nullptr;
bool released = false;
slot_state state = SLOT_STATE_IDLE;
slot_command command = SLOT_COMMAND_NONE;
@ -136,7 +123,6 @@ struct server_slot {
// sampling
llama_token sampled; // in speculative mode, this is the last accepted token
llama_tokens drafted;
common_speculative_type drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE;
json json_schema;
@ -171,11 +157,6 @@ struct server_slot {
// expiring logit bias
std::vector<common_sampler::elb_state> prev_elb_states;
bool has_mtp = false;
// saves recurrent state before a speculative batch so it can be restored on rejection
server_speculative_checkpoint spec_ckpt;
// speculative decoding stats
int32_t n_draft_total = 0; // Total draft tokens generated
int32_t n_draft_accepted = 0; // Draft tokens actually accepted
@ -195,6 +176,7 @@ struct server_slot {
void reset();
bool need_embd() const;
bool uses_mtp() const;
bool has_budget(gpt_params& global_params);
@ -266,11 +248,6 @@ struct server_context {
// multimodal
mtmd_context* mctx = nullptr;
// For speculative decoding
llama_model* model_draft = nullptr;
llama_context* ctx_draft = nullptr;
llama_context_params cparams_dft;
int32_t n_ctx; // total context for all clients / slots
// system prompt