mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Refactor: Move spec outside server (#1949)
* Refactor speculative decoding: move logic outside of server * remove duplicated tokens in mtp kv cache * narrow to only discard draft cells in MTP * revert mtp_speculative_gen_draft
This commit is contained in:
parent
d1339249d7
commit
8a38025174
@ -195,10 +195,38 @@ bool common_params_speculative::has_stage_type(common_speculative_type stage_typ
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void common_params_speculative::remove_stage_type(common_speculative_type stage_type) {
|
||||||
|
stages.erase(std::remove_if(stages.begin(), stages.end(), [stage_type](const common_speculative_stage_params & stage) {
|
||||||
|
return stage.type == stage_type;
|
||||||
|
}), stages.end());
|
||||||
|
|
||||||
|
if (type == stage_type) {
|
||||||
|
const auto resolved = get_resolved_stages();
|
||||||
|
type = resolved.empty() ? COMMON_SPECULATIVE_TYPE_NONE : resolved.front().type;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool common_params_speculative::has_composite_stage_chain() const {
|
bool common_params_speculative::has_composite_stage_chain() const {
|
||||||
return get_resolved_stages().size() > 1;
|
return get_resolved_stages().size() > 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool common_params_speculative::needs_dft_model() const {
|
||||||
|
return has_stage_type(COMMON_SPECULATIVE_TYPE_DRAFT) ||
|
||||||
|
(has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) && has_dft());
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_params_speculative::clear_dft() {
|
||||||
|
if (model_dft != nullptr) {
|
||||||
|
llama_free_model(model_dft);
|
||||||
|
model_dft = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
model.clear();
|
||||||
|
params.clear();
|
||||||
|
mparams_dft.path.clear();
|
||||||
|
cparams_dft = llama_context_default_params();
|
||||||
|
}
|
||||||
|
|
||||||
int32_t common_params_speculative::get_max_stage_n_max() const {
|
int32_t common_params_speculative::get_max_stage_n_max() const {
|
||||||
const auto resolved = get_resolved_stages();
|
const auto resolved = get_resolved_stages();
|
||||||
if (resolved.empty()) {
|
if (resolved.empty()) {
|
||||||
|
|||||||
@ -252,7 +252,10 @@ struct common_params_speculative {
|
|||||||
common_params_speculative with_stage_overrides(const common_speculative_stage_params & stage) const;
|
common_params_speculative with_stage_overrides(const common_speculative_stage_params & stage) const;
|
||||||
bool has_stage_chain() const;
|
bool has_stage_chain() const;
|
||||||
bool has_stage_type(common_speculative_type stage_type) const;
|
bool has_stage_type(common_speculative_type stage_type) const;
|
||||||
|
void remove_stage_type(common_speculative_type stage_type);
|
||||||
bool has_composite_stage_chain() const;
|
bool has_composite_stage_chain() const;
|
||||||
|
bool needs_dft_model() const;
|
||||||
|
void clear_dft();
|
||||||
int32_t get_max_stage_n_max() const;
|
int32_t get_max_stage_n_max() const;
|
||||||
int32_t get_min_usable_stage_n_min() const;
|
int32_t get_min_usable_stage_n_min() const;
|
||||||
|
|
||||||
|
|||||||
@ -47,6 +47,18 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
|
|||||||
{"suffix", COMMON_SPECULATIVE_TYPE_SUFFIX}
|
{"suffix", COMMON_SPECULATIVE_TYPE_SUFFIX}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
void common_speculative_checkpoint::clear() {
|
||||||
|
valid = false;
|
||||||
|
per_step_enabled = false;
|
||||||
|
n_past = 0;
|
||||||
|
sampled = LLAMA_TOKEN_NULL;
|
||||||
|
|
||||||
|
if (sampler != nullptr) {
|
||||||
|
common_sampler_free(sampler);
|
||||||
|
sampler = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct common_speculative_config {
|
struct common_speculative_config {
|
||||||
common_speculative_stage_params stage;
|
common_speculative_stage_params stage;
|
||||||
common_speculative_type type;
|
common_speculative_type type;
|
||||||
@ -172,6 +184,17 @@ struct common_speculative_state_mtp;
|
|||||||
static common_speculative_state_mtp * common_speculative_get_mtp_state(common_speculative * spec);
|
static common_speculative_state_mtp * common_speculative_get_mtp_state(common_speculative * spec);
|
||||||
static const common_speculative_state_mtp * common_speculative_get_mtp_state(const common_speculative * spec);
|
static const common_speculative_state_mtp * common_speculative_get_mtp_state(const common_speculative * spec);
|
||||||
static void mtp_invalidate_cached_drafts(common_speculative_state_mtp & state);
|
static void mtp_invalidate_cached_drafts(common_speculative_state_mtp & state);
|
||||||
|
static bool common_speculative_checkpoint_save(
|
||||||
|
common_speculative_checkpoint & ckpt,
|
||||||
|
llama_model * model,
|
||||||
|
llama_context * ctx,
|
||||||
|
common_sampler * sampler_src,
|
||||||
|
const common_params_sampling & sparams,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
llama_pos n_past,
|
||||||
|
llama_token sampled,
|
||||||
|
int max_tokens,
|
||||||
|
int ckpt_mode);
|
||||||
|
|
||||||
static std::vector<llama_token> mtp_speculative_gen_draft(
|
static std::vector<llama_token> mtp_speculative_gen_draft(
|
||||||
common_speculative_state_mtp & state,
|
common_speculative_state_mtp & state,
|
||||||
@ -1002,12 +1025,17 @@ struct common_speculative_state_suffix : public common_speculative_state {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct common_speculative {
|
struct common_speculative {
|
||||||
|
common_speculative_checkpoint checkpoint;
|
||||||
std::vector<common_speculative_config> configs; // resolved stage config for each implementation
|
std::vector<common_speculative_config> configs; // resolved stage config for each implementation
|
||||||
std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states
|
std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states
|
||||||
common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
|
common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
|
||||||
std::unique_ptr<spec_tuner> tuner;
|
std::unique_ptr<spec_tuner> tuner;
|
||||||
int last_n_drafted = 0;
|
int last_n_drafted = 0;
|
||||||
int64_t t_step_start_us = 0;
|
int64_t t_step_start_us = 0;
|
||||||
|
|
||||||
|
~common_speculative() {
|
||||||
|
checkpoint.clear();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static bool common_speculative_stage_chain_matches(
|
static bool common_speculative_stage_chain_matches(
|
||||||
@ -1315,6 +1343,7 @@ common_speculative * common_speculative_init(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto * result = new common_speculative {
|
auto * result = new common_speculative {
|
||||||
|
/* .checkpoint = */ {},
|
||||||
/* .configs = */ std::move(configs),
|
/* .configs = */ std::move(configs),
|
||||||
/* .impls = */ std::move(impls)
|
/* .impls = */ std::move(impls)
|
||||||
};
|
};
|
||||||
@ -1340,6 +1369,170 @@ common_speculative * common_speculative_init(
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
common_speculative_init_status common_speculative_try_init(
|
||||||
|
common_params_speculative & params,
|
||||||
|
llama_context * ctx_tgt,
|
||||||
|
common_speculative ** out_spec) {
|
||||||
|
if (out_spec != nullptr) {
|
||||||
|
*out_spec = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!params.has_stage_chain()) {
|
||||||
|
return COMMON_SPECULATIVE_INIT_SKIPPED;
|
||||||
|
}
|
||||||
|
|
||||||
|
common_speculative * spec = common_speculative_init(params, ctx_tgt);
|
||||||
|
if (spec != nullptr) {
|
||||||
|
if (out_spec != nullptr) {
|
||||||
|
*out_spec = spec;
|
||||||
|
}
|
||||||
|
return COMMON_SPECULATIVE_INIT_READY;
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_model * model = ctx_tgt != nullptr ? llama_get_model(ctx_tgt) : nullptr;
|
||||||
|
if (model != nullptr && llama_model_has_recurrent(model)) {
|
||||||
|
return COMMON_SPECULATIVE_INIT_ERR_RECURRENT;
|
||||||
|
}
|
||||||
|
if (params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
|
||||||
|
return COMMON_SPECULATIVE_INIT_ERR_MTP;
|
||||||
|
}
|
||||||
|
return COMMON_SPECULATIVE_INIT_ERR_GENERIC;
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_speculative_prepare_startup(
|
||||||
|
gpt_params & params_base,
|
||||||
|
bool allow_parallel_mtp) {
|
||||||
|
auto & params = params_base.speculative;
|
||||||
|
|
||||||
|
if (!allow_parallel_mtp && params_base.n_parallel > 1 && params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
|
||||||
|
LOG_WRN("%s: MTP is not supported with parallel slots yet, removing the MTP stage to avoid cross-slot corruption. n_parallel=%d, stage_chain=%s\n",
|
||||||
|
__func__, params_base.n_parallel, common_speculative_stage_chain_to_str(params).c_str());
|
||||||
|
params.remove_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!params.needs_dft_model()) {
|
||||||
|
params.clear_dft();
|
||||||
|
}
|
||||||
|
|
||||||
|
params_base.has_mtp = params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool common_speculative_finalize_startup(
|
||||||
|
gpt_params & params_base,
|
||||||
|
const llama_model * model) {
|
||||||
|
auto & params = params_base.speculative;
|
||||||
|
|
||||||
|
if (!params.needs_dft_model()) {
|
||||||
|
params.clear_dft();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.has_dft()) {
|
||||||
|
LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");
|
||||||
|
if (!common_speculative_load_draft_model(params, params_base)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
params_base.has_mtp = params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
|
||||||
|
const bool has_external_mtp = params_base.has_mtp &&
|
||||||
|
llama_model_is_gemma4_mtp_assistant(params.model_dft);
|
||||||
|
|
||||||
|
params_base.has_mtp = common_speculative_prepare_mtp_runtime(
|
||||||
|
params,
|
||||||
|
params_base,
|
||||||
|
model,
|
||||||
|
has_external_mtp);
|
||||||
|
if (params_base.has_mtp) {
|
||||||
|
params_base.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool common_speculative_load_draft_model(
|
||||||
|
common_params_speculative & params,
|
||||||
|
const gpt_params & params_base) {
|
||||||
|
if (!params.has_dft()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
gpt_params params_dft;
|
||||||
|
params_dft.devices = params.devices;
|
||||||
|
params_dft.model = params.model;
|
||||||
|
params_dft.main_gpu = params_base.main_gpu;
|
||||||
|
params_dft.n_gpu_layers = params.n_gpu_layers;
|
||||||
|
params_dft.rpc_servers = params_base.rpc_servers;
|
||||||
|
params_dft.cache_type_k = params.cache_type_k.empty() ? params_base.cache_type_k : params.cache_type_k;
|
||||||
|
params_dft.cache_type_v = params.cache_type_v.empty() ? params_base.cache_type_v : params.cache_type_v;
|
||||||
|
params_dft.flash_attn = params_base.flash_attn;
|
||||||
|
params_dft.k_cache_hadamard = params_base.k_cache_hadamard;
|
||||||
|
params_dft.v_cache_hadamard = params_base.v_cache_hadamard;
|
||||||
|
|
||||||
|
if (!params.params.empty()) {
|
||||||
|
auto [argc, argv] = parse_command_line("llama-server " + params.params);
|
||||||
|
if (!gpt_params_parse(argc, argv, params_dft)) {
|
||||||
|
gpt_params_print_usage(argc, argv, params_dft);
|
||||||
|
free_command_line(argc, argv);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
free_command_line(argc, argv);
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_INF("%s: loading draft model '%s'\n", __func__, params_dft.model.c_str());
|
||||||
|
|
||||||
|
if (params_dft.n_ctx == 0) {
|
||||||
|
params_dft.n_ctx = params.n_ctx;
|
||||||
|
}
|
||||||
|
params_dft.n_ctx = params_dft.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx;
|
||||||
|
params_dft.n_parallel = 1;
|
||||||
|
params_dft.n_batch = params_dft.n_ctx;
|
||||||
|
|
||||||
|
params.mparams_dft.path = params_dft.model;
|
||||||
|
|
||||||
|
llama_model_params mparams_dft = common_model_params_to_llama(params_dft);
|
||||||
|
llama_model * loaded_model = llama_model_load_from_file(params_dft.model.c_str(), mparams_dft);
|
||||||
|
if (loaded_model == nullptr) {
|
||||||
|
LOG_ERR("%s: failed to load draft model '%s'\n", __func__, params.model.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
params.model_dft = loaded_model;
|
||||||
|
params.cparams_dft = common_context_params_to_llama(params_dft);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool common_speculative_prepare_mtp_runtime(
|
||||||
|
common_params_speculative & params,
|
||||||
|
const gpt_params & params_base,
|
||||||
|
const llama_model * model,
|
||||||
|
bool has_external_mtp) {
|
||||||
|
if (!params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (llama_model_n_nextn_layer(model) == 0 && !has_external_mtp) {
|
||||||
|
LOG_WRN("%s: MTP speculative stage requested, but model has 0 NextN layers. Removing MTP from the configured stage chain.\n",
|
||||||
|
__func__);
|
||||||
|
params.remove_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
|
||||||
|
if (!params.needs_dft_model()) {
|
||||||
|
params.clear_dft();
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!has_external_mtp) {
|
||||||
|
gpt_params params_mtp = params_base;
|
||||||
|
params_mtp.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||||
|
params.cparams_dft = common_context_params_to_llama(params_mtp);
|
||||||
|
}
|
||||||
|
|
||||||
|
params.cparams_dft.mtp = true;
|
||||||
|
params.cparams_dft.mtp_op_type = MTP_OP_WARMUP;
|
||||||
|
params.cparams_dft.embeddings = true;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
void common_speculative_free(common_speculative * spec) {
|
void common_speculative_free(common_speculative * spec) {
|
||||||
if (spec == nullptr) {
|
if (spec == nullptr) {
|
||||||
return;
|
return;
|
||||||
@ -1353,6 +1546,11 @@ void common_speculative_begin(common_speculative * spec, const llama_tokens & pr
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
spec->checkpoint.clear();
|
||||||
|
spec->curr_impl = nullptr;
|
||||||
|
spec->last_n_drafted = 0;
|
||||||
|
spec->t_step_start_us = 0;
|
||||||
|
|
||||||
for (auto & impl : spec->impls) {
|
for (auto & impl : spec->impls) {
|
||||||
common_time_meas tm(impl->t_begin_us, !impl->gen_perf);
|
common_time_meas tm(impl->t_begin_us, !impl->gen_perf);
|
||||||
impl->begin(prompt);
|
impl->begin(prompt);
|
||||||
@ -1456,6 +1654,34 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool common_speculative_before_draft(
|
||||||
|
common_speculative * spec,
|
||||||
|
llama_model * model,
|
||||||
|
llama_context * ctx,
|
||||||
|
common_sampler * sampler_src,
|
||||||
|
const common_params_sampling & sparams,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
llama_pos n_past,
|
||||||
|
llama_token sampled,
|
||||||
|
int max_tokens,
|
||||||
|
int ckpt_mode) {
|
||||||
|
if (spec == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return common_speculative_checkpoint_save(
|
||||||
|
spec->checkpoint,
|
||||||
|
model,
|
||||||
|
ctx,
|
||||||
|
sampler_src,
|
||||||
|
sparams,
|
||||||
|
seq_id,
|
||||||
|
n_past,
|
||||||
|
sampled,
|
||||||
|
max_tokens,
|
||||||
|
ckpt_mode);
|
||||||
|
}
|
||||||
|
|
||||||
static bool common_speculative_has_type(const common_speculative * spec, common_speculative_type type) {
|
static bool common_speculative_has_type(const common_speculative * spec, common_speculative_type type) {
|
||||||
if (spec == nullptr) {
|
if (spec == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
@ -1663,6 +1889,38 @@ bool common_speculative_ensure_sequence_hidden(
|
|||||||
return common_speculative_capture_output_hidden(spec, ctx, -1, seq_id, pos);
|
return common_speculative_capture_output_hidden(spec, ctx, -1, seq_id, pos);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
common_speculative_draft_result common_speculative_draft_ex(
|
||||||
|
common_speculative * spec,
|
||||||
|
llama_context * ctx,
|
||||||
|
common_params_speculative & params,
|
||||||
|
const llama_tokens & prompt_tgt,
|
||||||
|
llama_token id_last,
|
||||||
|
llama_pos draft_base_pos,
|
||||||
|
llama_seq_id draft_seq_id) {
|
||||||
|
common_speculative_draft_result result = {};
|
||||||
|
|
||||||
|
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
|
||||||
|
if (!common_speculative_ensure_sequence_hidden(spec, ctx, draft_seq_id, draft_base_pos - 1)) {
|
||||||
|
LOG_ERR("%s: seq_id=%d MTP hidden state is empty during speculation\n",
|
||||||
|
__func__, (int) draft_seq_id);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result.tokens = common_speculative_draft(
|
||||||
|
spec,
|
||||||
|
params,
|
||||||
|
prompt_tgt,
|
||||||
|
id_last,
|
||||||
|
draft_base_pos,
|
||||||
|
draft_seq_id);
|
||||||
|
result.type = spec != nullptr && spec->curr_impl != nullptr
|
||||||
|
? spec->curr_impl->type
|
||||||
|
: COMMON_SPECULATIVE_TYPE_NONE;
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
int32_t common_speculative_on_target_seq_batch(
|
int32_t common_speculative_on_target_seq_batch(
|
||||||
common_speculative * spec,
|
common_speculative * spec,
|
||||||
llama_context * ctx_tgt,
|
llama_context * ctx_tgt,
|
||||||
@ -1834,6 +2092,234 @@ bool common_speculative_commit_accepted_output(
|
|||||||
hidden_rows);
|
hidden_rows);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool common_speculative_checkpoint_save(
|
||||||
|
common_speculative_checkpoint & ckpt,
|
||||||
|
llama_model * model,
|
||||||
|
llama_context * ctx,
|
||||||
|
common_sampler * sampler_src,
|
||||||
|
const common_params_sampling & sparams,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
llama_pos n_past,
|
||||||
|
llama_token sampled,
|
||||||
|
int max_tokens,
|
||||||
|
int ckpt_mode) {
|
||||||
|
ckpt.clear();
|
||||||
|
ckpt.n_past = n_past;
|
||||||
|
ckpt.sampled = sampled;
|
||||||
|
|
||||||
|
const int actual_mode = llama_spec_ckpt_init(ctx, ckpt_mode, max_tokens);
|
||||||
|
if (actual_mode == LLAMA_SPEC_CKPT_NONE) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
ckpt.per_step_enabled = (actual_mode == LLAMA_SPEC_CKPT_PER_STEP);
|
||||||
|
|
||||||
|
ckpt.valid = llama_spec_ckpt_save(ctx, seq_id);
|
||||||
|
if (!ckpt.valid) {
|
||||||
|
llama_spec_ckpt_discard(ctx);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
ckpt.sampler = common_sampler_init(model, sparams);
|
||||||
|
if (ckpt.sampler == nullptr) {
|
||||||
|
common_speculative_checkpoint_discard(ckpt, ctx);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sampler_src != nullptr) {
|
||||||
|
common_sampler_clone(sampler_src, ckpt.sampler);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
const common_speculative_checkpoint * common_speculative_get_checkpoint(const common_speculative * spec) {
|
||||||
|
return spec != nullptr ? &spec->checkpoint : nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_speculative_checkpoint_discard(
|
||||||
|
common_speculative_checkpoint & ckpt,
|
||||||
|
llama_context * ctx) {
|
||||||
|
ckpt.clear();
|
||||||
|
llama_spec_ckpt_discard(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_speculative_checkpoint_restore(
|
||||||
|
common_speculative_checkpoint & ckpt,
|
||||||
|
common_speculative * spec,
|
||||||
|
llama_context * ctx,
|
||||||
|
common_sampler * sampler_dst,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
common_speculative_type spec_type_used,
|
||||||
|
llama_token sampled_before,
|
||||||
|
const std::vector<llama_token> & ids,
|
||||||
|
int n_draft,
|
||||||
|
const std::vector<float> & mtp_hidden_state_pre,
|
||||||
|
int32_t mtp_n_past_base) {
|
||||||
|
if (!ckpt.valid) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ckpt.per_step_enabled) {
|
||||||
|
const int step = (int) ids.size() - 1;
|
||||||
|
llama_spec_ckpt_restore(ctx, seq_id, ckpt.n_past, step);
|
||||||
|
|
||||||
|
if (ckpt.sampler != nullptr && sampler_dst != nullptr) {
|
||||||
|
common_sampler_clone(ckpt.sampler, sampler_dst);
|
||||||
|
}
|
||||||
|
if (sampler_dst != nullptr) {
|
||||||
|
for (llama_token id : ids) {
|
||||||
|
common_sampler_accept(sampler_dst, ctx, id, true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) && !mtp_hidden_state_pre.empty()) {
|
||||||
|
if (!common_speculative_commit_accepted_hidden_rows(
|
||||||
|
spec,
|
||||||
|
spec_type_used,
|
||||||
|
seq_id,
|
||||||
|
mtp_n_past_base,
|
||||||
|
sampled_before,
|
||||||
|
ids,
|
||||||
|
mtp_hidden_state_pre)) {
|
||||||
|
common_speculative_clear_sequence_hidden(spec, seq_id);
|
||||||
|
} else if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) {
|
||||||
|
LOG_DBG("%s: seq_id=%d synced MTP target hidden state from accepted-prefix rows after per-step restore\n",
|
||||||
|
__func__, (int) seq_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_DBG("%s: seq_id=%d per-step restore: step=%d (rejected %d drafts)\n",
|
||||||
|
__func__, (int) seq_id, step, (int) (n_draft - (ids.size() - 1)));
|
||||||
|
} else {
|
||||||
|
llama_spec_ckpt_restore(ctx, seq_id, ckpt.n_past, 0);
|
||||||
|
|
||||||
|
if (ckpt.sampler != nullptr && sampler_dst != nullptr) {
|
||||||
|
common_sampler_clone(ckpt.sampler, sampler_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!ids.empty()) {
|
||||||
|
const int n_re = (int) ids.size();
|
||||||
|
llama_batch re_batch = llama_batch_init(n_re, 0, 1);
|
||||||
|
common_batch_add(re_batch, ckpt.sampled, ckpt.n_past, { seq_id }, n_re == 1);
|
||||||
|
for (int j = 0; j < n_re - 1; ++j) {
|
||||||
|
common_batch_add(re_batch, ids[j], ckpt.n_past + 1 + j, { seq_id }, j == n_re - 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
|
||||||
|
for (int j = 0; j < re_batch.n_tokens; ++j) {
|
||||||
|
re_batch.logits[j] = true;
|
||||||
|
}
|
||||||
|
llama_set_embeddings(ctx, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ret = llama_decode(ctx, re_batch);
|
||||||
|
if (ret != 0) {
|
||||||
|
LOG_ERR("%s: seq_id=%d failed to re-decode accepted tokens after checkpoint restore: %d\n",
|
||||||
|
__func__, (int) seq_id, ret);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
|
||||||
|
std::vector<int32_t> redecoded_indices(n_re);
|
||||||
|
for (int j = 0; j < n_re; ++j) {
|
||||||
|
redecoded_indices[j] = j;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!common_speculative_commit_accepted_output(
|
||||||
|
spec,
|
||||||
|
ctx,
|
||||||
|
spec_type_used,
|
||||||
|
seq_id,
|
||||||
|
ckpt.n_past,
|
||||||
|
sampled_before,
|
||||||
|
ids,
|
||||||
|
redecoded_indices)) {
|
||||||
|
common_speculative_clear_sequence_hidden(spec, seq_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sampler_dst != nullptr) {
|
||||||
|
for (llama_token id : ids) {
|
||||||
|
common_sampler_accept(sampler_dst, ctx, id, true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_batch_free(re_batch);
|
||||||
|
LOG_DBG("%s: seq_id=%d spec checkpoint restored: re-decoded %d tokens (rejected %d drafts)\n",
|
||||||
|
__func__, (int) seq_id, n_re, (int) (n_draft - (ids.size() - 1)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
common_speculative_checkpoint_discard(ckpt, ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_speculative_commit(
|
||||||
|
common_speculative * spec,
|
||||||
|
llama_context * ctx,
|
||||||
|
common_sampler * sampler_dst,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
llama_token sampled_before,
|
||||||
|
const std::vector<llama_token> & ids,
|
||||||
|
int n_draft,
|
||||||
|
llama_pos pos_base,
|
||||||
|
const std::vector<int32_t> & accepted_output_indices) {
|
||||||
|
GGML_ASSERT(spec != nullptr);
|
||||||
|
GGML_ASSERT(!ids.empty());
|
||||||
|
|
||||||
|
common_speculative_checkpoint & ckpt = spec->checkpoint;
|
||||||
|
const common_speculative_type spec_type_used = spec->curr_impl != nullptr
|
||||||
|
? spec->curr_impl->type
|
||||||
|
: COMMON_SPECULATIVE_TYPE_NONE;
|
||||||
|
const bool any_rejected = (int) ids.size() - 1 < n_draft;
|
||||||
|
std::vector<float> mtp_hidden_state_pre;
|
||||||
|
|
||||||
|
common_speculative_accept(spec, ids.size() - 1);
|
||||||
|
|
||||||
|
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) &&
|
||||||
|
any_rejected &&
|
||||||
|
ckpt.valid &&
|
||||||
|
!accepted_output_indices.empty()) {
|
||||||
|
if (!common_speculative_copy_output_hidden_rows(spec, ctx, accepted_output_indices, mtp_hidden_state_pre)) {
|
||||||
|
mtp_hidden_state_pre.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (any_rejected && ckpt.valid) {
|
||||||
|
common_speculative_checkpoint_restore(
|
||||||
|
ckpt,
|
||||||
|
spec,
|
||||||
|
ctx,
|
||||||
|
sampler_dst,
|
||||||
|
seq_id,
|
||||||
|
spec_type_used,
|
||||||
|
sampled_before,
|
||||||
|
ids,
|
||||||
|
n_draft,
|
||||||
|
mtp_hidden_state_pre,
|
||||||
|
pos_base);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) && !accepted_output_indices.empty()) {
|
||||||
|
if (!common_speculative_commit_accepted_output(
|
||||||
|
spec,
|
||||||
|
ctx,
|
||||||
|
spec_type_used,
|
||||||
|
seq_id,
|
||||||
|
pos_base,
|
||||||
|
sampled_before,
|
||||||
|
ids,
|
||||||
|
accepted_output_indices)) {
|
||||||
|
common_speculative_clear_sequence_hidden(spec, seq_id);
|
||||||
|
} else if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) {
|
||||||
|
LOG_DBG("%s: seq_id=%d synced MTP target hidden state from accepted-prefix rows\n",
|
||||||
|
__func__, (int) seq_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_kv_cache_seq_rm(ctx, seq_id, pos_base + (llama_pos) (ids.size() - 1), -1);
|
||||||
|
common_speculative_checkpoint_discard(ckpt, ctx);
|
||||||
|
}
|
||||||
|
|
||||||
void common_speculative_print_stats(const common_speculative * spec, double slot_tps, int n_decoded, int n_past, common_params_speculative * active_params) {
|
void common_speculative_print_stats(const common_speculative * spec, double slot_tps, int n_decoded, int n_past, common_params_speculative * active_params) {
|
||||||
if (spec == nullptr) {
|
if (spec == nullptr) {
|
||||||
return;
|
return;
|
||||||
@ -1980,6 +2466,50 @@ void common_speculative_clear_sequence_hidden(common_speculative * spec, llama_s
|
|||||||
mtp_clear_target_hidden(*mtp_state, seq_id);
|
mtp_clear_target_hidden(*mtp_state, seq_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void common_speculative_clear_sequence(
|
||||||
|
common_speculative * spec,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
bool clear_companion_ctx) {
|
||||||
|
if (spec != nullptr) {
|
||||||
|
spec->checkpoint.clear();
|
||||||
|
spec->curr_impl = nullptr;
|
||||||
|
spec->last_n_drafted = 0;
|
||||||
|
spec->t_step_start_us = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
common_speculative_clear_sequence_hidden(spec, seq_id);
|
||||||
|
|
||||||
|
if (clear_companion_ctx) {
|
||||||
|
if (auto * ctx_mtp = common_speculative_get_companion_ctx(spec); ctx_mtp != nullptr) {
|
||||||
|
llama_kv_cache_clear(ctx_mtp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool common_speculative_trim_sequence(
|
||||||
|
common_speculative * spec,
|
||||||
|
llama_context * ctx,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
llama_pos pos_begin) {
|
||||||
|
const bool target_trimmed = llama_kv_cache_seq_rm(ctx, seq_id, pos_begin, -1);
|
||||||
|
if (auto * ctx_mtp = common_speculative_get_companion_ctx(spec); ctx_mtp != nullptr) {
|
||||||
|
return target_trimmed && llama_kv_cache_seq_rm(ctx_mtp, seq_id, pos_begin, -1);
|
||||||
|
}
|
||||||
|
|
||||||
|
return target_trimmed;
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_speculative_clear_sequence_kv(
|
||||||
|
common_speculative * spec,
|
||||||
|
llama_context * ctx,
|
||||||
|
llama_seq_id seq_id) {
|
||||||
|
common_speculative_clear_sequence(spec, seq_id);
|
||||||
|
llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
|
||||||
|
if (auto * ctx_mtp = common_speculative_get_companion_ctx(spec); ctx_mtp != nullptr) {
|
||||||
|
llama_kv_cache_seq_rm(ctx_mtp, seq_id, -1, -1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
llama_context * common_speculative_get_companion_ctx(common_speculative * spec) {
|
llama_context * common_speculative_get_companion_ctx(common_speculative * spec) {
|
||||||
if (auto * mtp_state = common_speculative_get_mtp_state(spec); mtp_state != nullptr) {
|
if (auto * mtp_state = common_speculative_get_mtp_state(spec); mtp_state != nullptr) {
|
||||||
return mtp_state->ctx_mtp;
|
return mtp_state->ctx_mtp;
|
||||||
@ -2184,13 +2714,10 @@ std::vector<llama_token> mtp_speculative_gen_draft(
|
|||||||
// This prevents cache state corruption where two cells map to the same logical position.
|
// This prevents cache state corruption where two cells map to the same logical position.
|
||||||
// If the state contained in `last` had a valid token id and probability, it means that we
|
// If the state contained in `last` had a valid token id and probability, it means that we
|
||||||
// have previously run an "accept" batch, where the token sampled from the main model was included.
|
// have previously run an "accept" batch, where the token sampled from the main model was included.
|
||||||
// In that case, we need to discard all tokens that we ran here to get the KV cache to the correct state.
|
// Even in that case, the token at `n_past` is already committed and must remain in the KV cache,
|
||||||
// => for i0 = 1 we discard from n_past
|
// so we only discard the speculative tail starting at `n_past + 1`.
|
||||||
// But if we did not have a valid last token_id, it means the first token we run was sampled from the
|
|
||||||
// main model. Hence we want to keep this token in the KV cache and discard all other tokens.
|
|
||||||
// => for i0 = 0 we discard from n_past + 1
|
|
||||||
if (n_decode > 0) {
|
if (n_decode > 0) {
|
||||||
llama_kv_cache_seq_rm(ctx, seq_id, n_past + 1 - i0, n_past + n_decode + 2);
|
llama_kv_cache_seq_rm(ctx, seq_id, n_past + 1, n_past + n_decode + 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
return drafts;
|
return drafts;
|
||||||
|
|||||||
@ -7,6 +7,14 @@
|
|||||||
|
|
||||||
struct common_speculative;
|
struct common_speculative;
|
||||||
|
|
||||||
|
enum common_speculative_init_status {
|
||||||
|
COMMON_SPECULATIVE_INIT_SKIPPED,
|
||||||
|
COMMON_SPECULATIVE_INIT_READY,
|
||||||
|
COMMON_SPECULATIVE_INIT_ERR_RECURRENT,
|
||||||
|
COMMON_SPECULATIVE_INIT_ERR_MTP,
|
||||||
|
COMMON_SPECULATIVE_INIT_ERR_GENERIC,
|
||||||
|
};
|
||||||
|
|
||||||
using common_speculative_feature_kind = llama_spec_feature_kind;
|
using common_speculative_feature_kind = llama_spec_feature_kind;
|
||||||
using common_speculative_feature_row_view = llama_spec_feature_row_view;
|
using common_speculative_feature_row_view = llama_spec_feature_row_view;
|
||||||
using common_speculative_feature_view = llama_spec_feature_view;
|
using common_speculative_feature_view = llama_spec_feature_view;
|
||||||
@ -14,6 +22,21 @@ using common_speculative_feature_view = llama_spec_feature_view;
|
|||||||
static constexpr common_speculative_feature_kind COMMON_SPECULATIVE_FEATURE_NONE = LLAMA_SPEC_FEATURE_NONE;
|
static constexpr common_speculative_feature_kind COMMON_SPECULATIVE_FEATURE_NONE = LLAMA_SPEC_FEATURE_NONE;
|
||||||
static constexpr common_speculative_feature_kind COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE = LLAMA_SPEC_FEATURE_HIDDEN_STATE;
|
static constexpr common_speculative_feature_kind COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE = LLAMA_SPEC_FEATURE_HIDDEN_STATE;
|
||||||
|
|
||||||
|
struct common_speculative_checkpoint {
|
||||||
|
bool valid = false;
|
||||||
|
bool per_step_enabled = false;
|
||||||
|
llama_pos n_past = 0;
|
||||||
|
llama_token sampled = LLAMA_TOKEN_NULL;
|
||||||
|
common_sampler * sampler = nullptr;
|
||||||
|
|
||||||
|
void clear();
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_speculative_draft_result {
|
||||||
|
llama_tokens tokens;
|
||||||
|
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE;
|
||||||
|
};
|
||||||
|
|
||||||
// comma separated list of all types
|
// comma separated list of all types
|
||||||
std::string common_speculative_type_name_str();
|
std::string common_speculative_type_name_str();
|
||||||
|
|
||||||
@ -31,6 +54,29 @@ common_speculative * common_speculative_init(
|
|||||||
common_params_speculative & params,
|
common_params_speculative & params,
|
||||||
llama_context * ctx_tgt);
|
llama_context * ctx_tgt);
|
||||||
|
|
||||||
|
common_speculative_init_status common_speculative_try_init(
|
||||||
|
common_params_speculative & params,
|
||||||
|
llama_context * ctx_tgt,
|
||||||
|
common_speculative ** out_spec);
|
||||||
|
|
||||||
|
void common_speculative_prepare_startup(
|
||||||
|
gpt_params & params_base,
|
||||||
|
bool allow_parallel_mtp = true);
|
||||||
|
|
||||||
|
bool common_speculative_finalize_startup(
|
||||||
|
gpt_params & params_base,
|
||||||
|
const llama_model * model);
|
||||||
|
|
||||||
|
bool common_speculative_load_draft_model(
|
||||||
|
common_params_speculative & params,
|
||||||
|
const gpt_params & params_base);
|
||||||
|
|
||||||
|
bool common_speculative_prepare_mtp_runtime(
|
||||||
|
common_params_speculative & params,
|
||||||
|
const gpt_params & params_base,
|
||||||
|
const llama_model * model,
|
||||||
|
bool has_external_mtp);
|
||||||
|
|
||||||
void common_speculative_free(common_speculative * spec);
|
void common_speculative_free(common_speculative * spec);
|
||||||
|
|
||||||
// optionally call once at the beginning of a new generation
|
// optionally call once at the beginning of a new generation
|
||||||
@ -46,9 +92,30 @@ llama_tokens common_speculative_draft(
|
|||||||
llama_pos draft_base_pos = -1,
|
llama_pos draft_base_pos = -1,
|
||||||
llama_seq_id draft_seq_id = 0);
|
llama_seq_id draft_seq_id = 0);
|
||||||
|
|
||||||
|
common_speculative_draft_result common_speculative_draft_ex(
|
||||||
|
common_speculative * spec,
|
||||||
|
llama_context * ctx,
|
||||||
|
common_params_speculative & params,
|
||||||
|
const llama_tokens & prompt,
|
||||||
|
llama_token id_last,
|
||||||
|
llama_pos draft_base_pos = -1,
|
||||||
|
llama_seq_id draft_seq_id = 0);
|
||||||
|
|
||||||
// informs the speculative decoder that n_accepted tokens were accepted by the target model
|
// informs the speculative decoder that n_accepted tokens were accepted by the target model
|
||||||
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
|
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
|
||||||
|
|
||||||
|
bool common_speculative_before_draft(
|
||||||
|
common_speculative * spec,
|
||||||
|
llama_model * model,
|
||||||
|
llama_context * ctx,
|
||||||
|
common_sampler * sampler_src,
|
||||||
|
const common_params_sampling & sparams,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
llama_pos n_past,
|
||||||
|
llama_token sampled,
|
||||||
|
int max_tokens,
|
||||||
|
int ckpt_mode);
|
||||||
|
|
||||||
bool common_speculative_ensure_sequence_hidden(
|
bool common_speculative_ensure_sequence_hidden(
|
||||||
common_speculative * spec,
|
common_speculative * spec,
|
||||||
llama_context * ctx,
|
llama_context * ctx,
|
||||||
@ -87,10 +154,56 @@ bool common_speculative_commit_accepted_output(
|
|||||||
const std::vector<llama_token> & ids,
|
const std::vector<llama_token> & ids,
|
||||||
const std::vector<int32_t> & output_indices);
|
const std::vector<int32_t> & output_indices);
|
||||||
|
|
||||||
|
const common_speculative_checkpoint * common_speculative_get_checkpoint(const common_speculative * spec);
|
||||||
|
|
||||||
|
void common_speculative_checkpoint_discard(
|
||||||
|
common_speculative_checkpoint & ckpt,
|
||||||
|
llama_context * ctx);
|
||||||
|
|
||||||
|
void common_speculative_checkpoint_restore(
|
||||||
|
common_speculative_checkpoint & ckpt,
|
||||||
|
common_speculative * spec,
|
||||||
|
llama_context * ctx,
|
||||||
|
common_sampler * sampler_dst,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
common_speculative_type spec_type_used,
|
||||||
|
llama_token sampled_before,
|
||||||
|
const std::vector<llama_token> & ids,
|
||||||
|
int n_draft,
|
||||||
|
const std::vector<float> & mtp_hidden_state_pre,
|
||||||
|
int32_t mtp_n_past_base);
|
||||||
|
|
||||||
|
void common_speculative_commit(
|
||||||
|
common_speculative * spec,
|
||||||
|
llama_context * ctx,
|
||||||
|
common_sampler * sampler_dst,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
llama_token sampled_before,
|
||||||
|
const std::vector<llama_token> & ids,
|
||||||
|
int n_draft,
|
||||||
|
llama_pos pos_base,
|
||||||
|
const std::vector<int32_t> & accepted_output_indices);
|
||||||
|
|
||||||
bool common_speculative_has_sequence_hidden(const common_speculative * spec, llama_seq_id seq_id);
|
bool common_speculative_has_sequence_hidden(const common_speculative * spec, llama_seq_id seq_id);
|
||||||
|
|
||||||
void common_speculative_clear_sequence_hidden(common_speculative * spec, llama_seq_id seq_id);
|
void common_speculative_clear_sequence_hidden(common_speculative * spec, llama_seq_id seq_id);
|
||||||
|
|
||||||
|
void common_speculative_clear_sequence(
|
||||||
|
common_speculative * spec,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
bool clear_companion_ctx = false);
|
||||||
|
|
||||||
|
bool common_speculative_trim_sequence(
|
||||||
|
common_speculative * spec,
|
||||||
|
llama_context * ctx,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
llama_pos pos_begin);
|
||||||
|
|
||||||
|
void common_speculative_clear_sequence_kv(
|
||||||
|
common_speculative * spec,
|
||||||
|
llama_context * ctx,
|
||||||
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
llama_context * common_speculative_get_companion_ctx(common_speculative * spec);
|
llama_context * common_speculative_get_companion_ctx(common_speculative * spec);
|
||||||
|
|
||||||
int32_t common_speculative_on_target_seq_batch(
|
int32_t common_speculative_on_target_seq_batch(
|
||||||
|
|||||||
@ -45,11 +45,6 @@ static void log_text(const gpt_params & params_base, const std::string & text) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool params_use_gemma4_external_mtp(const gpt_params & params_base) {
|
|
||||||
return params_base.has_mtp &&
|
|
||||||
llama_model_is_gemma4_mtp_assistant(params_base.speculative.model_dft);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct server_mtp_warmup {
|
struct server_mtp_warmup {
|
||||||
llama_context * ctx_tgt;
|
llama_context * ctx_tgt;
|
||||||
server_slot * slot;
|
server_slot * slot;
|
||||||
@ -72,72 +67,6 @@ static bool server_response_needs_chat_parse(oaicompat_type oaicompat) {
|
|||||||
oaicompat == OAICOMPAT_TYPE_RESP;
|
oaicompat == OAICOMPAT_TYPE_RESP;
|
||||||
}
|
}
|
||||||
|
|
||||||
void server_speculative_checkpoint::clear() {
|
|
||||||
valid = false;
|
|
||||||
per_step_enabled = false;
|
|
||||||
n_past = 0;
|
|
||||||
sampled = LLAMA_TOKEN_NULL;
|
|
||||||
|
|
||||||
if (sampler != nullptr) {
|
|
||||||
common_sampler_free(sampler);
|
|
||||||
sampler = nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void discard_speculative_checkpoint(server_slot & slot, llama_context * ctx) {
|
|
||||||
slot.spec_ckpt.clear();
|
|
||||||
llama_spec_ckpt_discard(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool save_speculative_checkpoint(server_slot & slot, llama_model * model, llama_context * ctx, int ckpt_mode) {
|
|
||||||
slot.spec_ckpt.clear();
|
|
||||||
const int32_t n_pre_spec_tokens = slot.cache_tokens.n_tokens() - (int32_t)(slot.drafted.size() + 1);
|
|
||||||
slot.spec_ckpt.n_past = slot.cache_tokens.pos_next(n_pre_spec_tokens);
|
|
||||||
slot.spec_ckpt.sampled = slot.sampled;
|
|
||||||
|
|
||||||
const int max_tokens = (int)slot.drafted.size() + 1;
|
|
||||||
const int actual_mode = llama_spec_ckpt_init(ctx, ckpt_mode, max_tokens);
|
|
||||||
if (actual_mode == LLAMA_SPEC_CKPT_NONE) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
slot.spec_ckpt.per_step_enabled = (actual_mode == LLAMA_SPEC_CKPT_PER_STEP);
|
|
||||||
|
|
||||||
slot.spec_ckpt.valid = llama_spec_ckpt_save(ctx, slot.id);
|
|
||||||
if (!slot.spec_ckpt.valid) {
|
|
||||||
llama_spec_ckpt_discard(ctx);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
slot.spec_ckpt.sampler = common_sampler_init(model, slot.sparams);
|
|
||||||
if (slot.spec_ckpt.sampler == nullptr) {
|
|
||||||
discard_speculative_checkpoint(slot, ctx);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
common_sampler_clone(slot.ctx_sampling, slot.spec_ckpt.sampler);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void server_remove_speculative_stage(common_params_speculative & spec, common_speculative_type type) {
|
|
||||||
spec.stages.erase(std::remove_if(spec.stages.begin(), spec.stages.end(), [type](const common_speculative_stage_params & stage) {
|
|
||||||
return stage.type == type;
|
|
||||||
}), spec.stages.end());
|
|
||||||
|
|
||||||
if (spec.type == type) {
|
|
||||||
spec.type = COMMON_SPECULATIVE_TYPE_NONE;
|
|
||||||
const auto resolved = spec.get_resolved_stages();
|
|
||||||
spec.type = resolved.empty() ? COMMON_SPECULATIVE_TYPE_NONE : resolved.front().type;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool server_speculative_needs_draft_model(const common_params_speculative & spec) {
|
|
||||||
return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_DRAFT);
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool server_speculative_has_mtp(const common_params_speculative & spec) {
|
|
||||||
return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool server_speculative_same_stage_types(
|
static bool server_speculative_same_stage_types(
|
||||||
const common_params_speculative & lhs,
|
const common_params_speculative & lhs,
|
||||||
const common_params_speculative & rhs) {
|
const common_params_speculative & rhs) {
|
||||||
@ -233,29 +162,17 @@ server_context::~server_context() {
|
|||||||
}
|
}
|
||||||
// Free multimodal
|
// Free multimodal
|
||||||
mtmd_free(mctx);
|
mtmd_free(mctx);
|
||||||
// Free draft model and context if they exist
|
|
||||||
if (ctx_draft) {
|
|
||||||
llama_free(ctx_draft);
|
|
||||||
ctx_draft = nullptr;
|
|
||||||
}
|
|
||||||
if (model_draft) {
|
|
||||||
llama_free_model(model_draft);
|
|
||||||
model_draft = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear any sampling context
|
// Clear any sampling context
|
||||||
for (server_slot& slot : slots) {
|
for (server_slot& slot : slots) {
|
||||||
if (slot.ctx_sampling != nullptr) {
|
if (slot.ctx_sampling != nullptr) {
|
||||||
common_sampler_free(slot.ctx_sampling);
|
common_sampler_free(slot.ctx_sampling);
|
||||||
}
|
}
|
||||||
slot.spec_ckpt.clear();
|
|
||||||
if (slot.ctx_dft) {
|
|
||||||
llama_free(slot.ctx_dft);
|
|
||||||
}
|
|
||||||
common_speculative_free(slot.spec);
|
common_speculative_free(slot.spec);
|
||||||
llama_batch_free(slot.batch_spec);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
params_base.speculative.clear_dft();
|
||||||
|
|
||||||
llama_batch_free(batch);
|
llama_batch_free(batch);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -278,24 +195,10 @@ bool server_context::load_model(const gpt_params& params_) {
|
|||||||
add_bos_token = llama_should_add_bos_token(model);
|
add_bos_token = llama_should_add_bos_token(model);
|
||||||
has_eos_token = llama_add_eos_token(model) != 1;
|
has_eos_token = llama_add_eos_token(model) != 1;
|
||||||
|
|
||||||
if (params_base.n_parallel > 1 && server_speculative_has_mtp(params_base.speculative)) {
|
common_speculative_prepare_startup(params_base, false);
|
||||||
LOG_WARNING("MTP is not supported with parallel slots yet, removing the MTP stage to avoid cross-slot corruption.\n", {
|
|
||||||
{"n_parallel", params_base.n_parallel},
|
|
||||||
{"stage_chain", common_speculative_stage_chain_to_str(params_base.speculative)},
|
|
||||||
});
|
|
||||||
|
|
||||||
params_base.has_mtp = false;
|
const bool has_draft_model = params_base.speculative.has_dft();
|
||||||
server_remove_speculative_stage(params_base.speculative, COMMON_SPECULATIVE_TYPE_MTP);
|
std::string & mmproj_path = params_base.mmproj.path;
|
||||||
|
|
||||||
if (!server_speculative_needs_draft_model(params_base.speculative)) {
|
|
||||||
params_base.speculative.model.clear();
|
|
||||||
params_base.speculative.params.clear();
|
|
||||||
params_base.speculative.model_dft = nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool has_draft_model = !params_base.speculative.model.empty() || !params_base.speculative.params.empty();
|
|
||||||
std::string& mmproj_path = params_base.mmproj.path;
|
|
||||||
if (!mmproj_path.empty()) {
|
if (!mmproj_path.empty()) {
|
||||||
mtmd_context_params mparams = mtmd_context_params_default();
|
mtmd_context_params mparams = mtmd_context_params_default();
|
||||||
mparams.use_gpu = params_base.mmproj_use_gpu;
|
mparams.use_gpu = params_base.mmproj_use_gpu;
|
||||||
@ -309,10 +212,10 @@ bool server_context::load_model(const gpt_params& params_) {
|
|||||||
mparams.image_max_tokens = params_base.image_max_tokens;
|
mparams.image_max_tokens = params_base.image_max_tokens;
|
||||||
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
|
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
|
||||||
if (mctx == nullptr) {
|
if (mctx == nullptr) {
|
||||||
LOG_ERROR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
|
LOG_ERROR("failed to load multimodal model, %s\n", mmproj_path.c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
LOG_INFO("loaded multimodal model, '%s'\n", mmproj_path.c_str());
|
LOG_INFO("loaded multimodal model, %s\n", mmproj_path.c_str());
|
||||||
|
|
||||||
//if (params.n_cache_reuse) {
|
//if (params.n_cache_reuse) {
|
||||||
// params_base.n_cache_reuse = 0;
|
// params_base.n_cache_reuse = 0;
|
||||||
@ -323,71 +226,22 @@ bool server_context::load_model(const gpt_params& params_) {
|
|||||||
LOG_ERROR("%s\n", "err: speculative decode is not supported by multimodal");
|
LOG_ERROR("%s\n", "err: speculative decode is not supported by multimodal");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const auto spec_stages = params_base.speculative.get_resolved_stages();
|
|
||||||
const bool multimodal_spec_supported = spec_stages.empty() ||
|
const auto spec_stages = params_base.speculative.get_resolved_stages();
|
||||||
(spec_stages.size() == 1 && spec_stages.front().type == COMMON_SPECULATIVE_TYPE_MTP);
|
const bool multimodal_spec_supported = spec_stages.empty() ||
|
||||||
if (!multimodal_spec_supported) {
|
(spec_stages.size() == 1 && spec_stages.front().type == COMMON_SPECULATIVE_TYPE_MTP);
|
||||||
|
if (!multimodal_spec_supported) {
|
||||||
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
|
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
|
||||||
params_base.speculative.stages.clear();
|
params_base.speculative.stages.clear();
|
||||||
params_base.has_mtp = false;
|
params_base.has_mtp = false;
|
||||||
SRV_WRN("%s\n", "speculative decoding is not supported by multimodal, it will be disabled");
|
SRV_WRN("%s\n", "speculative decoding is not supported by multimodal, it will be disabled");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Load draft model for speculative decoding if specified
|
|
||||||
if (has_draft_model) {
|
|
||||||
LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");
|
|
||||||
|
|
||||||
gpt_params params_dft;
|
|
||||||
params_dft.devices = params_base.speculative.devices;
|
|
||||||
params_dft.model = params_base.speculative.model;
|
|
||||||
params_dft.main_gpu = params_base.main_gpu;
|
|
||||||
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
|
|
||||||
params_dft.rpc_servers = params_base.rpc_servers;
|
|
||||||
params_dft.cache_type_k = params_base.speculative.cache_type_k.empty() ? params_base.cache_type_k : params_base.speculative.cache_type_k;
|
|
||||||
params_dft.cache_type_v = params_base.speculative.cache_type_v.empty() ? params_base.cache_type_v : params_base.speculative.cache_type_v;
|
|
||||||
params_dft.flash_attn = params_base.flash_attn;
|
|
||||||
params_dft.k_cache_hadamard = params_base.k_cache_hadamard;
|
|
||||||
params_dft.v_cache_hadamard = params_base.v_cache_hadamard;
|
|
||||||
if (!params_base.speculative.params.empty()) {
|
|
||||||
auto [argc, argv] = parse_command_line("llama-server " + params_base.speculative.params);
|
|
||||||
if (!gpt_params_parse(argc, argv, params_dft)) {
|
|
||||||
gpt_params_print_usage(argc, argv, params_dft);
|
|
||||||
free_command_line(argc, argv);
|
|
||||||
return false;
|
|
||||||
};
|
|
||||||
free_command_line(argc, argv);
|
|
||||||
}
|
|
||||||
LOG_INFO("", { {"model", params_dft.model} });
|
|
||||||
if (params_dft.n_ctx == 0) {
|
|
||||||
params_dft.n_ctx = params_base.speculative.n_ctx;
|
|
||||||
}
|
|
||||||
params_dft.n_ctx = params_dft.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx;
|
|
||||||
params_dft.n_parallel = 1;
|
|
||||||
params_dft.n_batch = params_dft.n_ctx;
|
|
||||||
|
|
||||||
params_base.speculative.mparams_dft.path = params_dft.model; //
|
|
||||||
|
|
||||||
llama_model_params mparams_dft = common_model_params_to_llama(params_dft);
|
|
||||||
|
|
||||||
llama_model * model_dft = llama_model_load_from_file(params_dft.model.c_str(), mparams_dft);
|
|
||||||
if (model_dft == nullptr) {
|
|
||||||
LOG_ERROR("failed to load draft model", { {"model", params_base.speculative.model} });
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
cparams_dft = common_context_params_to_llama(params_dft);
|
|
||||||
|
|
||||||
params_base.speculative.model_dft = model_dft;
|
|
||||||
params_base.speculative.cparams_dft = cparams_dft;
|
|
||||||
|
|
||||||
|
if (!common_speculative_finalize_startup(params_base, model)) {
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
if (server_speculative_has_mtp(params_base.speculative) &&
|
|
||||||
llama_model_n_nextn_layer(model) == 0 &&
|
|
||||||
!params_use_gemma4_external_mtp(params_base)) {
|
|
||||||
LOG_WARNING("WARNING: MTP speculative stage requested, but model has 0 NextN layers. MTP will be disabled.\n", {});
|
|
||||||
params_base.has_mtp = false;
|
|
||||||
server_remove_speculative_stage(params_base.speculative, COMMON_SPECULATIVE_TYPE_MTP);
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -396,6 +250,20 @@ void server_context::init() {
|
|||||||
|
|
||||||
LOG_INFO("initializing slots", { {"n_slots", params_base.n_parallel} });
|
LOG_INFO("initializing slots", { {"n_slots", params_base.n_parallel} });
|
||||||
|
|
||||||
|
if (params_base.has_mtp) {
|
||||||
|
SRV_INF("%s\n", "MTP needs embeddings on decode, enabling");
|
||||||
|
llama_set_embeddings(ctx, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
const bool requested_spec = params_base.speculative.has_stage_chain();
|
||||||
|
bool can_spec = true;
|
||||||
|
if (!params_base.dry_run) {
|
||||||
|
can_spec = common_speculative_is_compat(ctx);
|
||||||
|
}
|
||||||
|
if (!can_spec && requested_spec) {
|
||||||
|
SRV_WRN("%s", "speculative decoding not supported by this context\n");
|
||||||
|
}
|
||||||
|
|
||||||
for (int i = 0; i < params_base.n_parallel; i++) {
|
for (int i = 0; i < params_base.n_parallel; i++) {
|
||||||
server_slot slot;
|
server_slot slot;
|
||||||
|
|
||||||
@ -440,68 +308,27 @@ void server_context::init() {
|
|||||||
slot.params.speculative = params_base.speculative;
|
slot.params.speculative = params_base.speculative;
|
||||||
slot.sparams = params_base.sparams;
|
slot.sparams = params_base.sparams;
|
||||||
|
|
||||||
const bool wants_mtp_stage = server_speculative_has_mtp(params_base.speculative);
|
|
||||||
if (wants_mtp_stage) {
|
|
||||||
const bool has_external_mtp = params_use_gemma4_external_mtp(params_base);
|
|
||||||
|
|
||||||
if (llama_model_n_nextn_layer(model) > 0 || has_external_mtp) {
|
|
||||||
params_base.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
|
||||||
|
|
||||||
if (!has_external_mtp) {
|
|
||||||
params_base.speculative.cparams_dft = common_context_params_to_llama(params_base);
|
|
||||||
}
|
|
||||||
|
|
||||||
params_base.speculative.cparams_dft.mtp = true;
|
|
||||||
params_base.speculative.cparams_dft.mtp_op_type = MTP_OP_WARMUP;
|
|
||||||
params_base.speculative.cparams_dft.embeddings = true;
|
|
||||||
|
|
||||||
slot.has_mtp = true;
|
|
||||||
slot.params.speculative.cparams_dft = params_base.speculative.cparams_dft;
|
|
||||||
|
|
||||||
slot.batch_spec = llama_batch_init(slot.params.speculative.get_max_stage_n_max() + 1, 0, 1);
|
|
||||||
SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);
|
|
||||||
|
|
||||||
SRV_INF("%s\n", "MTP needs embeddings on decode, enabling");
|
|
||||||
llama_set_embeddings(ctx, true);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
SRV_WRN("%s\n", "MTP speculative stage requested, but model has 0 NextN layers. Removing MTP from the configured stage chain.");
|
|
||||||
params_base.has_mtp = false;
|
|
||||||
server_remove_speculative_stage(params_base.speculative, COMMON_SPECULATIVE_TYPE_MTP);
|
|
||||||
slot.params.speculative = params_base.speculative;
|
|
||||||
slot.has_mtp = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const bool requested_spec = !params_base.speculative.get_resolved_stages().empty();
|
|
||||||
|
|
||||||
bool can_spec = true;
|
|
||||||
if (!params_base.dry_run) {
|
|
||||||
can_spec = common_speculative_is_compat(ctx);
|
|
||||||
}
|
|
||||||
if (!can_spec) {
|
|
||||||
SRV_WRN("%s", "speculative decoding not supported by this context\n");
|
|
||||||
}
|
|
||||||
// try speculative decoding
|
// try speculative decoding
|
||||||
if (can_spec && requested_spec) {
|
if (can_spec && requested_spec) {
|
||||||
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
|
switch (common_speculative_try_init(params_base.speculative, slot.ctx, &slot.spec)) {
|
||||||
if (slot.spec) {
|
case COMMON_SPECULATIVE_INIT_READY:
|
||||||
if (mctx && !slot.has_mtp) {
|
if (mctx && !slot.uses_mtp()) {
|
||||||
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
|
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
|
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
|
||||||
} else {
|
break;
|
||||||
if (llama_model_has_recurrent(model)) {
|
case COMMON_SPECULATIVE_INIT_ERR_RECURRENT:
|
||||||
SRV_ERR("%s", "failed to initialize recurrent speculative context\n");
|
SRV_ERR("%s", "failed to initialize recurrent speculative context\n");
|
||||||
throw std::runtime_error("recurrent speculative context initialization failed");
|
throw std::runtime_error("recurrent speculative context initialization failed");
|
||||||
} else if (slot.has_mtp) {
|
case COMMON_SPECULATIVE_INIT_ERR_MTP:
|
||||||
SRV_ERR("%s", "failed to initialize MTP speculative context\n");
|
SRV_ERR("%s", "failed to initialize MTP speculative context\n");
|
||||||
throw std::runtime_error("MTP speculative context initialization failed");
|
throw std::runtime_error("MTP speculative context initialization failed");
|
||||||
} else {
|
case COMMON_SPECULATIVE_INIT_ERR_GENERIC:
|
||||||
SRV_ERR("%s", "failed to initialize speculative decoding context\n");
|
SRV_ERR("%s", "failed to initialize speculative decoding context\n");
|
||||||
throw std::runtime_error("speculative decoding context initialization failed");
|
throw std::runtime_error("speculative decoding context initialization failed");
|
||||||
}
|
case COMMON_SPECULATIVE_INIT_SKIPPED:
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -620,9 +447,7 @@ void server_slot::reset() {
|
|||||||
n_kept_prompt = 0;
|
n_kept_prompt = 0;
|
||||||
n_sent_text = 0;
|
n_sent_text = 0;
|
||||||
drafted.clear();
|
drafted.clear();
|
||||||
drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE;
|
|
||||||
i_batch_dft.clear();
|
i_batch_dft.clear();
|
||||||
spec_ckpt.clear();
|
|
||||||
n_sent_token_probs = 0;
|
n_sent_token_probs = 0;
|
||||||
infill = false;
|
infill = false;
|
||||||
ga_i = 0;
|
ga_i = 0;
|
||||||
@ -640,7 +465,7 @@ void server_slot::reset() {
|
|||||||
image_just_processed = false;
|
image_just_processed = false;
|
||||||
do_checkpoint = false;
|
do_checkpoint = false;
|
||||||
if (spec != nullptr) {
|
if (spec != nullptr) {
|
||||||
common_speculative_clear_sequence_hidden(spec, id);
|
common_speculative_clear_sequence(spec, id);
|
||||||
}
|
}
|
||||||
|
|
||||||
positional_bans.clear();
|
positional_bans.clear();
|
||||||
@ -675,7 +500,11 @@ void server_slot::reset() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool server_slot::need_embd() const {
|
bool server_slot::need_embd() const {
|
||||||
return embedding || has_mtp;
|
return embedding || uses_mtp();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool server_slot::uses_mtp() const {
|
||||||
|
return params.speculative.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool server_slot::has_budget(gpt_params& global_params) {
|
bool server_slot::has_budget(gpt_params& global_params) {
|
||||||
@ -711,7 +540,7 @@ void server_slot::add_token_string(const completion_token_output& token) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool server_slot::can_speculate() const {
|
bool server_slot::can_speculate() const {
|
||||||
return (!!spec || has_mtp);
|
return (!!spec || uses_mtp());
|
||||||
}
|
}
|
||||||
|
|
||||||
int server_slot::get_n_draft_max() const {
|
int server_slot::get_n_draft_max() const {
|
||||||
@ -1327,7 +1156,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
|||||||
throw std::runtime_error("Error: per-request speculative stages must match the server startup stage types; only stage parameter overrides are supported");
|
throw std::runtime_error("Error: per-request speculative stages must match the server startup stage types; only stage parameter overrides are supported");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slot.params.speculative.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) && !slot.has_mtp) {
|
if (slot.params.speculative.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) && !params_base.has_mtp) {
|
||||||
throw std::runtime_error("Error: MTP speculative stage requested, but the server was not started with MTP support");
|
throw std::runtime_error("Error: MTP speculative stage requested, but the server was not started with MTP support");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2107,10 +1936,7 @@ void server_context::kv_cache_clear() {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
common_speculative_clear_sequence_hidden(slot.spec, slot.id);
|
common_speculative_clear_sequence(slot.spec, slot.id, true);
|
||||||
if (auto * ctx_companion = common_speculative_get_companion_ctx(slot.spec); ctx_companion != nullptr) {
|
|
||||||
llama_kv_cache_clear(ctx_companion);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
clean_kv_cache = false;
|
clean_kv_cache = false;
|
||||||
}
|
}
|
||||||
@ -3360,7 +3186,7 @@ void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_sl
|
|||||||
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
||||||
llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(kv_keep), slot.cache_tokens.pos_next(kv_keep + kv_discard));
|
llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(kv_keep), slot.cache_tokens.pos_next(kv_keep + kv_discard));
|
||||||
llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard);
|
llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard);
|
||||||
if (slot.has_mtp && slot.spec) {
|
if (slot.uses_mtp() && slot.spec) {
|
||||||
common_speculative_context_shift(slot.spec, slot.id, kv_keep, kv_discard, kv_past);
|
common_speculative_context_shift(slot.spec, slot.id, kv_keep, kv_discard, kv_past);
|
||||||
}
|
}
|
||||||
if (slot.params.cache_prompt) {
|
if (slot.params.cache_prompt) {
|
||||||
@ -3569,33 +3395,27 @@ void server_context::add_sampled_tokens() {
|
|||||||
// perform the speculative drafting for all sequences at the same time in a single batch
|
// perform the speculative drafting for all sequences at the same time in a single batch
|
||||||
const int n_draft_max_pre = slot.get_n_draft_max();
|
const int n_draft_max_pre = slot.get_n_draft_max();
|
||||||
if (n_draft_max_pre > 0) {
|
if (n_draft_max_pre > 0) {
|
||||||
if (mctx && !slot.has_mtp) {
|
if (mctx && !slot.uses_mtp()) {
|
||||||
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
|
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
|
||||||
GGML_ABORT("not supported by multimodal");
|
GGML_ABORT("not supported by multimodal");
|
||||||
}
|
}
|
||||||
|
|
||||||
static const llama_tokens empty_prompt;
|
static const llama_tokens empty_prompt;
|
||||||
const llama_tokens & cached_text_tokens = slot.has_mtp && !slot.params.speculative.has_composite_stage_chain()
|
const llama_tokens & cached_text_tokens = slot.uses_mtp() && !slot.params.speculative.has_composite_stage_chain()
|
||||||
? empty_prompt
|
? empty_prompt
|
||||||
: slot.cache_tokens.get_text_tokens();
|
: slot.cache_tokens.get_text_tokens();
|
||||||
|
|
||||||
auto & params_spec = slot.params.speculative;
|
auto & params_spec = slot.params.speculative;
|
||||||
const llama_pos draft_base_pos = slot.has_mtp ? slot.cache_tokens.pos_next() : -1;
|
const llama_pos draft_base_pos = slot.uses_mtp() ? slot.cache_tokens.pos_next() : -1;
|
||||||
|
common_speculative_draft_result draft_result = common_speculative_draft_ex(
|
||||||
if (slot.has_mtp) {
|
|
||||||
if (!common_speculative_ensure_sequence_hidden(slot.spec, ctx, slot.id, draft_base_pos - 1)) {
|
|
||||||
LOG_ERROR("MTP hidden state is empty during speculation", {});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_tokens draft = common_speculative_draft(
|
|
||||||
slot.spec,
|
slot.spec,
|
||||||
|
ctx,
|
||||||
params_spec,
|
params_spec,
|
||||||
cached_text_tokens,
|
cached_text_tokens,
|
||||||
slot.sampled,
|
slot.sampled,
|
||||||
draft_base_pos,
|
draft_base_pos,
|
||||||
slot.id);
|
slot.id);
|
||||||
slot.drafted_spec_type = common_speculative_current_type(slot.spec);
|
llama_tokens & draft = draft_result.tokens;
|
||||||
|
|
||||||
const int n_draft_max = slot.get_n_draft_max();
|
const int n_draft_max = slot.get_n_draft_max();
|
||||||
|
|
||||||
@ -3620,7 +3440,6 @@ void server_context::add_sampled_tokens() {
|
|||||||
// fallback to normal decoding
|
// fallback to normal decoding
|
||||||
slot.i_batch = slot.i_batch_dft[0];
|
slot.i_batch = slot.i_batch_dft[0];
|
||||||
slot.drafted.clear();
|
slot.drafted.clear();
|
||||||
slot.drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE;
|
|
||||||
slot.i_batch_dft.clear();
|
slot.i_batch_dft.clear();
|
||||||
} else {
|
} else {
|
||||||
// keep track of total number of drafted tokens tested
|
// keep track of total number of drafted tokens tested
|
||||||
@ -3637,7 +3456,6 @@ void server_context::add_sampled_tokens() {
|
|||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
// no speculative decoding
|
// no speculative decoding
|
||||||
slot.drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE;
|
|
||||||
slot.i_batch = batch.n_tokens;
|
slot.i_batch = batch.n_tokens;
|
||||||
|
|
||||||
common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
|
common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
|
||||||
@ -3977,15 +3795,10 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
|||||||
slot.cache_tokens.keep_first(slot.n_past);
|
slot.cache_tokens.keep_first(slot.n_past);
|
||||||
int p0 = (int)system_tokens.size() + slot.n_past;
|
int p0 = (int)system_tokens.size() + slot.n_past;
|
||||||
p0 = system_tokens.size() + slot.cache_tokens.pos_next();
|
p0 = system_tokens.size() + slot.cache_tokens.pos_next();
|
||||||
auto * ctx_companion = slot.spec ? common_speculative_get_companion_ctx(slot.spec) : nullptr;
|
const bool trimmed = common_speculative_trim_sequence(slot.spec, ctx, slot.id, p0);
|
||||||
const bool target_trimmed = llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
|
if (!trimmed) {
|
||||||
const bool companion_trimmed = ctx_companion == nullptr || llama_kv_cache_seq_rm(ctx_companion, slot.id, p0, -1);
|
|
||||||
if (!target_trimmed || !companion_trimmed) {
|
|
||||||
// could not partially delete (likely using a non-Transformer model)
|
// could not partially delete (likely using a non-Transformer model)
|
||||||
llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
|
common_speculative_clear_sequence_kv(slot.spec, ctx, slot.id);
|
||||||
if (ctx_companion != nullptr) {
|
|
||||||
llama_kv_cache_seq_rm(ctx_companion, slot.id, -1, -1);
|
|
||||||
}
|
|
||||||
|
|
||||||
p0 = (int)system_tokens.size();
|
p0 = (int)system_tokens.size();
|
||||||
if (p0 != 0) {
|
if (p0 != 0) {
|
||||||
@ -4022,7 +3835,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
|||||||
llama_pos p1 = slot.cache_tokens.pos_next() + slot.n_past_prompt - slot.n_past; // add offset to prompt
|
llama_pos p1 = slot.cache_tokens.pos_next() + slot.n_past_prompt - slot.n_past; // add offset to prompt
|
||||||
server_mtp_warmup mtp_media_warmup {
|
server_mtp_warmup mtp_media_warmup {
|
||||||
ctx,
|
ctx,
|
||||||
slot.has_mtp && slot.spec ? &slot : nullptr,
|
slot.uses_mtp() && slot.spec ? &slot : nullptr,
|
||||||
};
|
};
|
||||||
mtmd_helper_eval_batch_callback mtp_media_callback =
|
mtmd_helper_eval_batch_callback mtp_media_callback =
|
||||||
mtp_media_warmup.slot ? server_mtp_media_warmup_callback : nullptr;
|
mtp_media_warmup.slot ? server_mtp_media_warmup_callback : nullptr;
|
||||||
@ -4164,103 +3977,6 @@ void server_context::extend_context(const int32_t n_tokens) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Restore recurrent state and re-decode accepted tokens after speculative-decode rejection.
|
|
||||||
static void restore_speculative_checkpoint(
|
|
||||||
server_slot & slot, llama_context * ctx, llama_model * model,
|
|
||||||
common_speculative_type spec_type_used,
|
|
||||||
llama_token sampled_before,
|
|
||||||
const std::vector<llama_token> & ids, int n_draft,
|
|
||||||
const std::vector<float> & mtp_hidden_state_pre, int32_t mtp_n_past_base) {
|
|
||||||
if (slot.spec_ckpt.per_step_enabled) {
|
|
||||||
const int step = (int)ids.size() - 1;
|
|
||||||
llama_spec_ckpt_restore(ctx, slot.id, slot.spec_ckpt.n_past, step);
|
|
||||||
|
|
||||||
if (slot.spec_ckpt.sampler) {
|
|
||||||
common_sampler_clone(slot.spec_ckpt.sampler, slot.ctx_sampling);
|
|
||||||
}
|
|
||||||
for (llama_token id : ids) {
|
|
||||||
common_sampler_accept(slot.ctx_sampling, ctx, id, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update MTP KV cache and hidden state using embeddings collected before checkpoint restore.
|
|
||||||
if (slot.has_mtp && !mtp_hidden_state_pre.empty()) {
|
|
||||||
if (!common_speculative_commit_accepted_hidden_rows(
|
|
||||||
slot.spec,
|
|
||||||
spec_type_used,
|
|
||||||
slot.id,
|
|
||||||
mtp_n_past_base,
|
|
||||||
sampled_before,
|
|
||||||
ids,
|
|
||||||
mtp_hidden_state_pre)) {
|
|
||||||
common_speculative_clear_sequence_hidden(slot.spec, slot.id);
|
|
||||||
} else if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) {
|
|
||||||
SLT_DBG(slot, "%s", "synced MTP target hidden state from accepted-prefix rows after per-step restore");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SLT_DBG(slot, "per-step restore: step=%d (rejected %d drafts)\n",
|
|
||||||
step, (int)(n_draft - (ids.size() - 1)));
|
|
||||||
} else {
|
|
||||||
// Restore pre-speculation recurrent state then re-decode accepted tokens.
|
|
||||||
llama_spec_ckpt_restore(ctx, slot.id, slot.spec_ckpt.n_past, 0);
|
|
||||||
|
|
||||||
if (slot.spec_ckpt.sampler) {
|
|
||||||
common_sampler_clone(slot.spec_ckpt.sampler, slot.ctx_sampling);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!ids.empty()) {
|
|
||||||
// Re-decode to advance recurrent state to the accepted position.
|
|
||||||
const int n_re = (int)ids.size();
|
|
||||||
llama_batch re_batch = llama_batch_init(n_re, 0, 1);
|
|
||||||
common_batch_add(re_batch, slot.spec_ckpt.sampled, slot.spec_ckpt.n_past, { slot.id }, n_re == 1);
|
|
||||||
for (int j = 0; j < n_re - 1; j++) {
|
|
||||||
common_batch_add(re_batch, ids[j], slot.spec_ckpt.n_past + 1 + j, { slot.id }, j == n_re - 2);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (slot.has_mtp) {
|
|
||||||
for (int j = 0; j < re_batch.n_tokens; j++) {
|
|
||||||
re_batch.logits[j] = true;
|
|
||||||
}
|
|
||||||
llama_set_embeddings(ctx, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
const int ret = llama_decode(ctx, re_batch);
|
|
||||||
if (ret != 0) {
|
|
||||||
SLT_ERR(slot, "failed to re-decode accepted tokens after checkpoint restore: %d\n", ret);
|
|
||||||
}
|
|
||||||
if (slot.has_mtp) {
|
|
||||||
const int n_accepted = (int)ids.size();
|
|
||||||
std::vector<int32_t> redecoded_indices(n_accepted);
|
|
||||||
for (int j = 0; j < n_accepted; ++j) {
|
|
||||||
redecoded_indices[j] = j;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!common_speculative_commit_accepted_output(
|
|
||||||
slot.spec,
|
|
||||||
ctx,
|
|
||||||
spec_type_used,
|
|
||||||
slot.id,
|
|
||||||
slot.spec_ckpt.n_past,
|
|
||||||
sampled_before,
|
|
||||||
ids,
|
|
||||||
redecoded_indices)) {
|
|
||||||
common_speculative_clear_sequence_hidden(slot.spec, slot.id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (llama_token id : ids) {
|
|
||||||
common_sampler_accept(slot.ctx_sampling, ctx, id, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_batch_free(re_batch);
|
|
||||||
SLT_DBG(slot, "spec checkpoint restored: re-decoded %d tokens (rejected %d drafts)\n",
|
|
||||||
n_re, (int)(n_draft - (ids.size() - 1)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
discard_speculative_checkpoint(slot, ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
void server_context::speculative_decoding_accept() {
|
void server_context::speculative_decoding_accept() {
|
||||||
for (auto& slot : slots) {
|
for (auto& slot : slots) {
|
||||||
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
|
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
|
||||||
@ -4268,7 +3984,6 @@ void server_context::speculative_decoding_accept() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const llama_token sampled_before = slot.sampled;
|
const llama_token sampled_before = slot.sampled;
|
||||||
const common_speculative_type spec_type_used = slot.drafted_spec_type;
|
|
||||||
size_t n_draft = slot.drafted.size();
|
size_t n_draft = slot.drafted.size();
|
||||||
|
|
||||||
slot.ctx_sampling->to_generated_text = &slot.generated_text;
|
slot.ctx_sampling->to_generated_text = &slot.generated_text;
|
||||||
@ -4298,28 +4013,15 @@ void server_context::speculative_decoding_accept() {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const bool any_rejected = (ids.size() - 1) < n_draft;
|
|
||||||
int32_t mtp_n_past_base = 0;
|
|
||||||
std::vector<float> mtp_hidden_state_pre;
|
|
||||||
std::vector<int32_t> accepted_output_indices;
|
std::vector<int32_t> accepted_output_indices;
|
||||||
if (slot.has_mtp) {
|
if (slot.uses_mtp()) {
|
||||||
const int32_t n_pre_spec_tokens = slot.cache_tokens.n_tokens() - (int32_t)(slot.drafted.size() + 1);
|
|
||||||
mtp_n_past_base = slot.cache_tokens.pos_next(n_pre_spec_tokens);
|
|
||||||
|
|
||||||
if (!ids.empty()) {
|
if (!ids.empty()) {
|
||||||
accepted_output_indices.assign(slot.i_batch_dft.begin(), slot.i_batch_dft.begin() + ids.size());
|
accepted_output_indices.assign(slot.i_batch_dft.begin(), slot.i_batch_dft.begin() + ids.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (any_rejected && slot.spec_ckpt.valid && !accepted_output_indices.empty()) {
|
|
||||||
if (!common_speculative_copy_output_hidden_rows(slot.spec, ctx, accepted_output_indices, mtp_hidden_state_pre)) {
|
|
||||||
mtp_hidden_state_pre.clear();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.i_batch_dft.clear();
|
slot.i_batch_dft.clear();
|
||||||
slot.drafted.clear();
|
slot.drafted.clear();
|
||||||
slot.drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE;
|
|
||||||
|
|
||||||
slot.n_past += ids.size();
|
slot.n_past += ids.size();
|
||||||
slot.n_decoded += ids.size();
|
slot.n_decoded += ids.size();
|
||||||
@ -4329,11 +4031,9 @@ void server_context::speculative_decoding_accept() {
|
|||||||
// update how many tokens out of those tested were accepted
|
// update how many tokens out of those tested were accepted
|
||||||
slot.n_draft_accepted += ids.size() - 1;
|
slot.n_draft_accepted += ids.size() - 1;
|
||||||
|
|
||||||
// inform the speculative decoding about the number of accepted tokens
|
|
||||||
common_speculative_accept(slot.spec, ids.size() - 1);
|
|
||||||
|
|
||||||
// rollback to the state before sampling the draft tokens
|
// rollback to the state before sampling the draft tokens
|
||||||
slot.cache_tokens.keep_first(slot.cache_tokens.n_tokens() - n_draft);
|
slot.cache_tokens.keep_first(slot.cache_tokens.n_tokens() - n_draft);
|
||||||
|
const llama_pos spec_pos_base = slot.cache_tokens.pos_next();
|
||||||
|
|
||||||
// add accepted tokens to the prompt
|
// add accepted tokens to the prompt
|
||||||
for (auto it = ids.begin(); it != ids.end() - 1; ++it) {
|
for (auto it = ids.begin(); it != ids.end() - 1; ++it) {
|
||||||
@ -4342,28 +4042,16 @@ void server_context::speculative_decoding_accept() {
|
|||||||
slot.sampled = ids.back(); // last accepted token
|
slot.sampled = ids.back(); // last accepted token
|
||||||
slot.n_past = slot.cache_tokens.n_tokens();
|
slot.n_past = slot.cache_tokens.n_tokens();
|
||||||
|
|
||||||
// for recurrent/hybrid models: if any drafts were rejected, restore recurrent state
|
common_speculative_commit(
|
||||||
if (any_rejected && slot.spec_ckpt.valid) {
|
slot.spec,
|
||||||
restore_speculative_checkpoint(slot, ctx, model, spec_type_used, sampled_before, ids, n_draft, mtp_hidden_state_pre, mtp_n_past_base);
|
ctx,
|
||||||
} else {
|
slot.ctx_sampling,
|
||||||
if (slot.has_mtp && !accepted_output_indices.empty()) {
|
slot.id,
|
||||||
if (!common_speculative_commit_accepted_output(
|
sampled_before,
|
||||||
slot.spec,
|
ids,
|
||||||
ctx,
|
n_draft,
|
||||||
spec_type_used,
|
spec_pos_base,
|
||||||
slot.id,
|
accepted_output_indices);
|
||||||
mtp_n_past_base,
|
|
||||||
sampled_before,
|
|
||||||
ids,
|
|
||||||
accepted_output_indices)) {
|
|
||||||
common_speculative_clear_sequence_hidden(slot.spec, slot.id);
|
|
||||||
} else if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) {
|
|
||||||
SLT_DBG(slot, "%s", "synced MTP target hidden state from accepted-prefix rows");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(slot.n_past), -1);
|
|
||||||
discard_speculative_checkpoint(slot, ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t i = 0; i < ids.size(); ++i) {
|
for (size_t i = 0; i < ids.size(); ++i) {
|
||||||
completion_token_output result;
|
completion_token_output result;
|
||||||
@ -4737,9 +4425,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
|||||||
continue; // continue loop of n_batch
|
continue; // continue loop of n_batch
|
||||||
}
|
}
|
||||||
|
|
||||||
if (server_speculative_has_mtp(params_base.speculative)) {
|
if (params_base.speculative.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
|
||||||
for (auto & slot : slots) {
|
for (auto & slot : slots) {
|
||||||
if (!slot.spec || !slot.has_mtp) {
|
if (!slot.spec || !slot.uses_mtp()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -4779,7 +4467,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
|||||||
|
|
||||||
if (slot.n_decoded == 0 && slot.can_speculate()) {
|
if (slot.n_decoded == 0 && slot.can_speculate()) {
|
||||||
static const llama_tokens empty_prompt;
|
static const llama_tokens empty_prompt;
|
||||||
const llama_tokens & spec_prompt = slot.has_mtp && !slot.params.speculative.has_composite_stage_chain()
|
const llama_tokens & spec_prompt = slot.uses_mtp() && !slot.params.speculative.has_composite_stage_chain()
|
||||||
? empty_prompt
|
? empty_prompt
|
||||||
: slot.cache_tokens.get_text_tokens();
|
: slot.cache_tokens.get_text_tokens();
|
||||||
common_speculative_begin(slot.spec, spec_prompt);
|
common_speculative_begin(slot.spec, spec_prompt);
|
||||||
@ -4803,7 +4491,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
|||||||
completion_token_output result;
|
completion_token_output result;
|
||||||
const int tok_idx = slot.i_batch - i;
|
const int tok_idx = slot.i_batch - i;
|
||||||
|
|
||||||
if (slot.has_mtp && slot.n_decoded == 0) {
|
if (slot.uses_mtp() && slot.n_decoded == 0) {
|
||||||
(void) common_speculative_capture_output_hidden(slot.spec, ctx, tok_idx, slot.id, slot.n_past);
|
(void) common_speculative_capture_output_hidden(slot.spec, ctx, tok_idx, slot.id, slot.n_past);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -4935,10 +4623,25 @@ void server_context::update_slots() {
|
|||||||
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
|
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (save_speculative_checkpoint(slot, model, ctx, ckpt_mode)) {
|
const int32_t n_pre_spec_tokens = slot.cache_tokens.n_tokens() - (int32_t) (slot.drafted.size() + 1);
|
||||||
const char * mode_name = slot.spec_ckpt.per_step_enabled ? "per-step" : "shadow/cpu";
|
const llama_pos n_past_pre_spec = slot.cache_tokens.pos_next(n_pre_spec_tokens);
|
||||||
|
const int max_tokens = (int) slot.drafted.size() + 1;
|
||||||
|
if (common_speculative_before_draft(
|
||||||
|
slot.spec,
|
||||||
|
model,
|
||||||
|
ctx,
|
||||||
|
slot.ctx_sampling,
|
||||||
|
slot.sparams,
|
||||||
|
slot.id,
|
||||||
|
n_past_pre_spec,
|
||||||
|
slot.sampled,
|
||||||
|
max_tokens,
|
||||||
|
ckpt_mode)) {
|
||||||
|
const common_speculative_checkpoint * ckpt = common_speculative_get_checkpoint(slot.spec);
|
||||||
|
GGML_ASSERT(ckpt != nullptr);
|
||||||
|
const char * mode_name = ckpt->per_step_enabled ? "per-step" : "shadow/cpu";
|
||||||
SLT_DBG(slot, "spec checkpoint saved (mode=%s), n_past_pre_spec=%d\n",
|
SLT_DBG(slot, "spec checkpoint saved (mode=%s), n_past_pre_spec=%d\n",
|
||||||
mode_name, slot.spec_ckpt.n_past);
|
mode_name, ckpt->n_past);
|
||||||
} else {
|
} else {
|
||||||
SLT_WRN(slot, "%s", "failed to save spec checkpoint\n");
|
SLT_WRN(slot, "%s", "failed to save spec checkpoint\n");
|
||||||
}
|
}
|
||||||
|
|||||||
@ -22,16 +22,6 @@ enum slot_command {
|
|||||||
SLOT_COMMAND_RELEASE,
|
SLOT_COMMAND_RELEASE,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_speculative_checkpoint {
|
|
||||||
bool valid = false;
|
|
||||||
bool per_step_enabled = false; // per-step SSM checkpoints active
|
|
||||||
llama_pos n_past = 0;
|
|
||||||
llama_token sampled = LLAMA_TOKEN_NULL;
|
|
||||||
common_sampler * sampler = nullptr; // saved sampler state
|
|
||||||
|
|
||||||
void clear();
|
|
||||||
};
|
|
||||||
|
|
||||||
struct server_slot {
|
struct server_slot {
|
||||||
int id;
|
int id;
|
||||||
int id_task = -1;
|
int id_task = -1;
|
||||||
@ -39,9 +29,6 @@ struct server_slot {
|
|||||||
|
|
||||||
struct slot_params params;
|
struct slot_params params;
|
||||||
|
|
||||||
llama_batch batch_spec = {};
|
|
||||||
llama_context * ctx_dft = nullptr;
|
|
||||||
|
|
||||||
bool released = false;
|
bool released = false;
|
||||||
slot_state state = SLOT_STATE_IDLE;
|
slot_state state = SLOT_STATE_IDLE;
|
||||||
slot_command command = SLOT_COMMAND_NONE;
|
slot_command command = SLOT_COMMAND_NONE;
|
||||||
@ -136,7 +123,6 @@ struct server_slot {
|
|||||||
// sampling
|
// sampling
|
||||||
llama_token sampled; // in speculative mode, this is the last accepted token
|
llama_token sampled; // in speculative mode, this is the last accepted token
|
||||||
llama_tokens drafted;
|
llama_tokens drafted;
|
||||||
common_speculative_type drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE;
|
|
||||||
|
|
||||||
json json_schema;
|
json json_schema;
|
||||||
|
|
||||||
@ -171,11 +157,6 @@ struct server_slot {
|
|||||||
// expiring logit bias
|
// expiring logit bias
|
||||||
std::vector<common_sampler::elb_state> prev_elb_states;
|
std::vector<common_sampler::elb_state> prev_elb_states;
|
||||||
|
|
||||||
bool has_mtp = false;
|
|
||||||
|
|
||||||
// saves recurrent state before a speculative batch so it can be restored on rejection
|
|
||||||
server_speculative_checkpoint spec_ckpt;
|
|
||||||
|
|
||||||
// speculative decoding stats
|
// speculative decoding stats
|
||||||
int32_t n_draft_total = 0; // Total draft tokens generated
|
int32_t n_draft_total = 0; // Total draft tokens generated
|
||||||
int32_t n_draft_accepted = 0; // Draft tokens actually accepted
|
int32_t n_draft_accepted = 0; // Draft tokens actually accepted
|
||||||
@ -195,6 +176,7 @@ struct server_slot {
|
|||||||
void reset();
|
void reset();
|
||||||
|
|
||||||
bool need_embd() const;
|
bool need_embd() const;
|
||||||
|
bool uses_mtp() const;
|
||||||
|
|
||||||
bool has_budget(gpt_params& global_params);
|
bool has_budget(gpt_params& global_params);
|
||||||
|
|
||||||
@ -266,11 +248,6 @@ struct server_context {
|
|||||||
// multimodal
|
// multimodal
|
||||||
mtmd_context* mctx = nullptr;
|
mtmd_context* mctx = nullptr;
|
||||||
|
|
||||||
// For speculative decoding
|
|
||||||
llama_model* model_draft = nullptr;
|
|
||||||
llama_context* ctx_draft = nullptr;
|
|
||||||
llama_context_params cparams_dft;
|
|
||||||
|
|
||||||
int32_t n_ctx; // total context for all clients / slots
|
int32_t n_ctx; // total context for all clients / slots
|
||||||
|
|
||||||
// system prompt
|
// system prompt
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user