mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
* wip: test logic to use multiple specs * feat: introduce composite speculative decoding stages * handle MTP context and draft invalidation * fix: allow gemma mtp for speculative stages * fix: normalize spec stage keys * refactor: remove enable_mtp flag and improve speculative stage handling * fix: update cached text tokens handling for stage chains * feat: implement sync for external MTP after non-MTP accept
78 lines
2.9 KiB
C++
78 lines
2.9 KiB
C++
#pragma once
|
|
|
|
#include "llama.h"
|
|
#include "common.h"
|
|
#include "spec-tuner.h"
|
|
|
|
// Opaque speculative-decoding state. It is only forward-declared in this header;
// callers hold it through the `common_speculative *` returned by
// common_speculative_init() and hand it back to common_speculative_free().
struct common_speculative;
|
|
|
|
// Returns a comma-separated list of the names of all speculative decoding types.
|
|
std::string common_speculative_type_name_str();
|
|
|
|
// Converts a speculative decoding type name to its enum value.
// NOTE(review): the handling of an unrecognized name (fallback value vs. error) is
// decided by the implementation, which is not visible here. Confirm before relying on it.
|
|
enum common_speculative_type common_speculative_type_from_name(const std::string & name);
|
|
|
|
// Converts a speculative decoding type to its string name.
// NOTE(review): presumably the inverse of common_speculative_type_from_name; verify.
|
|
std::string common_speculative_type_to_str(enum common_speculative_type type);
|
|
|
|
// Checks whether the target context ctx_tgt is compatible with speculative decoding.
// Returns true if it is.
|
|
// note: this clears the memory of the context. Do not call it on a context whose
// existing state must be preserved.
|
|
bool common_speculative_is_compat(llama_context * ctx_tgt);
|
|
|
|
// Creates a speculative decoding object for the target context ctx_tgt, configured
// from params. Pass the result to common_speculative_free() when it is no longer needed.
// NOTE(review): params is taken by non-const reference, so the implementation may
// update it. Whether nullptr is returned on failure is not visible here; confirm.
common_speculative * common_speculative_init(
|
|
common_params_speculative & params,
|
|
llama_context * ctx_tgt);
|
|
|
|
// Destroys a speculative decoding object obtained from common_speculative_init().
// NOTE(review): whether spec == nullptr is accepted is decided by the implementation; confirm.
void common_speculative_free(common_speculative * spec);
|
|
|
|
// Optionally call once at the beginning of a new generation. `prompt` holds the
// prompt tokens of that generation.
|
|
void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);
|
|
|
|
// Drafts speculative tokens that follow id_last and returns them. The draft is the
// return value; no batch is passed in. After the target model has checked the draft,
// report how many tokens were accepted via common_speculative_accept().
// NOTE(review): the draft length limit (n_draft) presumably comes from `params`; confirm.
|
|
// draft_base_pos/draft_seq_id override the MTP position (and sequence id) used for
// id_last. NOTE(review): the default draft_base_pos = -1 presumably means "no
// override"; confirm against the implementation.
|
|
llama_tokens common_speculative_draft(
|
|
common_speculative * spec,
|
|
common_params_speculative & params,
|
|
const llama_tokens & prompt,
|
|
llama_token id_last,
|
|
llama_pos draft_base_pos = -1,
|
|
llama_seq_id draft_seq_id = 0);
|
|
|
|
// Informs the speculative decoder that n_accepted of the drafted tokens were accepted
// by the target model.
// NOTE(review): n_accepted is uint16_t, so a larger count passed by the caller would
// be truncated by the implicit integral conversion at the call site.
|
|
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
|
|
|
|
// Prints statistics about the speculative decoding performed by spec.
// All remaining parameters are optional and defaulted. slot_tps, n_decoded and n_past
// presumably describe the calling slot (throughput, decoded-token count, current
// position). active_params presumably points at the speculative params currently in
// effect, with nullptr meaning "not provided". Confirm against the implementation.
|
|
void common_speculative_print_stats(const common_speculative * spec, double slot_tps = 0.0, int n_decoded = 0, int n_past = 0, common_params_speculative * active_params = nullptr);
|
|
|
|
// Returns the MTP context held by the speculative object, or nullptr if the object is
// not of MTP type.
// NOTE(review): the returned context is presumably still owned by spec (do not free it); confirm.
|
|
llama_context * common_speculative_get_mtp_ctx(common_speculative * spec);
|
|
// Returns the speculative decoding type that spec is currently using.
common_speculative_type common_speculative_current_type(const common_speculative * spec);
|
|
|
|
// Applies a context shift to the MTP state for sequence seq_id. It mirrors how the
// server shifts the main model's context, so both stay aligned.
// NOTE(review): presumably kv_keep = number of leading positions kept, kv_discard =
// number of positions removed after them, kv_past = current number of positions.
// Confirm against the server's context-shift code.
|
|
void common_speculative_context_shift(
|
|
common_speculative * spec,
|
|
llama_seq_id seq_id,
|
|
llama_pos kv_keep,
|
|
llama_pos kv_discard,
|
|
llama_pos kv_past);
|
|
|
|
// Generates speculative draft tokens using the Multi-Token Prediction (MTP) architecture
// and returns them.
// Parameter meanings below are inferred from their names; confirm against the implementation:
//   smpl     - sampler used to pick each draft token (presumably)
//   ctx      - context the MTP passes run on (presumably the MTP context)
//   n_draft  - maximum number of tokens to draft (presumably)
//   p_min    - NOTE(review): presumably a minimum token probability below which drafting stops
//   id_last  - last token the draft continues from
//   n_past   - position associated with id_last (presumably)
//   seq_id   - sequence being drafted
//   constant_draft_positions - NOTE(review): semantics not visible here; defaults to false
|
|
std::vector<llama_token> mtp_speculative_gen_draft(
|
|
struct common_sampler * smpl,
|
|
struct llama_context * ctx,
|
|
int n_draft,
|
|
float p_min,
|
|
llama_token id_last,
|
|
llama_pos n_past,
|
|
llama_seq_id seq_id,
|
|
bool constant_draft_positions = false);
|
|
|
|
// Updates the MTP KV cache of ctx from `batch`.
// NOTE(review): the following is inferred from names and must be confirmed:
// is_prompt_warmup presumably marks prompt processing (as opposed to generation), and
// the int32_t result is presumably a decode status code.
int32_t mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);
|
|
|
|
// Informs the MTP state of ctx about accepted tokens.
// NOTE(review): presumably feeds the accepted token ids into the MTP state for sequence
// seq_id, starting at position n_past_base. Confirm against the implementation.
void mtp_accept_tokens(
|
|
struct llama_context * ctx,
|
|
const std::vector<llama_token> & ids,
|
|
int32_t n_past_base,
|
|
llama_seq_id seq_id
|
|
);
|