diff --git a/common/common.cpp b/common/common.cpp index dad54c71..2ee4e772 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2966,6 +2966,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n" "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n" "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" }); + options.push_back({ "*", " --expiring-logit-bias-file", + "original PR: https://github.com/ikawrakow/ik_llama.cpp/pull/1731\n"}); options.push_back({ "main", " --cfg-negative-prompt PROMPT", "negative prompt to use for guidance (default: '%s')", sparams.cfg_negative_prompt.c_str() }); options.push_back({ "main", " --cfg-negative-prompt-file FNAME", @@ -3491,6 +3493,28 @@ void string_process_escapes(std::string & input) { input.resize(output_idx); } +std::string string_unescape(const std::string& str) { + std::string result; + result.reserve(2 * str.length()); + for (const auto c: str) { + switch (c) { + case '\n': + result.append("\\n"); + break; + case '\t': + result.append("\\t"); + break; + case '\r': + result.append("\\r"); + break; + default: + result.append(1, c); + break; + } + } + return result; +} + bool string_parse_kv_override(const char * data, std::vector & overrides) { const char * sep = strchr(data, '='); if (sep == nullptr || sep - data >= 128) { @@ -3537,6 +3561,42 @@ bool string_parse_kv_override(const char * data, std::vector string_extract(const std::string& str, const char c, std::vector& posi) { + std::vector extracts; + auto pos = str.find(c); + size_t count = 0; + while (pos != std::string::npos) { + if (count % 2 == 0) { + // opening c + posi.push_back(pos); + ++count; + } else { + // closing c must be unescaped + auto esc_pos = pos; + size_t n_esc = 0; + while ((esc_pos > 0) && (str[--esc_pos] == '\\')) { + ++n_esc; + } + if (n_esc % 2 == 0) { + extracts.push_back(str.substr(posi.back() + 1, pos - posi.back() - 1)); + string_process_escapes(extracts.back()); + posi.push_back(pos); + ++count; + } + } + pos = str.find(c, pos + 1); + } + return extracts; +} + +bool string_is_found(const std::string& window, const std::string& str, size_t& pos) { + if (str.empty()) { + return false; + } + pos = window.find(str); + return pos != std::string::npos; +} + // // Filesystem utils // @@ -5170,121 +5230,169 @@ std::tuple argparse_allowlist_unicode_ru } void argparse_expiring_logit_bias(const std::string& content, common_params_sampling& sparams) { - decltype(sparams.elb_params) elb_params = { { { }, "" } }; + auto elb_params = sparams.elb_params; + elb_params.push_back({ { }, "", "" }); + auto entries = elb_params[0].entries; - int32_t saved_duration = 0; - std::vector saved_phrases; - std::vector saved_biases; - bool saved_is_range = false; - - for (auto line: string_split(content, "\n")) { - string_strip(line); + const auto lines = string_split(content, "\n"); + for (size_t i = 0; i < lines.size(); ++i) { + auto line = string_strip(lines[i]); const char c0 = line.empty() ? '#' : line[0]; if (c0 == '#') { - // comment + LLAMA_LOG_DEBUG("%s: line %zu: comment or empty\n", __func__, i); continue; // next line } + // (... "EXTRACT" ... "EXTRACT" ...) + std::vector qq_posi = { 0 }; + auto extracts = string_extract(line, '"', qq_posi); + qq_posi.push_back(std::string::npos); + for (int32_t j = 0; j < int32_t(qq_posi.size()) - 1; j += 2) { + const auto pnd_pos = line.find('#', qq_posi[j]); + if (pnd_pos < qq_posi[j + 1]) { + LLAMA_LOG_DEBUG("%s: line %zu: inline comment @ %zu\n", __func__, i, pnd_pos); + line = string_strip(line.substr(0, pnd_pos)); + qq_posi.resize(j + 2); + qq_posi.back() = std::string::npos; + extracts.resize(j / 2); + break; + } + } + const auto last_qq_pos = qq_posi[qq_posi.size() - 2]; + auto n_char = line.length(); const char cE = line[n_char - 1]; - if (n_char > 1) { - if ('(' == c0 && cE == ')') { - const bool is_nested = '(' == line[1] && line[n_char - 2] == ')'; - if (is_nested) { - if (n_char == 4) { - // (()) - saved_phrases.clear(); - saved_biases.clear(); - continue; // next line - } - n_char -= 2; - line = line.substr(1, n_char); - } - - auto qqpos = line.find('"'); - - // (DURATION : ...) - int32_t duration = is_nested ? -1 : 1; - const auto cpos = line.find(':'); - if ((cpos != std::string::npos) && (1 < cpos) && (cpos < qqpos)) { - auto sub = line.substr(1, cpos - 1); - duration = std::stoi(sub); - } - if (duration == 0) { + LLAMA_LOG_DEBUG("%s: line %zu: %s\n", __func__, i, line.c_str()); + if ('(' == c0 && cE == ')') { + const bool is_nested = '(' == line[1] && line[n_char - 2] == ')'; + if (is_nested) { + if (n_char == 4) { + // (()) + entries.clear(); + LLAMA_LOG_DEBUG("%s: line %zu: persistent entry clear\n", __func__, i); continue; // next line } + n_char -= 2; + line = line.substr(1, n_char); + LLAMA_LOG_DEBUG("%s: line %zu: persistent entry\n", __func__, i); + } - // (... "PHRASE" ... "PHRASE" ...) - std::vector phrases; - auto pos = line.find('"', qqpos + 1); - while (pos != std::string::npos) { - if (line[pos - 1] == '\\') { - pos = line.find('"', pos + 1); - } else { - auto phrase = line.substr(qqpos + 1, pos - qqpos - 1); - string_process_escapes(phrase); - phrases.push_back(std::move(phrase)); - qqpos = line.find('"', pos + 1); - if (qqpos == std::string::npos) { - break; - } - pos = line.find('"', qqpos + 1); + // (DURATION : ...) + int32_t duration = is_nested ? -1 : 1; + const auto cln_pos = line.find(':'); + if ((cln_pos != std::string::npos) && (1 < cln_pos) && (cln_pos < qq_posi[1])) { + duration = std::stoi(line.substr(1, cln_pos - 1)); + } + if (duration == 0) { + LLAMA_LOG_DEBUG("%s: line %zu: invalid duration\n", __func__, i); + continue; // next line + } + + #undef X + #define X(T, MEMBER, DV, PRECAST) #MEMBER, + static const std::vector names = { X_COMMON_PARAMS_SAMPLING }; + + std::vector addsubs(names.size(), 0.0f); + bool is_sb = false; + + // (... : SPARAM ...) + const auto window = line.substr(last_qq_pos + 1); + for (int j = 0; j < names.size(); ++j) { + const auto& name = names[j]; + auto pos = window.find(name); + if (pos != std::string::npos) { + pos += name.length(); + auto next_pos = window.find(",", pos + 1); + if (next_pos == std::string::npos) { + next_pos = n_char - 1; + } + auto sub = string_strip(window.substr(pos, next_pos - pos)); + if (sub[0] == '~') { + addsubs[j] += std::stof(sub.substr(1)); + is_sb = true; + LLAMA_LOG_DEBUG("%s: line %zu: bias = %f\n", __func__, i, addsubs[j]); } } - if (phrases.empty()) { + } + + auto& phrases = extracts; + if (phrases.empty()) { + if (is_sb) { + phrases.push_back(""); + } else { continue; // next line } + } + const auto n_phrase = phrases.size(); + std::vector biases; + bool is_range = false; + + if (!is_sb) { // (... : BIAS ...) - std::vector biases; - bool is_range = false; - const auto rcpos = line.rfind(':'); - if ((rcpos != std::string::npos) && (line.rfind('"') < rcpos)) { - auto sub = line.substr(rcpos + 1, n_char - rcpos - 2); - if (sub.find("~") != std::string::npos) { - // (... : BIAS ~ BIAS) - const auto splits = string_split(sub, '~'); - auto split = splits.front(); - biases.push_back(std::stof(split)); - split = splits.back(); - biases.push_back(std::stof(split)); - is_range = true; - } else { - // (... : BIAS, BIAS, ..., BIAS) - for (auto split: string_split(sub, ',')) { - if (!split.empty()) { - biases.push_back(std::stof(split)); - } + const auto cln_rpos = line.rfind(':'); + auto sub = line.substr(cln_rpos + 1, n_char - cln_rpos - 2); + if (sub.find("~") != std::string::npos) { + // (... : BIAS ~ BIAS) + const auto splits = string_split(sub, '~'); + biases.push_back(std::stof(splits.front())); + LLAMA_LOG_DEBUG("%s: line %zu: logit bias = %f\n", __func__, i, biases.back()); + biases.push_back(std::stof(splits.back())); + LLAMA_LOG_DEBUG("%s: line %zu: logit bias = %f\n", __func__, i, biases.back()); + is_range = true; + } else { + // (... : BIAS, BIAS, ..., BIAS) + for (const auto& split: string_split(sub, ',')) { + if (!split.empty()) { + biases.push_back(std::stof(split)); + LLAMA_LOG_DEBUG("%s: line %zu: logit bias = %f\n", __func__, i, biases.back()); } } } if (biases.empty()) { continue; // next line } - - if (is_nested) { - saved_duration = duration; - saved_phrases = std::move(phrases); - saved_biases = std::move(biases); - saved_is_range = is_range; - } else { - elb_params.back().entries.push_back({ std::move(phrases), std::move(biases), duration, is_range }); - } - continue; // next line } + + size_t max_phrase_len = 0; + for (const auto& phrase: phrases) { + LLAMA_LOG_DEBUG("%s: line %zu: phrase = \"%s\"\n", __func__, i, phrase.c_str()); + max_phrase_len = std::max(phrase.length(), max_phrase_len); + } + LLAMA_LOG_DEBUG("%s: line %zu: max_phrase_len = %zu\n", __func__, i, max_phrase_len); + + common_params_sampling::elb_param::elb_entry entry = { + std::vector(n_phrase, 0), + std::move(addsubs), + std::vector(n_phrase, false), + max_phrase_len, + std::move(phrases), + std::move(biases), + duration, + is_range + }; + if (is_nested) { + entries.push_back(entry); + } + elb_params.back().entries.push_back(std::move(entry)); + continue; // next line } - // exitword - if ('"' == c0 && cE == '"') { - line = line.substr(1, n_char - 2); + if (last_qq_pos > 0) { + elb_params.back().op = string_strip(line.substr(last_qq_pos + 1)); } - string_process_escapes(line); - elb_params.back().exitword = std::move(line); - if (!saved_phrases.empty() && !saved_biases.empty() && (saved_duration != 0)) { - elb_params.back().entries.push_back({ saved_phrases, saved_biases, saved_duration, saved_is_range }); + + auto& exitwords = extracts; + if (exitwords.empty()) { + string_process_escapes(line); + exitwords.push_back(std::move(line)); } - elb_params.push_back({ { }, "" }); + + // maybe support multiple exitwords in future + elb_params.back().exitword = std::move(exitwords[0]); + + elb_params.push_back({ entries, "", "" }); } sparams.elb_params = std::move(elb_params); diff --git a/common/common.h b/common/common.h index 2d4db6d9..abcc1203 100644 --- a/common/common.h +++ b/common/common.h @@ -646,6 +646,11 @@ std::vector string_split(const std::string& input, cha bool string_parse_kv_override(const char * data, std::vector & overrides); void string_process_escapes(std::string & input); +std::string string_unescape(const std::string& str); + +std::vector string_extract(const std::string& str, const char c, std::vector& posi); + +bool string_is_found(const std::string& window, const std::string& str, size_t& pos); // // Filesystem utils diff --git a/common/sampling.cpp b/common/sampling.cpp index 99fc48ce..03504bee 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -803,6 +803,29 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample return result; } + +static void elb_print(common_params_sampling& sparams, const common_params_sampling::elb_param::elb_entry& entry) { + #undef X + #define X(T, MEMBER, DV, PRECAST) #MEMBER, + static const std::vector names = { X_COMMON_PARAMS_SAMPLING }; + #undef X + #define X(T, MEMBER, DV, PRECAST) if (std::abs(entry.addsubs[SPARAMS_ ## MEMBER ## _ENUM]) > 0.0f) \ + { LLAMA_LOG_DEBUG("%s: %s = %f\n", __func__, names[SPARAMS_ ## MEMBER ## _ENUM].c_str(), float(A_DOT_B(sparams, MEMBER))); } + X_COMMON_PARAMS_SAMPLING +} + +static void elb_add(common_params_sampling& sparams, const common_params_sampling::elb_param::elb_entry& entry) { + #undef X + #define X(T, MEMBER, _, PRECAST) A_DOT_B(sparams, MEMBER) += static_cast(PRECAST(entry.addsubs[SPARAMS_ ## MEMBER ## _ENUM])); + X_COMMON_PARAMS_SAMPLING +} + +static void elb_sub(common_params_sampling& sparams, const common_params_sampling::elb_param::elb_entry& entry) { + #undef X + #define X(T, MEMBER, _, PRECAST) A_DOT_B(sparams, MEMBER) -= static_cast(PRECAST(entry.addsubs[SPARAMS_ ## MEMBER ## _ENUM])); + X_COMMON_PARAMS_SAMPLING +} + void common_expiring_logit_bias_apply(struct common_sampler* ctx_sampling, float* logits) { auto index_first_inactive = [](auto countup, auto& tokens) { return std::distance( @@ -853,6 +876,58 @@ void common_expiring_logit_bias_apply(struct common_sampler* ctx_sampling, float } } } + + // expiring sampler bias + for (auto& entry: ctx_sampling->params.elb_params[ctx_sampling->elb_idx].entries) { + if (!entry.biases.empty()) { + continue; // next entry + } + for (size_t j = 0; j < entry.phrases.size(); ++j) { + const auto& phrase = entry.phrases[j]; + if (phrase.empty()) { + // duration bound only + if (elb.countup == 0) { + LLAMA_LOG_DEBUG("%s: before add\n", __func__); + elb_print(ctx_sampling->params, entry); + + elb_add(ctx_sampling->params, entry); + entry.addflags[j] = true; + + LLAMA_LOG_DEBUG("%s: after add\n", __func__); + elb_print(ctx_sampling->params, entry); + } else if (elb.countup == entry.duration) { + LLAMA_LOG_DEBUG("%s: before sub\n", __func__); + elb_print(ctx_sampling->params, entry); + + elb_sub(ctx_sampling->params, entry); + entry.addflags[j] = false; + + LLAMA_LOG_DEBUG("%s: after sub\n", __func__); + elb_print(ctx_sampling->params, entry); + } + continue; // next entry + } + size_t count = 0; + auto pos = ctx_sampling->to_generated_text->find(phrase, entry.posi[j]); + while (pos != std::string::npos) { + LLAMA_LOG_DEBUG("%s: found %s @ %zu\n", __func__, phrase.c_str(), pos); + ++count; + pos = ctx_sampling->to_generated_text->find(phrase, pos + phrase.length()); + } + entry.posi[j] = std::max(0, int32_t(ctx_sampling->to_generated_text->length()) - int32_t(phrase.length()) + 1); + if (count % 2 == 1) { + // even = no match or cancelled + LLAMA_LOG_DEBUG("%s: before\n", __func__); + elb_print(ctx_sampling->params, entry); + + (entry.addflags[j] ? elb_sub : elb_add)(ctx_sampling->params, entry); + entry.addflags[j] = !entry.addflags[j]; + + LLAMA_LOG_DEBUG("%s: after\n", __func__); + elb_print(ctx_sampling->params, entry); + } + } + } } void common_expiring_logit_bias_accept(struct common_sampler* ctx_sampling, struct llama_context * ctx_main) { @@ -861,25 +936,53 @@ void common_expiring_logit_bias_accept(struct common_sampler* ctx_sampling, stru return; } - auto& elb = ctx_sampling->elb_states[ctx_sampling->elb_idx]; - const int32_t exitword_len = elb.exitword.length(); - if ((elb.delay > ++elb.countup) || (exitword_len == 0)) { + auto idx = ctx_sampling->elb_idx; + auto& elb = ctx_sampling->elb_states[idx]; + if ((elb.delay > ++elb.countup) || (elb.search_word_len == 0)) { return; } - const std::string search_window = ctx_sampling->to_generated_text->substr(std::min( + const std::string window = ctx_sampling->to_generated_text->substr(std::min( ctx_sampling->to_generated_text->length(), - size_t(ctx_sampling->elb_search_pos) - )) + common_token_to_piece(ctx_main, ctx_sampling->prev.back(), true); - - const auto exitword_pos = search_window.find(elb.exitword); - if (exitword_pos != std::string::npos) { + ctx_sampling->elb_search_pos)) + common_token_to_piece(ctx_main, ctx_sampling->prev.back(), true); + size_t pos = 0; + if (string_is_found(window, elb.jumpword, pos)) { + LLAMA_LOG_DEBUG("%s: found %s in %s @ %zu\n", __func__, string_unescape(elb.jumpword).c_str(), string_unescape(window).c_str(), pos); + pos += ctx_sampling->elb_search_pos + elb.jumpword.length(); + ctx_sampling->elb_idx = elb.jump_idx; + } else if (string_is_found(window, elb.exitword, pos)) { + LLAMA_LOG_DEBUG("%s: found %s in %s @ %zu\n", __func__, string_unescape(elb.exitword).c_str(), string_unescape(window).c_str(), pos); + pos += ctx_sampling->elb_search_pos + elb.exitword.length(); ++ctx_sampling->elb_idx; - // no double counting characters that matched - ctx_sampling->elb_search_pos += exitword_pos + exitword_len; } else { - // move search position to include next token - ctx_sampling->elb_search_pos += std::max(0, int32_t(search_window.length()) - exitword_len + 1); + // not found. move search position to include next token + ctx_sampling->elb_search_pos += std::max(0, int32_t(window.length()) - int32_t(elb.search_word_len) + 1); + return; + } + + // single character clearance + // e.g. stop \n\n from expiring two \n immediately + ctx_sampling->elb_search_pos = pos + 1; + + // undo current sampler bias + for (auto& entry: ctx_sampling->params.elb_params[idx].entries) { + for (const auto addflag: entry.addflags) { + if (addflag) { + LLAMA_LOG_DEBUG("%s: before\n", __func__); + elb_print(ctx_sampling->params, entry); + + elb_sub(ctx_sampling->params, entry); + + LLAMA_LOG_DEBUG("%s: after\n", __func__); + elb_print(ctx_sampling->params, entry); + } + } + } + + // prepare next sampler bias + for (auto& entry: ctx_sampling->params.elb_params[ctx_sampling->elb_idx].entries) { + // no clearance for sampler bias + std::fill(entry.posi.begin(), entry.posi.end(), pos); } } diff --git a/common/sampling.h b/common/sampling.h index bdaf5ee5..c36aff4f 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -9,6 +9,8 @@ #include #include +#define A_DOT_B(a, b) a.b + // sampler types enum class llama_sampler_type : char { DRY = 'd', @@ -80,37 +82,49 @@ inline bool common_grammar_needs_prefill(const common_grammar & g) { } +#define X_COMMON_PARAMS_SAMPLING /* \ + */ X( int32_t , min_keep , 0 , std::round ) /* 0 = disabled, otherwise samplers should return at least min_keep tokens \ + */ X( int32_t , top_k , 40 , std::round ) /* <= 0 to use vocab size \ + */ X( float , top_p , 0.95f , ) /* 1.0 = disabled \ + */ X( float , min_p , 0.05f , ) /* 0.0 = disabled \ + */ X( float , tfs_z , 1.00f , ) /* 1.0 = disabled \ + */ X( float , typical_p , 1.00f , ) /* 1.0 = disabled \ + */ X( float , temp , 0.80f , ) /* <= 0.0 to sample greedily, 0.0 to not output probabilities \ + */ X( float , dynatemp_range , 0.00f , ) /* 0.0 = disabled \ + */ X( float , dynatemp_exponent , 1.00f , ) /* controls how entropy maps to temperature in dynamic temperature sampler \ + */ X( int32_t , penalty_last_n , 64 , std::round ) /* last n tokens to penalize (0 = disable penalty, -1 = context size) \ + */ X( float , penalty_repeat , 1.00f , ) /* 1.0 = disabled \ + */ X( float , penalty_freq , 0.00f , ) /* 0.0 = disabled \ + */ X( float , penalty_present , 0.00f , ) /* 0.0 = disabled \ + */ X( float , dry_multiplier , 0.0f , ) /* 0.0 = disabled; DRY repetition penalty for tokens extending repetition: \ + */ X( float , dry_base , 1.75f , ) /* 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) \ + */ X( int32_t , dry_allowed_length , 2 , std::round ) /* tokens extending repetitions beyond this receive penalty \ + */ X( int32_t , dry_penalty_last_n , -1 , std::round ) /* how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) \ + */ X( int32_t , mirostat , 0 , std::round ) /* 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 \ + */ X( float , mirostat_tau , 5.00f , ) /* target entropy \ + */ X( float , mirostat_eta , 0.10f , ) /* learning rate \ + */ X( float , xtc_probability , 0.0f , ) /* xtc probability \ + */ X( float , xtc_threshold , 1.0f , ) /* xtc threshold, disabled if > 0.5 \ + */ X( float , top_n_sigma , 0.0f , ) /* top-n-sigma \ + */ X( float , adaptive_target , -1.0f , ) /* select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled) \ + */ X( float , adaptive_decay , 0.90f , ) /* decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation) \ + */ X( bool , adaptive_updt_w_cur , false , std::round ) /* update state with current probability \ + */ + +enum { + #undef X + #define X(T, MEMBER, DV, PRECAST) SPARAMS_ ## MEMBER ## _ENUM, + X_COMMON_PARAMS_SAMPLING +}; + // sampling parameters typedef struct common_params_sampling { + #undef X + #define X(T, MEMBER, DV, _) T MEMBER = DV; + X_COMMON_PARAMS_SAMPLING int32_t n_prev = 64; // number of previous tokens to remember int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. - int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens - int32_t top_k = 40; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float min_p = 0.05f; // 0.0 = disabled - float tfs_z = 1.00f; // 1.0 = disabled - float typical_p = 1.00f; // 1.0 = disabled - float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities - float dynatemp_range = 0.00f; // 0.0 = disabled - float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler - int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat = 1.00f; // 1.0 = disabled - float penalty_freq = 0.00f; // 0.0 = disabled - float penalty_present = 0.00f; // 0.0 = disabled - float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition: - float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) - int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty - int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) - int32_t total_context_size = 16840; - int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate - float xtc_probability = 0.0f; // xtc probability - float xtc_threshold = 1.0f; // xtc threshold, disabled if > 0.5 - float top_n_sigma = 0.0f; // top-n-sigma - float adaptive_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled) - float adaptive_decay = 0.90f; // decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation) - bool adaptive_updt_w_cur = false; // update state with current probability + int32_t total_context_size = 16840; bool penalize_nl = false; // consider newlines as a repeatable token uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context @@ -163,6 +177,10 @@ typedef struct common_params_sampling { // expiring logit bias struct elb_param { struct elb_entry { + std::vector posi; // positions of phrases in generated text + std::vector addsubs; // add/modify then subtract/restore sampling parameters + std::vector addflags; // true if added + size_t max_phrase_len; std::vector phrases; std::vector biases; // for each phrase, nth bias for nth token, extrapolate int32_t duration; // bias duration, unless exitword matches @@ -171,13 +189,18 @@ typedef struct common_params_sampling { return (is_range == other.is_range) && (duration == other.duration) && (biases == other.biases) - && (phrases == other.phrases); + && (phrases == other.phrases) + && (addflags == other.addflags) + && (addsubs == other.addsubs) + && (posi == other.posi); } }; std::vector entries; std::string exitword; // move to next state if matched during generation + std::string op; // exitword operator bool operator == (const struct elb_param& other) const { - return (exitword == other.exitword) + return (op == other.op) + && (exitword == other.exitword) && (entries == other.entries); } }; @@ -233,10 +256,13 @@ struct common_sampler { size_t countup; // compare against duration size_t delay; // to avoid early termination of positively biased phrases int32_t max_cond_len; + std::string jumpword; + size_t jump_idx; + size_t search_word_len; }; std::vector elb_states; size_t elb_idx; // for elb_states - int32_t elb_search_pos; // for exitword + size_t elb_search_pos; }; diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index e5568de5..99345458 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -1929,7 +1929,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) do // populate expiring logit bias { - const auto elb_prev_params = slot.sparams.elb_params; + const auto prev_elb_params = slot.sparams.elb_params; const auto& expiring_logit_bias = data.find("expiring_logit_bias"); if (expiring_logit_bias != data.end() && expiring_logit_bias->is_array()) { @@ -1955,9 +1955,9 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) break; } - if (!slot.elb_prev_states.empty() && (elb_params == elb_prev_params)) { + if (!slot.prev_elb_states.empty() && (elb_params == prev_elb_params)) { // reset and reuse previous states - slot.ctx_sampling->elb_states = slot.elb_prev_states; + slot.ctx_sampling->elb_states = slot.prev_elb_states; for (auto& elb_state: slot.ctx_sampling->elb_states) { elb_state.countup = 0; } @@ -1970,22 +1970,40 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) slot.ctx_sampling->elb_states.reserve(n_elb_param); // 1 state <-> 1 exitword <-> 1+ entries - for (const auto& [entries, exitword]: elb_params) { - slot.ctx_sampling->elb_states.push_back({ { }, { }, exitword, 0, 0, 0 }); + for (int32_t i = 0; i < elb_params.size(); ++i) { + const auto& [entries, exitword, op] = elb_params[i]; + + if (op == ">>") { + for (auto& elb_state: slot.ctx_sampling->elb_states) { + if (elb_state.jumpword.empty()) { + elb_state.jumpword = exitword; + elb_state.jump_idx = i + 1; + elb_state.search_word_len = std::max(elb_state.exitword.length(), elb_state.jumpword.length()); + } + } + } + + slot.ctx_sampling->elb_states.push_back({ { }, { }, exitword, 0, 0, 0, "", 0, exitword.length() }); + auto& first_tokens = slot.ctx_sampling->elb_states.back().first_tokens; auto& other_tokens = slot.ctx_sampling->elb_states.back().other_tokens; auto& delay = slot.ctx_sampling->elb_states.back().delay; auto& max_cond_len = slot.ctx_sampling->elb_states.back().max_cond_len; // 1 entry <-> 1 phrase <-> 1+ biases - for (auto [phrases, biases, duration, is_range]: entries) { - for (const auto& phrase: phrases) { - if (phrase.empty()) { - continue; - } + for (auto& entry: entries) { + auto biases = entry.biases; + if (biases.empty()) { + // expiring sampler bias + continue; + } + + // expiring logit bias + for (const auto& phrase: entry.phrases) { + auto duration = entry.duration; const auto ids = common_tokenize(model, phrase, false, true); - if (!is_range) { + if (!entry.is_range) { // extrapolate biases.resize(ids.size(), biases.back()); } else if (ids.size() == 1) { @@ -2040,7 +2058,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) }); } } while (false); - slot.elb_prev_states = slot.ctx_sampling->elb_states; + slot.prev_elb_states = slot.ctx_sampling->elb_states; slot.command = SLOT_COMMAND_LOAD_PROMPT; // slot.prompt_tokens.clear(); diff --git a/examples/server/server-context.h b/examples/server/server-context.h index 9d643e02..a33c2113 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ -169,7 +169,7 @@ struct server_slot { common_sampler * ctx_sampling = nullptr; // expiring logit bias - decltype(ctx_sampling->elb_states) elb_prev_states; + std::vector prev_elb_states; bool has_mtp = false;