Extend expiring logit bias to other sampling parameters (#1770)

* initial commit

* fix underflow bug, add debug prints, update macro/variable names

* fix phrases-sharing-1-flag bug, replace macros with struct member function

* cleanup

* fix file parsing

* string_split_open_close() -> string_extract(), improve escape handling

* support multiple nested entries

* make persistent entries global, simplify file parsing

* cosmetic changes

* add support for jumping to exitword

* update variable names

* fix bad search bug

* better debug prints, reorg

* replace lambda with string_is_found(), add string_unescape() for debug

* add support for inline comments

* add missing debug print macro

* fix type promotion bug

* actually fix type promotion bug
This commit is contained in:
dungquixote42 2026-05-23 12:19:12 -04:00 committed by GitHub
parent 40d8cb196a
commit 642c038ccd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 402 additions and 142 deletions

View File

@ -2966,6 +2966,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n" options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n"
"i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n" "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"
"or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" }); "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" });
options.push_back({ "*", " --expiring-logit-bias-file",
"original PR: https://github.com/ikawrakow/ik_llama.cpp/pull/1731\n"});
options.push_back({ "main", " --cfg-negative-prompt PROMPT", options.push_back({ "main", " --cfg-negative-prompt PROMPT",
"negative prompt to use for guidance (default: '%s')", sparams.cfg_negative_prompt.c_str() }); "negative prompt to use for guidance (default: '%s')", sparams.cfg_negative_prompt.c_str() });
options.push_back({ "main", " --cfg-negative-prompt-file FNAME", options.push_back({ "main", " --cfg-negative-prompt-file FNAME",
@ -3491,6 +3493,28 @@ void string_process_escapes(std::string & input) {
input.resize(output_idx); input.resize(output_idx);
} }
// Produce a printable copy of `str` for log output: every newline, tab and
// carriage-return character is replaced by the two-character sequence
// \n, \t or \r respectively. Every other character, backslashes included,
// is copied through unchanged.
std::string string_unescape(const std::string& str) {
    std::string printable;
    // worst case: every input character expands to two output characters
    printable.reserve(str.length() * 2);
    for (size_t i = 0; i < str.length(); ++i) {
        const char ch = str[i];
        if (ch == '\n') {
            printable += "\\n";
        } else if (ch == '\t') {
            printable += "\\t";
        } else if (ch == '\r') {
            printable += "\\r";
        } else {
            printable.push_back(ch);
        }
    }
    return printable;
}
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) { bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
const char * sep = strchr(data, '='); const char * sep = strchr(data, '=');
if (sep == nullptr || sep - data >= 128) { if (sep == nullptr || sep - data >= 128) {
@ -3537,6 +3561,42 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
return true; return true;
} }
// Collect the substrings of `str` that are enclosed between pairs of the
// delimiter character `c`.
//
// Delimiters are consumed alternately as opening and closing marks. Any
// occurrence of `c` is accepted as an opening mark; a closing mark is only
// accepted when it is preceded by an even number of backslashes (i.e. it is
// not escaped). Each extracted substring is passed through
// string_process_escapes() before being returned.
//
// `posi` is appended to (existing elements are kept): the position of every
// opening mark and of every accepted closing mark is pushed in order. An
// opening mark that never finds an unescaped closing mark stays in `posi`
// without a corresponding extract.
std::vector<std::string> string_extract(const std::string& str, const char c, std::vector<size_t>& posi) {
    std::vector<std::string> found;
    bool inside = false; // true while an opening mark is waiting for its partner
    for (size_t at = str.find(c); at != std::string::npos; at = str.find(c, at + 1)) {
        if (!inside) {
            posi.push_back(at);
            inside = true;
            continue;
        }
        // count the run of backslashes directly in front of this candidate
        size_t n_backslash = 0;
        for (size_t k = at; (k > 0) && (str[k - 1] == '\\'); --k) {
            ++n_backslash;
        }
        if (n_backslash % 2 != 0) {
            continue; // escaped delimiter, keep looking for the closing mark
        }
        const size_t open_at = posi.back();
        found.push_back(str.substr(open_at + 1, at - open_at - 1));
        string_process_escapes(found.back());
        posi.push_back(at);
        inside = false;
    }
    return found;
}
// Search `window` for the first occurrence of `str`.
// Returns true and stores the match offset in `pos` when found.
// When `str` is non-empty but absent, `pos` is set to std::string::npos and
// false is returned. An empty `str` never matches and leaves `pos` untouched.
bool string_is_found(const std::string& window, const std::string& str, size_t& pos) {
    if (!str.empty()) {
        pos = window.find(str);
        return std::string::npos != pos;
    }
    return false;
}
// //
// Filesystem utils // Filesystem utils
// //
@ -5170,121 +5230,169 @@ std::tuple<uint32_t, uint32_t, std::string, float> argparse_allowlist_unicode_ru
} }
void argparse_expiring_logit_bias(const std::string& content, common_params_sampling& sparams) { void argparse_expiring_logit_bias(const std::string& content, common_params_sampling& sparams) {
decltype(sparams.elb_params) elb_params = { { { }, "" } }; auto elb_params = sparams.elb_params;
elb_params.push_back({ { }, "", "" });
auto entries = elb_params[0].entries;
int32_t saved_duration = 0; const auto lines = string_split(content, "\n");
std::vector<std::string> saved_phrases; for (size_t i = 0; i < lines.size(); ++i) {
std::vector<float> saved_biases; auto line = string_strip(lines[i]);
bool saved_is_range = false;
for (auto line: string_split(content, "\n")) {
string_strip(line);
const char c0 = line.empty() ? '#' : line[0]; const char c0 = line.empty() ? '#' : line[0];
if (c0 == '#') { if (c0 == '#') {
// comment LLAMA_LOG_DEBUG("%s: line %zu: comment or empty\n", __func__, i);
continue; // next line continue; // next line
} }
// (... "EXTRACT" ... "EXTRACT" ...)
std::vector<size_t> qq_posi = { 0 };
auto extracts = string_extract(line, '"', qq_posi);
qq_posi.push_back(std::string::npos);
for (int32_t j = 0; j < int32_t(qq_posi.size()) - 1; j += 2) {
const auto pnd_pos = line.find('#', qq_posi[j]);
if (pnd_pos < qq_posi[j + 1]) {
LLAMA_LOG_DEBUG("%s: line %zu: inline comment @ %zu\n", __func__, i, pnd_pos);
line = string_strip(line.substr(0, pnd_pos));
qq_posi.resize(j + 2);
qq_posi.back() = std::string::npos;
extracts.resize(j / 2);
break;
}
}
const auto last_qq_pos = qq_posi[qq_posi.size() - 2];
auto n_char = line.length(); auto n_char = line.length();
const char cE = line[n_char - 1]; const char cE = line[n_char - 1];
if (n_char > 1) { LLAMA_LOG_DEBUG("%s: line %zu: %s\n", __func__, i, line.c_str());
if ('(' == c0 && cE == ')') { if ('(' == c0 && cE == ')') {
const bool is_nested = '(' == line[1] && line[n_char - 2] == ')'; const bool is_nested = '(' == line[1] && line[n_char - 2] == ')';
if (is_nested) { if (is_nested) {
if (n_char == 4) { if (n_char == 4) {
// (()) // (())
saved_phrases.clear(); entries.clear();
saved_biases.clear(); LLAMA_LOG_DEBUG("%s: line %zu: persistent entry clear\n", __func__, i);
continue; // next line
}
n_char -= 2;
line = line.substr(1, n_char);
}
auto qqpos = line.find('"');
// (DURATION : ...)
int32_t duration = is_nested ? -1 : 1;
const auto cpos = line.find(':');
if ((cpos != std::string::npos) && (1 < cpos) && (cpos < qqpos)) {
auto sub = line.substr(1, cpos - 1);
duration = std::stoi(sub);
}
if (duration == 0) {
continue; // next line continue; // next line
} }
n_char -= 2;
line = line.substr(1, n_char);
LLAMA_LOG_DEBUG("%s: line %zu: persistent entry\n", __func__, i);
}
// (... "PHRASE" ... "PHRASE" ...) // (DURATION : ...)
std::vector<std::string> phrases; int32_t duration = is_nested ? -1 : 1;
auto pos = line.find('"', qqpos + 1); const auto cln_pos = line.find(':');
while (pos != std::string::npos) { if ((cln_pos != std::string::npos) && (1 < cln_pos) && (cln_pos < qq_posi[1])) {
if (line[pos - 1] == '\\') { duration = std::stoi(line.substr(1, cln_pos - 1));
pos = line.find('"', pos + 1); }
} else { if (duration == 0) {
auto phrase = line.substr(qqpos + 1, pos - qqpos - 1); LLAMA_LOG_DEBUG("%s: line %zu: invalid duration\n", __func__, i);
string_process_escapes(phrase); continue; // next line
phrases.push_back(std::move(phrase)); }
qqpos = line.find('"', pos + 1);
if (qqpos == std::string::npos) { #undef X
break; #define X(T, MEMBER, DV, PRECAST) #MEMBER,
} static const std::vector<std::string> names = { X_COMMON_PARAMS_SAMPLING };
pos = line.find('"', qqpos + 1);
std::vector<float> addsubs(names.size(), 0.0f);
bool is_sb = false;
// (... : SPARAM ...)
const auto window = line.substr(last_qq_pos + 1);
for (int j = 0; j < names.size(); ++j) {
const auto& name = names[j];
auto pos = window.find(name);
if (pos != std::string::npos) {
pos += name.length();
auto next_pos = window.find(",", pos + 1);
if (next_pos == std::string::npos) {
next_pos = n_char - 1;
}
auto sub = string_strip(window.substr(pos, next_pos - pos));
if (sub[0] == '~') {
addsubs[j] += std::stof(sub.substr(1));
is_sb = true;
LLAMA_LOG_DEBUG("%s: line %zu: bias = %f\n", __func__, i, addsubs[j]);
} }
} }
if (phrases.empty()) { }
auto& phrases = extracts;
if (phrases.empty()) {
if (is_sb) {
phrases.push_back("");
} else {
continue; // next line continue; // next line
} }
}
const auto n_phrase = phrases.size();
std::vector<float> biases;
bool is_range = false;
if (!is_sb) {
// (... : BIAS ...) // (... : BIAS ...)
std::vector<float> biases; const auto cln_rpos = line.rfind(':');
bool is_range = false; auto sub = line.substr(cln_rpos + 1, n_char - cln_rpos - 2);
const auto rcpos = line.rfind(':'); if (sub.find("~") != std::string::npos) {
if ((rcpos != std::string::npos) && (line.rfind('"') < rcpos)) { // (... : BIAS ~ BIAS)
auto sub = line.substr(rcpos + 1, n_char - rcpos - 2); const auto splits = string_split(sub, '~');
if (sub.find("~") != std::string::npos) { biases.push_back(std::stof(splits.front()));
// (... : BIAS ~ BIAS) LLAMA_LOG_DEBUG("%s: line %zu: logit bias = %f\n", __func__, i, biases.back());
const auto splits = string_split(sub, '~'); biases.push_back(std::stof(splits.back()));
auto split = splits.front(); LLAMA_LOG_DEBUG("%s: line %zu: logit bias = %f\n", __func__, i, biases.back());
biases.push_back(std::stof(split)); is_range = true;
split = splits.back(); } else {
biases.push_back(std::stof(split)); // (... : BIAS, BIAS, ..., BIAS)
is_range = true; for (const auto& split: string_split(sub, ',')) {
} else { if (!split.empty()) {
// (... : BIAS, BIAS, ..., BIAS) biases.push_back(std::stof(split));
for (auto split: string_split(sub, ',')) { LLAMA_LOG_DEBUG("%s: line %zu: logit bias = %f\n", __func__, i, biases.back());
if (!split.empty()) {
biases.push_back(std::stof(split));
}
} }
} }
} }
if (biases.empty()) { if (biases.empty()) {
continue; // next line continue; // next line
} }
if (is_nested) {
saved_duration = duration;
saved_phrases = std::move(phrases);
saved_biases = std::move(biases);
saved_is_range = is_range;
} else {
elb_params.back().entries.push_back({ std::move(phrases), std::move(biases), duration, is_range });
}
continue; // next line
} }
size_t max_phrase_len = 0;
for (const auto& phrase: phrases) {
LLAMA_LOG_DEBUG("%s: line %zu: phrase = \"%s\"\n", __func__, i, phrase.c_str());
max_phrase_len = std::max(phrase.length(), max_phrase_len);
}
LLAMA_LOG_DEBUG("%s: line %zu: max_phrase_len = %zu\n", __func__, i, max_phrase_len);
common_params_sampling::elb_param::elb_entry entry = {
std::vector<size_t>(n_phrase, 0),
std::move(addsubs),
std::vector<bool>(n_phrase, false),
max_phrase_len,
std::move(phrases),
std::move(biases),
duration,
is_range
};
if (is_nested) {
entries.push_back(entry);
}
elb_params.back().entries.push_back(std::move(entry));
continue; // next line
} }
// exitword if (last_qq_pos > 0) {
if ('"' == c0 && cE == '"') { elb_params.back().op = string_strip(line.substr(last_qq_pos + 1));
line = line.substr(1, n_char - 2);
} }
string_process_escapes(line);
elb_params.back().exitword = std::move(line); auto& exitwords = extracts;
if (!saved_phrases.empty() && !saved_biases.empty() && (saved_duration != 0)) { if (exitwords.empty()) {
elb_params.back().entries.push_back({ saved_phrases, saved_biases, saved_duration, saved_is_range }); string_process_escapes(line);
exitwords.push_back(std::move(line));
} }
elb_params.push_back({ { }, "" });
// maybe support multiple exitwords in future
elb_params.back().exitword = std::move(exitwords[0]);
elb_params.push_back({ entries, "", "" });
} }
sparams.elb_params = std::move(elb_params); sparams.elb_params = std::move(elb_params);

View File

@ -646,6 +646,11 @@ std::vector<std::string> string_split<std::string>(const std::string& input, cha
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides); bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input); void string_process_escapes(std::string & input);
std::string string_unescape(const std::string& str);
std::vector<std::string> string_extract(const std::string& str, const char c, std::vector<size_t>& posi);
bool string_is_found(const std::string& window, const std::string& str, size_t& pos);
// //
// Filesystem utils // Filesystem utils

View File

@ -803,6 +803,29 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
return result; return result;
} }
// Debug helper: for every parameter listed in X_COMMON_PARAMS_SAMPLING whose
// slot in entry.addsubs is non-zero, log the parameter name and the current
// value of that member in `sparams` (converted to float) via LLAMA_LOG_DEBUG.
// Nothing is modified. Note that the macro X is redefined here and is left
// defined when the function ends.
static void elb_print(common_params_sampling& sparams, const common_params_sampling::elb_param::elb_entry& entry) {
// first expansion: stringize each member name to build the name table,
// in the same order as the SPARAMS_*_ENUM indices
#undef X
#define X(T, MEMBER, DV, PRECAST) #MEMBER,
static const std::vector<std::string> names = { X_COMMON_PARAMS_SAMPLING };
// second expansion: one guarded log statement per member
#undef X
#define X(T, MEMBER, DV, PRECAST) if (std::abs(entry.addsubs[SPARAMS_ ## MEMBER ## _ENUM]) > 0.0f) \
{ LLAMA_LOG_DEBUG("%s: %s = %f\n", __func__, names[SPARAMS_ ## MEMBER ## _ENUM].c_str(), float(A_DOT_B(sparams, MEMBER))); }
X_COMMON_PARAMS_SAMPLING
}
// Apply an entry's parameter offsets: for every member listed in
// X_COMMON_PARAMS_SAMPLING, add entry.addsubs[SPARAMS_<member>_ENUM] to
// sparams.<member>. The offset is first passed through the row's PRECAST
// (std::round for the int32_t/bool rows, nothing for the float rows) and then
// static_cast to the member's type T. Members whose offset is 0 are still
// visited. elb_sub() performs the inverse operation.
static void elb_add(common_params_sampling& sparams, const common_params_sampling::elb_param::elb_entry& entry) {
// X is redefined for this expansion and stays defined afterwards
#undef X
#define X(T, MEMBER, _, PRECAST) A_DOT_B(sparams, MEMBER) += static_cast<T>(PRECAST(entry.addsubs[SPARAMS_ ## MEMBER ## _ENUM]));
X_COMMON_PARAMS_SAMPLING
}
// Revert an entry's parameter offsets: for every member listed in
// X_COMMON_PARAMS_SAMPLING, subtract entry.addsubs[SPARAMS_<member>_ENUM] from
// sparams.<member>, using the same PRECAST + static_cast<T> conversion as
// elb_add() so that the two calls mirror each other.
static void elb_sub(common_params_sampling& sparams, const common_params_sampling::elb_param::elb_entry& entry) {
// X is redefined for this expansion and stays defined afterwards
#undef X
#define X(T, MEMBER, _, PRECAST) A_DOT_B(sparams, MEMBER) -= static_cast<T>(PRECAST(entry.addsubs[SPARAMS_ ## MEMBER ## _ENUM]));
X_COMMON_PARAMS_SAMPLING
}
void common_expiring_logit_bias_apply(struct common_sampler* ctx_sampling, float* logits) { void common_expiring_logit_bias_apply(struct common_sampler* ctx_sampling, float* logits) {
auto index_first_inactive = [](auto countup, auto& tokens) { auto index_first_inactive = [](auto countup, auto& tokens) {
return std::distance( return std::distance(
@ -853,6 +876,58 @@ void common_expiring_logit_bias_apply(struct common_sampler* ctx_sampling, float
} }
} }
} }
// expiring sampler bias
for (auto& entry: ctx_sampling->params.elb_params[ctx_sampling->elb_idx].entries) {
if (!entry.biases.empty()) {
continue; // next entry
}
for (size_t j = 0; j < entry.phrases.size(); ++j) {
const auto& phrase = entry.phrases[j];
if (phrase.empty()) {
// duration bound only
if (elb.countup == 0) {
LLAMA_LOG_DEBUG("%s: before add\n", __func__);
elb_print(ctx_sampling->params, entry);
elb_add(ctx_sampling->params, entry);
entry.addflags[j] = true;
LLAMA_LOG_DEBUG("%s: after add\n", __func__);
elb_print(ctx_sampling->params, entry);
} else if (elb.countup == entry.duration) {
LLAMA_LOG_DEBUG("%s: before sub\n", __func__);
elb_print(ctx_sampling->params, entry);
elb_sub(ctx_sampling->params, entry);
entry.addflags[j] = false;
LLAMA_LOG_DEBUG("%s: after sub\n", __func__);
elb_print(ctx_sampling->params, entry);
}
continue; // next entry
}
size_t count = 0;
auto pos = ctx_sampling->to_generated_text->find(phrase, entry.posi[j]);
while (pos != std::string::npos) {
LLAMA_LOG_DEBUG("%s: found %s @ %zu\n", __func__, phrase.c_str(), pos);
++count;
pos = ctx_sampling->to_generated_text->find(phrase, pos + phrase.length());
}
entry.posi[j] = std::max(0, int32_t(ctx_sampling->to_generated_text->length()) - int32_t(phrase.length()) + 1);
if (count % 2 == 1) {
// even = no match or cancelled
LLAMA_LOG_DEBUG("%s: before\n", __func__);
elb_print(ctx_sampling->params, entry);
(entry.addflags[j] ? elb_sub : elb_add)(ctx_sampling->params, entry);
entry.addflags[j] = !entry.addflags[j];
LLAMA_LOG_DEBUG("%s: after\n", __func__);
elb_print(ctx_sampling->params, entry);
}
}
}
} }
void common_expiring_logit_bias_accept(struct common_sampler* ctx_sampling, struct llama_context * ctx_main) { void common_expiring_logit_bias_accept(struct common_sampler* ctx_sampling, struct llama_context * ctx_main) {
@ -861,25 +936,53 @@ void common_expiring_logit_bias_accept(struct common_sampler* ctx_sampling, stru
return; return;
} }
auto& elb = ctx_sampling->elb_states[ctx_sampling->elb_idx]; auto idx = ctx_sampling->elb_idx;
const int32_t exitword_len = elb.exitword.length(); auto& elb = ctx_sampling->elb_states[idx];
if ((elb.delay > ++elb.countup) || (exitword_len == 0)) { if ((elb.delay > ++elb.countup) || (elb.search_word_len == 0)) {
return; return;
} }
const std::string search_window = ctx_sampling->to_generated_text->substr(std::min( const std::string window = ctx_sampling->to_generated_text->substr(std::min(
ctx_sampling->to_generated_text->length(), ctx_sampling->to_generated_text->length(),
size_t(ctx_sampling->elb_search_pos) ctx_sampling->elb_search_pos)) + common_token_to_piece(ctx_main, ctx_sampling->prev.back(), true);
)) + common_token_to_piece(ctx_main, ctx_sampling->prev.back(), true); size_t pos = 0;
if (string_is_found(window, elb.jumpword, pos)) {
const auto exitword_pos = search_window.find(elb.exitword); LLAMA_LOG_DEBUG("%s: found %s in %s @ %zu\n", __func__, string_unescape(elb.jumpword).c_str(), string_unescape(window).c_str(), pos);
if (exitword_pos != std::string::npos) { pos += ctx_sampling->elb_search_pos + elb.jumpword.length();
ctx_sampling->elb_idx = elb.jump_idx;
} else if (string_is_found(window, elb.exitword, pos)) {
LLAMA_LOG_DEBUG("%s: found %s in %s @ %zu\n", __func__, string_unescape(elb.exitword).c_str(), string_unescape(window).c_str(), pos);
pos += ctx_sampling->elb_search_pos + elb.exitword.length();
++ctx_sampling->elb_idx; ++ctx_sampling->elb_idx;
// no double counting characters that matched
ctx_sampling->elb_search_pos += exitword_pos + exitword_len;
} else { } else {
// move search position to include next token // not found. move search position to include next token
ctx_sampling->elb_search_pos += std::max(0, int32_t(search_window.length()) - exitword_len + 1); ctx_sampling->elb_search_pos += std::max(0, int32_t(window.length()) - int32_t(elb.search_word_len) + 1);
return;
}
// single character clearance
// e.g. stop \n\n from expiring two \n immediately
ctx_sampling->elb_search_pos = pos + 1;
// undo current sampler bias
for (auto& entry: ctx_sampling->params.elb_params[idx].entries) {
for (const auto addflag: entry.addflags) {
if (addflag) {
LLAMA_LOG_DEBUG("%s: before\n", __func__);
elb_print(ctx_sampling->params, entry);
elb_sub(ctx_sampling->params, entry);
LLAMA_LOG_DEBUG("%s: after\n", __func__);
elb_print(ctx_sampling->params, entry);
}
}
}
// prepare next sampler bias
for (auto& entry: ctx_sampling->params.elb_params[ctx_sampling->elb_idx].entries) {
// no clearance for sampler bias
std::fill(entry.posi.begin(), entry.posi.end(), pos);
} }
} }

View File

@ -9,6 +9,8 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
// Expands to the member access expression a.b, so that a member name supplied
// as a macro argument can be applied to an object. The arguments are not
// parenthesized: pass a plain object expression for `a`.
#define A_DOT_B(a, b) a.b
// sampler types // sampler types
enum class llama_sampler_type : char { enum class llama_sampler_type : char {
DRY = 'd', DRY = 'd',
@ -80,37 +82,49 @@ inline bool common_grammar_needs_prefill(const common_grammar & g) {
} }
// X-macro table of sampling parameters. Each row has the form
//   X(type, member name, default value, pre-cast function)
// and the meaning of X is chosen by the code that expands the table
// (#undef X / #define X right before use). The fourth column is std::round
// for the int32_t and bool rows and empty for the float rows.
// Layout note: each row's description sits in a block comment that is closed
// by the leading "*/" of the following row, and the lines are spliced
// together, so the whole table is one single #define. Keep that shape when
// adding or editing rows.
#define X_COMMON_PARAMS_SAMPLING /* \
*/ X( int32_t , min_keep , 0 , std::round ) /* 0 = disabled, otherwise samplers should return at least min_keep tokens \
*/ X( int32_t , top_k , 40 , std::round ) /* <= 0 to use vocab size \
*/ X( float , top_p , 0.95f , ) /* 1.0 = disabled \
*/ X( float , min_p , 0.05f , ) /* 0.0 = disabled \
*/ X( float , tfs_z , 1.00f , ) /* 1.0 = disabled \
*/ X( float , typical_p , 1.00f , ) /* 1.0 = disabled \
*/ X( float , temp , 0.80f , ) /* <= 0.0 to sample greedily, 0.0 to not output probabilities \
*/ X( float , dynatemp_range , 0.00f , ) /* 0.0 = disabled \
*/ X( float , dynatemp_exponent , 1.00f , ) /* controls how entropy maps to temperature in dynamic temperature sampler \
*/ X( int32_t , penalty_last_n , 64 , std::round ) /* last n tokens to penalize (0 = disable penalty, -1 = context size) \
*/ X( float , penalty_repeat , 1.00f , ) /* 1.0 = disabled \
*/ X( float , penalty_freq , 0.00f , ) /* 0.0 = disabled \
*/ X( float , penalty_present , 0.00f , ) /* 0.0 = disabled \
*/ X( float , dry_multiplier , 0.0f , ) /* 0.0 = disabled; DRY repetition penalty for tokens extending repetition: \
*/ X( float , dry_base , 1.75f , ) /* 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) \
*/ X( int32_t , dry_allowed_length , 2 , std::round ) /* tokens extending repetitions beyond this receive penalty \
*/ X( int32_t , dry_penalty_last_n , -1 , std::round ) /* how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) \
*/ X( int32_t , mirostat , 0 , std::round ) /* 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 \
*/ X( float , mirostat_tau , 5.00f , ) /* target entropy \
*/ X( float , mirostat_eta , 0.10f , ) /* learning rate \
*/ X( float , xtc_probability , 0.0f , ) /* xtc probability \
*/ X( float , xtc_threshold , 1.0f , ) /* xtc threshold, disabled if > 0.5 \
*/ X( float , top_n_sigma , 0.0f , ) /* top-n-sigma \
*/ X( float , adaptive_target , -1.0f , ) /* select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled) \
*/ X( float , adaptive_decay , 0.90f , ) /* decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation) \
*/ X( bool , adaptive_updt_w_cur , false , std::round ) /* update state with current probability \
*/
// Index constants generated from X_COMMON_PARAMS_SAMPLING: one enumerator
// named SPARAMS_<member>_ENUM per table row, numbered from 0 in table order.
// Reordering the table therefore renumbers these constants.
enum {
#undef X
#define X(T, MEMBER, DV, PRECAST) SPARAMS_ ## MEMBER ## _ENUM,
X_COMMON_PARAMS_SAMPLING
};
// sampling parameters // sampling parameters
typedef struct common_params_sampling { typedef struct common_params_sampling {
#undef X
#define X(T, MEMBER, DV, _) T MEMBER = DV;
X_COMMON_PARAMS_SAMPLING
int32_t n_prev = 64; // number of previous tokens to remember int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens int32_t total_context_size = 16840;
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
int32_t total_context_size = 16840;
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
float xtc_probability = 0.0f; // xtc probability
float xtc_threshold = 1.0f; // xtc threshold, disabled if > 0.5
float top_n_sigma = 0.0f; // top-n-sigma
float adaptive_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
float adaptive_decay = 0.90f; // decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
bool adaptive_updt_w_cur = false; // update state with current probability
bool penalize_nl = false; // consider newlines as a repeatable token bool penalize_nl = false; // consider newlines as a repeatable token
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
@ -163,6 +177,10 @@ typedef struct common_params_sampling {
// expiring logit bias // expiring logit bias
struct elb_param { struct elb_param {
struct elb_entry { struct elb_entry {
std::vector<size_t> posi; // positions of phrases in generated text
std::vector<float> addsubs; // add/modify then subtract/restore sampling parameters
std::vector<bool> addflags; // true if added
size_t max_phrase_len;
std::vector<std::string> phrases; std::vector<std::string> phrases;
std::vector<float> biases; // for each phrase, nth bias for nth token, extrapolate std::vector<float> biases; // for each phrase, nth bias for nth token, extrapolate
int32_t duration; // bias duration, unless exitword matches int32_t duration; // bias duration, unless exitword matches
@ -171,13 +189,18 @@ typedef struct common_params_sampling {
return (is_range == other.is_range) return (is_range == other.is_range)
&& (duration == other.duration) && (duration == other.duration)
&& (biases == other.biases) && (biases == other.biases)
&& (phrases == other.phrases); && (phrases == other.phrases)
&& (addflags == other.addflags)
&& (addsubs == other.addsubs)
&& (posi == other.posi);
} }
}; };
std::vector<struct elb_entry> entries; std::vector<struct elb_entry> entries;
std::string exitword; // move to next state if matched during generation std::string exitword; // move to next state if matched during generation
std::string op; // exitword operator
bool operator == (const struct elb_param& other) const { bool operator == (const struct elb_param& other) const {
return (exitword == other.exitword) return (op == other.op)
&& (exitword == other.exitword)
&& (entries == other.entries); && (entries == other.entries);
} }
}; };
@ -233,10 +256,13 @@ struct common_sampler {
size_t countup; // compare against duration size_t countup; // compare against duration
size_t delay; // to avoid early termination of positively biased phrases size_t delay; // to avoid early termination of positively biased phrases
int32_t max_cond_len; int32_t max_cond_len;
std::string jumpword;
size_t jump_idx;
size_t search_word_len;
}; };
std::vector<struct elb_state> elb_states; std::vector<struct elb_state> elb_states;
size_t elb_idx; // for elb_states size_t elb_idx; // for elb_states
int32_t elb_search_pos; // for exitword size_t elb_search_pos;
}; };

View File

@ -1929,7 +1929,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
do // populate expiring logit bias do // populate expiring logit bias
{ {
const auto elb_prev_params = slot.sparams.elb_params; const auto prev_elb_params = slot.sparams.elb_params;
const auto& expiring_logit_bias = data.find("expiring_logit_bias"); const auto& expiring_logit_bias = data.find("expiring_logit_bias");
if (expiring_logit_bias != data.end() && expiring_logit_bias->is_array()) { if (expiring_logit_bias != data.end() && expiring_logit_bias->is_array()) {
@ -1955,9 +1955,9 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
break; break;
} }
if (!slot.elb_prev_states.empty() && (elb_params == elb_prev_params)) { if (!slot.prev_elb_states.empty() && (elb_params == prev_elb_params)) {
// reset and reuse previous states // reset and reuse previous states
slot.ctx_sampling->elb_states = slot.elb_prev_states; slot.ctx_sampling->elb_states = slot.prev_elb_states;
for (auto& elb_state: slot.ctx_sampling->elb_states) { for (auto& elb_state: slot.ctx_sampling->elb_states) {
elb_state.countup = 0; elb_state.countup = 0;
} }
@ -1970,22 +1970,40 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
slot.ctx_sampling->elb_states.reserve(n_elb_param); slot.ctx_sampling->elb_states.reserve(n_elb_param);
// 1 state <-> 1 exitword <-> 1+ entries // 1 state <-> 1 exitword <-> 1+ entries
for (const auto& [entries, exitword]: elb_params) { for (int32_t i = 0; i < elb_params.size(); ++i) {
slot.ctx_sampling->elb_states.push_back({ { }, { }, exitword, 0, 0, 0 }); const auto& [entries, exitword, op] = elb_params[i];
if (op == ">>") {
for (auto& elb_state: slot.ctx_sampling->elb_states) {
if (elb_state.jumpword.empty()) {
elb_state.jumpword = exitword;
elb_state.jump_idx = i + 1;
elb_state.search_word_len = std::max(elb_state.exitword.length(), elb_state.jumpword.length());
}
}
}
slot.ctx_sampling->elb_states.push_back({ { }, { }, exitword, 0, 0, 0, "", 0, exitword.length() });
auto& first_tokens = slot.ctx_sampling->elb_states.back().first_tokens; auto& first_tokens = slot.ctx_sampling->elb_states.back().first_tokens;
auto& other_tokens = slot.ctx_sampling->elb_states.back().other_tokens; auto& other_tokens = slot.ctx_sampling->elb_states.back().other_tokens;
auto& delay = slot.ctx_sampling->elb_states.back().delay; auto& delay = slot.ctx_sampling->elb_states.back().delay;
auto& max_cond_len = slot.ctx_sampling->elb_states.back().max_cond_len; auto& max_cond_len = slot.ctx_sampling->elb_states.back().max_cond_len;
// 1 entry <-> 1 phrase <-> 1+ biases // 1 entry <-> 1 phrase <-> 1+ biases
for (auto [phrases, biases, duration, is_range]: entries) { for (auto& entry: entries) {
for (const auto& phrase: phrases) { auto biases = entry.biases;
if (phrase.empty()) { if (biases.empty()) {
continue; // expiring sampler bias
} continue;
}
// expiring logit bias
for (const auto& phrase: entry.phrases) {
auto duration = entry.duration;
const auto ids = common_tokenize(model, phrase, false, true); const auto ids = common_tokenize(model, phrase, false, true);
if (!is_range) { if (!entry.is_range) {
// extrapolate // extrapolate
biases.resize(ids.size(), biases.back()); biases.resize(ids.size(), biases.back());
} else if (ids.size() == 1) { } else if (ids.size() == 1) {
@ -2040,7 +2058,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
}); });
} }
} while (false); } while (false);
slot.elb_prev_states = slot.ctx_sampling->elb_states; slot.prev_elb_states = slot.ctx_sampling->elb_states;
slot.command = SLOT_COMMAND_LOAD_PROMPT; slot.command = SLOT_COMMAND_LOAD_PROMPT;
// slot.prompt_tokens.clear(); // slot.prompt_tokens.clear();

View File

@ -169,7 +169,7 @@ struct server_slot {
common_sampler * ctx_sampling = nullptr; common_sampler * ctx_sampling = nullptr;
// expiring logit bias // expiring logit bias
decltype(ctx_sampling->elb_states) elb_prev_states; std::vector<common_sampler::elb_state> prev_elb_states;
bool has_mtp = false; bool has_mtp = false;