mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
server : optimize get_token_probabilities (#24796)
Use std::partial_sort to order only the requested top-n tokens instead
of sorting the full vocabulary.
logprobs sort: vocab=128000 n_top=0 iters=100
full sort: 8555.6 us/op
partial sort: 704.3 us/op
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
parent
e475fa2b5f
commit
4b48a53b6c
@ -12,6 +12,7 @@
|
||||
#include <random>
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
|
||||
json format_error_response(const std::string & message, const enum error_type type) {
|
||||
std::string type_str;
|
||||
@ -1238,7 +1239,7 @@ json format_response_rerank(
|
||||
// other utils
|
||||
//
|
||||
|
||||
std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
|
||||
std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx, size_t n_top) {
|
||||
std::vector<llama_token_data> cur;
|
||||
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
@ -1257,21 +1258,34 @@ std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int i
|
||||
}
|
||||
}
|
||||
|
||||
// sort tokens by logits
|
||||
std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
});
|
||||
// sort tokens by logits (partial: only the leading `n_top` need ordering)
|
||||
if (n_top > cur.size()) {
|
||||
n_top = cur.size();
|
||||
}
|
||||
if (n_top > 0) {
|
||||
std::partial_sort(cur.begin(), cur.begin() + n_top, cur.end(),
|
||||
[](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
});
|
||||
}
|
||||
|
||||
// apply softmax
|
||||
float max_l = cur[0].logit;
|
||||
float max_l = -std::numeric_limits<float>::infinity();
|
||||
if (n_top > 0) {
|
||||
max_l = cur[0].logit; // partial_sort guarantees the absolute maximum is at index 0
|
||||
} else {
|
||||
for (const auto & t : cur) {
|
||||
max_l = std::max(max_l, t.logit);
|
||||
}
|
||||
}
|
||||
float cum_sum = 0.0f;
|
||||
for (size_t i = 0; i < cur.size(); ++i) {
|
||||
float p = expf(cur[i].logit - max_l);
|
||||
cur[i].p = p;
|
||||
for (auto & t : cur) {
|
||||
float p = expf(t.logit - max_l);
|
||||
t.p = p;
|
||||
cum_sum += p;
|
||||
}
|
||||
for (size_t i = 0; i < cur.size(); ++i) {
|
||||
cur[i].p /= cum_sum;
|
||||
for (auto & t : cur) {
|
||||
t.p /= cum_sum;
|
||||
}
|
||||
|
||||
return cur;
|
||||
|
||||
@ -326,7 +326,7 @@ json format_response_rerank(
|
||||
// other utils
|
||||
//
|
||||
|
||||
std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx);
|
||||
std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx, size_t n_top);
|
||||
|
||||
std::string safe_json_to_str(const json & data);
|
||||
|
||||
|
||||
@ -1824,8 +1824,7 @@ private:
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// TODO: optimize this with min-p optimization
|
||||
std::vector<llama_token_data> cur = get_token_probabilities(ctx_tgt, idx);
|
||||
std::vector<llama_token_data> cur = get_token_probabilities(ctx_tgt, idx, n_probs_request);
|
||||
const size_t max_probs = cur.size();
|
||||
const size_t n_probs = std::min(max_probs, n_probs_request);
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user