ik_llama.cpp/common/sampling.h
dungquixote42 642c038ccd
Extend expiring logit bias to other sampling parameters (#1770)
* initial commit

* fix underflow bug, add debug prints, update macro/variable names

* fix phrases-sharing-1-flag bug, replace macros with struct member function

* cleanup

* fix file parsing

* string_split_open_close() -> string_extract(), improve escape handling

* support multiple nested entries

* make persistent entries global, simplify file parsing

* cosmetic changes

* add support for jumping to exitword

* update variable names

* fix bad search bug

* better debug prints, reorg

* replace lambda with string_is_found(), add string_unescape() for debug

* add support for inline comments

* add missing debug print macro

* fix type promotion bug

* actually fix type promotion bug
2026-05-23 19:19:12 +03:00

369 lines
16 KiB
C++

#pragma once
#include "llama.h"
#include "llama-grammar.h"
#include "reasoning-budget.h"
#include <set>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
// Member-access helper: A_DOT_B(a, b) expands to `a.b`.
// Neither argument is parenthesized in the expansion, so pass a plain object expression for `a`;
// for example A_DOT_B(*p, x) expands to `*p.x`, which parses as `*(p.x)`.
#define A_DOT_B(a, b) a.b
// Sampler stages.
// The underlying type is char and every enumerator carries a distinct
// single-character code as its value.
enum class llama_sampler_type : char {
    DRY         = 'd',
    TOP_K       = 'k',
    TOP_P       = 'p',
    MIN_P       = 'm',
    TFS_Z       = 'f',
    XTC         = 'x',
    TOP_N_SIGMA = 'n',
    TYPICAL_P   = 'y',
    TEMPERATURE = 't',
    ADAPTIVE_P  = 'w',
    DIST        = 's',
};
// Kinds of grammar trigger. The numbering is spelled out explicitly and is the
// same 0..3 sequence the enumerators would receive implicitly.
enum common_grammar_trigger_type {
    COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN        = 0,
    COMMON_GRAMMAR_TRIGGER_TYPE_WORD         = 1,
    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN      = 2,
    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL = 3,
};
// One trigger for a lazily applied grammar (see common_params_sampling::grammar_triggers).
struct common_grammar_trigger {
// Kind of trigger. There is no default initializer, so callers must set it explicitly.
common_grammar_trigger_type type;
// Trigger payload — presumably the word or pattern text for the WORD/PATTERN kinds; confirm against users.
std::string value;
// Token id, LLAMA_TOKEN_NULL unless set — presumably only meaningful for the TOKEN kind; confirm.
llama_token token = LLAMA_TOKEN_NULL;
// JSON conversion helpers; only declared here.
// T can only be nlohmann::ordered_json
template <class T> T to_json() const;
template <class T> static common_grammar_trigger from_json(const T& in);
};
// Origin of a grammar constraint. Values are written out explicitly and match
// the implicit 0..3 numbering.
enum common_grammar_type {
    // no grammar set
    COMMON_GRAMMAR_TYPE_NONE          = 0,
    // user-provided GBNF (--grammar / "grammar" API field)
    COMMON_GRAMMAR_TYPE_USER          = 1,
    // auto-generated from JSON schema (--json-schema / "json_schema" API field)
    COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT = 2,
    // auto-generated by chat template parser for function calling
    COMMON_GRAMMAR_TYPE_TOOL_CALLS    = 3,
};
// A grammar constraint: the grammar text plus a tag recording where it came from.
struct common_grammar {
    common_grammar_type type = COMMON_GRAMMAR_TYPE_NONE;
    std::string         grammar;

    // No grammar: type stays COMMON_GRAMMAR_TYPE_NONE and the text is empty.
    common_grammar() = default;

    // Build from a type tag and grammar text; the text is moved into the member.
    // The assertion fails only for the pair (COMMON_GRAMMAR_TYPE_NONE, empty text).
    common_grammar(common_grammar_type t, std::string g)
        : type(t)
        , grammar(std::move(g)) {
        GGML_ASSERT(type != COMMON_GRAMMAR_TYPE_NONE || !grammar.empty());
    }

    // True when there is nothing to apply: either the type is NONE or the text is empty.
    bool empty() const {
        if (type == COMMON_GRAMMAR_TYPE_NONE) {
            return true;
        }
        return grammar.empty();
    }
};
// Returns a reference to the stored grammar string (empty when no grammar text was stored).
// The type tag is not consulted: the string is returned as-is even when g.type is
// COMMON_GRAMMAR_TYPE_NONE, so use g.empty() to ask whether a grammar is actually set.
// The reference is to g's own member and is only valid while g is alive.
inline const std::string & common_grammar_value(const common_grammar & g) {
return g.grammar;
}
// Tells whether the generation_prompt has to be prefilled into the grammar sampler.
// Output-format and tool-call grammars need the prefill; every other type
// (including user-supplied grammars) must not be prefilled.
inline bool common_grammar_needs_prefill(const common_grammar & g) {
    switch (g.type) {
        case COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT:
        case COMMON_GRAMMAR_TYPE_TOOL_CALLS:
            return true;
        default:
            return false;
    }
}
// X-macro table of the scalar sampling parameters.
// Each row is X(type, member, default value, precast). Define X, then expand
// X_COMMON_PARAMS_SAMPLING to generate one item per row, in table order.
// The precast column is std::round for the int32_t and bool rows and empty for the float rows.
// NOTE(review): presumably precast is applied to a floating-point value before it is converted to
// the member type; any expansion that uses it needs <cmath>, which this header does not include — confirm.
// Layout: each row's trailing comment is left open and ends with a backslash, and the next line closes
// it with */, so the comments double as line continuations and the whole table is a single #define.
// Keep that shape when adding or editing rows.
#define X_COMMON_PARAMS_SAMPLING /* \
*/ X( int32_t , min_keep , 0 , std::round ) /* 0 = disabled, otherwise samplers should return at least min_keep tokens \
*/ X( int32_t , top_k , 40 , std::round ) /* <= 0 to use vocab size \
*/ X( float , top_p , 0.95f , ) /* 1.0 = disabled \
*/ X( float , min_p , 0.05f , ) /* 0.0 = disabled \
*/ X( float , tfs_z , 1.00f , ) /* 1.0 = disabled \
*/ X( float , typical_p , 1.00f , ) /* 1.0 = disabled \
*/ X( float , temp , 0.80f , ) /* <= 0.0 to sample greedily, 0.0 to not output probabilities \
*/ X( float , dynatemp_range , 0.00f , ) /* 0.0 = disabled \
*/ X( float , dynatemp_exponent , 1.00f , ) /* controls how entropy maps to temperature in dynamic temperature sampler \
*/ X( int32_t , penalty_last_n , 64 , std::round ) /* last n tokens to penalize (0 = disable penalty, -1 = context size) \
*/ X( float , penalty_repeat , 1.00f , ) /* 1.0 = disabled \
*/ X( float , penalty_freq , 0.00f , ) /* 0.0 = disabled \
*/ X( float , penalty_present , 0.00f , ) /* 0.0 = disabled \
*/ X( float , dry_multiplier , 0.0f , ) /* 0.0 = disabled; DRY repetition penalty for tokens extending repetition: \
*/ X( float , dry_base , 1.75f , ) /* 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) \
*/ X( int32_t , dry_allowed_length , 2 , std::round ) /* tokens extending repetitions beyond this receive penalty \
*/ X( int32_t , dry_penalty_last_n , -1 , std::round ) /* how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) \
*/ X( int32_t , mirostat , 0 , std::round ) /* 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 \
*/ X( float , mirostat_tau , 5.00f , ) /* target entropy \
*/ X( float , mirostat_eta , 0.10f , ) /* learning rate \
*/ X( float , xtc_probability , 0.0f , ) /* xtc probability \
*/ X( float , xtc_threshold , 1.0f , ) /* xtc threshold, disabled if > 0.5 \
*/ X( float , top_n_sigma , 0.0f , ) /* top-n-sigma \
*/ X( float , adaptive_target , -1.0f , ) /* select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled) \
*/ X( float , adaptive_decay , 0.90f , ) /* decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation) \
*/ X( bool , adaptive_updt_w_cur , false , std::round ) /* update state with current probability \
*/
// Index constants for the X_COMMON_PARAMS_SAMPLING rows: one enumerator named
// SPARAMS_<member>_ENUM per row, numbered from 0 in table order (no explicit values are given).
// X is #undef'd first in case an earlier definition exists, and it is still defined
// (with this enumerator-generating body) when the block ends.
enum {
#undef X
#define X(T, MEMBER, DV, PRECAST) SPARAMS_ ## MEMBER ## _ENUM,
X_COMMON_PARAMS_SAMPLING
};
// sampling parameters
//
// Configuration for one sampling context: the scalar tunables generated from the
// X_COMMON_PARAMS_SAMPLING table, the sampler order, grammar constraints, classifier-free
// guidance, logit biases, the reasoning budget and the expiring-logit-bias (elb) rules.
// llama_sampling_params is a typedef alias of this struct.
typedef struct common_params_sampling {
    // One `T MEMBER = DV;` data member per table row (the 4th column is ignored here).
    // X remains defined with this body after the expansion.
#undef X
#define X(T, MEMBER, DV, _) T MEMBER = DV;
    X_COMMON_PARAMS_SAMPLING

    int32_t  n_prev             = 64;    // number of previous tokens to remember
    int32_t  n_probs            = 0;     // if greater than 0, output the probabilities of top n_probs tokens.
    int32_t  total_context_size = 16840; // NOTE(review): 16840 is not a power of two (16384?) — confirm the default is intended
    bool     penalize_nl        = false; // consider newlines as a repeatable token
    uint32_t seed               = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context

    std::vector<std::string> dry_sequence_breakers = { "\n", ":", "\"", "*" }; // default sequence breakers for DRY

    // Default sampler order.
    std::vector<llama_sampler_type> samplers_sequence = {
        llama_sampler_type::DRY,
        llama_sampler_type::TOP_K,
        llama_sampler_type::TFS_Z,
        llama_sampler_type::TYPICAL_P,
        llama_sampler_type::TOP_P,
        llama_sampler_type::MIN_P,
        llama_sampler_type::XTC,
        llama_sampler_type::TOP_N_SIGMA,
        llama_sampler_type::TEMPERATURE,
        llama_sampler_type::ADAPTIVE_P,
        llama_sampler_type::DIST,
    };

    //std::string grammar; // optional BNF-like grammar to constrain sampling
    common_grammar grammar; // optional grammar constraint (user / output-format / tool-calls)
    bool grammar_lazy = false;
    std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
    std::set<llama_token> preserved_tokens;

    // Classifier-Free Guidance
    // https://arxiv.org/abs/2306.17806
    std::string cfg_negative_prompt; // string to help guidance
    float cfg_scale = 1.f; // how strong is guidance

    std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens

    // The assistant generation prompt already prefilled into the prompt.
    // Fed to the grammar sampler (to advance past pre-existing tokens) and used
    // to determine the reasoning budget sampler's initial state.
    // Only applied when the grammar is of output-format or tool-calls type.
    std::string generation_prompt;

    // reasoning budget sampler parameters
    // these are populated by the server/CLI based on chat template params
    int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
    std::vector<llama_token> reasoning_budget_start; // start tag token sequence
    std::vector<llama_token> reasoning_budget_end; // end tag token sequence
    std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)

    std::vector<llama_token> penalty_prompt_tokens;
    bool use_penalty_prompt_tokens = false;

    // expiring logit bias
    struct elb_param {
        struct elb_entry {
            std::vector<size_t> posi; // positions of phrases in generated text
            std::vector<float> addsubs; // add/modify then subtract/restore sampling parameters
            std::vector<bool> addflags; // true if added
            // The scalar members are zero-initialized so that a default-constructed entry
            // never holds indeterminate values (operator== below reads duration and is_range).
            // Zero is what value-initialization (`elb_entry{}`) already produced.
            size_t max_phrase_len = 0;
            std::vector<std::string> phrases;
            std::vector<float> biases; // for each phrase, nth bias for nth token, extrapolate
            int32_t duration = 0; // bias duration, unless exitword matches
            bool is_range = false; // has lower and upper biases

            // Member-wise equality over every field except max_phrase_len.
            // NOTE(review): max_phrase_len is presumably derived from `phrases` and therefore
            // left out of the comparison on purpose — confirm.
            bool operator == (const struct elb_entry& other) const {
                return (is_range == other.is_range)
                    && (duration == other.duration)
                    && (biases == other.biases)
                    && (phrases == other.phrases)
                    && (addflags == other.addflags)
                    && (addsubs == other.addsubs)
                    && (posi == other.posi);
            }
        };

        std::vector<struct elb_entry> entries;
        std::string exitword; // move to next state if matched during generation
        std::string op; // exitword operator

        // Two params are equal when op, exitword and all entries compare equal.
        bool operator == (const struct elb_param& other) const {
            return (op == other.op)
                && (exitword == other.exitword)
                && (entries == other.entries);
        }
    };
    std::vector<struct elb_param> elb_params;
} llama_sampling_params;
// general sampler context
// TODO: move to llama.h
//
// Every raw pointer and scalar member carries a default member initializer (null / zero), so an
// object created without value-initialization (`common_sampler s;`, `new common_sampler`) starts
// in a defined state. The defaults equal what value-initialization (`new common_sampler()`)
// already produced, so code that relied on zeroed members sees no change.
struct common_sampler {
    // parameters that will be used for sampling
    common_params_sampling params;

    // mirostat sampler state
    float mirostat_mu = 0.0f;

    std::string grammar_str;
    std::string grammar_root;
    llama_grammar * grammar = nullptr;

    // TODO: replace with ring-buffer
    std::vector<llama_token> prev;
    std::vector<llama_token_data> cur;

    llama_sampler_dry* smpl = nullptr;
    llama_sampler_adaptive_p * adapt_p_ctx = nullptr; // adaptive p sampler
    common_reasoning_budget_ctx * rbudget = nullptr; // reasoning budget sampler

    size_t n_valid = 0; // Number of correct top tokens with correct probabilities.
    llama_token_data_array cur_p = {}; // current candidates

    std::mt19937 rng;

    std::vector<float>* server_biases = nullptr;
    std::string drafted_text;
    std::string* to_generated_text = nullptr;

    // expiring logit bias
    struct elb_state {
        struct elb_token {
            int32_t id = 0;
            float bias = 0.0f;
            size_t duration = 0;
            std::string cond; // bias activation condition
        };

        std::vector<struct elb_token> first_tokens; // first token of each phrase
        std::vector<struct elb_token> other_tokens;
        std::string exitword;
        size_t countup = 0; // compare against duration
        size_t delay = 0; // to avoid early termination of positively biased phrases
        int32_t max_cond_len = 0;
        std::string jumpword;
        size_t jump_idx = 0;
        size_t search_word_len = 0;
    };
    std::vector<struct elb_state> elb_states;
    size_t elb_idx = 0; // for elb_states
    size_t elb_search_pos = 0;
};
// Create a new sampling context instance.
// Returns a pointer to the new context, built for `model` from `params`.
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
// Release a sampling context — presumably one obtained from common_sampler_init; confirm against the implementation.
void common_sampler_free(struct common_sampler * ctx);
// Reset the sampler context
// - clear prev tokens
// - reset grammar
void common_sampler_reset(common_sampler * ctx);
// Review stateful samplers
// - rewind internal states (maybe)
// NOTE(review): the meaning of n_unsent and rewind_status is not visible in this header — see the implementation.
void common_sampler_review(common_sampler * ctx, const size_t n_unsent, const bool rewind_status);
// Set the sampler seed
void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed);
// Copy the sampler context
// (argument order is source first, destination second)
void common_sampler_clone(common_sampler * src, common_sampler * dst);
// Get the last sampled token
llama_token llama_sampling_last(common_sampler * ctx);
// Get a string representation of the last sampled tokens
// NOTE(review): n is presumably the number of most recent tokens to render — confirm.
std::string llama_sampling_prev_str(common_sampler * ctx_sampling, llama_context * ctx_main, int n);
// Print sampling parameters into a string
std::string llama_sampling_print(const common_params_sampling & params);
// Print sampling order into a string
std::string llama_sampling_order_print(const common_params_sampling & params);
// Conversions between llama_sampler_type and its textual forms.
// NOTE(review): presumably *_from_names maps sampler names (optionally alternative spellings) and
// *_from_chars maps each character of names_string to the enum's char codes — confirm.
std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);
std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string);
// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call
// common_sampler_reset when a sequence ends
//
// required:
// - ctx_main: context to use for sampling
// - ctx_sampling: sampling-specific context
//
// optional:
// - ctx_cfg: context to use for classifier-free guidance
// - idx: sample from llama_get_logits_ith(ctx, idx)
//
// returns:
// - token: sampled token
// - candidates: vector of candidate tokens
//
// NOTE(review): both functions below return only a llama_token; the "candidates" item above
// presumably refers to state kept in ctx_sampling — confirm.
// Legacy variant: takes ctx_cfg explicitly; idx defaults to -1.
llama_token common_sampler_sample_legacy(
struct common_sampler * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
int idx = -1);
// Variant without ctx_cfg; idx defaults to -1 and grammar_first defaults to false.
llama_token common_sampler_sample(
struct common_sampler * ctx_sampling,
struct llama_context * ctx_main,
int idx = -1,
bool grammar_first = false);
// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
// Defaults: idx = 0, apply_grammar = true, original_logits = nullptr.
// NOTE(review): original_logits is presumably an optional output receiving the unmodified logits — confirm.
llama_token_data_array llama_sampling_prepare(
struct common_sampler * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
int idx = 0,
bool apply_grammar = true,
std::vector<float> * original_logits = nullptr);
// if is_generated is true, the token is accepted by the sampling chain, the reasoning budget sampler, and the grammar sampler
void common_sampler_accept(
struct common_sampler * ctx_sampling,
struct llama_context * ctx_main,
llama_token id,
bool is_generated);
// access the internal list of current candidate tokens
// (do_sort defaults to false)
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * ctx_sampling, bool do_sort = false);
// returns at least 1 token, up to draft.size()
// NOTE(review): this remark appears to describe the two *_sample_and_accept_n overloads below — confirm.
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft);
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first = false);
// Greedy argmax sampling for speculative drafting
// NOTE(review): out_prob is an optional output (defaults to nullptr), presumably the probability of the returned token — confirm.
llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, float * out_prob = nullptr);
// Expiring logit bias hooks.
// NOTE(review): presumably *_apply adjusts `logits` in place from ctx_sampling's elb state and
// *_accept advances that state after a token is accepted — confirm against the implementation.
void common_expiring_logit_bias_apply(struct common_sampler* ctx_sampling, float* logits);
void common_expiring_logit_bias_accept(struct common_sampler* ctx_sampling, struct llama_context * ctx_main);
// Creates a llama_grammar from a grammar kind and grammar data, both given as C strings.
// NOTE(review): "llg" presumably refers to the LLGuidance backend — confirm.
llama_grammar* llama_sampler_init_llg(const llama_vocab* vocab,
const char* grammar_kind, const char* grammar_data);