From a1824902b573134458945b0c7973e105a7837b59 Mon Sep 17 00:00:00 2001 From: Ruixiang Wang Date: Tue, 16 Jun 2026 11:05:52 +0200 Subject: [PATCH] spec: add backend sampling support for eagle3 (#24655) --- common/speculative.cpp | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index a744c79ae5..6f387f2cfc 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -418,6 +418,9 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { std::vector smpls; + // backend sampler chain per seq, attached to ctx_dft + std::vector backend_chains; + int32_t n_embd_dec = 0; // draft hidden size int32_t n_embd_enc = 0; // target_layer_ids_n * target_hidden_size int32_t n_embd_tgt = 0; // target model hidden size @@ -443,7 +446,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { , params(params.draft) { LOG_INF("%s: adding speculative implementation 'draft-eagle3'\n", __func__); - LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f\n", __func__, params.draft.n_max, params.draft.n_min, params.draft.p_min); + LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f, backend_sampling=%d\n", __func__, params.draft.n_max, params.draft.n_min, params.draft.p_min, (int) params.draft.backend_sampling); auto * ctx_tgt = this->params.ctx_tgt; auto * ctx_dft = this->params.ctx_dft; @@ -478,6 +481,22 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams)); } + // offload draft sampling to the backend + backend_chains.assign(n_seq, nullptr); + if (this->params.backend_sampling) { + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params()); + llama_sampler_chain_add(chain, llama_sampler_init_top_k(10)); + + if (!llama_set_sampler(ctx_dft, seq_id, chain)) { + LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id); + llama_sampler_free(chain); + chain = nullptr; + } + backend_chains[seq_id] = chain; + } + } + // turn on extraction of the target layers' input embeddings for (uint32_t k = 0; k < target_layer_ids_n; ++k) { llama_set_embeddings_layer_inp(ctx_tgt, (uint32_t) target_layer_ids[k], true); @@ -496,6 +515,18 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { } ~common_speculative_impl_draft_eagle3() override { + auto * ctx_dft = this->params.ctx_dft; + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) backend_chains.size(); ++seq_id) { + if (backend_chains[seq_id] == nullptr) { + continue; + } + if (ctx_dft) { + llama_set_sampler(ctx_dft, seq_id, nullptr); + } + llama_sampler_free(backend_chains[seq_id]); + } + backend_chains.clear(); + if (batch.token != nullptr) { free(batch.token); batch.token = nullptr;