add arg --threads-sampling

This commit is contained in:
Xuan Son Nguyen 2026-06-22 20:03:49 +02:00
parent c62fdd5fd0
commit 095058ca19
5 changed files with 75 additions and 23 deletions

View File

@ -1168,6 +1168,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
));
add_opt(common_arg(
{"--threads-sampling"}, "N",
"number of threads to use during sampling (default: same as --threads)",
[](common_params & params, int value) {
params.sampling_n_threads = value;
if (params.sampling_n_threads <= 0) {
params.sampling_n_threads = std::thread::hardware_concurrency();
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-C", "--cpu-mask"}, "M",
"CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")",

View File

@ -471,6 +471,8 @@ struct common_params {
common_cpu_params cpuparams;
common_cpu_params cpuparams_batch;
int sampling_n_threads = -1; // number of threads for sampling, used by server
ggml_backend_sched_eval_callback cb_eval = nullptr;
void * cb_eval_user_data = nullptr;

View File

@ -1598,32 +1598,36 @@ server_threadpool::~server_threadpool() {
}
void server_threadpool::init(int n) {
for (int i = 0; i < n; i++) {
threads.emplace_back([this]() {
while (true) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(mtx);
cv.wait(lock, [this]() { return stop || !tasks.empty(); });
if (stop && tasks.empty()) return;
task = std::move(tasks.front());
tasks.pop();
}
task();
{
std::lock_guard<std::mutex> lock(mtx);
pending--;
}
cv_done.notify_one();
}
});
// the caller (main thread) participates as a worker, so spawn n-1 threads
const int n_workers = std::max(1, n) - 1;
for (int i = 0; i < n_workers; i++) {
threads.emplace_back([this]() { run_worker(); });
}
}
void server_threadpool::run_worker() {
while (true) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(mtx);
cv.wait(lock, [this]() { return stop || !tasks.empty(); });
if (stop && tasks.empty()) return;
task = std::move(tasks.front());
tasks.pop();
}
task();
{
std::lock_guard<std::mutex> lock(mtx);
pending--;
}
cv_done.notify_all();
}
}
void server_threadpool::enqueue(std::function<void()> fn) {
{
std::lock_guard<std::mutex> lock(mtx);
GGML_ASSERT(!stop && !threads.empty());
GGML_ASSERT(!stop);
tasks.push(std::move(fn));
pending++;
}
@ -1631,6 +1635,30 @@ void server_threadpool::enqueue(std::function<void()> fn) {
}
void server_threadpool::wait_all() {
std::unique_lock<std::mutex> lock(mtx);
cv_done.wait(lock, [this]() { return pending == 0; });
// the calling thread helps drain the queue until no tasks remain pending
while (true) {
std::function<void()> task;
{
std::lock_guard<std::mutex> lock(mtx);
if (pending == 0) {
return;
}
if (!tasks.empty()) {
task = std::move(tasks.front());
tasks.pop();
}
}
if (task) {
task();
{
std::lock_guard<std::mutex> lock(mtx);
pending--;
}
cv_done.notify_all();
} else {
// no task available right now, but some are still pending (being run by workers)
std::unique_lock<std::mutex> lock(mtx);
cv_done.wait(lock, [this]() { return pending == 0 || !tasks.empty(); });
}
}
}

View File

@ -381,6 +381,8 @@ server_tokens format_prompt_rerank(
// to be used for multi-threaded sampling
//
// the main thread participates as one of the pool's workers, so init(n)
// only spawns n-1 background threads (the caller is the nth)
struct server_threadpool {
std::vector<std::thread> threads;
std::queue<std::function<void()>> tasks;
@ -400,10 +402,12 @@ struct server_threadpool {
handler(item);
});
}
// the calling thread runs tasks too, until all are done
wait_all();
}
private:
void enqueue(std::function<void()> fn);
void wait_all();
void run_worker();
};

View File

@ -1460,7 +1460,15 @@ private:
metrics.init();
threadpool.init(params_base.cpuparams.n_threads);
// initialize threadpool
{
int threadpool_size = params_base.sampling_n_threads;
if (threadpool_size <= 0) {
threadpool_size = params_base.cpuparams.n_threads;
}
SRV_DBG("%s: initializing threadpool, size = %d\n", __func__, threadpool_size);
threadpool.init(threadpool_size);
}
if (params_base.cache_idle_slots) {
if (params_base.cache_ram_mib == 0) {