diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index a0d78a5dae..98e7475a10 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -24,6 +24,11 @@ #include #include #include +#include +#include +#include +#include +#include // fix problem with std::min and std::max #if defined(_WIN32) @@ -3616,6 +3621,67 @@ private: return true; } + struct sampler_task { + server_slot * slot; + int32_t tok_idx; + }; + + struct sampler_threadpool { + std::vector threads; + std::queue> tasks; + std::mutex mtx; + std::condition_variable cv; + std::condition_variable cv_done; + int pending = 0; + bool stop = false; + + ~sampler_threadpool() { + { + std::lock_guard lock(mtx); + stop = true; + } + cv.notify_all(); + for (auto & t : threads) t.join(); + } + + void init(int n) { + for (int i = 0; i < n; i++) { + threads.emplace_back([this]() { + while (true) { + std::function task; + { + std::unique_lock lock(mtx); + cv.wait(lock, [this]() { return stop || !tasks.empty(); }); + if (stop && tasks.empty()) return; + task = std::move(tasks.front()); + tasks.pop(); + } + task(); + { + std::lock_guard lock(mtx); + pending--; + } + cv_done.notify_one(); + } + }); + } + } + + void enqueue(std::function fn) { + { + std::lock_guard lock(mtx); + tasks.push(std::move(fn)); + pending++; + } + cv.notify_one(); + } + + void wait_all() { + std::unique_lock lock(mtx); + cv_done.wait(lock, [this]() { return pending == 0; }); + } + }; + void post_decode(int32_t n_batch_tokens, int32_t off, llama_batch & batch_view) { // for checking if a given batch index is inside batch_view auto is_inside_view = [&](int32_t idx) { @@ -3637,6 +3703,8 @@ private: slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); }; + std::vector smpl_tasks; + iterate(slots, [&](server_slot & slot) { // optionally send prompt processing progress if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { @@ -3684,13 +3752,42 @@ private: // shifted according to the current sub-batch const int tok_idx = slot.i_batch - off; + smpl_tasks.push_back({&slot, tok_idx}); + }); - llama_token id; - { - scoped_timer timer(t_sampl, n_sampl); - id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx); + std::unordered_map sampled_token; + + // run common_sampler_sample in a thread pool to sample all tokens in parallel + if (!smpl_tasks.empty()) { + scoped_timer timer(t_sampl, n_sampl); + static sampler_threadpool pool; + if (pool.threads.empty()) { + pool.init(params_base.n_parallel); } + std::vector> results(smpl_tasks.size()); + for (size_t i = 0; i < smpl_tasks.size(); i++) { + const auto & task = smpl_tasks[i]; + pool.enqueue([&results, i, slot_ptr = task.slot, tok_idx = task.tok_idx]() { + results[i] = {slot_ptr, common_sampler_sample(slot_ptr->smpl.get(), slot_ptr->ctx_tgt, tok_idx)}; + }); + } + pool.wait_all(); + + for (const auto & [slot_ptr, id] : results) { + sampled_token[slot_ptr] = id; + } + } + + iterate(slots, [&](server_slot & slot) { + const int tok_idx = slot.i_batch - off; + auto it = sampled_token.find(&slot); + if (it == sampled_token.end()) { + // no token sampled for this slot, skip + return; + } + auto id = it->second; + slot.i_batch = -1; common_sampler_accept(slot.smpl.get(), id, true);