Move to backend sampling for MTP draft path (#23287)

* Move to backend sampling for MTP draft path

Run top_k(10) on the draft backend. D2H transfers happen only for the top 10 logits

Make backend sampling more robust and fallback to CPU on failure cases, such as with "-sm tensor" or when a backend doesn't support TOP_K.

* Allow sampler chains to be partially offloaded to backend

* Add --spec-draft-backend-sampling argument. Enabled by default.
This commit is contained in:
Gaurav Garg 2026-05-20 22:34:45 +05:30 committed by GitHub
parent 3a6db741a8
commit ad27757261
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 56 additions and 1 deletions

View File

@ -3591,6 +3591,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.draft.p_min = std::stof(value);
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_P_MIN"));
add_opt(common_arg(
{"--spec-draft-backend-sampling"},
{"--no-spec-draft-backend-sampling"},
string_format("offload draft sampling to the backend (default: %s)",
params.speculative.draft.backend_sampling ? "enabled" : "disabled"),
[](common_params & params, bool value) {
params.speculative.draft.backend_sampling = value;
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_BACKEND_SAMPLING"));
add_opt(common_arg(
{"--spec-draft-device", "-devd", "--device-draft"}, "<dev1,dev2,..>",
"comma-separated list of devices to use for offloading the draft model (none = don't offload)\n"

View File

@ -305,6 +305,8 @@ struct common_params_speculative_draft {
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.0f; // minimum speculative decoding probability (greedy)
bool backend_sampling = true; // offload draft sampling to the backend (default: on)
common_params_model mparams;
llama_context * ctx_tgt = nullptr;

View File

@ -413,6 +413,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
std::vector<common_sampler_ptr> smpls;
// backend sampler chain per seq, attached to ctx_dft
std::vector<llama_sampler *> backend_chains;
int32_t n_embd = 0;
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
@ -444,7 +447,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
n_embd = llama_model_n_embd(llama_get_model(ctx_dft));
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
this->params.n_gpu_layers,
ggml_type_name(this->params.cache_type_k),
@ -468,6 +471,22 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
}
// offload draft sampling to the backend
backend_chains.assign(n_seq, nullptr);
if (this->params.backend_sampling) {
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_top_k(10));
if (!llama_set_sampler(ctx_dft, seq_id, chain)) {
LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id);
llama_sampler_free(chain);
chain = nullptr;
}
backend_chains[seq_id] = chain;
}
}
llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false);
llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true);
@ -483,6 +502,18 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
}
~common_speculative_impl_draft_mtp() override {
auto * ctx_dft = this->params.ctx_dft;
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) backend_chains.size(); ++seq_id) {
if (backend_chains[seq_id] == nullptr) {
continue;
}
if (ctx_dft) {
llama_set_sampler(ctx_dft, seq_id, nullptr);
}
llama_sampler_free(backend_chains[seq_id]);
}
backend_chains.clear();
if (batch.token != nullptr) {
free(batch.token);
batch.token = nullptr;

View File

@ -1137,6 +1137,19 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
if (sampler && model.split_mode() == LLAMA_SPLIT_MODE_TENSOR) {
static bool warned = false;
if (!warned) {
LLAMA_LOG_WARN("%s: backend sampling not supported with SPLIT_MODE_TENSOR; using CPU\n", __func__);
warned = true;
}
if (sampling.samplers.count(seq_id) > 0) {
sched_need_reserve = true;
}
sampling.samplers.erase(seq_id);
return false;
}
const bool can_offload =
sampler &&
sampler->iface->backend_init &&