From 4b48a53b6cc60e051f35f2acbd06264a909bb255 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Fri, 19 Jun 2026 23:26:54 +0200 Subject: [PATCH] server : optimize get_token_probabilities (#24796) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use std::partial_sort to order only the requested top-n tokens instead of the full vocabulary logprobs sort: vocab=128000 n_top=0 iters=100 full sort: 8555.6 us/op partial sort: 704.3 us/op Signed-off-by: Adrien Gallouët --- tools/server/server-common.cpp | 36 +++++++++++++++++++++++---------- tools/server/server-common.h | 2 +- tools/server/server-context.cpp | 3 +-- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 75729e62dd..3dc686bb46 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -12,6 +12,7 @@ #include #include #include +#include <limits> json format_error_response(const std::string & message, const enum error_type type) { std::string type_str; @@ -1238,7 +1239,7 @@ json format_response_rerank( // other utils // -std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) { +std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx, size_t n_top) { std::vector<llama_token_data> cur; const auto * logits = llama_get_logits_ith(ctx, idx); @@ -1257,21 +1258,34 @@ std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int i } } - // sort tokens by logits - std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); + // sort tokens by logits (partial: only the leading `n_top` need ordering) + if (n_top > cur.size()) { + n_top = cur.size(); + } + if (n_top > 0) { + std::partial_sort(cur.begin(), cur.begin() + n_top, cur.end(), + [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + } // apply softmax - float max_l = cur[0].logit; + float max_l = 
-std::numeric_limits<float>::infinity(); + if (n_top > 0) { + max_l = cur[0].logit; // partial_sort guarantees the absolute maximum is at index 0 + } else { + for (const auto & t : cur) { + max_l = std::max(max_l, t.logit); + } + } float cum_sum = 0.0f; - for (size_t i = 0; i < cur.size(); ++i) { - float p = expf(cur[i].logit - max_l); - cur[i].p = p; + for (auto & t : cur) { + float p = expf(t.logit - max_l); + t.p = p; cum_sum += p; } - for (size_t i = 0; i < cur.size(); ++i) { - cur[i].p /= cum_sum; + for (auto & t : cur) { + t.p /= cum_sum; } return cur; diff --git a/tools/server/server-common.h b/tools/server/server-common.h index f286b3d156..efd31733b0 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -326,7 +326,7 @@ json format_response_rerank( // other utils // -std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx); +std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx, size_t n_top); std::string safe_json_to_str(const json & data); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 791188b1e7..1f0e1bfd42 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1824,8 +1824,7 @@ private: }); } } else { - // TODO: optimize this with min-p optimization - std::vector<llama_token_data> cur = get_token_probabilities(ctx_tgt, idx); + std::vector<llama_token_data> cur = get_token_probabilities(ctx_tgt, idx, n_probs_request); const size_t max_probs = cur.size(); const size_t n_probs = std::min(max_probs, n_probs_request);