ik_llama.cpp/common/speculative.h
Samuel Oliveira Alves f4f4b3ff26
Allow dual speculative decoding (#1789)
* wip: test logic to use multiple specs

* feat: introduce composite speculative decoding stages

* handle MTP context and draft invalidation

* fix: allow gemma mtp for speculative stages

* fix: normalize spec stage keys

* refactor: remove enable_mtp flag and improve speculative stage handling

* fix: update cached text tokens handling for stage chains

* feat: implement sync for external MTP after non-MTP accept
2026-05-15 10:10:40 +03:00

78 lines
2.9 KiB
C++

#pragma once
#include "llama.h"
#include "common.h"
#include "spec-tuner.h"
// Forward declaration of the speculative decoding state. Only pointers to it
// are used by the functions declared in this header; the definition is not
// visible here.
struct common_speculative;
// Returns a comma-separated list of all speculative type names.
// NOTE(review): presumably intended for help/usage text - confirm against callers.
std::string common_speculative_type_name_str();
// Converts a type name string to the corresponding common_speculative_type.
// NOTE(review): the result for an unrecognized name is not visible from this
// header - check the implementation before relying on it.
enum common_speculative_type common_speculative_type_from_name(const std::string & name);
// Converts a common_speculative_type to its string name
// (the counterpart of common_speculative_type_from_name).
std::string common_speculative_type_to_str(enum common_speculative_type type);
// Checks whether the target llama_context is compatible with speculative decoding.
// Returns true if it is compatible.
// Note: this call clears the memory of the context, so it has a side effect
// on ctx_tgt.
bool common_speculative_is_compat(llama_context * ctx_tgt);
// Creates a speculative decoding object from `params` for the target context `ctx_tgt`.
// `params` is taken by non-const reference.
// NOTE(review): presumably the returned pointer must be released with
// common_speculative_free and may be nullptr on failure - confirm in the implementation.
common_speculative * common_speculative_init(
common_params_speculative & params,
llama_context * ctx_tgt);
// Releases a speculative decoding object.
// NOTE(review): presumably the counterpart of common_speculative_init; whether
// nullptr is accepted is not visible from this header - confirm.
void common_speculative_free(common_speculative * spec);
// Optional: call once at the beginning of a new generation.
//   spec   - speculative decoding object
//   prompt - tokens of the prompt for the new generation
void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);
// Samples up to n_draft tokens and adds them to the batch using the draft model.
// Returns the drafted tokens.
//   spec           - speculative decoding object
//   params         - speculative parameters (taken by non-const reference)
//   prompt         - prompt tokens
//   id_last        - last token; drafting starts after it (TODO confirm exact semantics)
//   draft_base_pos - together with draft_seq_id, overrides the MTP position for id_last;
//                    defaults to -1 (presumably "no override" - confirm in the implementation)
//   draft_seq_id   - sequence id used with draft_base_pos; defaults to 0
llama_tokens common_speculative_draft(
common_speculative * spec,
common_params_speculative & params,
const llama_tokens & prompt,
llama_token id_last,
llama_pos draft_base_pos = -1,
llama_seq_id draft_seq_id = 0);
// Informs the speculative decoder that n_accepted tokens were accepted by the target model.
// n_accepted is a uint16_t, so the count passed in a single call cannot exceed 65535.
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
// Prints statistics about the speculative decoding.
// All parameters after `spec` are optional and default to 0.0 / 0 / 0 / nullptr.
// NOTE(review): slot_tps, n_decoded, n_past and active_params look like extra
// per-slot context for the report (throughput, decoded count, position, current
// parameters) - confirm how they are used in the implementation.
void common_speculative_print_stats(const common_speculative * spec, double slot_tps = 0.0, int n_decoded = 0, int n_past = 0, common_params_speculative * active_params = nullptr);
// Returns the MTP (Multi-Token Prediction) context held by the speculative object,
// or nullptr if the speculative object is not of an MTP type.
llama_context * common_speculative_get_mtp_ctx(common_speculative * spec);
// Returns the speculative type currently associated with `spec`.
// NOTE(review): with several speculative stages configured this presumably
// reports the stage in use at the moment of the call - confirm in the implementation.
common_speculative_type common_speculative_current_type(const common_speculative * spec);
// Context shift for MTP, mirroring how the server handles the main model.
//   spec       - speculative decoding object
//   seq_id     - sequence to shift
// NOTE(review): kv_keep / kv_discard / kv_past presumably are the number of
// leading positions to keep, the number of positions to discard, and the
// current past position, as in the server's context shift - confirm against the caller.
void common_speculative_context_shift(
common_speculative * spec,
llama_seq_id seq_id,
llama_pos kv_keep,
llama_pos kv_discard,
llama_pos kv_past);
// Generates speculative draft tokens using the Multi-Token Prediction (MTP) architecture.
// Returns the drafted tokens.
//   smpl    - sampler used for drafting
//   ctx     - llama context to draft with
//   n_draft - number of tokens to draft (presumably an upper bound - confirm)
//   p_min   - presumably the minimum probability required to keep drafting - confirm
//   id_last - last token, from which drafting continues
//   n_past  - position associated with id_last (TODO confirm exact convention)
//   seq_id  - sequence id to draft for
//   constant_draft_positions - defaults to false; exact effect on draft positions
//                              is not visible from this header - see the implementation
std::vector<llama_token> mtp_speculative_gen_draft(
struct common_sampler * smpl,
struct llama_context * ctx,
int n_draft,
float p_min,
llama_token id_last,
llama_pos n_past,
llama_seq_id seq_id,
bool constant_draft_positions = false);
// Updates the MTP KV cache of `ctx` from the tokens in `batch`.
// `is_prompt_warmup` distinguishes prompt warmup from other calls.
// NOTE(review): the meaning of the int32_t return value (presumably a
// decode-style status code) is not visible from this header - confirm.
int32_t mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);
// Notifies the MTP state in `ctx` about accepted tokens.
//   ctx         - llama context holding the MTP state
//   ids         - accepted token ids
//   n_past_base - presumably the position of the first token in `ids` - confirm
//   seq_id      - sequence the tokens belong to
void mtp_accept_tokens(
struct llama_context * ctx,
const std::vector<llama_token> & ids,
int32_t n_past_base,
llama_seq_id seq_id
);