mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
add arg --threads-sampling
This commit is contained in:
parent
c62fdd5fd0
commit
095058ca19
@ -1168,6 +1168,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
}
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"--threads-sampling"}, "N",
|
||||
"number of threads to use during sampling (default: same as --threads)",
|
||||
[](common_params & params, int value) {
|
||||
params.sampling_n_threads = value;
|
||||
if (params.sampling_n_threads <= 0) {
|
||||
params.sampling_n_threads = std::thread::hardware_concurrency();
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-C", "--cpu-mask"}, "M",
|
||||
"CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")",
|
||||
|
||||
@ -471,6 +471,8 @@ struct common_params {
|
||||
common_cpu_params cpuparams;
|
||||
common_cpu_params cpuparams_batch;
|
||||
|
||||
int sampling_n_threads = -1; // number of threads for sampling, used by server
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval = nullptr;
|
||||
void * cb_eval_user_data = nullptr;
|
||||
|
||||
|
||||
@ -1598,32 +1598,36 @@ server_threadpool::~server_threadpool() {
|
||||
}
|
||||
|
||||
void server_threadpool::init(int n) {
|
||||
for (int i = 0; i < n; i++) {
|
||||
threads.emplace_back([this]() {
|
||||
while (true) {
|
||||
std::function<void()> task;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
cv.wait(lock, [this]() { return stop || !tasks.empty(); });
|
||||
if (stop && tasks.empty()) return;
|
||||
task = std::move(tasks.front());
|
||||
tasks.pop();
|
||||
}
|
||||
task();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
pending--;
|
||||
}
|
||||
cv_done.notify_one();
|
||||
}
|
||||
});
|
||||
// the caller (main thread) participates as a worker, so spawn n-1 threads
|
||||
const int n_workers = std::max(1, n) - 1;
|
||||
for (int i = 0; i < n_workers; i++) {
|
||||
threads.emplace_back([this]() { run_worker(); });
|
||||
}
|
||||
}
|
||||
|
||||
void server_threadpool::run_worker() {
|
||||
while (true) {
|
||||
std::function<void()> task;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
cv.wait(lock, [this]() { return stop || !tasks.empty(); });
|
||||
if (stop && tasks.empty()) return;
|
||||
task = std::move(tasks.front());
|
||||
tasks.pop();
|
||||
}
|
||||
task();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
pending--;
|
||||
}
|
||||
cv_done.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
void server_threadpool::enqueue(std::function<void()> fn) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
GGML_ASSERT(!stop && !threads.empty());
|
||||
GGML_ASSERT(!stop);
|
||||
tasks.push(std::move(fn));
|
||||
pending++;
|
||||
}
|
||||
@ -1631,6 +1635,30 @@ void server_threadpool::enqueue(std::function<void()> fn) {
|
||||
}
|
||||
|
||||
void server_threadpool::wait_all() {
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
cv_done.wait(lock, [this]() { return pending == 0; });
|
||||
// the calling thread helps drain the queue until no tasks remain pending
|
||||
while (true) {
|
||||
std::function<void()> task;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
if (pending == 0) {
|
||||
return;
|
||||
}
|
||||
if (!tasks.empty()) {
|
||||
task = std::move(tasks.front());
|
||||
tasks.pop();
|
||||
}
|
||||
}
|
||||
if (task) {
|
||||
task();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
pending--;
|
||||
}
|
||||
cv_done.notify_all();
|
||||
} else {
|
||||
// no task available right now, but some are still pending (being run by workers)
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
cv_done.wait(lock, [this]() { return pending == 0 || !tasks.empty(); });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -381,6 +381,8 @@ server_tokens format_prompt_rerank(
|
||||
// to be used for multi-threaded sampling
|
||||
//
|
||||
|
||||
// the main thread participates as one of the pool's workers, so init(n)
|
||||
// only spawns n-1 background threads (the caller is the nth)
|
||||
struct server_threadpool {
|
||||
std::vector<std::thread> threads;
|
||||
std::queue<std::function<void()>> tasks;
|
||||
@ -400,10 +402,12 @@ struct server_threadpool {
|
||||
handler(item);
|
||||
});
|
||||
}
|
||||
// the calling thread runs tasks too, until all are done
|
||||
wait_all();
|
||||
}
|
||||
|
||||
private:
|
||||
void enqueue(std::function<void()> fn);
|
||||
void wait_all();
|
||||
void run_worker();
|
||||
};
|
||||
|
||||
@ -1460,7 +1460,15 @@ private:
|
||||
|
||||
metrics.init();
|
||||
|
||||
threadpool.init(params_base.cpuparams.n_threads);
|
||||
// initialize threadpool
|
||||
{
|
||||
int threadpool_size = params_base.sampling_n_threads;
|
||||
if (threadpool_size <= 0) {
|
||||
threadpool_size = params_base.cpuparams.n_threads;
|
||||
}
|
||||
SRV_DBG("%s: initializing threadpool, size = %d\n", __func__, threadpool_size);
|
||||
threadpool.init(threadpool_size);
|
||||
}
|
||||
|
||||
if (params_base.cache_idle_slots) {
|
||||
if (params_base.cache_ram_mib == 0) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user