#define LLAMA_API_INTERNAL #include "sampling.h" #include "llama-vocab.h" #include "common.h" #include "reasoning-budget.cpp" #include #include #if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__)) #include #endif #include using json = nlohmann::ordered_json; struct llama_sampler_adaptive_p * llama_clone_adaptive_p(const struct llama_sampler_adaptive_p * adapt_p_ctx); void llama_free_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx); struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { const llama_vocab * vocab = llama_model_get_vocab(model); struct common_sampler * result = new common_sampler(); result->params = params; result->grammar = nullptr; result->rbudget = nullptr; struct llama_grammar* grmr; const std::string & grammar_str = common_grammar_value(params.grammar); if (grammar_str.compare(0, 11, "%llguidance") == 0) { #ifdef LLAMA_USE_LLGUIDANCE grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()); result->grammar = grmr; #else GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); #endif // LLAMA_USE_LLGUIDANCE } else { std::vector trigger_patterns; std::vector trigger_tokens; for (const auto & trigger : params.grammar_triggers) { switch (trigger.type) { case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: { const auto & word = trigger.value; trigger_patterns.push_back(regex_escape(word)); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: { trigger_patterns.push_back(trigger.value); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL: { const auto & pattern = trigger.value; std::string anchored = "^$"; if (!pattern.empty()) { anchored = (pattern.front() != '^' ? "^" : "") + pattern + (pattern.back() != '$' ? "$" : ""); } trigger_patterns.push_back(anchored); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: { const auto token = trigger.token; trigger_tokens.push_back(token); break; } default: GGML_ASSERT(false && "unknown trigger type"); } } std::vector trigger_patterns_c; trigger_patterns_c.reserve(trigger_patterns.size()); for (const auto & regex : trigger_patterns) { trigger_patterns_c.push_back(regex.c_str()); } if (!grammar_str.empty()) { grmr = params.grammar_lazy ? llama_sampler_init_grammar_lazy_patterns(vocab, grammar_str.c_str(), "root", trigger_patterns_c.data(), trigger_patterns_c.size(), trigger_tokens.data(), trigger_tokens.size()) : llama_sampler_init_grammar(vocab, grammar_str.c_str(), "root"); if (grmr) { result->prev.resize(params.n_prev); result->grammar = grmr; } } result->n_valid = 0; result->grammar_str = grammar_str; result->grammar_root = "root"; } // Compute prefill tokens from the generation prompt std::vector prefill_tokens; if (!params.generation_prompt.empty()) { GGML_ASSERT(vocab != nullptr); auto tokens = common_tokenize(vocab, params.generation_prompt, false, true); for (size_t i = 0; i < tokens.size(); i++) { std::string piece = common_token_to_piece(vocab, tokens[i], true); if (i == 0 && std::isspace(piece[0]) && !std::isspace(params.generation_prompt[0])) { // Some tokenizers will add a space before the first special token, need to exclude continue; } LOG_DBG("%s: prefill token: %d = %s\n", __func__, tokens[i], piece.c_str()); prefill_tokens.push_back(tokens[i]); } } // Feed generation prompt tokens to the grammar sampler so it advances past // tokens the template already placed in the prompt. // Only applies to output-format and tool-call grammars; user-supplied grammars must not be prefilled. if (grmr && !params.grammar_lazy && common_grammar_needs_prefill(params.grammar)) { try { for (const auto & token : prefill_tokens) { llama_grammar_accept_impl(*grmr, vocab, nullptr, token); LOG_DBG("%s: grammar accepted prefill token (%d)\n", __func__, token); } } catch (std::exception & e) { LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__, common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str()); throw e; } } // reasoning budget sampler (skip when budget is unlimited unless a lazy grammar is active, which needs rbudget for thinking-block suppression) if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty() && (params.grammar_lazy || params.reasoning_budget_tokens >= 0)) { result->rbudget = common_reasoning_budget_init( vocab, params.reasoning_budget_start, params.reasoning_budget_end, params.reasoning_budget_forced, params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens); for (const auto & token : prefill_tokens) { common_reasoning_budget_accept(result->rbudget, token); LOG_DBG("%s: reasoning-budget accepted prefill token (%d)\n", __func__, token); } } llama_sampling_set_rng_seed(result, params.seed); for (const auto& cnstr : params.samplers_sequence) { switch (cnstr) { case llama_sampler_type::DRY: { std::vector c_breakers; c_breakers.reserve(params.dry_sequence_breakers.size()); for (const auto& str : params.dry_sequence_breakers) { c_breakers.push_back(str.c_str()); } result->smpl=llama_sampler_init_dry(vocab, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()); break; } case llama_sampler_type::ADAPTIVE_P: { if (params.adaptive_target >= 0.0f) { GGML_ASSERT(vocab); auto n_vocab = llama_vocab_n_tokens(vocab); result->adapt_p_ctx = llama_init_adaptive_p(n_vocab, params.adaptive_target, params.adaptive_decay, params.adaptive_updt_w_cur, result->rng()); } break; } default: break; } } result->elb_idx = 0; result->elb_search_pos = 0; return result; } void common_sampler_free(struct common_sampler * ctx) { if (!ctx) { return; } if (ctx->grammar) { llama_grammar_free(ctx->grammar); } if (ctx->smpl) llama_sampler_dry_free(ctx->smpl); if (ctx->adapt_p_ctx) llama_free_adaptive_p(ctx->adapt_p_ctx); if (ctx->rbudget) common_reasoning_budget_free(ctx->rbudget); delete ctx; } static void llama_grammar_reset(common_sampler * ctx) { if (!ctx->grammar) { return; } std::vector trigger_patterns_c; trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size()); for (auto& trigger_pattern : ctx->grammar->trigger_patterns) { trigger_patterns_c.push_back(trigger_pattern.pattern.c_str()); } auto* grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(), ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); llama_grammar_free_impl(ctx->grammar); ctx->grammar = grammar_new; } void common_sampler_reset(common_sampler * ctx) { // llama_grammar_reset(ctx); ctx->prev.clear(); llama_sampler_dry_reset(ctx->smpl); } void common_sampler_review(common_sampler * ctx, const size_t n_unsent, const bool rewind_status) { // add stateful samplers here if (ctx->adapt_p_ctx != nullptr) { llama_review_adaptive_p(ctx->adapt_p_ctx, n_unsent, rewind_status); } } void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed) { if (seed == LLAMA_DEFAULT_SEED) { seed = std::random_device{}(); } ctx->rng.seed(seed); } void common_sampler_clone(common_sampler * src, common_sampler * dst) { dst->params = src->params; dst->mirostat_mu = src->mirostat_mu; dst->n_valid = src->n_valid; dst->rng = src->rng; dst->server_biases = src->server_biases; if (dst->grammar) { llama_grammar_free(dst->grammar); dst->grammar = nullptr; } if (src->grammar) { dst->grammar_root = src->grammar_root; dst->grammar_str = src->grammar_str; dst->grammar = llama_grammar_copy(src->grammar); } dst->prev = src->prev; if (dst->smpl) { llama_sampler_dry_free(dst->smpl); dst->smpl = nullptr; } if (src->smpl) { dst->smpl = llama_sampler_dry_clone(src->smpl); } if (dst->adapt_p_ctx) { llama_free_adaptive_p(dst->adapt_p_ctx); dst->adapt_p_ctx = nullptr; } if (src->adapt_p_ctx) { dst->adapt_p_ctx = llama_clone_adaptive_p(src->adapt_p_ctx); } if (dst->rbudget) { common_reasoning_budget_free(dst->rbudget); dst->rbudget = nullptr; } if (src->rbudget) { dst->rbudget = common_reasoning_budget_clone(src->rbudget); } } llama_token llama_sampling_last(common_sampler * ctx) { return ctx->prev.back(); } std::string llama_sampling_prev_str(common_sampler * ctx_sampling, llama_context * ctx_main, int n) { const int size = ctx_sampling->prev.size(); n = std::min(n, size); std::string result; for (int i = size - n; i < size; i++) { result += common_token_to_piece(ctx_main, ctx_sampling->prev[i]); } return result; } std::string llama_sampling_print(const common_params_sampling & params) { char result[1024]; snprintf(result, sizeof(result), "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n" "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f\n" "\txtc_probability = %.3f, xtc_threshold = %.3f, top_n_sigma = %.3f\n" "\tadaptive_target = %.2f, adaptive_decay = %.2f", params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present, params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau, params.xtc_probability, params.xtc_threshold, params.top_n_sigma, params.adaptive_target, params.adaptive_decay); return std::string(result); } std::string llama_sampling_order_print(const common_params_sampling & params) { std::string result = "CFG -> Penalties "; if (params.mirostat == 0) { for (auto sampler_type : params.samplers_sequence) { const auto sampler_type_name = llama_sampling_type_to_str(sampler_type); if (!sampler_type_name.empty()) { result += "-> " + sampler_type_name + " "; } } } else { result += "-> mirostat "; } return result; } std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) { switch (sampler_type) { case llama_sampler_type::DRY: return "dry"; case llama_sampler_type::TOP_K: return "top_k"; case llama_sampler_type::TFS_Z: return "tfs_z"; case llama_sampler_type::TYPICAL_P: return "typical_p"; case llama_sampler_type::TOP_P: return "top_p"; case llama_sampler_type::MIN_P: return "min_p"; case llama_sampler_type::TEMPERATURE: return "temperature"; case llama_sampler_type::XTC : return "xtc"; case llama_sampler_type::TOP_N_SIGMA: return "top_n_sigma"; case llama_sampler_type::ADAPTIVE_P : return "adaptive_p"; default : return ""; } } std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names) { std::unordered_map sampler_canonical_name_map { {"dry", llama_sampler_type::DRY}, {"top_k", llama_sampler_type::TOP_K}, {"top_p", llama_sampler_type::TOP_P}, {"typical_p", llama_sampler_type::TYPICAL_P}, {"min_p", llama_sampler_type::MIN_P}, {"tfs_z", llama_sampler_type::TFS_Z}, {"xtc", llama_sampler_type::XTC}, {"top_n_sigma", llama_sampler_type::TOP_N_SIGMA}, {"temperature", llama_sampler_type::TEMPERATURE}, {"adaptive_p", llama_sampler_type::ADAPTIVE_P}, }; // since samplers names are written multiple ways // make it ready for both system names and input names std::unordered_map sampler_alt_name_map { {"dry", llama_sampler_type::DRY}, {"top-k", llama_sampler_type::TOP_K}, {"top-p", llama_sampler_type::TOP_P}, {"nucleus", llama_sampler_type::TOP_P}, {"typical-p", llama_sampler_type::TYPICAL_P}, {"typical", llama_sampler_type::TYPICAL_P}, {"min-p", llama_sampler_type::MIN_P}, {"tfs-z", llama_sampler_type::TFS_Z}, {"tfs", llama_sampler_type::TFS_Z}, {"xtc", llama_sampler_type::XTC}, {"top-n-sigma", llama_sampler_type::TOP_N_SIGMA}, {"temp", llama_sampler_type::TEMPERATURE}, {"adaptive-p", llama_sampler_type::ADAPTIVE_P}, }; std::vector sampler_types; sampler_types.reserve(names.size()); for (const auto & name : names) { auto sampler_item = sampler_canonical_name_map.find(name); if (sampler_item != sampler_canonical_name_map.end()) { sampler_types.push_back(sampler_item->second); } else { if (allow_alt_names) { sampler_item = sampler_alt_name_map.find(name); if (sampler_item != sampler_alt_name_map.end()) { sampler_types.push_back(sampler_item->second); } } } } return sampler_types; } std::vector llama_sampling_types_from_chars(const std::string & names_string) { std::unordered_map sampler_name_map { {'d', llama_sampler_type::DRY}, {'k', llama_sampler_type::TOP_K}, {'p', llama_sampler_type::TOP_P}, {'y', llama_sampler_type::TYPICAL_P}, {'m', llama_sampler_type::MIN_P}, {'f', llama_sampler_type::TFS_Z}, {'x', llama_sampler_type::XTC}, {'n', llama_sampler_type::TOP_N_SIGMA}, {'t', llama_sampler_type::TEMPERATURE}, {'w', llama_sampler_type::ADAPTIVE_P}, }; std::vector sampler_types; sampler_types.reserve(names_string.size()); for (const auto & c : names_string) { const auto sampler_item = sampler_name_map.find(c); if (sampler_item != sampler_name_map.end()) { sampler_types.push_back(sampler_item->second); } } return sampler_types; } // no reasons to expose this function in header static void sampler_queue( struct llama_context* ctx_main, const common_params_sampling& params, common_sampler * ctx_sampling, llama_token_data_array& cur_p, size_t min_keep) { const float temp = params.temp; const float dynatemp_range = params.dynatemp_range; const float dynatemp_exponent = params.dynatemp_exponent; const int32_t top_k = params.top_k; const float top_p = params.top_p; const float min_p = params.min_p; const float tfs_z = params.tfs_z; const float typical_p = params.typical_p; const float xtc_probability = params.xtc_probability; const float xtc_threshold = params.xtc_threshold; const float top_n_sigma = params.top_n_sigma; const std::vector & samplers_sequence = params.samplers_sequence; bool use_adaptive_p = false; // see below for (auto sampler_type : samplers_sequence) { switch (sampler_type) { case llama_sampler_type::DRY : llama_sample_dry (ctx_main, ctx_sampling->smpl, &cur_p); break; case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break; case llama_sampler_type::TYPICAL_P : llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; case llama_sampler_type::XTC : llama_sample_xtc (ctx_main, &cur_p, xtc_probability, xtc_threshold, min_keep); break; case llama_sampler_type::TOP_N_SIGMA: llama_sample_top_n_sigma(ctx_main, &cur_p, top_n_sigma); break; case llama_sampler_type::DIST : llama_sample_dist (ctx_main, &cur_p); break; case llama_sampler_type::TEMPERATURE: if (dynatemp_range > 0) { float dynatemp_min = std::max(0.0f, temp - dynatemp_range); float dynatemp_max = std::max(0.0f, temp + dynatemp_range); llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent); } else { llama_sample_temp(ctx_main, &cur_p, temp); } break; case llama_sampler_type::ADAPTIVE_P: use_adaptive_p = ctx_sampling->adapt_p_ctx != nullptr; break; default : break; } } if (use_adaptive_p) { // adaptive p should be put to the last, so we ignore the order in the sampler llama_sample_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx); } } static bool grammar_should_apply(struct common_sampler * gsmpl) { if (!gsmpl->grammar) { return false; } if (!gsmpl->rbudget) { return true; } if (gsmpl->params.grammar_lazy) { // if grammar is lazy, only apply when reasoning budget is not active const auto state = common_reasoning_budget_get_state(gsmpl->rbudget); return state == REASONING_BUDGET_IDLE || state == REASONING_BUDGET_DONE; } return true; } static llama_token llama_sampling_sample_impl( struct common_sampler * ctx_sampling, struct llama_context * ctx_main, struct llama_context * ctx_cfg, const int idx, bool grammar_first) { const common_params_sampling & params = ctx_sampling->params; const float temp = params.temp; const int mirostat = params.mirostat; const float mirostat_tau = params.mirostat_tau; const float mirostat_eta = params.mirostat_eta; const float adaptive_target = params.adaptive_target; std::vector original_logits; llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* grammar_first= */ grammar_first, &original_logits); llama_token_data_array & cur_p = ctx_sampling->cur_p; if (ctx_sampling->grammar != NULL && !grammar_first) { GGML_ASSERT(!original_logits.empty()); } auto & rbudget = ctx_sampling->rbudget; llama_token id = 0; float * logits = llama_get_logits_ith(ctx_main, idx); // apply reasoning budget first common_reasoning_budget_apply(rbudget, &cur_p); // Sample grammar first for resampling if (ctx_sampling->grammar != NULL && grammar_first && grammar_should_apply(ctx_sampling)) { // Apply grammar constraints to all candidates llama_grammar_apply(ctx_sampling->grammar, ctx_main, &cur_p); } // llama_sampler_apply if (temp < 0.0) { // greedy sampling, with probs llama_sample_softmax(ctx_main, &cur_p); id = cur_p.data[0].id; } else if (temp == 0.0) { // greedy sampling, no probs id = llama_sample_token_greedy(ctx_main, &cur_p); } else { if (mirostat == 1) { const int mirostat_m = 100; llama_sample_temp(ctx_main, &cur_p, temp); id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu); } else if (mirostat == 2) { llama_sample_temp(ctx_main, &cur_p, temp); id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); } else if (adaptive_target >= 0.0f && ctx_sampling->adapt_p_ctx!=nullptr) { // adaptive p sampling llama_prep_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx); sampler_queue(ctx_main, params, ctx_sampling, cur_p, std::max(1, params.min_keep)); id = llama_sample_token_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx); } else { // temperature sampling size_t min_keep = std::max(1, params.min_keep); sampler_queue(ctx_main, params,ctx_sampling, cur_p, min_keep); id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng); } } if (grammar_first || !grammar_should_apply(ctx_sampling)) { return id; } if (ctx_sampling->grammar != NULL && !grammar_first && grammar_should_apply(ctx_sampling)) { // Get a pointer to the logits float * logits = llama_get_logits_ith(ctx_main, idx); // Create an array with a single token data element for the sampled id llama_token_data single_token_data = {id, logits[id], 0.0f}; llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; // Apply grammar constraints to the single token llama_grammar_apply(ctx_sampling->grammar, ctx_main, &single_token_data_array); // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY bool is_valid = single_token_data_array.data[0].logit != -INFINITY; // If the token is not valid according to the grammar, perform resampling if (!is_valid) { LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, common_token_to_piece(ctx_main, id).c_str()); // Restore logits from the copy std::copy(original_logits.begin(), original_logits.end(), logits); return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true); } } ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size; return id; } static llama_token_data_array llama_sampling_prepare_impl( struct common_sampler * ctx_sampling, struct llama_context * ctx_main, struct llama_context * ctx_cfg, const int idx, bool grammar_first, std::vector * original_logits) { const common_params_sampling & params = ctx_sampling->params; const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; const float penalty_repeat = params.penalty_repeat; const float penalty_freq = params.penalty_freq; const float penalty_present = params.penalty_present; const bool penalize_nl = params.penalize_nl; auto & prev = ctx_sampling->prev; auto & cur = ctx_sampling->cur; // Get a pointer to the logits float * logits = llama_get_logits_ith(ctx_main, idx); if (ctx_sampling->grammar != NULL && !grammar_first) { GGML_ASSERT(original_logits != NULL); // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this. *original_logits = {logits, logits + n_vocab}; } // apply params.logit_bias map for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { logits[it->first] += it->second; } if (ctx_cfg) { float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx); llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale); } if (ctx_sampling->elb_states.size() > ctx_sampling->elb_idx) { common_expiring_logit_bias_apply(ctx_sampling, logits); } cur.resize(n_vocab); if ((ctx_sampling->server_biases != nullptr) && (ctx_sampling->server_biases->size() == n_vocab)) { for (llama_token token_id = 0; token_id < n_vocab; token_id++) { cur[token_id] = llama_token_data{token_id, logits[token_id] + ctx_sampling->server_biases->at(token_id), 0.0f}; } } else { for (llama_token token_id = 0; token_id < n_vocab; token_id++) { cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; } } ctx_sampling->cur_p = { cur.data(), cur.size(), false }; llama_token_data_array & cur_p = ctx_sampling->cur_p; // apply penalties const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev; const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n); if (penalty_tokens_used_size) { const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; llama_sample_repetition_penalties(ctx_main, &cur_p, penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); if (!penalize_nl) { for (size_t idx = 0; idx < cur_p.size; idx++) { if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) { cur_p.data[idx].logit = nl_logit; break; } } } } // apply grammar checks before sampling logic if (grammar_first && ctx_sampling->grammar != NULL) { llama_grammar_apply(ctx_sampling->grammar, ctx_main, &cur_p); } return cur_p; } llama_token common_sampler_sample_legacy( struct common_sampler * ctx_sampling, struct llama_context * ctx_main, struct llama_context * ctx_cfg, const int idx) { // Call the implementation function with is_resampling set to false by default return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false); } llama_token common_sampler_sample( struct common_sampler * ctx_sampling, struct llama_context * ctx_main, const int idx, bool grammar_first) { // Call the implementation function with is_resampling set to false by default return llama_sampling_sample_impl(ctx_sampling, ctx_main, nullptr, idx, /* is_resampling= */ grammar_first); } llama_token_data_array llama_sampling_prepare( struct common_sampler * ctx_sampling, struct llama_context * ctx_main, struct llama_context * ctx_cfg, const int idx, bool grammar_first, std::vector * original_logits) { return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, grammar_first, original_logits); } void common_sampler_accept( struct common_sampler * ctx_sampling, struct llama_context * ctx_main, llama_token token, bool is_generated) { if (ctx_sampling->prev.size() > 0) { ctx_sampling->prev.erase(ctx_sampling->prev.begin()); } ctx_sampling->prev.push_back(token); // grammar_should_apply() checks the reasoning budget state, so calculate this before we accept const auto accept_grammar = is_generated && grammar_should_apply(ctx_sampling); if (ctx_sampling->rbudget && is_generated) { common_reasoning_budget_accept(ctx_sampling->rbudget, token); } if (ctx_sampling->grammar && accept_grammar) { llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, token); } if (ctx_sampling->smpl) { llama_sampler_dry_accept(ctx_sampling->smpl, token); } if (ctx_sampling->elb_states.size() > ctx_sampling->elb_idx) { common_expiring_logit_bias_accept(ctx_sampling, ctx_main); } } llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) { auto * res = &gsmpl->cur_p; if (do_sort && !res->sorted) { // remember the selected token before sorting const llama_token id = res->data[res->selected].id; std::sort(res->data, res->data + res->size, [](const llama_token_data & a, const llama_token_data & b) { return a.p > b.p; }); // restore the selected token after sorting for (size_t i = 0; i < res->size; ++i) { if (res->data[i].id == id) { res->selected = i; break; } } res->sorted = true; } return res; } std::vector llama_sampling_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & draft) { std::vector idxs(draft.size() + 1); for (size_t i = 0; i < idxs.size(); ++i) { idxs[i] = i; } return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft); } std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const std::vector & draft, bool grammar_first) { GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); std::vector result; result.reserve(idxs.size()); size_t i = 0; for (; i < draft.size(); i++) { const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); gsmpl->drafted_text += common_token_to_piece(ctx, id, true); common_sampler_accept(gsmpl, ctx, id, true); result.push_back(id); if (draft[i] != id) { break; } } if (i == draft.size()) { const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); gsmpl->drafted_text += common_token_to_piece(ctx, id, true); common_sampler_accept(gsmpl, ctx, id, true); result.push_back(id); } return result; } static void elb_print(common_params_sampling& sparams, const common_params_sampling::elb_param::elb_entry& entry) { #undef X #define X(T, MEMBER, DV, PRECAST) #MEMBER, static const std::vector names = { X_COMMON_PARAMS_SAMPLING }; #undef X #define X(T, MEMBER, DV, PRECAST) if (std::abs(entry.addsubs[SPARAMS_ ## MEMBER ## _ENUM]) > 0.0f) \ { LLAMA_LOG_DEBUG("%s: %s = %f\n", __func__, names[SPARAMS_ ## MEMBER ## _ENUM].c_str(), float(A_DOT_B(sparams, MEMBER))); } X_COMMON_PARAMS_SAMPLING } static void elb_add(common_params_sampling& sparams, const common_params_sampling::elb_param::elb_entry& entry) { #undef X #define X(T, MEMBER, _, PRECAST) A_DOT_B(sparams, MEMBER) += static_cast(PRECAST(entry.addsubs[SPARAMS_ ## MEMBER ## _ENUM])); X_COMMON_PARAMS_SAMPLING } static void elb_sub(common_params_sampling& sparams, const common_params_sampling::elb_param::elb_entry& entry) { #undef X #define X(T, MEMBER, _, PRECAST) A_DOT_B(sparams, MEMBER) -= static_cast(PRECAST(entry.addsubs[SPARAMS_ ## MEMBER ## _ENUM])); X_COMMON_PARAMS_SAMPLING } void common_expiring_logit_bias_apply(struct common_sampler* ctx_sampling, float* logits) { auto index_first_inactive = [](auto countup, auto& tokens) { return std::distance( tokens.begin(), std::upper_bound(tokens.begin(), tokens.end(), countup, [](const auto& countup, const auto& token) { return countup > token.duration; }) ); }; const auto& elb = ctx_sampling->elb_states[ctx_sampling->elb_idx]; std::string combined_text; const std::string* search_window = &combined_text; if (!ctx_sampling->drafted_text.empty()) { // add speculated tokens combined_text = ctx_sampling->to_generated_text != nullptr ? ( ctx_sampling->to_generated_text->substr(std::max(0, int32_t(ctx_sampling->to_generated_text->length()) - elb.max_cond_len)) ) : "" + ctx_sampling->drafted_text; } else if (ctx_sampling->to_generated_text != nullptr) { search_window = ctx_sampling->to_generated_text; } if (!search_window->empty() && !elb.other_tokens.empty() && (elb.other_tokens.front().duration > elb.countup)) { const auto ifi = index_first_inactive(elb.countup, elb.other_tokens); for (size_t j = 0; j < ifi; ++j) { const auto& [id, bias, _, cond] = elb.other_tokens[j]; if (string_ends_with(*search_window, cond)) { logits[id] += bias; } } } if (!elb.first_tokens.empty() && (elb.first_tokens.front().duration > elb.countup)) { const auto ifi = index_first_inactive(elb.countup, elb.first_tokens); if (search_window->empty()) { // empty case here for (size_t j = 0; j < ifi; ++j) { logits[elb.first_tokens[j].id] += elb.first_tokens[j].bias; } } else { for (size_t j = 0; j < ifi; ++j) { const auto& [id, bias, _, cond] = elb.first_tokens[j]; // no bias if seen (probably too late) if (!string_ends_with(*search_window, cond)) { logits[id] += bias; } } } } // expiring sampler bias for (auto& entry: ctx_sampling->params.elb_params[ctx_sampling->elb_idx].entries) { if (!entry.biases.empty()) { continue; // next entry } for (size_t j = 0; j < entry.phrases.size(); ++j) { const auto& phrase = entry.phrases[j]; if (phrase.empty()) { // duration bound only if (elb.countup == 0) { LLAMA_LOG_DEBUG("%s: before add\n", __func__); elb_print(ctx_sampling->params, entry); elb_add(ctx_sampling->params, entry); entry.addflags[j] = true; LLAMA_LOG_DEBUG("%s: after add\n", __func__); elb_print(ctx_sampling->params, entry); } else if (elb.countup == entry.duration) { LLAMA_LOG_DEBUG("%s: before sub\n", __func__); elb_print(ctx_sampling->params, entry); elb_sub(ctx_sampling->params, entry); entry.addflags[j] = false; LLAMA_LOG_DEBUG("%s: after sub\n", __func__); elb_print(ctx_sampling->params, entry); } continue; // next entry } size_t count = 0; auto pos = ctx_sampling->to_generated_text->find(phrase, entry.posi[j]); while (pos != std::string::npos) { LLAMA_LOG_DEBUG("%s: found %s @ %zu\n", __func__, phrase.c_str(), pos); ++count; pos = ctx_sampling->to_generated_text->find(phrase, pos + phrase.length()); } entry.posi[j] = std::max(0, int32_t(ctx_sampling->to_generated_text->length()) - int32_t(phrase.length()) + 1); if (count % 2 == 1) { // even = no match or cancelled LLAMA_LOG_DEBUG("%s: before\n", __func__); elb_print(ctx_sampling->params, entry); (entry.addflags[j] ? elb_sub : elb_add)(ctx_sampling->params, entry); entry.addflags[j] = !entry.addflags[j]; LLAMA_LOG_DEBUG("%s: after\n", __func__); elb_print(ctx_sampling->params, entry); } } } } void common_expiring_logit_bias_accept(struct common_sampler* ctx_sampling, struct llama_context * ctx_main) { if (ctx_sampling->to_generated_text == nullptr) { // prompt processing return; } auto idx = ctx_sampling->elb_idx; auto& elb = ctx_sampling->elb_states[idx]; if ((elb.delay > ++elb.countup) || (elb.search_word_len == 0)) { return; } const std::string window = ctx_sampling->to_generated_text->substr(std::min( ctx_sampling->to_generated_text->length(), ctx_sampling->elb_search_pos)) + common_token_to_piece(ctx_main, ctx_sampling->prev.back(), true); size_t pos = 0; if (string_is_found(window, elb.jumpword, pos)) { LLAMA_LOG_DEBUG("%s: found %s in %s @ %zu\n", __func__, string_unescape(elb.jumpword).c_str(), string_unescape(window).c_str(), pos); pos += ctx_sampling->elb_search_pos + elb.jumpword.length(); ctx_sampling->elb_idx = elb.jump_idx; } else if (string_is_found(window, elb.exitword, pos)) { LLAMA_LOG_DEBUG("%s: found %s in %s @ %zu\n", __func__, string_unescape(elb.exitword).c_str(), string_unescape(window).c_str(), pos); pos += ctx_sampling->elb_search_pos + elb.exitword.length(); ++ctx_sampling->elb_idx; } else { // not found. move search position to include next token ctx_sampling->elb_search_pos += std::max(0, int32_t(window.length()) - int32_t(elb.search_word_len) + 1); return; } // single character clearance // e.g. stop \n\n from expiring two \n immediately ctx_sampling->elb_search_pos = pos + 1; // undo current sampler bias for (auto& entry: ctx_sampling->params.elb_params[idx].entries) { for (const auto addflag: entry.addflags) { if (addflag) { LLAMA_LOG_DEBUG("%s: before\n", __func__); elb_print(ctx_sampling->params, entry); elb_sub(ctx_sampling->params, entry); LLAMA_LOG_DEBUG("%s: after\n", __func__); elb_print(ctx_sampling->params, entry); } } } // prepare next sampler bias for (auto& entry: ctx_sampling->params.elb_params[ctx_sampling->elb_idx].entries) { // no clearance for sampler bias std::fill(entry.posi.begin(), entry.posi.end(), pos); } } template <> json common_grammar_trigger::to_json() const { json out{ {"type", (int)type}, {"value", value}, }; if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { out["token"] = (int)token; } return out; } template <> common_grammar_trigger common_grammar_trigger::from_json(const json& in) { common_grammar_trigger out; out.type = (common_grammar_trigger_type)in.at("type").get(); out.value = in.at("value").get(); if (out.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { out.token = (llama_token)in.at("token").get(); } return out; } #if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__)) __attribute__((target("avx2"))) static bool common_sampler_speculative_top1_avx2(const float * logits, const int n_vocab, int & best_id, float & max_val) { if (n_vocab < 8) { return false; } __m256 max_v = _mm256_loadu_ps(logits); __m256i id_v = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); const __m256i step = _mm256_set1_epi32(8); __m256i cur_id = _mm256_add_epi32(id_v, step); int i = 8; for (; i + 7 < n_vocab; i += 8) { const __m256 x = _mm256_loadu_ps(logits + i); const __m256 gt_max = _mm256_cmp_ps(x, max_v, _CMP_GT_OQ); max_v = _mm256_blendv_ps(max_v, x, gt_max); id_v = _mm256_blendv_epi8(id_v, cur_id, _mm256_castps_si256(gt_max)); cur_id = _mm256_add_epi32(cur_id, step); } alignas(32) float max_buf[8]; alignas(32) int id_buf[8]; _mm256_store_ps(max_buf, max_v); _mm256_store_si256((__m256i *) id_buf, id_v); best_id = id_buf[0]; max_val = max_buf[0]; for (int j = 1; j < 8; ++j) { if (max_buf[j] > max_val) { max_val = max_buf[j]; best_id = id_buf[j]; } } for (; i < n_vocab; ++i) { if (logits[i] > max_val) { max_val = logits[i]; best_id = i; } } return true; } __attribute__((target("avx2,fma"))) static inline __m256 v_expf(__m256 x) { const __m256 r = _mm256_set1_ps(0x1.8p23f); const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r); const __m256 n = _mm256_sub_ps(z, r); const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f), _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x)); const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23); const __m256 k = _mm256_castsi256_ps( _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1)))); const __m256i c = _mm256_castps_si256( _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), _mm256_set1_ps(126), _CMP_GT_OQ)); const __m256 u = _mm256_mul_ps(b, b); const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b, _mm256_set1_ps(0x1.573e2ep-5f)), u, _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b, _mm256_set1_ps(0x1.fffdb6p-2f))), u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b)); if (!_mm256_movemask_ps(_mm256_castsi256_ps(c))) return _mm256_fmadd_ps(j, k, k); const __m256i g = _mm256_and_si256( _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)), _mm256_set1_epi32(0x82000000u)); const __m256 s1 = _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u))); const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g)); const __m256i d = _mm256_castps_si256( _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), _mm256_set1_ps(192), _CMP_GT_OQ)); return _mm256_or_ps( _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)), _mm256_andnot_ps( _mm256_castsi256_ps(d), _mm256_or_ps( _mm256_and_ps(_mm256_castsi256_ps(c), _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)), _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k))))); } __attribute__((target("avx2"))) static inline float hsum_float_4(__m128 x) { x = _mm_add_ps(x, _mm_movehl_ps(x, x)); x = _mm_add_ss(x, _mm_movehdup_ps(x)); return _mm_cvtss_f32(x); } __attribute__((target("avx2"))) static inline float hsum_float_8(__m256 x) { return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1))); } __attribute__((target("avx2,fma"))) static float prob_avx2(int n, const float * logits, float max_val) { float sumf = 0; int i = 0; if (n >= 8) { auto sum_v = _mm256_setzero_ps(); auto max_v = _mm256_set1_ps(max_val); for (; i < n - 7; i += 8) { auto x = _mm256_loadu_ps(logits + i); auto exp_x = v_expf(_mm256_sub_ps(x, max_v)); sum_v = _mm256_add_ps(sum_v, exp_x); } sumf = hsum_float_8(sum_v); } for (; i < n; ++i) { sumf += expf(logits[i] - max_val); } return 1.0f/sumf; } #endif static float prob_scalar(int n, const float * logits, float max_val) { double sum_exp = 0.0; for (int i = 0; i < n; ++i) { sum_exp += exp((double)(logits[i] - max_val)); } return (float)(1./sum_exp); } llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, float * out_prob) { GGML_UNUSED(gsmpl); float * logits = llama_get_logits_ith(ctx, idx); const int n_vocab = llama_n_vocab(llama_get_model(ctx)); int best_id = 0; float max_val = logits[0]; #if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__)) static const bool has_avx2 = __builtin_cpu_supports("avx2"); if (has_avx2 && common_sampler_speculative_top1_avx2(logits, n_vocab, best_id, max_val)) { if (out_prob) { static const bool has_fma = __builtin_cpu_supports("fma"); if (has_fma) { *out_prob = prob_avx2(n_vocab, logits, max_val); } else { *out_prob = prob_scalar(n_vocab, logits, max_val); } } return best_id; } #endif for (int i = 1; i < n_vocab; ++i) { if (logits[i] > max_val) { max_val = logits[i]; best_id = i; } } if (out_prob) { *out_prob = prob_scalar(n_vocab, logits, max_val); } return best_id; }