diff --git a/common/arg.cpp b/common/arg.cpp index 5297d90753..2e10ad8971 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1168,6 +1168,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } } )); + add_opt(common_arg( + {"--threads-sampling"}, "N", + "number of threads to use during sampling (default: same as --threads)", + [](common_params & params, int value) { + params.sampling_n_threads = value; + if (params.sampling_n_threads <= 0) { + params.sampling_n_threads = std::thread::hardware_concurrency(); + } + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-C", "--cpu-mask"}, "M", "CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")", diff --git a/common/common.h b/common/common.h index f2f2202ec2..0537b3ed1f 100644 --- a/common/common.h +++ b/common/common.h @@ -471,6 +471,8 @@ struct common_params { common_cpu_params cpuparams; common_cpu_params cpuparams_batch; + int sampling_n_threads = -1; // number of threads for sampling, used by server + ggml_backend_sched_eval_callback cb_eval = nullptr; void * cb_eval_user_data = nullptr; diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 81ea63a4f2..ab6f8fc5c0 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1598,32 +1598,36 @@ server_threadpool::~server_threadpool() { } void server_threadpool::init(int n) { - for (int i = 0; i < n; i++) { - threads.emplace_back([this]() { - while (true) { - std::function task; - { - std::unique_lock lock(mtx); - cv.wait(lock, [this]() { return stop || !tasks.empty(); }); - if (stop && tasks.empty()) return; - task = std::move(tasks.front()); - tasks.pop(); - } - task(); - { - std::lock_guard lock(mtx); - pending--; - } - cv_done.notify_one(); - } - }); + // the caller (main thread) participates as a worker, so spawn n-1 threads + const int n_workers = std::max(1, n) - 1; + for (int i = 0; i < n_workers; i++) { + threads.emplace_back([this]() { run_worker(); }); + } +} + +void server_threadpool::run_worker() { + while (true) { + std::function task; + { + std::unique_lock lock(mtx); + cv.wait(lock, [this]() { return stop || !tasks.empty(); }); + if (stop && tasks.empty()) return; + task = std::move(tasks.front()); + tasks.pop(); + } + task(); + { + std::lock_guard lock(mtx); + pending--; + } + cv_done.notify_all(); } } void server_threadpool::enqueue(std::function fn) { { std::lock_guard lock(mtx); - GGML_ASSERT(!stop && !threads.empty()); + GGML_ASSERT(!stop); tasks.push(std::move(fn)); pending++; } @@ -1631,6 +1635,30 @@ void server_threadpool::enqueue(std::function fn) { } void server_threadpool::wait_all() { - std::unique_lock lock(mtx); - cv_done.wait(lock, [this]() { return pending == 0; }); + // the calling thread helps drain the queue until no tasks remain pending + while (true) { + std::function task; + { + std::lock_guard lock(mtx); + if (pending == 0) { + return; + } + if (!tasks.empty()) { + task = std::move(tasks.front()); + tasks.pop(); + } + } + if (task) { + task(); + { + std::lock_guard lock(mtx); + pending--; + } + cv_done.notify_all(); + } else { + // no task available right now, but some are still pending (being run by workers) + std::unique_lock lock(mtx); + cv_done.wait(lock, [this]() { return pending == 0 || !tasks.empty(); }); + } + } } diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 072bc623e9..2d44c70d64 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -381,6 +381,8 @@ server_tokens format_prompt_rerank( // to be used for multi-threaded sampling // +// the main thread participates as one of the pool's workers, so init(n) +// only spawns n-1 background threads (the caller is the nth) struct server_threadpool { std::vector threads; std::queue> tasks; @@ -400,10 +402,12 @@ struct server_threadpool { handler(item); }); } + // the calling thread runs tasks too, until all are done wait_all(); } private: void enqueue(std::function fn); void wait_all(); + void run_worker(); }; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index bc991e8d68..6d68138f58 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1460,7 +1460,15 @@ private: metrics.init(); - threadpool.init(params_base.cpuparams.n_threads); + // initialize threadpool + { + int threadpool_size = params_base.sampling_n_threads; + if (threadpool_size <= 0) { + threadpool_size = params_base.cpuparams.n_threads; + } + SRV_DBG("%s: initializing threadpool, size = %d\n", __func__, threadpool_size); + threadpool.init(threadpool_size); + } if (params_base.cache_idle_slots) { if (params_base.cache_ram_mib == 0) {