llama.cpp/tools/server/server-models.cpp
Pascal 1a87dcdc45
server + ui: SSE Replay Buffer (#23226)
* server: SSE replay buffer, survives client disconnect

Opt in on POST /v1/chat/completions when the client sends
X-Stream-Resume: 1 and a non-empty X-Conversation-Id. The conv id is
the session identity end to end, with no extra opaque token. The drain
runs detached on the server and buffers SSE bytes, so the generation
survives an HTTP disconnect or a page reload (F5), and users can switch
from iOS Safari to another app without losing the response in progress.

Routes:
  GET    /v1/stream/<conv_id>?from=N       replay
  GET    /v1/streams[?conversation_id=X]   list, drives sidebar spinners
  DELETE /v1/stream/<conv_id>              Stop, idempotent
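
A minimal sketch of the replay idea behind the ?from=N route, assuming N is a byte offset into the buffered SSE output. The class name replay_buffer and its methods are hypothetical and only illustrate the technique, they are not the names used in server-stream:

```cpp
#include <cassert>
#include <cstddef>
#include <mutex>
#include <string>

// Hypothetical replay buffer: keeps every SSE byte produced so far so that a
// reconnecting client can ask for everything from byte offset `from` onward.
class replay_buffer {
public:
    // Append a chunk of SSE bytes as the producer emits them.
    void append(const std::string & chunk) {
        std::lock_guard<std::mutex> lk(mu_);
        data_ += chunk;
    }

    // Return the bytes from offset `from` to the current end.
    // An offset at or past the end yields an empty string.
    std::string read_from(std::size_t from) const {
        std::lock_guard<std::mutex> lk(mu_);
        if (from >= data_.size()) {
            return {};
        }
        return data_.substr(from);
    }

    // Total bytes buffered, i.e. the next offset a client would resume from.
    std::size_t size() const {
        std::lock_guard<std::mutex> lk(mu_);
        return data_.size();
    }

private:
    mutable std::mutex mu_;
    std::string data_;
};
```

A client that already received the first 9 bytes would reconnect with from=9 and get only the remainder, so nothing is duplicated on its side.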

The router parent fans out list and delete to its children, probes on
GET to route to the owning child, and fans out a DELETE on each POST so
that "one session per conv" holds across model swaps.

WebUI: the layout snapshots /v1/streams at mount and on
visibilitychange, so the sidebar reflects live inferences across all
convs. The chat page reattaches on mount. Append vs fresh is detected
from the existing content, so a continue that is mid stream keeps its
prefix.

update_slots: when llama_memory_seq_rm refuses a removal at a deep
position, clear the whole seq and re-prefill from zero instead of
hitting GGML_ABORT.

OAI strict path unchanged when the opt in headers are absent.

* server: create stream session only after post_tasks succeeds

* server, ui: drop X-Stream-Resume, X-Conversation-Id alone enables the replay buffer

* server: drop magic 17, derive the X-Conversation-Id header length from sizeof at build time

* refactor: address review feedback from ngxson

* server-context: cleaning

* server-stream: fix use-after-free on rd

Guard stop_producer with a shared alive flag, flipped by on_stream_end
before rd dies. Prevents a late cancel (session eviction by a later
POST on the same conv_id, or a DELETE arriving after the producer
ended) from touching a destroyed rd.
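
A hedged sketch of the alive flag pattern described above. The reader and alive_flag types and the make_stop_producer signature are assumptions for illustration. The mutex around the flag is also a choice of this sketch, the actual code may use a different primitive:

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <mutex>

// Stand-in for the response reader ("rd" in the text above).
struct reader {
    int stop_calls = 0;
    void stop() { ++stop_calls; }
};

// Shared between the cancel callback and the end-of-stream hook.
struct alive_flag {
    std::mutex mu;
    bool       alive = true;
};

// Build a cancel callback that only touches `rd` while the flag says it is alive.
// Returns true when the stop was delivered, false when the reader was already gone.
std::function<bool()> make_stop_producer(reader * rd, std::shared_ptr<alive_flag> flag) {
    return [rd, flag]() {
        std::lock_guard<std::mutex> lk(flag->mu);
        if (!flag->alive) {
            return false; // late cancel: rd may already be destroyed
        }
        rd->stop();
        return true;
    };
}

// Called when the stream ends, before the reader is destroyed.
void on_stream_end(const std::shared_ptr<alive_flag> & flag) {
    std::lock_guard<std::mutex> lk(flag->mu);
    flag->alive = false;
}
```

The flag outlives the reader because both sides hold it through a shared_ptr, so a late eviction or DELETE only ever reads the flag and never the dead reader.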

* ui: fix cross-conversation contamination

Scope streaming flags per conv so one finishing does not unflag the
others, guard discoverActiveStream against concurrent runs to avoid
duplicate attaches, and stop racing syncRemoteRunningStreams for the
sidebar set.

* server-http: keep request alive in detached SSE drain

The response next() lambda may reach into *request via &req long
after on_complete reset the request shared_ptr. Capture request in
the detached thread so it outlives the drain.
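
The lifetime fix can be sketched as follows. This is an illustration only: the request struct and make_drain are hypothetical, and a stored closure stands in for the body of the detached thread so the example stays single threaded:

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>

// Stand-in for the HTTP request object that the next() lambda reads from.
struct request {
    std::string body;
};

// Build the work that the drain would run later. Capturing `req` by value
// copies the shared_ptr, so the request stays alive until the closure is
// destroyed, even if the caller resets its own pointer first.
std::function<std::string()> make_drain(std::shared_ptr<request> req) {
    return [req]() { return req->body; };
}
```

Capturing a raw reference instead would leave the closure pointing at freed memory once the last outside owner resets its pointer.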

* ui: address review feedback from coder543

Forward Authorization to the /v1/stream and /v1/streams fetches: the resumable
routes must obey --api-key like the rest of the API.

Wrap reader.read() in a try/catch. When the underlying connection drops, the read
rejects with a TypeError instead of resolving with done=true, so treat it as a
premature end of stream and let the existing resume loop kick in.

Freeze the model at session start in chatStreamingStates.model and thread it
through cancel and resume. The dropdown selection may have changed since the
POST, and the server side identity is fixed at that time.

* format

* ui: remove unused selectedModelName

* server-stream: poll session->is_cancelled() in stream_aware_should_stop

Address review feedback from coder543. The cancel propagation through
rd.stop() relies on the slot eventually processing the cancel task and
posting a result that notifies the recv condvar, remove_waiting_task_ids
does not notify directly. Add a defensive poll on session->is_cancelled()
so the producer-side next() loop exits on its next iteration after
cancel() without waiting for the cancel task to round trip through a slot.
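
A minimal sketch of the defensive poll, assuming the stop predicate is a callable and the session exposes is_cancelled(). The stream_session type and the make_stream_aware_should_stop helper are hypothetical shapes, only the idea of combining both checks comes from the text above:

```cpp
#include <atomic>
#include <cassert>
#include <functional>
#include <memory>
#include <utility>

// Hypothetical session holding only the cancel state.
struct stream_session {
    void cancel() { cancelled_.store(true); }
    bool is_cancelled() const { return cancelled_.load(); }
private:
    std::atomic<bool> cancelled_{false};
};

// Wrap an existing should_stop predicate so that a cancelled session stops the
// producer loop on its next iteration, without waiting for the base predicate.
std::function<bool()> make_stream_aware_should_stop(
        std::function<bool()>           base_should_stop,
        std::shared_ptr<stream_session> session) {
    return [base = std::move(base_should_stop), session]() {
        return session->is_cancelled() || base();
    };
}
```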

* server-stream, ui: replace GET /v1/streams with POST /v1/streams/lookup

Address review feedback from coder543. Listing live sessions leaks the
conversation_id of every concurrent user, which defeats the random UUID
unguessability. The new route takes {conversation_ids: [...]} in the
body and returns matches only for the ids the caller already owns, so
foreign UUIDs stay private. The router fans out the same POST to every
child and aggregates, the WebUI passes the convs visible in its sidebar.
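
The privacy property of the lookup route can be sketched as an intersection: the caller only learns about ids it already supplied. The function name and the plain string containers are assumptions, the real route works on a JSON body:

```cpp
#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// Return the requested ids that have a live session, in request order.
// Ids that the caller did not send are never revealed.
std::vector<std::string> lookup_streams(
        const std::unordered_set<std::string> & live_sessions,
        const std::vector<std::string> &        conversation_ids) {
    std::vector<std::string> matches;
    for (const auto & id : conversation_ids) {
        if (live_sessions.count(id) > 0) {
            matches.push_back(id);
        }
    }
    return matches;
}
```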

* ui: read conv ids from IndexedDB in syncRemoteRunningStreams

The conversations store is not hydrated yet at +layout onMount, so the
sidebar spinners stayed off for background convs until the user clicked
on them. Read straight from the DB to dodge the init race.

* server-models: deduplicate stream lookup timeouts behind one constant

* ui: extract visibility kick grace into a stream constant, bump to 1000 ms

* make it safer and simpler

* server-stream: survive client disconnect via stream_pipe::finish_producer

After the RAII rewrite the generation stopped the moment the client
disconnected. httplib bails its content provider on the is_peer_alive
check at the top of write_content_chunked, so returning true from the
provider never keeps it producing: the response resets, rd is destroyed
and its task gets cancelled.

Reinstate the disconnect survival inside the pipe. stream_pipe gains
finish_producer, which pumps the response next() into the ring buffer
until the generation ends, and mark_producer_done for the clean wire
end. server-http only triggers them: mark before sink.done on a clean
close, finish in on_complete when the peer left early. No detach, no
stream logic in server-http beyond the trigger, and the strict OAI path
is untouched when no pipe is attached.

Known limitation: finish_producer pumps synchronously on the http
worker, so a disconnected stream keeps its worker busy until the
generation ends. A follow-up will move the drain off the http worker so
no worker is held.

* server-stream: drain disconnected streams on a manager owned thread

The previous commit pumped the post disconnect drain synchronously in
on_complete, on the http worker, so a disconnected stream kept its
worker busy until the generation ended. Under a wave of reloads or tab
closes that pins workers from the pool.

Move the drain off the http worker. on_complete now hands the response
to stream_session_manager::adopt_orphan, which pumps it to completion on
a manager owned thread and releases the worker at once. One thread per
disconnected stream still generating, stored in a list, joined and
reaped on the next adopt, by the GC, and at shutdown. No detach, the
thread lifecycle is fully owned by the manager. needs_drain gates the
handoff so a cleanly finished stream never spawns a thread, and the
strict OAI path stays untouched when no pipe is attached.

stop_gc now cancels sessions before finalizing them, so an in flight
drain sees is_cancelled and exits instead of blocking the shutdown join
until the generation ends naturally.

* ui: add missing JSDoc

* server-stream: drain on the http worker, drop the manager thread

Address @ngxson review: httplib runs a large dynamic pool and a worker
blocked in next() sits on a condvar instead of burning cpu, so draining
the rest of the generation on that worker is fine and much simpler than
a dedicated thread.

on_complete calls finish_producer directly again. Removes adopt_orphan,
the orphan thread list and its reaping, the stop_gc session cancel that
only existed to unblock those threads, and the now dead drain_shutdown
flag.

* server-stream: split stream_pipe into producer and consumer classes

Address @ngxson review: one class covering both ends was messy. stream_pipe
is now a base holding the session and is_cancelled, with stream_pipe_producer
(write, mark_producer_done, finish_producer, cleanup, finalizes on destruct)
and stream_pipe_consumer (read only, no finalize) deriving from it.

Drops the is_producer_ discriminator and its runtime guards, the type now
encodes the role. res.spipe is retyped to shared_ptr<stream_pipe_producer>
since it is only ever a producer. No behavior change.

* server-stream: rename producer methods to unix pipe semantics

Address @ngxson review: mark_producer_done becomes done(), finish_producer
becomes close(), matching a unix pipe write end. The producer_done_ member
follows as done_. write() is unchanged. No behavior change.

* server, ui: route resumable streams via a conv map, persist resume identity

Address ngxson review: drop the polling probe, proxy_post records a conv_id ->
model map and the stream routes resolve the owning child with one lookup. The
map is the single source of truth, the ::model suffix stays for child session
uniqueness but the router never parses it.

UI: the server keys a session by the POST time identity (conv::model), but reload
probed with the bare conv id and missed model tagged sessions, so F5 stopped the
stream and sidebar spinners stayed off. Persist the model and rebuild the exact
identity on resume, single conv and bulk sidebar both send it.

Add unit coverage for the identity round trip.
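
A sketch of rebuilding the POST time identity. The conv::model shape comes from the text above, while the function name and the fallback to the bare conv id when no model is known are assumptions of this example:

```cpp
#include <cassert>
#include <string>

// Build the session identity used at POST time: "<conv_id>::<model>".
// Assumption: with no model recorded, the bare conv id is used as-is.
std::string make_stream_identity(const std::string & conv_id, const std::string & model) {
    if (model.empty()) {
        return conv_id;
    }
    return conv_id + "::" + model;
}
```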

* ui: resolve continue target by id to stop cross-conversation flash on switch

* ui: skip stream resume when the abort is intentional

* server: move the conv id to model map into a self contained tracker

Address review from ngxson: server_models held two mutexes side by side, the
global one and a bare conv_model_mu guarding a loose map, which made the locking
hard to follow. Wrap the map and its lock in a small conv_model_tracker struct
that owns its mutex, one mutex per struct. The remember, lookup and forget
methods move inline into the tracker, server_models exposes a single conv_models
member and the routes call models.conv_models.lookup and friends. No behavior
change, the map stays the single source of truth for routing resumable streams
to a child.
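
A minimal sketch of such a tracker. The struct name and the remember, lookup and forget methods come from the text above, the signatures and the member names are assumptions:

```cpp
#include <cassert>
#include <mutex>
#include <optional>
#include <string>
#include <unordered_map>

// Self contained conv_id -> model map that owns its mutex.
struct conv_model_tracker {
    // Record which model owns the conversation, replacing any previous entry.
    void remember(const std::string & conv_id, const std::string & model) {
        std::lock_guard<std::mutex> lk(mu);
        map[conv_id] = model;
    }

    // Return the owning model, or nullopt when the conversation is unknown.
    std::optional<std::string> lookup(const std::string & conv_id) {
        std::lock_guard<std::mutex> lk(mu);
        auto it = map.find(conv_id);
        if (it == map.end()) {
            return std::nullopt;
        }
        return it->second;
    }

    // Drop the entry, a no-op when the conversation is unknown.
    void forget(const std::string & conv_id) {
        std::lock_guard<std::mutex> lk(mu);
        map.erase(conv_id);
    }

private:
    std::mutex mu;
    std::unordered_map<std::string, std::string> map;
};
```

Keeping the lock private to the struct means callers cannot take it in the wrong order relative to the global models mutex.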

* ui: replace stream magic values with enums and shared constants

Address review from allozaur: lift the inline literals around the resumable
stream code into named symbols so the intent is explicit and reusable.

* ui: fold the stream resume and discovery helpers into ChatService

Address review from allozaur: drop the two standalone stream-*.service files.
They were used only by the chat service and store, carried no shared state, and
did not follow the static class pattern the other services use, so a separate
abstraction was not warranted. Move the helpers onto ChatService as static
methods. No behavior change, tests now exercise them through ChatService.

* docs: document the SSE replay buffer in server README-dev

Add the resumable streaming section, list stream_session_manager in the
backend component inventory, and link PR 23226 in the related PRs.

* ui: align attachServerStream call with onCompletionId param in handleStreamResponse

* server-http: rename del_ to del to match get and post

* ui: address review feedback from allozaur

* ui: drop duplicate SSE constants, keep sse.ts canonical

* ui: use svelte:document for the visibilitychange listener

Address review from allozaur: replace the manual document.addEventListener
in onMount with a declarative <svelte:document onvisibilitychange>. Svelte
handles attach, detach and SSR, so the typeof document guard and the onMount
cleanup go away. onMount keeps only the first load snapshot.

* server: trim redundant stream drain comments

Address review from ngxson

* server: balance and clean up stream comments

Remove redundant comments and tighten the verbose ones across the resumable
stream code, keeping the concurrency and lifetime rationale that is not obvious
from the code. Also fix two stale comments in server.cpp and server-models.h
that still described the old ::model suffix probe and fan out routing, now
replaced by the conv_id -> model map.

Address review from ngxson

* ui: balance and clean up stream comments

Deduplicate repeated rationale (frozen conv::model identity, the lookup privacy note,
the abort patterns) down to one canonical spot, tighten the verbose blocks, and
keep the concurrency and resume-offset reasoning. Fix stale comments in
stream-identity.ts and chat.service.ts that still described the old loopback
probe and fan out routing, now replaced by the conv_id -> model map.

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2026-06-26 09:31:29 +02:00

2306 lines
85 KiB
C++

#include "server-common.h"
#include "server-models.h"
#include "server-context.h"
#include "server-stream.h"
#include "build-info.h"
#include "preset.h"
#include "download.h"
#include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
#include <optional>
#include <sheredom/subprocess.h>
#include <functional>
#include <algorithm>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <cstring>
#include <cstdlib>
#include <atomic>
#include <chrono>
#include <queue>
#include <filesystem>
#include <random>
#include <sstream>
#ifdef _WIN32
#include <winsock2.h>
#include <windows.h>
#else
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
extern char **environ;
#endif
#if defined(__APPLE__) && defined(__MACH__)
// macOS: use _NSGetExecutablePath to get the executable path
#include <mach-o/dyld.h>
#include <limits.h>
#endif
#define DEFAULT_STOP_TIMEOUT 10 // seconds
#define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit"
#define CMD_CHILD_TO_ROUTER_STATE "cmd_child_to_router:state:" // followed by json string
// address for child process, this is needed because router may run on 0.0.0.0
// ref: https://github.com/ggml-org/llama.cpp/issues/17862
#define CHILD_ADDR "127.0.0.1"
struct server_subproc {
    std::optional<subprocess_s> sproc; // empty while in DOWNLOADING state
    std::atomic<bool> stopped{false}; // set to cancel a download or signal child process exit
    subprocess_s & get() {
        GGML_ASSERT(sproc.has_value() && "subprocess not initialized");
        return sproc.value();
    }
    bool is_alive() {
        return sproc.has_value() && subprocess_alive(&sproc.value());
    }
    void request_exit() {
        if (sproc.has_value()) {
            FILE * stdin_file = subprocess_stdin(&sproc.value());
            if (stdin_file) {
                fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
                fflush(stdin_file);
            }
        }
        stopped.store(true, std::memory_order_relaxed);
    }
    void terminate() {
        if (!sproc.has_value()) {
            return;
        }
#if defined(_WIN32)
        if (sproc->hProcess == NULL) {
            return;
        }
#else
        if (sproc->child <= 0) {
            return;
        }
#endif
        subprocess_terminate(&sproc.value());
    }
};
// short loopback budget for the resumable stream router to child JSON calls (probe, lookup,
// delete). distinct from params.timeout_read/write which only applies to the generation proxy
static constexpr int STREAM_LOOKUP_TIMEOUT_MS = 250;
static std::filesystem::path get_server_exec_path() {
#if defined(_WIN32)
    wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths
    DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf));
    if (len == 0 || len >= _countof(buf)) {
        throw std::runtime_error("GetModuleFileNameW failed or path too long");
    }
    return std::filesystem::path(buf);
#elif defined(__APPLE__) && defined(__MACH__)
    char small_path[PATH_MAX];
    uint32_t size = sizeof(small_path);
    if (_NSGetExecutablePath(small_path, &size) == 0) {
        // resolve any symlinks to get absolute path
        try {
            return std::filesystem::canonical(std::filesystem::path(small_path));
        } catch (...) {
            return std::filesystem::path(small_path);
        }
    } else {
        // buffer was too small, allocate required size and call again
        std::vector<char> buf(size);
        if (_NSGetExecutablePath(buf.data(), &size) == 0) {
            try {
                return std::filesystem::canonical(std::filesystem::path(buf.data()));
            } catch (...) {
                return std::filesystem::path(buf.data());
            }
        }
        throw std::runtime_error("_NSGetExecutablePath failed after buffer resize");
    }
#else
    char path[FILENAME_MAX];
    ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX);
    if (count <= 0) {
        throw std::runtime_error("failed to resolve /proc/self/exe");
    }
    return std::filesystem::path(std::string(path, count));
#endif
}
static void unset_reserved_args(common_preset & preset, bool unset_model_args) {
    preset.unset_option("LLAMA_ARG_SSL_KEY_FILE");
    preset.unset_option("LLAMA_ARG_SSL_CERT_FILE");
    preset.unset_option("LLAMA_API_KEY");
    preset.unset_option("LLAMA_ARG_MODELS_DIR");
    preset.unset_option("LLAMA_ARG_MODELS_MAX");
    preset.unset_option("LLAMA_ARG_MODELS_PRESET");
    preset.unset_option("LLAMA_ARG_MODELS_AUTOLOAD");
    if (unset_model_args) {
        preset.unset_option("LLAMA_ARG_MODEL");
        preset.unset_option("LLAMA_ARG_MMPROJ");
        preset.unset_option("LLAMA_ARG_ALIAS");
        preset.unset_option("LLAMA_ARG_HF_REPO");
    }
}
#ifdef _WIN32
static std::string wide_to_utf8(const wchar_t * ws) {
    if (!ws || !*ws) {
        return {};
    }
    const int len = static_cast<int>(std::wcslen(ws));
    const int bytes = WideCharToMultiByte(CP_UTF8, 0, ws, len, nullptr, 0, nullptr, nullptr);
    if (bytes == 0) {
        return {};
    }
    std::string utf8(bytes, '\0');
    WideCharToMultiByte(CP_UTF8, 0, ws, len, utf8.data(), bytes, nullptr, nullptr);
    return utf8;
}
#endif
static std::vector<std::string> get_environment() {
    std::vector<std::string> env;
#ifdef _WIN32
    LPWCH env_block = GetEnvironmentStringsW();
    if (!env_block) {
        return env;
    }
    for (LPWCH e = env_block; *e; e += wcslen(e) + 1) {
        env.emplace_back(wide_to_utf8(e));
    }
    FreeEnvironmentStringsW(env_block);
#else
    if (environ == nullptr) {
        return env;
    }
    for (char ** e = environ; *e != nullptr; e++) {
        env.emplace_back(*e);
    }
#endif
    return env;
}
void server_model_meta::update_args(common_preset_context & ctx_preset, std::string bin_path) {
    // update params
    unset_reserved_args(preset, false);
    preset.set_option(ctx_preset, "LLAMA_ARG_HOST", CHILD_ADDR);
    preset.set_option(ctx_preset, "LLAMA_ARG_PORT", std::to_string(port));
    preset.set_option(ctx_preset, "LLAMA_ARG_ALIAS", name);
    // TODO: maybe validate preset before rendering ?
    // render args
    args = preset.to_args(bin_path);
    // unified binary dispatches by subcommand, re-inject it right after the
    // binary path so the child starts as 'llama serve ...' not 'llama ...'
    const char * app_cmd = std::getenv("LLAMA_APP_CMD");
    if (app_cmd != nullptr && app_cmd[0] != '\0' && !bin_path.empty()) {
        args.insert(args.begin() + 1, app_cmd);
    }
}
void server_model_meta::update_caps() {
    try {
        common_params params;
        preset.apply_to_params(params, {
            "LLAMA_ARG_MODEL",
            "LLAMA_ARG_MODEL_URL",
            "LLAMA_ARG_MMPROJ",
            "LLAMA_ARG_MMPROJ_URL",
            "LLAMA_ARG_HF_REPO",
            "LLAMA_ARG_HF_REPO_FILE",
        });
        params.offline = true;
        common_models_handler handler = common_models_handler_init(params, LLAMA_EXAMPLE_SERVER);
        common_models_handler_apply(handler, params); // note: this won't download the model because offline=true
        if (params.mmproj.path.empty()) {
            multimodal = { false, false };
        } else {
            multimodal = mtmd_get_cap_from_file(params.mmproj.path.c_str());
        }
    } catch (const std::exception & e) {
        LOG_WRN("failed to initialize common_params for multimodal capability detection: %s\n", e.what());
        multimodal = { false, false };
    }
}
//
// server_models
//
server_models::server_models(
        const common_params & params,
        int argc,
        char ** argv)
            : ctx_preset(LLAMA_EXAMPLE_SERVER),
              base_params(params),
              base_env(get_environment()),
              base_preset(ctx_preset.load_from_args(argc, argv)) {
    // clean up base preset
    unset_reserved_args(base_preset, true);
    // set binary path
    try {
        bin_path = get_server_exec_path().string();
    } catch (const std::exception & e) {
        bin_path = argv[0];
        LOG_WRN("failed to get server executable path: %s\n", e.what());
        LOG_WRN("using original argv[0] as fallback: %s\n", argv[0]);
    }
    load_models();
}
void server_models::add_model(server_model_meta && meta) {
    if (mapping.find(meta.name) != mapping.end()) {
        throw std::runtime_error(string_format("model '%s' appears multiple times", meta.name.c_str()));
    }
    // check model name does not conflict with existing aliases
    for (const auto & [key, inst] : mapping) {
        if (inst.meta.aliases.count(meta.name)) {
            throw std::runtime_error(string_format("model name '%s' conflicts with alias of model '%s'",
                meta.name.c_str(), key.c_str()));
        }
    }
    // parse aliases from preset's --alias option (comma-separated)
    std::string alias_str;
    if (meta.preset.get_option("LLAMA_ARG_ALIAS", alias_str) && !alias_str.empty()) {
        for (auto & alias : string_split<std::string>(alias_str, ',')) {
            alias = string_strip(alias);
            if (!alias.empty()) {
                meta.aliases.insert(alias);
            }
        }
    }
    // parse tags from preset's --tags option (comma-separated)
    std::string tags_str;
    if (meta.preset.get_option("LLAMA_ARG_TAGS", tags_str) && !tags_str.empty()) {
        for (auto & tag : string_split<std::string>(tags_str, ',')) {
            tag = string_strip(tag);
            if (!tag.empty()) {
                meta.tags.insert(tag);
            }
        }
    }
    // validate aliases do not conflict with existing names or aliases
    for (const auto & alias : meta.aliases) {
        if (mapping.find(alias) != mapping.end()) {
            throw std::runtime_error(string_format("alias '%s' for model '%s' conflicts with existing model name",
                alias.c_str(), meta.name.c_str()));
        }
        for (const auto & [key, inst] : mapping) {
            if (inst.meta.aliases.count(alias)) {
                throw std::runtime_error(string_format("alias '%s' for model '%s' conflicts with alias of model '%s'",
                    alias.c_str(), meta.name.c_str(), key.c_str()));
            }
        }
    }
    meta.update_args(ctx_preset, bin_path); // render args
    meta.update_caps();
    std::string name = meta.name;
    mapping[name] = instance_t{
        /* subproc */ std::make_shared<server_subproc>(),
        /* th */ std::thread(),
        /* meta */ std::move(meta)
    };
}
void server_models::notify_sse(const std::string & event, const std::string & model_id, const json & data) {
    std::unique_ptr<server_task_result_router> result = std::make_unique<server_task_result_router>();
    result->data = {
        {"model", model_id},
        {"event", event},
    };
    if (!data.is_null()) {
        result->data["data"] = data;
    }
    SRV_DBG("notifying SSE clients about event '%s' for model '%s': %s\n", event.c_str(), model_id.c_str(), safe_json_to_str(result->data).c_str());
    sse.broadcast(std::move(result));
}
void server_models::load_models() {
    // Phase 1: load presets from all sources - pure I/O, no lock needed
    // 1. cached models
    common_presets cached_models = ctx_preset.load_from_cache();
    SRV_INF("Loaded %zu cached model presets\n", cached_models.size());
    // 2. local models from --models-dir
    common_presets local_models;
    if (!base_params.models_dir.empty()) {
        local_models = ctx_preset.load_from_models_dir(base_params.models_dir);
        SRV_INF("Loaded %zu local model presets from %s\n", local_models.size(), base_params.models_dir.c_str());
    }
    // 3. custom-path models from presets
    common_preset global = {};
    common_presets custom_presets = {};
    if (!base_params.models_preset.empty()) {
        custom_presets = ctx_preset.load_from_ini(base_params.models_preset, global);
        SRV_INF("Loaded %zu custom model presets from %s\n", custom_presets.size(), base_params.models_preset.c_str());
    }
    // cascade, apply global preset first
    cached_models = ctx_preset.cascade(global, cached_models);
    local_models = ctx_preset.cascade(global, local_models);
    custom_presets = ctx_preset.cascade(global, custom_presets);
    // note: if a model exists in both cached and local, local takes precedence
    common_presets final_presets;
    std::unordered_map<std::string, server_model_source> source_map;
    for (const auto & [name, preset] : cached_models) {
        final_presets[name] = preset;
        source_map[name] = SERVER_MODEL_SOURCE_CACHE;
    }
    for (const auto & [name, preset] : local_models) {
        final_presets[name] = preset;
        source_map[name] = SERVER_MODEL_SOURCE_MODELS_DIR;
    }
    for (const auto & [name, custom] : custom_presets) {
        if (final_presets.find(name) != final_presets.end()) {
            final_presets[name].merge(custom);
        } else {
            final_presets[name] = custom;
        }
        source_map[name] = SERVER_MODEL_SOURCE_PRESET;
    }
    // overlay router's own CLI args on top of every model preset so that
    // e.g. `llama-server --temp 0` is honoured by all child processes
    for (auto & [name, preset] : final_presets) {
        preset.merge(base_preset);
    }
    auto get_source = [&](const std::string & name) {
        return source_map.count(name) ? source_map.at(name) : SERVER_MODEL_SOURCE_PRESET;
    };
    // Helpers that read `mapping` - must be called while holding the lock.
    std::unordered_set<std::string> custom_names;
    for (const auto & [name, preset] : custom_presets) custom_names.insert(name);
    auto join_set = [](const std::set<std::string> & s) {
        std::string result;
        for (const auto & v : s) {
            if (!result.empty()) result += ", ";
            result += v;
        }
        return result;
    };
    auto log_available_models = [&]() {
        SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size());
        for (const auto & [name, inst] : mapping) {
            bool has_custom = custom_names.find(name) != custom_names.end();
            std::string info;
            if (!inst.meta.aliases.empty()) info += " (aliases: " + join_set(inst.meta.aliases) + ")";
            if (!inst.meta.tags.empty()) info += " [tags: " + join_set(inst.meta.tags) + "]";
            SRV_INF(" %c %s%s\n", has_custom ? '*' : ' ', name.c_str(), info.c_str());
        }
    };
    auto apply_stop_timeout = [&]() {
        for (auto & [name, inst] : mapping) {
            std::string val;
            if (inst.meta.preset.get_option(COMMON_ARG_PRESET_STOP_TIMEOUT, val)) {
                try {
                    inst.meta.stop_timeout = std::stoi(val);
                } catch (...) {
                    SRV_WRN("invalid stop-timeout value '%s' for model '%s', using default %d seconds\n",
                        val.c_str(), name.c_str(), DEFAULT_STOP_TIMEOUT);
                    inst.meta.stop_timeout = DEFAULT_STOP_TIMEOUT;
                }
            }
        }
    };
    // update_args() injects HOST/PORT/ALIAS, so strip them before comparing presets
    auto preset_options_for_compare = [](common_preset p) {
        p.unset_option("LLAMA_ARG_HOST");
        p.unset_option("LLAMA_ARG_PORT");
        p.unset_option("LLAMA_ARG_ALIAS");
        return p.options;
    };
    // Phase 2: acquire the lock once for all mapping mutations.
    // We temporarily release it only when calling functions that acquire it internally
    // (unload, load) or when joining threads (the monitoring thread calls update_status
    // which locks the mutex, so joining while holding it would deadlock).
    std::unique_lock<std::mutex> lk(mutex);
    need_reload = false;
    bool is_first_load = mapping.empty();
    if (is_first_load) {
        // FIRST LOAD: add all models, then unlock for autoloading
        for (const auto & [name, preset] : final_presets) {
            server_model_meta meta{
                /* source */ get_source(name),
                /* preset */ preset,
                /* name */ name,
                /* aliases */ {},
                /* tags */ {},
                /* port */ 0,
                /* status */ SERVER_MODEL_STATUS_UNLOADED,
                /* last_used */ 0,
                /* args */ std::vector<std::string>(),
                /* loaded_info */ {},
                /* progress */ {},
                /* exit_code */ 0,
                /* stop_timeout */ DEFAULT_STOP_TIMEOUT,
                /* multimodal */ mtmd_caps{false, false},
                // /* need_download */ false,
            };
            add_model(std::move(meta));
        }
        apply_stop_timeout();
        log_available_models();
        std::vector<std::string> models_to_load;
        for (const auto & [name, inst] : mapping) {
            std::string val;
            if (inst.meta.preset.get_option(COMMON_ARG_PRESET_LOAD_ON_STARTUP, val) && common_arg_utils::is_truthy(val)) {
                models_to_load.push_back(name);
            }
        }
        if ((int)models_to_load.size() > base_params.models_max) {
            throw std::runtime_error(string_format(
                "number of models to load on startup (%zu) exceeds models_max (%d)",
                models_to_load.size(), base_params.models_max));
        }
        lk.unlock();
        for (const auto & name : models_to_load) {
            SRV_INF("(startup) loading model %s\n", name.c_str());
            load(name);
        }
    } else {
        // RELOAD: diff the new preset list against the current mapping and reconcile
        is_reloading = true;
        // find running models whose source was removed or whose preset changed
        std::vector<std::string> to_unload;
        for (const auto & [name, inst] : mapping) {
            if (!inst.meta.is_running()) continue;
            auto it = final_presets.find(name);
            if (it == final_presets.end()) {
                to_unload.push_back(name); // removed from source
            } else if (preset_options_for_compare(inst.meta.preset) != preset_options_for_compare(it->second)) {
                to_unload.push_back(name); // preset changed
            }
        }
        // unload() acquires the lock internally, so release before each call
        for (const auto & name : to_unload) {
            SRV_INF("(reload) unloading model name=%s (source updated or removed)\n", name.c_str());
            lk.unlock();
            unload(name);
            lk.lock();
        }
        // wait for all targeted models to reach UNLOADED; cv.wait handles unlock/relock
        cv.wait(lk, [&]() {
            for (const auto & name : to_unload) {
                auto it = mapping.find(name);
                if (it != mapping.end() && it->second.meta.is_running()) return false;
            }
            return true;
        });
        // collect all threads to join in one pass while the lock is held:
        // - monitoring threads from just-unloaded models (to_unload)
        // - threads of already-UNLOADED models that are being removed from source
        std::vector<std::thread> threads_to_join;
        for (const auto & name : to_unload) {
            auto it = mapping.find(name);
            if (it != mapping.end() && it->second.th.joinable()) {
                threads_to_join.push_back(std::move(it->second.th));
            }
        }
        for (auto & [name, inst] : mapping) {
            if (inst.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
                continue; // downloading models are not from config sources, leave them alone
            }
            if (final_presets.find(name) == final_presets.end() && !inst.meta.is_running() && inst.th.joinable()) {
                threads_to_join.push_back(std::move(inst.th));
            }
        }
        // join outside the lock - monitoring thread calls update_status (needs lock)
        lk.unlock();
        for (auto & th : threads_to_join) th.join();
        lk.lock();
        // erase models no longer in any source
        for (auto it = mapping.begin(); it != mapping.end(); ) {
            if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
                ++it; // download thread is still busy, skip
            } else if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADED) {
                // download finished, safe to erase
                if (it->second.th.joinable()) {
                    it->second.th.join();
                }
                it = mapping.erase(it);
            } else if (final_presets.find(it->first) == final_presets.end()) {
                SRV_INF("(reload) removing model name=%s (no longer in source)\n", it->first.c_str());
                GGML_ASSERT(!it->second.th.joinable()); // must have been joined above
                it = mapping.erase(it);
            } else {
                ++it;
            }
        }
        // update presets for non-running models still in source
        for (auto & [name, inst] : mapping) {
            if (inst.meta.is_running()) continue;
            auto it = final_presets.find(name);
            if (it == final_presets.end()) continue; // erased above
            inst.meta.preset = it->second;
            // re-parse aliases, then validate against other models
            std::set<std::string> new_aliases;
            std::string alias_str;
            if (inst.meta.preset.get_option("LLAMA_ARG_ALIAS", alias_str) && !alias_str.empty()) {
                for (auto & alias : string_split<std::string>(alias_str, ',')) {
                    alias = string_strip(alias);
                    if (!alias.empty()) new_aliases.insert(alias);
                }
            }
            inst.meta.aliases.clear();
            for (const auto & alias : new_aliases) {
                bool conflict = false;
                for (const auto & [other_name, other_inst] : mapping) {
                    if (other_name == name) continue;
                    if (other_name == alias || other_inst.meta.aliases.count(alias)) {
                        SRV_WRN("(reload) alias '%s' for model '%s' conflicts with model '%s', skipping\n",
                            alias.c_str(), name.c_str(), other_name.c_str());
                        conflict = true;
                        break;
                    }
                }
                if (!conflict) inst.meta.aliases.insert(alias);
            }
            // re-parse tags
            inst.meta.tags.clear();
            std::string tags_str;
            if (inst.meta.preset.get_option("LLAMA_ARG_TAGS", tags_str) && !tags_str.empty()) {
                for (auto & tag : string_split<std::string>(tags_str, ',')) {
                    tag = string_strip(tag);
                    if (!tag.empty()) inst.meta.tags.insert(tag);
                }
            }
            inst.meta.exit_code = 0; // clear failed state so the model can be reloaded
            inst.meta.update_args(ctx_preset, bin_path);
            inst.meta.update_caps();
        }
        // add models that are new in this reload
        std::vector<std::string> newly_added;
        for (const auto & [name, preset] : final_presets) {
            if (mapping.find(name) == mapping.end()) {
                server_model_meta meta{
                    /* source */ get_source(name),
                    /* preset */ preset,
                    /* name */ name,
                    /* aliases */ {},
                    /* tags */ {},
                    /* port */ 0,
                    /* status */ SERVER_MODEL_STATUS_UNLOADED,
                    /* last_used */ 0,
                    /* args */ std::vector<std::string>(),
                    /* loaded_info */ {},
                    /* progress */ {},
                    /* exit_code */ 0,
                    /* stop_timeout */ DEFAULT_STOP_TIMEOUT,
                    /* multimodal */ mtmd_caps{false, false},
                    // /* need_download */ false,
                };
                add_model(std::move(meta));
                newly_added.push_back(name);
            }
        }
        apply_stop_timeout();
        // clear reload flag before unlocking for autoload - load() blocks on !is_reloading,
        // so clearing it here (while still locked) prevents a deadlock in the autoload calls below
        is_reloading = false;
        cv.notify_all();
        log_available_models();
        // collect autoload candidates while still under the lock
        std::vector<std::string> to_autoload;
        for (const auto & name : newly_added) {
            auto it = mapping.find(name);
            if (it != mapping.end()) {
                std::string val;
                if (it->second.meta.preset.get_option(COMMON_ARG_PRESET_LOAD_ON_STARTUP, val) && common_arg_utils::is_truthy(val)) {
                    to_autoload.push_back(name);
                }
            }
        }
        lk.unlock();
        for (const auto & name : to_autoload) {
            SRV_INF("(reload) loading new model %s\n", name.c_str());
            load(name);
        }
        notify_sse("models_reload", "*");
    }
}
void server_models::update_meta(const std::string & name, const server_model_meta & meta) {
std::lock_guard<std::mutex> lk(mutex);
auto it = mapping.find(name);
if (it != mapping.end()) {
it->second.meta = meta;
}
cv.notify_all(); // notify wait_until_loading_finished
}
bool server_models::has_model(const std::string & name) {
std::lock_guard<std::mutex> lk(mutex);
if (mapping.find(name) != mapping.end()) {
return true;
}
for (const auto & [key, inst] : mapping) {
if (inst.meta.aliases.count(name)) {
return true;
}
}
return false;
}
std::optional<server_model_meta> server_models::get_meta(const std::string & name) {
std::unique_lock<std::mutex> lk(mutex);
if (need_reload) {
lk.unlock();
load_models();
lk.lock();
}
auto it = mapping.find(name);
if (it != mapping.end()) {
return it->second.meta;
}
for (const auto & [key, inst] : mapping) {
if (inst.meta.aliases.count(name)) {
return inst.meta;
}
}
return std::nullopt;
}
static int get_free_port() {
#ifdef _WIN32
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
return -1;
}
typedef SOCKET native_socket_t;
#define INVALID_SOCKET_VAL INVALID_SOCKET
#define CLOSE_SOCKET(s) closesocket(s)
#else
typedef int native_socket_t;
#define INVALID_SOCKET_VAL -1
#define CLOSE_SOCKET(s) close(s)
#endif
native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock == INVALID_SOCKET_VAL) {
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
struct sockaddr_in serv_addr;
std::memset(&serv_addr, 0, sizeof(serv_addr));
serv_addr.sin_family = AF_INET;
serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
serv_addr.sin_port = htons(0);
if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
#ifdef _WIN32
int namelen = sizeof(serv_addr);
#else
socklen_t namelen = sizeof(serv_addr);
#endif
if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
int port = ntohs(serv_addr.sin_port);
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return port;
}
// helper to convert vector<string> to char **
// pointers are only valid as long as the original vector is valid
static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
std::vector<char *> result;
result.reserve(vec.size() + 1);
for (const auto & s : vec) {
result.push_back(const_cast<char*>(s.c_str()));
}
result.push_back(nullptr);
return result;
}
std::vector<server_model_meta> server_models::get_all_meta() {
std::unique_lock<std::mutex> lk(mutex);
if (need_reload) {
lk.unlock();
load_models();
lk.lock();
}
std::vector<server_model_meta> result;
result.reserve(mapping.size());
for (const auto & [name, inst] : mapping) {
result.push_back(inst.meta);
}
return result;
}
void server_models::unload_lru() {
if (base_params.models_max <= 0) {
return; // no limit
}
// unload the least recently used (LRU) running instance once models_max has been reached
std::string lru_model_name = "";
int64_t lru_last_used = ggml_time_ms();
size_t count_active = 0;
{
std::unique_lock<std::mutex> lk(mutex);
for (const auto & m : mapping) {
if (m.second.meta.is_running()) {
count_active++;
if (m.second.meta.last_used < lru_last_used) {
lru_model_name = m.first;
lru_last_used = m.second.meta.last_used;
}
}
}
}
if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) {
SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str());
unload(lru_model_name);
// wait for unload to complete
{
std::unique_lock<std::mutex> lk(mutex);
cv.wait(lk, [this, &lru_model_name]() {
return mapping[lru_model_name].meta.status == SERVER_MODEL_STATUS_UNLOADED;
});
}
}
}
void server_models::load(const std::string & name) {
load(name, load_options{});
}
void server_models::load(const std::string & name, const load_options & opts) {
if (!opts.custom_meta.has_value()) {
if (!has_model(name)) {
throw std::runtime_error("model name=" + name + " is not found");
}
unload_lru();
}
std::unique_lock<std::mutex> lk(mutex);
// edge case: block until any in-progress reload has finished so we always load
// against the freshest preset and a consistent mapping state
cv.wait(lk, [this]() { return !is_reloading; });
auto meta = opts.custom_meta.has_value() ? *opts.custom_meta : mapping[name].meta;
if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
SRV_INF("model %s is not in unloaded state, skipping load\n", name.c_str());
return;
}
// Re-check capacity under the lock to prevent concurrent loads from
// exceeding models_max. Without this, the window between unload_lru()
// releasing its lock and this unique_lock being acquired allows multiple
// threads to each observe free capacity and all proceed to load.
if (base_params.models_max > 0) {
size_t count_active = 0;
for (const auto & m : mapping) {
if (m.second.meta.is_running()) {
count_active++;
}
}
if (count_active >= (size_t)base_params.models_max) {
throw std::runtime_error("model limit reached, try again later");
}
}
// prepare new instance info
instance_t inst;
inst.meta = meta;
inst.meta.port = get_free_port();
inst.meta.status = SERVER_MODEL_STATUS_LOADING;
inst.meta.loaded_info = json{};
inst.meta.last_used = ggml_time_ms();
if (inst.meta.port <= 0) {
throw std::runtime_error("failed to get a port number");
}
inst.subproc = std::make_shared<server_subproc>();
{
SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port);
inst.meta.update_args(ctx_preset, bin_path); // render args
std::vector<std::string> child_args = inst.meta.args; // copy
std::vector<std::string> child_env = base_env; // copy
child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
if (opts.mode == SERVER_CHILD_MODE_DOWNLOAD) {
inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
child_env.push_back("LLAMA_SERVER_CHILD_MODE=download");
child_env.push_back("LLAMA_ARG_HF_REPO=" + name);
}
SRV_INF("%s", "spawning server instance with args:\n");
for (const auto & arg : child_args) {
SRV_INF(" %s\n", arg.c_str());
}
inst.meta.args = child_args; // save for debugging
std::vector<char *> argv = to_char_ptr_array(child_args);
std::vector<char *> envp = to_char_ptr_array(child_env);
// TODO @ngxson : maybe separate stdout and stderr in the future
// so that we can use stdout for commands and stderr for logging
int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
inst.subproc->sproc.emplace();
int result = subprocess_create_ex(argv.data(), options, envp.data(), &inst.subproc->get());
if (result != 0) {
throw std::runtime_error("failed to spawn server instance");
}
}
// start a thread to manage the child process
// captured variables are guaranteed to be destroyed only after the thread is joined
inst.th = std::thread([
this, name,
child_proc = inst.subproc,
port = inst.meta.port,
stop_timeout = inst.meta.stop_timeout,
child_mode = opts.mode
]() {
FILE * stdin_file = subprocess_stdin(&child_proc->get());
FILE * stdout_file = subprocess_stdout(&child_proc->get()); // combined stdout/stderr
std::thread log_thread([&]() {
// read stdout/stderr and forward to main server log
// also handle status report from child process
std::vector<char> vec_buf(128 * 1024); // large buffer for storing info
char * buffer = vec_buf.data();
if (stdout_file) {
while (fgets(buffer, vec_buf.size(), stdout_file) != nullptr) {
LOG("[%5d] %s", port, buffer);
std::string str(buffer);
if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_STATE)) {
this->handle_child_state(name, str);
}
}
} else {
SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str());
}
});
std::thread stopping_thread([&]() {
// thread to monitor explicit stop requests; child crash is signalled via child_proc->stopped
auto is_stopping = [this, &name]() {
return this->stopping_models.find(name) != this->stopping_models.end();
};
{
std::unique_lock<std::mutex> lk(this->mutex);
this->cv_stop.wait(lk, [&]() {
return is_stopping() || child_proc->stopped.load(std::memory_order_acquire);
});
}
// child crashed or finished on its own, skip graceful shutdown sequence
if (child_proc->stopped.load(std::memory_order_acquire)) {
return;
}
SRV_INF("stopping model instance name=%s\n", name.c_str());
fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
fflush(stdin_file);
int64_t start_time = ggml_time_ms();
while (true) {
std::unique_lock<std::mutex> lk(this->mutex);
if (!is_stopping() || child_proc->stopped.load(std::memory_order_acquire)) {
return;
}
int64_t elapsed = ggml_time_ms() - start_time;
if (elapsed >= stop_timeout * 1000) {
lk.unlock();
SRV_WRN("force-killing model instance name=%s after %d seconds timeout\n", name.c_str(), stop_timeout);
child_proc->terminate();
return;
}
this->cv_stop.wait_for(lk, std::chrono::seconds(1), [&]() {
return !is_stopping() || child_proc->stopped.load(std::memory_order_acquire);
});
}
});
// we reach here when the child process exits (stdout EOF)
// note: we cannot call subprocess_join() prior to this point because it will close stdin_file
if (log_thread.joinable()) {
log_thread.join();
}
child_proc->stopped.store(true, std::memory_order_release);
{
std::lock_guard<std::mutex> lk(this->mutex);
stopping_models.erase(name);
cv_stop.notify_all();
}
if (stopping_thread.joinable()) {
stopping_thread.join();
}
// get the exit code
int exit_code = 0;
subprocess_join(&child_proc->get(), &exit_code);
subprocess_destroy(&child_proc->get());
// update status and exit code
if (child_mode == SERVER_CHILD_MODE_DOWNLOAD) {
// instance will be cleaned up on next load_models() call
} else {
this->update_status(name, {
SERVER_MODEL_STATUS_UNLOADED,
exit_code
});
}
SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
});
// clean up old process/thread if exists
{
auto & old_instance = mapping[name];
// old process should have exited already, but just in case, we clean it up here
if (old_instance.subproc && old_instance.subproc->is_alive()) {
SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
old_instance.subproc->terminate(); // force kill
}
if (old_instance.th.joinable()) {
old_instance.th.join();
}
}
notify_sse("model_status", name, {
{"status", server_model_status_to_string(inst.meta.status)},
});
mapping[name] = std::move(inst);
cv.notify_all();
}
void server_models::unload(const std::string & name) {
std::unique_lock<std::mutex> lk(mutex);
auto it = mapping.find(name);
if (it != mapping.end()) {
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
SRV_INF("cancelling download for model name=%s\n", name.c_str());
it->second.subproc->request_exit();
// for convenience, we wait for the status change here
wait(lk, name, [](const server_model_meta & new_meta) {
return new_meta.status != SERVER_MODEL_STATUS_DOWNLOADING;
});
} else if (it->second.meta.is_running()) {
SRV_INF("stopping model instance name=%s\n", name.c_str());
stopping_models.insert(name);
if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) {
// special case: if model is in loading state, unloading means force-killing it
SRV_WRN("model name=%s is still loading, force-killing\n", name.c_str());
it->second.subproc->terminate();
}
cv_stop.notify_all();
// status change will be handled by the managing thread
} else {
SRV_WRN("model instance name=%s is not running\n", name.c_str());
}
}
}
void server_models::unload_all() {
std::vector<std::thread> to_join;
{
std::lock_guard<std::mutex> lk(mutex);
for (auto & [name, inst] : mapping) {
if (inst.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
SRV_INF("cancelling download for model name=%s\n", name.c_str());
inst.subproc->stopped.store(true, std::memory_order_relaxed);
} else if (inst.meta.is_running()) {
SRV_INF("stopping model instance name=%s\n", name.c_str());
stopping_models.insert(name);
cv_stop.notify_all();
// status change will be handled by the managing thread
}
// moving the thread to join list to avoid deadlock
to_join.push_back(std::move(inst.th));
}
}
for (auto & th : to_join) {
if (th.joinable()) {
th.join();
}
}
}
void server_models::update_status(const std::string & name, const update_status_args & args) {
std::unique_lock<std::mutex> lk(mutex);
auto it = mapping.find(name);
if (it != mapping.end()) {
auto & meta = it->second.meta;
meta.status = args.status;
meta.exit_code = args.exit_code;
if (!args.loaded_info.is_null()) {
meta.loaded_info = args.loaded_info;
}
if (!args.progress.is_null()) {
meta.progress = args.progress;
}
}
// broadcast status change to SSE
{
json data = {
{"status", server_model_status_to_string(args.status)},
};
if (args.status == SERVER_MODEL_STATUS_UNLOADED) {
data["exit_code"] = args.exit_code;
}
if (!args.loaded_info.is_null()) {
data["info"] = args.loaded_info;
}
if (!args.progress.is_null()) {
data["progress"] = args.progress;
}
// note: notify_sse doesn't acquire the lock, so no deadlock here
notify_sse("status_change", name, data);
}
cv.notify_all();
}
void server_models::update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok) {
json curr;
{
std::lock_guard<std::mutex> lk(mutex);
auto it = mapping.find(name);
if (it != mapping.end()) {
if (done) {
// mark the instance to be erased on next load_models() call
it->second.meta.status = SERVER_MODEL_STATUS_DOWNLOADED;
need_reload = true;
} else {
json & info = it->second.meta.loaded_info;
if (!info.contains("progress")) {
info["progress"] = json{};
}
info["progress"][progress.url] = {
{"done", progress.downloaded},
{"total", progress.total},
};
curr = it->second.meta.loaded_info; // copy
}
}
}
if (done) {
cv.notify_all(); // notify in case unload() is waiting for download to be cancelled
notify_sse(ok ? "download_finished" : "download_failed", name, {});
} else {
notify_sse("download_progress", name, curr);
}
}
bool server_models::remove(const std::string & name) {
// do everything under one lock acquisition; avoid get_meta() /
// unload() because they can trigger load_models() which erases
// transient DOWNLOADING / DOWNLOADED entries as a side-effect
std::unique_lock<std::mutex> lk(mutex);
auto it = mapping.find(name);
if (it == mapping.end()) {
throw std::runtime_error("model name=" + name + " is not found");
}
if (it->second.meta.source != SERVER_MODEL_SOURCE_CACHE) {
throw std::runtime_error("model name=" + name + " is not removable (not from cache)");
}
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
// cancel in-flight download
SRV_INF("cancelling download for model name=%s\n", name.c_str());
it->second.subproc->request_exit();
} else if (it->second.meta.is_running()) {
// stop running instance
SRV_INF("stopping model instance name=%s\n", name.c_str());
stopping_models.insert(name);
if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) {
it->second.subproc->terminate();
}
cv_stop.notify_all();
}
// wait until the monitoring thread finishes
wait(lk, name, [](const server_model_meta & meta) {
return meta.status == SERVER_MODEL_STATUS_UNLOADED
|| meta.status == SERVER_MODEL_STATUS_DOWNLOADED;
});
// re-find after wait - load_models() may have erased the entry during the wait
it = mapping.find(name);
if (it == mapping.end()) {
// load_models() already joined the thread and erased the entry;
// we just need to clean up the cached files on disk
lk.unlock();
bool ok = common_download_remove(name);
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial");
notify_sse("model_remove", name, {});
return true;
}
// join before erasing - thread no longer acquires this mutex
if (it->second.th.joinable()) {
it->second.th.join();
}
// remove from disk (best-effort: cancelled downloads may have no cached files)
bool ok = common_download_remove(name);
mapping.erase(name);
if (!ok) {
SRV_WRN("removing model name=%s from disk returned false (no cached files?)\n", name.c_str());
}
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial");
notify_sse("model_remove", name, {});
return true;
}
void server_models::wait(const std::string & name, std::function<bool(const server_model_meta &)> predicate) {
std::unique_lock<std::mutex> lk(mutex);
wait(lk, name, predicate);
}
void server_models::wait(std::unique_lock<std::mutex> & lk, const std::string & name, std::function<bool(const server_model_meta &)> predicate) {
cv.wait(lk, [this, &name, &predicate]() {
auto it = mapping.find(name);
if (it != mapping.end()) {
return predicate(it->second.meta);
}
// model was removed from mapping by another code path (e.g. load_models()).
// nothing left to wait for - tell the caller to proceed.
return true;
});
}
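// returns true if the caller had to wait for the model to finish loading,
// false if the model was already ready (or sleeping); throws if the model is not found or fails to load
bool server_models::ensure_model_ready(const std::string & name) {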
auto meta = get_meta(name);
if (!meta.has_value()) {
throw std::runtime_error("model name=" + name + " is not found");
}
if (meta->is_ready()) {
return false; // ready for taking requests
}
if (meta->status == SERVER_MODEL_STATUS_SLEEPING) {
return false; // child is sleeping but still running; new request will wake it up
}
if (meta->status == SERVER_MODEL_STATUS_UNLOADED) {
SRV_INF("model name=%s is not loaded, loading...\n", name.c_str());
load(name);
}
// wait for loading to complete
SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str());
wait(name, [&meta](const server_model_meta & new_meta) {
if (new_meta.status != SERVER_MODEL_STATUS_LOADING) {
meta = new_meta; // update meta for final check after wait
return true;
}
return false;
});
// check final status
if (!meta.has_value() || meta->is_failed()) {
throw std::runtime_error("model name=" + name + " failed to load");
}
return true;
}
server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used) {
auto meta = get_meta(name);
if (!meta.has_value()) {
throw std::runtime_error("model name=" + name + " is not found");
}
if (!meta->is_running()) {
throw std::invalid_argument("model name=" + name + " is not running");
}
if (update_last_used) {
std::unique_lock<std::mutex> lk(mutex);
mapping[name].meta.last_used = ggml_time_ms();
}
SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port);
std::string proxy_path = req.path;
if (!req.query_string.empty()) {
proxy_path += '?' + req.query_string;
}
auto proxy = std::make_unique<server_http_proxy>(
method,
"http",
CHILD_ADDR,
meta->port,
proxy_path,
req.headers,
req.body,
req.files,
req.should_stop,
base_params.timeout_read,
base_params.timeout_write
);
return proxy;
}
void server_models::handle_child_state(const std::string & name, const std::string & raw_input) {
server_state state;
json payload;
try {
json data = json::parse(raw_input.substr(strlen(CMD_CHILD_TO_ROUTER_STATE)));
state = server_state_from_str(json_value(data, "state", std::string()));
payload = json_value(data, "payload", json{});
} catch (const std::exception & e) {
SRV_ERR("failed to parse child state update for name=%s: %s\n", name.c_str(), e.what());
return;
}
switch (state) {
case SERVER_STATE_DOWNLOADING:
{
std::string result = json_value(payload, "result", std::string());
std::string url = json_value(payload, "url", std::string());
auto request_exit = [&]() {
std::lock_guard<std::mutex> lk(mutex);
auto it = mapping.find(name);
if (it != mapping.end()) {
it->second.subproc->request_exit();
}
};
if (result == "download_finished") {
update_download_progress(name, {}, true, true);
request_exit();
} else if (result == "download_failed") {
update_download_progress(name, {}, true, false);
request_exit();
} else if (!url.empty()) {
common_download_progress p;
p.url = url;
p.downloaded = json_value(payload, "downloaded", (size_t)0);
p.total = json_value(payload, "total", (size_t)0);
update_download_progress(name, p, false);
}
} break;
case SERVER_STATE_LOADING:
{
update_status(name, {
SERVER_MODEL_STATUS_LOADING,
0,
nullptr, // no loaded_info yet
payload,
});
} break;
case SERVER_STATE_READY:
{
update_status(name, {
SERVER_MODEL_STATUS_LOADED,
0,
// note: payload can be empty if this is a wakeup from sleep
payload.size() > 0 ? payload : nullptr,
{}, // reset progress info
});
} break;
case SERVER_STATE_SLEEPING:
{
update_status(name, { SERVER_MODEL_STATUS_SLEEPING });
} break;
default:
// should never happen, but just in case
GGML_ASSERT(false && "unexpected state from child server");
}
}
//
// server_child
//
bool server_child::is_child() {
const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
return router_port != nullptr;
}
server_child_mode server_child::get_mode() {
const char * mode = std::getenv("LLAMA_SERVER_CHILD_MODE");
std::string mode_str(mode ? mode : "");
if (mode_str == "download") {
return SERVER_CHILD_MODE_DOWNLOAD;
} else {
return SERVER_CHILD_MODE_NORMAL;
}
}
struct server_download_state : public common_download_callback {
server_child * self;
std::function<bool()> should_stop;
std::atomic<int64_t> last_progress_time{0}; // multiple files downloading in different threads
bool is_ok = false;
server_download_state(server_child * s) : self(s) {}
bool run(common_params & params) {
try {
common_models_handler handler = common_models_handler_init(params, LLAMA_EXAMPLE_SERVER);
common_models_handler_apply(handler, params, this);
is_ok = true;
} catch (const std::exception & e) {
auto model_name = params.model.get_name();
SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what());
is_ok = false;
}
return is_ok;
}
void on_progress(const common_download_progress & p) {
json data = {
{"url", p.url},
{"downloaded", p.downloaded},
{"total", p.total},
};
self->notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), data);
}
void on_start(const common_download_progress & p) override {
on_progress(p);
}
void on_update(const common_download_progress & p) override {
int64_t now = ggml_time_ms();
// throttle progress updates to avoid flooding logs
if (now - last_progress_time.load(std::memory_order_relaxed) >= 100) {
on_progress(p);
last_progress_time.store(now, std::memory_order_relaxed);
}
}
void on_done(const common_download_progress & p, bool) override {
on_progress(p);
}
bool is_cancelled() const override {
return should_stop ? should_stop() : false;
}
};
int server_child::run_download(common_params & params) {
auto cancelled = std::make_shared<std::atomic<bool>>(false);
// monitor stdin for cancellation command from the router
std::thread signal_thread = setup([cancelled](int) {
cancelled->store(true, std::memory_order_relaxed);
});
server_download_state dl(this);
dl.should_stop = [cancelled]() {
return cancelled->load(std::memory_order_relaxed);
};
bool ok = dl.run(params);
notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), {
{"result", ok ? "download_finished" : "download_failed"},
});
// router should send CMD_ROUTER_TO_CHILD_EXIT after receiving the result
if (signal_thread.joinable()) {
signal_thread.join();
}
SRV_INF("download completed %s\n", ok ? "successfully" : "with errors");
return 0;
}
std::thread server_child::setup(const std::function<void(int)> & shutdown_handler) {
// setup thread for monitoring stdin
return std::thread([shutdown_handler]() {
// wait for EOF on stdin
SRV_INF("%s", "child server monitoring thread started, waiting for EOF on stdin...\n");
bool eof = false;
while (true) {
std::string line;
if (!std::getline(std::cin, line)) {
// EOF detected, which means the router server has unexpectedly exited or been killed
eof = true;
break;
}
if (line.find(CMD_ROUTER_TO_CHILD_EXIT) != std::string::npos) {
SRV_INF("%s", "exit command received, exiting...\n");
shutdown_handler(0);
break;
}
}
if (eof) {
SRV_INF("%s", "EOF on stdin detected, forcing shutdown...\n");
exit(1);
}
});
}
void server_child::notify_to_router(const std::string & state, const json & payload) {
json data = {
{"state", state},
{"payload", payload},
};
std::lock_guard<std::mutex> lk(mtx_stdout);
common_log_pause(common_log_main());
fflush(stdout);
fprintf(stdout, "%s%s\n", CMD_CHILD_TO_ROUTER_STATE, safe_json_to_str(data).c_str());
fflush(stdout);
common_log_resume(common_log_main());
}
//
// server_models_routes
//
// RAII wrapper similar to server_response_reader, but doesn't use server_queue
static std::atomic<int> sse_client_id_counter = 0;
struct server_models_sse_client {
server_response & queue_results;
int client_id;
server_models_sse_client(server_response & q)
: queue_results(q), client_id(sse_client_id_counter.fetch_add(1, std::memory_order_relaxed)) {
SRV_DBG("new SSE client connected, assigned client_id=%d\n", client_id);
queue_results.add_waiting_task_id(client_id);
}
~server_models_sse_client() {
SRV_DBG("SSE client disconnected, removing client_id=%d\n", client_id);
queue_results.remove_waiting_task_id(client_id);
}
// return nullptr if should_stop() is true before receiving a result
// note: if one error is received, it will stop further processing and return error result
server_task_result_ptr next(const std::function<bool()> & should_stop) {
while (true) {
static const int http_polling_seconds = 1; // check should_stop every 1 second
server_task_result_ptr result = queue_results.recv_with_timeout({client_id}, http_polling_seconds);
if (result == nullptr) {
// timeout, check stop condition
if (should_stop()) {
return nullptr;
}
// continue waiting otherwise
} else {
SRV_DBG("recv result for client_id=%d: %s\n", client_id, safe_json_to_str(result->to_json()).c_str());
return result;
}
}
// should not reach here
}
};
static void res_ok(std::unique_ptr<server_http_res> & res, const json & response_data) {
res->status = 200;
res->data = safe_json_to_str(response_data);
}
static void res_err(std::unique_ptr<server_http_res> & res, const json & error_data) {
res->status = json_value(error_data, "code", 500);
res->data = safe_json_to_str({{ "error", error_data }});
}
static bool router_validate_model(std::string & name, server_models & models, bool models_autoload, std::unique_ptr<server_http_res> & res) {
if (name.empty()) {
res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
return false;
}
auto meta = models.get_meta(name);
if (!meta.has_value()) {
res_err(res, format_error_response(string_format("model '%s' not found", name.c_str()), ERROR_TYPE_INVALID_REQUEST));
return false;
}
// resolve alias to canonical model name
name = meta->name;
if (models_autoload) {
models.ensure_model_ready(name);
} else {
if (!meta->is_running()) {
res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
return false;
}
}
return true;
}
static bool is_autoload(const common_params & params, const server_http_req & req) {
std::string autoload = req.get_param("autoload");
if (autoload.empty()) {
return params.models_autoload;
} else {
return autoload == "true" || autoload == "1";
}
}
// percent encode one query or path component, covers reserved chars without pulling in
// httplib::detail. used by the stream routes to forward conversation_id to children safely
static std::string encode_qs(const std::string & in) {
std::string out;
out.reserve(in.size() * 3);
for (unsigned char c : in) {
bool safe = (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9')
|| c == '-' || c == '_' || c == '.' || c == '~';
if (safe) {
out.push_back(char(c));
} else {
char buf[4];
std::snprintf(buf, sizeof(buf), "%%%02X", c);
out.append(buf, 3);
}
}
return out;
}
// resolve the child that owns a conversation's stream session via the conv_id -> model map
// populated when the POST was routed. single map lookup then a meta lookup, no polling, no
// parsing of the conv id. returns nullopt when nothing maps, the caller answers not found and
// the client recovers
static std::optional<server_model_meta> resolve_child_for_conv(
server_models & models, const std::string & conversation_id) {
if (conversation_id.empty()) {
return std::nullopt;
}
auto tracked = models.conv_models.lookup(conversation_id);
if (!tracked.has_value()) {
return std::nullopt;
}
auto meta = models.get_meta(*tracked);
if (meta.has_value() && meta->is_ready()) {
return meta;
}
return std::nullopt;
}
void server_models_routes::init_routes() {
this->get_router_props = [this](const server_http_req & req) {
std::string name = req.get_param("model");
if (name.empty()) {
// main instance
auto res = std::make_unique<server_http_res>();
res_ok(res, {
// TODO: add support for this on web UI
{"role", "router"},
{"max_instances", params.models_max},
{"models_autoload", params.models_autoload},
// this is a dummy response to make sure the UI doesn't break
{"model_alias", "llama-server"},
{"model_path", "none"},
{"default_generation_settings", {
{"params", json{}},
{"n_ctx", 0},
}},
// New key
{"ui_settings", ui_settings},
{"build_info", std::string(llama_build_info())},
{"cors_proxy_enabled", params.ui_mcp_proxy},
});
return res;
}
return proxy_get(req);
};
this->proxy_get = [this](const server_http_req & req) {
std::string method = "GET";
std::string name = req.get_param("model");
bool autoload = is_autoload(params, req);
auto error_res = std::make_unique<server_http_res>();
if (!router_validate_model(name, models, autoload, error_res)) {
return error_res;
}
return models.proxy_request(req, method, name, false);
};
this->proxy_post = [this](const server_http_req & req) {
std::string method = "POST";
json body = json::parse(req.body);
std::string name = json_value(body, "model", std::string());
bool autoload = is_autoload(params, req);
auto error_res = std::make_unique<server_http_res>();
if (!router_validate_model(name, models, autoload, error_res)) {
return error_res;
}
// remember which child serves this conversation so the stream routes can route straight
// to it without polling, keyed on the exact conv id from the header
std::string conv_id = stream_conv_id_from_headers(req.headers);
if (!conv_id.empty()) {
models.conv_models.remember(conv_id, name);
}
return models.proxy_request(req, method, name, true); // update last usage for POST request only
};
this->post_router_models_load = [this](const server_http_req & req) {
auto res = std::make_unique<server_http_res>();
json body = json::parse(req.body);
std::string name = json_value(body, "model", std::string());
auto meta = models.get_meta(name);
if (!meta.has_value()) {
res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
return res;
}
if (meta->is_running()) {
res_err(res, format_error_response("model is already running", ERROR_TYPE_INVALID_REQUEST));
return res;
}
models.load(meta->name);
res_ok(res, {{"success", true}});
return res;
};
this->get_router_models = [this](const server_http_req & req) {
bool reload = !req.get_param("reload", "").empty();
if (reload) {
models.load_models();
}
auto res = std::make_unique<server_http_res>();
json models_json = json::array();
auto all_models = models.get_all_meta();
std::time_t t = std::time(0);
for (const auto & meta : all_models) {
json status {
{"value", server_model_status_to_string(meta.status)},
{"args", meta.args},
};
if (!meta.preset.name.empty()) {
common_preset preset_copy = meta.preset;
unset_reserved_args(preset_copy, false);
preset_copy.unset_option("LLAMA_ARG_HOST");
preset_copy.unset_option("LLAMA_ARG_PORT");
preset_copy.unset_option("LLAMA_ARG_ALIAS");
preset_copy.unset_option("LLAMA_ARG_TAGS");
status["preset"] = preset_copy.to_ini();
}
if (meta.is_failed()) {
status["exit_code"] = meta.exit_code;
status["failed"] = true;
}
// pi coding agent multimodal compatibility
json input_modalities = json::array({"text"});
if (meta.multimodal.inp_vision) {
input_modalities.push_back("image");
}
if (meta.multimodal.inp_audio) {
input_modalities.push_back("audio");
}
json architecture {
{"input_modalities", input_modalities},
{"output_modalities", json::array({"text"})},
};
json model_info = json {
{"id", meta.name},
{"aliases", meta.aliases},
{"tags", meta.tags},
{"object", "model"}, // for OAI-compat
{"owned_by", "llamacpp"}, // for OAI-compat
{"created", t}, // for OAI-compat
{"status", status},
{"architecture", architecture},
{"source", server_model_source_to_string(meta.source)},
{"can_remove", meta.source == SERVER_MODEL_SOURCE_CACHE},
// {"need_download", meta.need_download},
// TODO: add other fields, may require reading GGUF metadata
};
// merge with loaded_info from the child process if available
if (meta.is_running()) {
for (auto it = meta.loaded_info.begin(); it != meta.loaded_info.end(); ++it) {
if (!model_info.contains(it.key())) {
model_info[it.key()] = it.value();
}
}
}
models_json.push_back(model_info);
}
res_ok(res, {
{"data", models_json},
{"object", "list"},
});
return res;
};
this->post_router_models_unload = [this](const server_http_req & req) {
auto res = std::make_unique<server_http_res>();
json body = json::parse(req.body);
std::string name = json_value(body, "model", std::string());
auto model = models.get_meta(name);
if (!model.has_value()) {
res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
return res;
}
if (!model->is_running() && model->status != SERVER_MODEL_STATUS_DOWNLOADING) {
res_err(res, format_error_response("model is not running", ERROR_TYPE_INVALID_REQUEST));
return res;
}
models.unload(model->name);
res_ok(res, {{"success", true}});
return res;
};
this->get_router_models_sse = [this](const server_http_req & req) {
auto res = std::make_unique<server_http_res>();
res->status = 200;
res->content_type = "text/event-stream";
auto sse_client = std::make_shared<server_models_sse_client>(models.sse);
res->next = [this, sse_client, &req](std::string & output) -> bool {
auto result = sse_client->next([&]() {
return stopping.load(std::memory_order_relaxed) || req.should_stop();
});
if (result == nullptr) {
return false; // client disconnected or should_stop
}
output = "data: " + safe_json_to_str(result->to_json()) + "\n\n";
return true; // listen for the next event
};
return res;
};
this->post_router_models = [this](const server_http_req & req) {
auto res = std::make_unique<server_http_res>();
json body = json::parse(req.body);
std::string name = json_value(body, "model", std::string());
if (name.empty()) {
throw std::invalid_argument("model must be a non-empty string");
}
common_params p;
p.model.hf_repo = name;
p.hf_token = params.hf_token;
// validate by fetching metadata; a failure is logged here, then rethrown
// so that the outer ex_wrapper() turns it into an error response
try {
common_models_handler_init(p, LLAMA_EXAMPLE_SERVER);
} catch (...) {
SRV_ERR("error while validating model '%s'\n", name.c_str());
throw;
}
// reject if model already exists
if (models.has_model(name)) {
throw std::invalid_argument("model '" + name + "' already exists");
}
// then, proceed with the actual download
SRV_INF("starting download for model '%s'\n", name.c_str());
{
server_models::load_options load_opts;
load_opts.mode = SERVER_CHILD_MODE_DOWNLOAD;
load_opts.custom_meta = server_model_meta{};
load_opts.custom_meta->source = SERVER_MODEL_SOURCE_CACHE;
load_opts.custom_meta->name = name;
models.load(name, load_opts);
}
res_ok(res, {{"success", true}});
return res;
};
this->del_router_models = [this](const server_http_req & req) {
auto res = std::make_unique<server_http_res>();
std::string name = req.get_param("model");
if (name.empty()) {
throw std::invalid_argument("model must be a non-empty string");
}
models.remove(name); // throws on error
res_ok(res, {{"success", true}});
return res;
};
this->router_stream_get = [this](const server_http_req & req) {
// GET /v1/stream/<conv_id>?from=N. resolve the owning child from the conv_id -> model
// map, 404 when nothing maps
auto res = std::make_unique<server_http_res>();
std::string conv_id = req.get_param("conv_id");
if (conv_id.empty()) {
res_err(res, format_error_response("Missing conversation id in path", ERROR_TYPE_INVALID_REQUEST));
return res;
}
std::optional<server_model_meta> owner = resolve_child_for_conv(models, conv_id);
if (!owner.has_value()) {
res_err(res, format_error_response("Stream not found or expired", ERROR_TYPE_NOT_FOUND));
return res;
}
std::string from = req.get_param("from");
std::string child_path = "/v1/stream/" + encode_qs(conv_id);
if (!from.empty()) {
// encode the client-supplied value so it cannot inject extra query parameters
child_path += "?from=" + encode_qs(from);
}
SRV_INF("proxying stream resume to model %s on port %d, path=%s\n",
owner->name.c_str(), owner->port, child_path.c_str());
auto proxy = std::make_unique<server_http_proxy>(
"GET",
"http",
CHILD_ADDR,
owner->port,
child_path,
req.headers,
req.body,
req.files,
req.should_stop,
params.timeout_read,
params.timeout_write);
return std::unique_ptr<server_http_res>(std::move(proxy));
};
this->router_streams_lookup = [this](const server_http_req & req) {
// POST /v1/streams/lookup. resolve each requested conv id to its owning child via the
// map, group the ids per child, and query only the children that actually own some of
// them instead of fanning out to every ready child. a child only answers for the ids
// it owns, never lists anything else
auto res = std::make_unique<server_http_res>();
std::vector<std::string> requested;
try {
json body = json::parse(req.body);
if (body.contains("conversation_ids") && body["conversation_ids"].is_array()) {
for (const auto & v : body["conversation_ids"]) {
if (v.is_string() && !v.get<std::string>().empty()) {
requested.push_back(v.get<std::string>());
}
}
}
} catch (const std::exception &) {
// malformed body: answer with an empty list instead of an error
res_ok(res, json::array());
return res;
}
// group requested ids by the child port that owns them, drop ids that map to nothing
std::unordered_map<int, json> per_child;
for (const auto & cid : requested) {
auto owner = resolve_child_for_conv(models, cid);
if (!owner.has_value()) {
continue;
}
per_child[owner->port].push_back(cid);
}
json aggregated = json::array();
for (auto & [port, ids] : per_child) {
json child_body = {{"conversation_ids", ids}};
httplib::Client cli(CHILD_ADDR, port);
cli.set_connection_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
cli.set_read_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
cli.set_write_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
auto resp = cli.Post("/v1/streams/lookup", child_body.dump(), "application/json");
if (!resp || resp->status != 200) {
continue;
}
try {
json child_arr = json::parse(resp->body);
if (!child_arr.is_array()) {
continue;
}
for (auto & entry : child_arr) {
if (entry.is_object()) {
aggregated.push_back(entry);
}
}
} catch (const std::exception &) {
continue;
}
}
res_ok(res, aggregated);
return res;
};
this->router_stream_delete = [this](const server_http_req & req) {
// DELETE /v1/stream/<conv_id>. resolve the owning child via the map and forward only to
// it, evict_and_cancel is idempotent on the child
auto res = std::make_unique<server_http_res>();
std::string conv_id = req.get_param("conv_id");
if (conv_id.empty()) {
res_err(res, format_error_response("Missing conversation id in path", ERROR_TYPE_INVALID_REQUEST));
return res;
}
std::string child_path = "/v1/stream/" + encode_qs(conv_id);
auto owner = resolve_child_for_conv(models, conv_id);
if (owner.has_value()) {
httplib::Client cli(CHILD_ADDR, owner->port);
cli.set_connection_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
cli.set_read_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
cli.set_write_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
auto resp = cli.Delete(child_path.c_str());
(void) resp; // best effort, 404 and network errors are equivalent to no op
}
// drop the tracking entry, the session is being torn down
models.conv_models.forget(conv_id);
res->status = 204;
res->content_type = "application/json";
return res;
};
}
//
// server_http_proxy
//
// simple implementation of a pipe
// used for streaming data between threads
template<typename T>
struct pipe_t {
std::mutex mutex;
std::condition_variable cv;
std::queue<T> queue;
std::atomic<bool> writer_closed{false};
std::atomic<bool> reader_closed{false};
void close_write() {
writer_closed.store(true, std::memory_order_relaxed);
cv.notify_all();
}
void close_read() {
reader_closed.store(true, std::memory_order_relaxed);
cv.notify_all();
}
bool read(T & output, const std::function<bool()> & should_stop) {
std::unique_lock<std::mutex> lk(mutex);
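// wake up periodically to re-check should_stop() and the closed flags;
// the flags are set without holding the mutex, so a notification can be missed
constexpr auto poll_interval = std::chrono::milliseconds(500);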
while (true) {
if (!queue.empty()) {
output = std::move(queue.front());
queue.pop();
return true;
}
if (writer_closed.load()) {
return false; // clean EOF
}
if (should_stop()) {
close_read(); // signal broken pipe to writer
return false; // cancelled / reader no longer alive
}
cv.wait_for(lk, poll_interval);
}
}
bool write(T && data) {
std::lock_guard<std::mutex> lk(mutex);
if (reader_closed.load()) {
return false; // broken pipe
}
queue.push(std::move(data));
cv.notify_one();
return true;
}
};
static std::string to_lower_copy(const std::string & value) {
std::string lowered(value.size(), '\0');
std::transform(value.begin(), value.end(), lowered.begin(), [](unsigned char c) { return std::tolower(c); });
return lowered;
}
static bool should_strip_proxy_header(const std::string & header_name) {
// Headers that get duplicated when router forwards child responses
if (header_name == "server" ||
header_name == "transfer-encoding" ||
header_name == "content-length" || // quick fix for https://github.com/ggml-org/llama.cpp/issues/17710
header_name == "keep-alive") {
return true;
}
// The router injects its own CORS headers; forwarding the child's copies would duplicate them
if (header_name.rfind("access-control-", 0) == 0) {
return true;
}
return false;
}
static std::string generate_multipart_boundary() {
thread_local std::mt19937 gen(std::random_device{}());
static const char chars[] = "0123456789abcdefghijklmnopqrstuvwxyz";
std::uniform_int_distribution<> dis(0, sizeof(chars) - 2);
std::string boundary = "----llama-cpp-proxy-";
for (int i = 0; i < 16; i++) {
boundary += chars[dis(gen)];
}
return boundary;
}
static std::string build_multipart_body(
const json & form_fields,
const std::map<std::string, uploaded_file> & files,
const std::string & boundary) {
static auto sanitize_field = [](const std::string & text) {
std::string result;
result.reserve(text.size());
for (char c : text) {
if (c != '\n' && c != '\r' && c != '"') {
result += c;
}
}
return result;
};
std::ostringstream body;
for (const auto & [key, value] : form_fields.items()) {
if (value.is_array()) {
for (const auto & item : value) {
body << "--" << boundary << "\r\n";
body << "Content-Disposition: form-data; name=\"" << sanitize_field(key) << "\"\r\n";
body << "\r\n";
if (!item.is_string()) {
throw std::invalid_argument("multipart form field '" + key + "' must contain only strings");
}
body << item.get<std::string>() << "\r\n";
}
} else {
body << "--" << boundary << "\r\n";
body << "Content-Disposition: form-data; name=\"" << sanitize_field(key) << "\"\r\n";
body << "\r\n";
if (!value.is_string()) {
throw std::invalid_argument("multipart form field '" + key + "' must be a string");
}
body << value.get<std::string>() << "\r\n";
}
}
for (const auto & [key, file] : files) {
body << "--" << boundary << "\r\n";
body << "Content-Disposition: form-data; name=\"" << sanitize_field(key) << "\"";
if (!file.filename.empty()) {
body << "; filename=\"" << sanitize_field(file.filename) << "\"";
}
body << "\r\n";
if (!file.content_type.empty()) {
body << "Content-Type: " << sanitize_field(file.content_type) << "\r\n";
} else {
body << "Content-Type: application/octet-stream\r\n";
}
body << "\r\n";
body.write(reinterpret_cast<const char*>(file.data.data()), file.data.size());
body << "\r\n";
}
body << "--" << boundary << "--\r\n";
return body.str();
}
server_http_proxy::server_http_proxy(
const std::string & method,
const std::string & scheme,
const std::string & host,
int port,
const std::string & path,
const std::map<std::string, std::string> & headers,
const std::string & body,
const std::map<std::string, uploaded_file> & files,
const std::function<bool()> should_stop,
int32_t timeout_read,
int32_t timeout_write
) {
// shared between reader and writer threads
auto cli = std::make_shared<httplib::ClientImpl>(host, port);
auto pipe = std::make_shared<pipe_t<msg_t>>();
if (scheme == "https") {
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
cli.reset(new httplib::SSLClient(host, port));
#else
throw std::runtime_error("HTTPS requested but CPPHTTPLIB_OPENSSL_SUPPORT is not defined");
#endif
}
// setup Client
cli->set_follow_location(true);
cli->set_connection_timeout(timeout_read, 0); // use --timeout value instead of hardcoded 5 s
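// timeouts are swapped on purpose: the request the router reads from its client
// is written to the child, and the response it writes back is read from the child
cli->set_write_timeout(timeout_read, 0);
cli->set_read_timeout(timeout_write, 0);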
this->status = 500; // to be overwritten upon response
this->cleanup = [pipe]() {
pipe->close_read();
pipe->close_write();
};
// wire up the receive end of the pipe
this->next = [pipe, should_stop](std::string & out) -> bool {
msg_t msg;
bool has_next = pipe->read(msg, should_stop);
if (!msg.data.empty()) {
out = std::move(msg.data);
}
return has_next; // false if EOF or pipe broken
};
// wire up the HTTP client
// note: do NOT capture `this` pointer, as it may be destroyed before the thread ends
httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) {
msg_t msg;
msg.status = response.status;
for (const auto & [key, value] : response.headers) {
const auto lowered = to_lower_copy(key);
if (should_strip_proxy_header(lowered)) {
continue;
}
if (lowered == "content-type") {
msg.content_type = value;
continue;
}
msg.headers[key] = value;
}
return pipe->write(std::move(msg)); // send headers first
};
httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) {
// send data chunks
// returns false if pipe is closed / broken (signal to stop receiving)
return pipe->write({{}, 0, std::string(data, data_length), ""});
};
// when files are present, the body was converted from multipart form data to JSON
// we need to reconstruct the multipart body for the downstream server
std::string effective_body = body;
std::string override_content_type;
bool has_files = !files.empty();
if (has_files) {
json form_fields = json::parse(body, nullptr, false);
if (!form_fields.is_discarded()) {
auto boundary = generate_multipart_boundary();
effective_body = build_multipart_body(form_fields, files, boundary);
override_content_type = "multipart/form-data; boundary=" + boundary;
} else {
throw std::runtime_error("failed to parse multipart form fields JSON");
}
}
// prepare the request to destination server
httplib::Request req;
{
req.method = method;
req.path = path;
for (const auto & [key, value] : headers) {
const auto lowered = to_lower_copy(key);
if (lowered == "accept-encoding") {
// disable Accept-Encoding to avoid compressed responses
continue;
}
if (lowered == "transfer-encoding") {
// the body is already decoded
continue;
}
if (lowered == "content-length") {
// let httplib calculate Content-Length from the actual body
continue;
}
if (lowered == "content-type") {
if (has_files) {
// we set our own Content-Type with the new boundary
continue;
}
// when no files but the original request was multipart,
// the body is now JSON, so correct the Content-Type
if (value.find("multipart/form-data") != std::string::npos) {
override_content_type = "application/json; charset=utf-8";
continue;
}
}
if (lowered == "host") {
bool is_default_port = (scheme == "https" && port == 443) || (scheme == "http" && port == 80);
req.set_header(key, is_default_port ? host : host + ":" + std::to_string(port));
} else {
req.set_header(key, value);
}
}
req.body = effective_body;
if (!override_content_type.empty()) {
req.set_header("Content-Type", override_content_type);
}
req.response_handler = response_handler;
req.content_receiver = content_receiver;
}
// start the proxy thread
SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str());
this->thread = std::thread([cli, pipe, req]() {
auto result = cli->send(req); // req is a const capture, so std::move() would still copy
if (result.error() != httplib::Error::Success) {
auto err_str = httplib::to_string(result.error());
SRV_ERR("http client error: %s\n", err_str.c_str());
pipe->write({{}, 500, "", ""}); // header
pipe->write({{}, 0, "proxy error: " + err_str, ""}); // body
}
pipe->close_write(); // signal EOF to reader
SRV_DBG("%s", "client request thread ended\n");
});
this->thread.detach();
// wait for the first chunk (headers)
{
msg_t header;
if (pipe->read(header, should_stop)) {
SRV_DBG("%s", "received response headers\n");
this->status = header.status;
this->headers = std::move(header.headers);
if (!header.content_type.empty()) {
this->content_type = std::move(header.content_type);
}
} else {
SRV_DBG("%s", "no response headers received (request cancelled?)\n");
}
}
}