server: (router) add model management API (#23976)

* wip

* server: (router) add SSE realtime updates API

* nits

* wip

* add download API

* add download api

* update docs

* add delete endpoint

* fix std::terminate

* fix crash

* fix 2

* add tests

* nits
This commit is contained in:
Xuan-Son Nguyen 2026-06-17 18:04:58 +02:00 committed by GitHub
parent b4024af6c2
commit 4b4d13ae72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 855 additions and 63 deletions

View File

@ -997,3 +997,87 @@ std::vector<common_cached_model_info> common_list_cached_models() {
return result;
}
bool common_download_remove(const std::string & hf_repo_with_tag) {
namespace fs = std::filesystem;
auto [repo_id, tag] = common_download_split_repo_tag(hf_repo_with_tag);
if (tag.empty()) {
return hf_cache::remove_cached_repo(repo_id);
}
std::string tag_upper = tag;
for (char & c : tag_upper) {
c = (char) std::toupper((unsigned char) c);
}
auto files = hf_cache::get_cached_files(repo_id);
if (files.empty()) {
return false;
}
// collect snapshot entries whose tag matches
std::vector<fs::path> to_remove;
for (const auto & f : files) {
auto split = get_gguf_split_info(f.path);
if (split.tag == tag_upper) {
to_remove.emplace_back(f.local_path);
}
}
if (to_remove.empty()) {
return false;
}
// resolve blob paths from symlinks before deleting snapshot entries
std::vector<fs::path> blobs_to_check;
for (const auto & p : to_remove) {
std::error_code ec;
if (fs::is_symlink(p, ec)) {
auto target = fs::read_symlink(p, ec);
if (!ec) {
blobs_to_check.push_back((p.parent_path() / target).lexically_normal());
}
}
}
// remove snapshot entries
for (const auto & p : to_remove) {
std::error_code ec;
fs::remove(p, ec);
if (ec) {
LOG_WRN("%s: failed to remove %s: %s\n", __func__, p.string().c_str(), ec.message().c_str());
}
}
if (blobs_to_check.empty()) {
return true;
}
// collect blobs still referenced by remaining snapshot entries
std::unordered_set<std::string> still_referenced;
for (const auto & f : hf_cache::get_cached_files(repo_id)) {
fs::path p(f.local_path);
std::error_code ec;
if (fs::is_symlink(p, ec)) {
auto target = fs::read_symlink(p, ec);
if (!ec) {
still_referenced.insert((p.parent_path() / target).lexically_normal().string());
}
}
}
// remove orphaned blobs
for (const auto & blob : blobs_to_check) {
if (still_referenced.find(blob.string()) == still_referenced.end()) {
std::error_code ec;
fs::remove(blob, ec);
if (ec) {
LOG_WRN("%s: failed to remove blob %s: %s\n", __func__, blob.string().c_str(), ec.message().c_str());
}
}
}
return true;
}

View File

@ -115,3 +115,10 @@ int common_download_file_single(const std::string & url,
// resolve and download model from Docker registry
// return local path to downloaded model file
std::string common_docker_resolve_model(const std::string & docker);
// Remove a cached model from disk
// input format: "user/model" or "user/model:tag"
// - if tag is omitted, removes the entire repo cache directory
// - if tag is present, removes only files matching that tag (and orphaned blobs)
// returns true if anything was removed
bool common_download_remove(const std::string & hf_repo_with_tag);

View File

@ -495,4 +495,19 @@ std::string finalize_file(const hf_file & file) {
return file.final_path;
}
bool remove_cached_repo(const std::string & repo_id) {
if (!is_valid_repo_id(repo_id)) {
LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str());
return false;
}
fs::path repo_path = get_repo_path(repo_id);
std::error_code ec;
auto removed = fs::remove_all(repo_path, ec);
if (ec) {
LOG_ERR("%s: failed to remove repo cache %s: %s\n", __func__, repo_path.string().c_str(), ec.message().c_str());
return false;
}
return removed > 0;
}
} // namespace hf_cache

View File

@ -29,4 +29,7 @@ hf_files get_cached_files(const std::string & repo_id = {});
// Create snapshot path (link or move/copy) and return it
std::string finalize_file(const hf_file & file);
// Remove the entire cached directory for a repo, returns true if removed
bool remove_cached_repo(const std::string & repo_id);
} // namespace hf_cache

View File

@ -180,6 +180,24 @@ That requires `JSON.stringify` when formatted to message content:
}
```
### Model management API (router mode)
Model management API was added via PR [#23976](https://github.com/ggml-org/llama.cpp/pull/23976)
The main goal of this API is to allow downloading models and/or removing models from the web UI. It relies on the model cache infrastructure under the hood to manage the list of models dynamically.
Instead of building everything from the ground up (like what most AI agents will do when you ask them to implement a similar feature), we built on top of existing, already well-engineered components inside the codebase:
- Model cache infrastructure as mentioned above (`common/download.h`)
- Server response queue (`server-queue.h`). We use this feature to broadcast events to SSE clients.
- Server router thread management (`server-models.h`). We re-use the same thread model that is used for managing subprocess life cycle, except that we don't create a new subprocess, but launch the download right inside the thread.
The flow for downloading a new model:
- POST request comes in --> `post_router_models` --> validation
- `server_models::download()` is called
- Sets up a new thread `inst.th` and runs the download inside
- If a stop request comes in, set `stop_download` to `true`
- Otherwise, upon completion, we call `load_models()` to refresh the list of models
### Notable Related PRs
- Initial server implementation: https://github.com/ggml-org/llama.cpp/pull/1443

View File

@ -1778,6 +1778,20 @@ The `status` object can be:
}
```
Note: for "downloading" state, there can be multiple files be downloading in parallel
```json
"status": {
"value": "downloading",
"progress": {
"https://...model.gguf": {
"done": 195963406,
"total": 219307424
}
}
}
```
### POST `/models/load`: Load a model
Load a model
@ -1820,6 +1834,107 @@ Response:
}
```
### GET `/models/sse`: Real-time events
Example events:
```js
{
"model": "...",
"event": "model_status",
"data": {
"status": "loading"
}
}
{
"model": "...",
"event": "download_progress",
"data": {
// note: there can be multiple files being downloaded in parallel
"https://...model.gguf": {
"done": 195963406,
"total": 219307424
}
}
}
{
"model": "...",
"event": "download_finished",
"data": {
"status": "loading"
}
}
{
"model": "...",
"event": "model_remove"
}
// special event: reload of the list of all models
{
"model": "*",
"event": "models_reload"
}
```
### POST `/models`: Download new model
Trigger a new download (non-blocking), the progress can be tracked via SSE endpoint `/models/sse`
To cancel model downloading, send an event to `/models/unload`
Download procedure:
- Send POST request to `/models`
- Subscribe to `/models/sse` for updates
- On downloading completed, you will receive either `download_finished` or `download_failed` event
- Call GET `/models` to trigger model list update. If the download success, you should see the new model in the list
Payload:
```json
{
"model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M",
}
```
Response (download is started in the background):
```json
{
"success": true
}
```
Response (error, cannot start the download):
```json
{
"error": {
"code": 400,
"message": "model validation failed, unable to download",
"type": "invalid_request_error"
}
}
```
### DELETE `/models`: Delete a model from cache
IMPORTANT: only model stored in cache can be deleted. You cannot delete models in a preset.
Model name must be passed via query param: `?model={name}`
If delete success, it will send an SSE event of type `model_remove`
Response:
```json
{
"success": true
}
```
## API errors
`llama-server` returns errors in the same format as OAI: https://github.com/openai/openai-openapi

View File

@ -588,6 +588,23 @@ void server_http_context::post(const std::string & path, const server_http_conte
});
}
void server_http_context::del(const std::string & path, const server_http_context::handler_t & handler) const {
handlers.emplace(path, handler);
pimpl->srv->Delete(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
get_params(req),
get_headers(req),
req.path,
build_query_string(req),
req.body,
{},
req.is_connection_closed
});
server_http_res_ptr response = handler(*request);
process_handler_response(std::move(request), response, res);
});
}
//
// Vertex AI Prediction protocol (AIP_PREDICT_ROUTE)
// https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements

View File

@ -86,6 +86,7 @@ struct server_http_context {
void get(const std::string & path, const handler_t & handler) const;
void post(const std::string & path, const handler_t & handler) const;
void del(const std::string & path, const handler_t & handler) const;
// Register the Google Cloud Platform (Vertex AI) compat (AIP_PREDICT_ROUTE env var, or /predict)
// Must be called AFTER all other API routes are registered

View File

@ -9,6 +9,7 @@
#include <sheredom/subprocess.h>
#include <functional>
#include <optional>
#include <algorithm>
#include <thread>
#include <mutex>
@ -51,6 +52,21 @@ extern char **environ;
// ref: https://github.com/ggml-org/llama.cpp/issues/17862
#define CHILD_ADDR "127.0.0.1"
struct server_subproc {
std::optional<subprocess_s> sproc; // empty while in DOWNLOADING state
std::atomic<bool> stop_download{false}; // flag to signal download cancellation
subprocess_s & get() {
GGML_ASSERT(sproc.has_value() && "subprocess not initialized");
return sproc.value();
}
bool is_alive() {
return sproc.has_value() && subprocess_alive(&sproc.value());
}
};
static std::filesystem::path get_server_exec_path() {
#if defined(_WIN32)
wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths
@ -272,12 +288,25 @@ void server_models::add_model(server_model_meta && meta) {
meta.update_caps();
std::string name = meta.name;
mapping[name] = instance_t{
/* subproc */ std::make_shared<subprocess_s>(),
/* subproc */ std::make_shared<server_subproc>(),
/* th */ std::thread(),
/* meta */ std::move(meta)
};
}
void server_models::notify_sse(const std::string & event, const std::string & model_id, const json & data) {
std::unique_ptr<server_task_result_router> result = std::make_unique<server_task_result_router>();
result->data = {
{"model", model_id},
{"event", event},
};
if (!data.is_null()) {
result->data["data"] = data;
}
SRV_DBG("notifying SSE clients about event '%s' for model '%s': %s\n", event.c_str(), model_id.c_str(), safe_json_to_str(result->data).c_str());
sse.broadcast(std::move(result));
}
void server_models::load_models() {
// Phase 1: load presets from all sources — pure I/O, no lock needed
// 1. cached models
@ -304,19 +333,27 @@ void server_models::load_models() {
// note: if a model exists in both cached and local, local takes precedence
common_presets final_presets;
for (const auto & [name, preset] : cached_models) final_presets[name] = preset;
for (const auto & [name, preset] : local_models) final_presets[name] = preset;
std::unordered_map<std::string, server_model_source> source_map;
for (const auto & [name, preset] : cached_models) {
final_presets[name] = preset;
source_map[name] = SERVER_MODEL_SOURCE_CACHE;
}
for (const auto & [name, preset] : local_models) {
final_presets[name] = preset;
source_map[name] = SERVER_MODEL_SOURCE_MODELS_DIR;
}
for (const auto & [name, custom] : custom_presets) {
if (final_presets.find(name) != final_presets.end()) {
final_presets[name].merge(custom);
} else {
final_presets[name] = custom;
}
source_map[name] = SERVER_MODEL_SOURCE_PRESET;
}
// server base preset from CLI args takes highest precedence
for (auto & [name, preset] : final_presets) {
preset.merge(base_preset);
}
auto get_source = [&](const std::string & name) {
return source_map.count(name) ? source_map.at(name) : SERVER_MODEL_SOURCE_PRESET;
};
// Helpers that read `mapping` — must be called while holding the lock.
std::unordered_set<std::string> custom_names;
@ -366,12 +403,15 @@ void server_models::load_models() {
// (unload, load) or when joining threads (the monitoring thread calls update_status
// which locks the mutex, so joining while holding it would deadlock).
std::unique_lock<std::mutex> lk(mutex);
need_reload = false;
bool is_first_load = mapping.empty();
if (is_first_load) {
// FIRST LOAD: add all models, then unlock for autoloading
for (const auto & [name, preset] : final_presets) {
server_model_meta meta{
/* source */ get_source(name),
/* preset */ preset,
/* name */ name,
/* aliases */ {},
@ -384,7 +424,7 @@ void server_models::load_models() {
/* exit_code */ 0,
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
/* multimodal */ mtmd_caps{false, false},
/* need_download */ false,
// /* need_download */ false,
};
add_model(std::move(meta));
}
@ -453,6 +493,9 @@ void server_models::load_models() {
}
}
for (auto & [name, inst] : mapping) {
if (inst.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
continue; // downloading models are not from config sources, leave them alone
}
if (final_presets.find(name) == final_presets.end() && !inst.meta.is_running() && inst.th.joinable()) {
threads_to_join.push_back(std::move(inst.th));
}
@ -465,7 +508,15 @@ void server_models::load_models() {
// erase models no longer in any source
for (auto it = mapping.begin(); it != mapping.end(); ) {
if (final_presets.find(it->first) == final_presets.end()) {
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
++it; // download thread is still busy, skip
} else if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADED) {
// download finished, safe to erase
if (it->second.th.joinable()) {
it->second.th.join();
}
it = mapping.erase(it);
} else if (final_presets.find(it->first) == final_presets.end()) {
SRV_INF("(reload) removing model name=%s (no longer in source)\n", it->first.c_str());
GGML_ASSERT(!it->second.th.joinable()); // must have been joined above
it = mapping.erase(it);
@ -526,6 +577,7 @@ void server_models::load_models() {
for (const auto & [name, preset] : final_presets) {
if (mapping.find(name) == mapping.end()) {
server_model_meta meta{
/* source */ get_source(name),
/* preset */ preset,
/* name */ name,
/* aliases */ {},
@ -538,7 +590,7 @@ void server_models::load_models() {
/* exit_code */ 0,
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
/* multimodal */ mtmd_caps{false, false},
/* need_download */ false,
// /* need_download */ false,
};
add_model(std::move(meta));
newly_added.push_back(name);
@ -571,6 +623,8 @@ void server_models::load_models() {
SRV_INF("(reload) loading new model %s\n", name.c_str());
load(name);
}
notify_sse("models_reload", "*");
}
}
@ -597,7 +651,13 @@ bool server_models::has_model(const std::string & name) {
}
std::optional<server_model_meta> server_models::get_meta(const std::string & name) {
std::lock_guard<std::mutex> lk(mutex);
std::unique_lock<std::mutex> lk(mutex);
if (need_reload) {
lk.unlock();
load_models();
lk.lock();
}
auto it = mapping.find(name);
if (it != mapping.end()) {
return it->second.meta;
@ -683,7 +743,13 @@ static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & ve
}
std::vector<server_model_meta> server_models::get_all_meta() {
std::lock_guard<std::mutex> lk(mutex);
std::unique_lock<std::mutex> lk(mutex);
if (need_reload) {
lk.unlock();
load_models();
lk.lock();
}
std::vector<server_model_meta> result;
result.reserve(mapping.size());
for (const auto & [name, inst] : mapping) {
@ -770,7 +836,7 @@ void server_models::load(const std::string & name) {
throw std::runtime_error("failed to get a port number");
}
inst.subproc = std::make_shared<subprocess_s>();
inst.subproc = std::make_shared<server_subproc>();
{
SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port);
@ -792,19 +858,20 @@ void server_models::load(const std::string & name) {
// TODO @ngxson : maybe separate stdout and stderr in the future
// so that we can use stdout for commands and stderr for logging
int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
int result = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get());
inst.subproc->sproc.emplace();
int result = subprocess_create_ex(argv.data(), options, envp.data(), &inst.subproc->get());
if (result != 0) {
throw std::runtime_error("failed to spawn server instance");
}
inst.stdin_file = subprocess_stdin(inst.subproc.get());
inst.stdin_file = subprocess_stdin(&inst.subproc->get());
}
// start a thread to manage the child process
// captured variables are guaranteed to be destroyed only after the thread is joined
inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() {
FILE * stdin_file = subprocess_stdin(child_proc.get());
FILE * stdout_file = subprocess_stdout(child_proc.get()); // combined stdout/stderr
FILE * stdin_file = subprocess_stdin(&child_proc->get());
FILE * stdout_file = subprocess_stdout(&child_proc->get()); // combined stdout/stderr
std::thread log_thread([&]() {
// read stdout/stderr and forward to main server log
@ -834,14 +901,14 @@ void server_models::load(const std::string & name) {
return this->stopping_models.find(name) != this->stopping_models.end();
};
auto should_wake = [&]() {
return is_stopping() || !subprocess_alive(child_proc.get());
return is_stopping() || !child_proc->is_alive();
};
{
std::unique_lock<std::mutex> lk(this->mutex);
this->cv_stop.wait(lk, should_wake);
}
// child may have already exited (e.g. crashed) — skip shutdown sequence
if (!subprocess_alive(child_proc.get())) {
if (!child_proc->is_alive()) {
return;
}
SRV_INF("stopping model instance name=%s\n", name.c_str());
@ -859,7 +926,7 @@ void server_models::load(const std::string & name) {
if (elapsed >= stop_timeout * 1000) {
// timeout, force kill
SRV_WRN("force-killing model instance name=%s after %d seconds timeout\n", name.c_str(), stop_timeout);
subprocess_terminate(child_proc.get());
subprocess_terminate(&child_proc->get());
return;
}
this->cv_stop.wait_for(lk, std::chrono::seconds(1));
@ -884,8 +951,8 @@ void server_models::load(const std::string & name) {
// get the exit code
int exit_code = 0;
subprocess_join(child_proc.get(), &exit_code);
subprocess_destroy(child_proc.get());
subprocess_join(&child_proc->get(), &exit_code);
subprocess_destroy(&child_proc->get());
// update status and exit code
this->update_status(name, SERVER_MODEL_STATUS_UNLOADED, exit_code);
@ -896,30 +963,118 @@ void server_models::load(const std::string & name) {
{
auto & old_instance = mapping[name];
// old process should have exited already, but just in case, we clean it up here
if (subprocess_alive(old_instance.subproc.get())) {
if (old_instance.subproc->is_alive()) {
SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
subprocess_terminate(old_instance.subproc.get()); // force kill
subprocess_terminate(&old_instance.subproc->get()); // force kill
}
if (old_instance.th.joinable()) {
old_instance.th.join();
}
}
notify_sse("model_status", name, {
{"status", server_model_status_to_string(inst.meta.status)},
});
mapping[name] = std::move(inst);
cv.notify_all();
}
// callback for model downloading functionality
struct server_models_download_res : public common_download_callback {
common_params_model model;
common_download_opts opts;
std::function<bool()> should_stop;
std::function<void(const common_download_progress & p)> on_progress;
bool is_ok = false;
bool run() {
try {
common_download_model(model, opts);
is_ok = true;
} catch (const std::exception & e) {
SRV_ERR("download failed for model name=%s: %s\n", model.name.c_str(), e.what());
is_ok = false;
}
return is_ok;
}
void on_start(const common_download_progress & p) override {
on_progress(p);
}
void on_update(const common_download_progress & p) override {
on_progress(p);
}
void on_done(const common_download_progress &, bool ok) override {
is_ok = ok;
}
bool is_cancelled() const override {
return should_stop();
}
};
void server_models::download(common_params_model && model, common_download_opts && opts) {
std::string name = model.name;
GGML_ASSERT(name == model.hf_repo);
std::unique_lock<std::mutex> lk(mutex);
if (mapping.find(name) != mapping.end()) {
throw std::runtime_error("model name=" + name + " already exists");
}
instance_t inst;
inst.meta.name = name;
inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
inst.subproc = std::make_shared<server_subproc>();
auto dl = std::make_unique<server_models_download_res>();
dl->model = model; // copy
dl->opts = opts; // copy
dl->should_stop = [sp = inst.subproc]() {
return sp->stop_download.load(std::memory_order_relaxed);
};
dl->on_progress = [this, name](const common_download_progress & p) {
update_download_progress(name, p, false);
};
inst.th = std::thread([this, dl = std::move(dl)]() {
dl->opts.callback = dl.get();
bool ok = dl->run();
SRV_INF("download finished for model name=%s with status=%s\n",
dl->model.name.c_str(), ok ? "success" : "failure");
update_download_progress(dl->model.name, {}, true, ok);
// need_reload is set inside update_download_progress under the mutex;
// the next load_models() call will clean up this instance
});
mapping[name] = std::move(inst);
notify_sse("status_update", name, {
{"status", server_model_status_to_string(SERVER_MODEL_STATUS_DOWNLOADING)},
});
cv.notify_all();
}
void server_models::unload(const std::string & name) {
std::lock_guard<std::mutex> lk(mutex);
std::unique_lock<std::mutex> lk(mutex);
auto it = mapping.find(name);
if (it != mapping.end()) {
if (it->second.meta.is_running()) {
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
SRV_INF("cancelling download for model name=%s\n", name.c_str());
it->second.subproc->stop_download.store(true, std::memory_order_relaxed);
// for convenience, we wait the status change here
wait(lk, name, [](const server_model_meta & new_meta) {
return new_meta.status != SERVER_MODEL_STATUS_DOWNLOADING;
});
} else if (it->second.meta.is_running()) {
SRV_INF("stopping model instance name=%s\n", name.c_str());
stopping_models.insert(name);
if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) {
// special case: if model is in loading state, unloading means force-killing it
SRV_WRN("model name=%s is still loading, force-killing\n", name.c_str());
subprocess_terminate(it->second.subproc.get());
subprocess_terminate(&it->second.subproc->get());
}
cv_stop.notify_all();
// status change will be handled by the managing thread
@ -934,7 +1089,10 @@ void server_models::unload_all() {
{
std::lock_guard<std::mutex> lk(mutex);
for (auto & [name, inst] : mapping) {
if (inst.meta.is_running()) {
if (inst.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
SRV_INF("cancelling download for model name=%s\n", name.c_str());
inst.subproc->stop_download.store(true, std::memory_order_relaxed);
} else if (inst.meta.is_running()) {
SRV_INF("stopping model instance name=%s\n", name.c_str());
stopping_models.insert(name);
cv_stop.notify_all();
@ -959,6 +1117,17 @@ void server_models::update_status(const std::string & name, server_model_status
meta.status = status;
meta.exit_code = exit_code;
}
// broadcast status change to SSE
{
json data = {
{"status", server_model_status_to_string(status)},
};
if (status == SERVER_MODEL_STATUS_UNLOADED) {
data["exit_code"] = exit_code;
}
// note: notify_sse doesn't acquire the lock, so no deadlock here
notify_sse("status_change", name, data);
}
cv.notify_all();
}
@ -985,12 +1154,82 @@ void server_models::update_loaded_info(const std::string & name, std::string & r
cv.notify_all();
}
void server_models::wait_until_loading_finished(const std::string & name) {
std::unique_lock<std::mutex> lk(mutex);
cv.wait(lk, [this, &name]() {
void server_models::update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok) {
json curr;
{
std::lock_guard<std::mutex> lk(mutex);
auto it = mapping.find(name);
if (it != mapping.end()) {
return it->second.meta.status != SERVER_MODEL_STATUS_LOADING;
if (done) {
// mark the instance to be erased on next load_models() call
it->second.meta.status = SERVER_MODEL_STATUS_DOWNLOADED;
need_reload = true;
} else {
json & info = it->second.meta.loaded_info;
if (!info.contains("progress")) {
info["progress"] = json{};
}
info["progress"][progress.url] = {
{"done", progress.downloaded},
{"total", progress.total},
};
curr = it->second.meta.loaded_info; // copy
}
}
}
if (done) {
cv.notify_all(); // notify in case unload() is waiting for download to be cancelled
notify_sse(ok ? "download_finished" : "download_failed", name, {});
} else {
notify_sse("download_progress", name, curr);
}
}
bool server_models::remove(const std::string & name) {
auto meta = get_meta(name);
if (!meta.has_value()) {
throw std::runtime_error("model name=" + name + " is not found");
}
if (meta->source != SERVER_MODEL_SOURCE_CACHE) {
throw std::runtime_error("model name=" + name + " is not removable (not from cache)");
}
unload(name); // cancel download or stop running instance
{
std::unique_lock<std::mutex> lk(mutex);
// a cancelled download lands on DOWNLOADED; a stopped instance lands on UNLOADED
wait(lk, name, [](const server_model_meta & new_meta) {
return new_meta.status == SERVER_MODEL_STATUS_UNLOADED
|| new_meta.status == SERVER_MODEL_STATUS_DOWNLOADED;
});
// join before erasing - after status reaches UNLOADED/DOWNLOADED the thread no
// longer acquires this mutex, so joining while holding it is safe
if (mapping[name].th.joinable()) {
mapping[name].th.join();
}
// remove the model from disk (hold lock to prevent concurrent load)
bool ok = common_download_remove(name);
if (ok) {
mapping.erase(name);
}
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "failed");
notify_sse("model_remove", name, {});
return ok;
}
}
void server_models::wait(const std::string & name, std::function<bool(const server_model_meta &)> predicate) {
std::unique_lock<std::mutex> lk(mutex);
wait(lk, name, predicate);
}
void server_models::wait(std::unique_lock<std::mutex> & lk, const std::string & name, std::function<bool(const server_model_meta &)> predicate) {
cv.wait(lk, [this, &name, &predicate]() {
auto it = mapping.find(name);
if (it != mapping.end()) {
return predicate(it->second.meta);
}
return false;
});
@ -1014,10 +1253,15 @@ bool server_models::ensure_model_ready(const std::string & name) {
// wait for loading to complete
SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str());
wait_until_loading_finished(name);
wait(name, [&meta](const server_model_meta & new_meta) {
if (new_meta.status != SERVER_MODEL_STATUS_LOADING) {
meta = new_meta; // update meta for final check after wait
return true;
}
return false;
});
// check final status
meta = get_meta(name);
if (!meta.has_value() || meta->is_failed()) {
throw std::runtime_error("model name=" + name + " failed to load");
}
@ -1111,6 +1355,42 @@ void server_models::notify_router_sleeping_state(bool is_sleeping) {
// server_models_routes
//
// RAII wrapper similar to server_response_reader, but doesn't use server_queue
static std::atomic<int> sse_client_id_counter = 0;
struct server_models_sse_client {
server_response & queue_results;
int client_id;
server_models_sse_client(server_response & q)
: queue_results(q), client_id(sse_client_id_counter.fetch_add(1, std::memory_order_relaxed)) {
SRV_DBG("new SSE client connected, assigned client_id=%d\n", client_id);
queue_results.add_waiting_task_id(client_id);
}
~server_models_sse_client() {
SRV_DBG("SSE client disconnected, removing client_id=%d\n", client_id);
queue_results.remove_waiting_task_id(client_id);
}
// return nullptr if should_stop() is true before receiving a result
// note: if one error is received, it will stop further processing and return error result
server_task_result_ptr next(const std::function<bool()> & should_stop) {
while (true) {
static const int http_polling_seconds = 1; // check should_stop every 1 second
server_task_result_ptr result = queue_results.recv_with_timeout({client_id}, http_polling_seconds);
if (result == nullptr) {
// timeout, check stop condition
if (should_stop()) {
return nullptr;
}
// continue waiting otherwise
} else {
SRV_DBG("recv result for client_id=%d: %s\n", client_id, safe_json_to_str(result->to_json()).c_str());
return result;
}
}
// should not reach here
}
};
static void res_ok(std::unique_ptr<server_http_res> & res, const json & response_data) {
res->status = 200;
res->data = safe_json_to_str(response_data);
@ -1274,7 +1554,9 @@ void server_models_routes::init_routes() {
{"created", t}, // for OAI-compat
{"status", status},
{"architecture", architecture},
{"need_download", meta.need_download},
{"source", server_model_source_to_string(meta.source)},
{"can_remove", meta.source == SERVER_MODEL_SOURCE_CACHE},
// {"need_download", meta.need_download},
// TODO: add other fields, may require reading GGUF metadata
};
@ -1312,6 +1594,87 @@ void server_models_routes::init_routes() {
res_ok(res, {{"success", true}});
return res;
};
this->get_router_models_sse = [this](const server_http_req & req) {
auto res = std::make_unique<server_http_res>();
res->status = 200;
res->content_type = "text/event-stream";
auto sse_client = std::make_shared<server_models_sse_client>(models.sse);
res->next = [this, sse_client, &req](std::string & output) -> bool {
auto result = sse_client->next([&]() {
return stopping.load(std::memory_order_relaxed) || req.should_stop();
});
if (result == nullptr) {
return false; // client disconnected or should_stop
}
output = "data: " + safe_json_to_str(result->to_json()) + "\n\n";
return true; // listen for the next event
};
return res;
};
this->post_router_models = [this](const server_http_req & req) {
auto res = std::make_unique<server_http_res>();
json body = json::parse(req.body);
std::string name = json_value(body, "model", std::string());
if (name.empty()) {
throw std::invalid_argument("model must be a non-empty string");
}
common_params_model model;
common_download_opts opts;
model.name = name;
model.hf_repo = name;
opts.bearer_token = params.hf_token;
opts.download_mmproj = true;
opts.download_mtp = true;
// first, only check if the model is valid and can be downloaded
opts.skip_download = true;
bool ok = false;
try {
auto validation = common_download_model(model, opts);
ok = !validation.model_path.empty();
} catch (const common_skip_download_exception &) {
// model is valid and will be downloaded
ok = true;
} catch (...) {
SRV_ERR("unknown error while validating model '%s'\n", name.c_str());
// other exceptions will be handled by the outer ex_wrapper()
throw;
}
if (!ok) {
throw std::invalid_argument("model validation failed, unable to download");
}
// then, proceed with the actual download
opts.skip_download = false;
SRV_INF("starting download for model '%s'\n", name.c_str());
models.download(std::move(model), std::move(opts));
res_ok(res, {{"success", true}});
return res;
};
this->del_router_models = [this](const server_http_req & req) {
auto res = std::make_unique<server_http_res>();
std::string name = req.get_param("model");
if (name.empty()) {
throw std::invalid_argument("model must be a non-empty string");
}
bool ok = models.remove(name);
if (!ok) {
throw std::runtime_error("failed to remove model '" + name + "'");
}
res_ok(res, {{"success", true}});
return res;
};
}

View File

@ -1,9 +1,11 @@
#pragma once
#include "common.h"
#include "download.h"
#include "preset.h"
#include "server-common.h"
#include "server-http.h"
#include "server-queue.h"
#include <mutex>
#include <condition_variable>
@ -14,6 +16,8 @@
/**
* state diagram:
*
* DOWNLOADING DOWNLOADED (replaced by new instance)
*
* UNLOADED LOADING LOADED SLEEPING
*
* failed
@ -22,39 +26,43 @@
*/
enum server_model_status {
// TODO: also add downloading state when the logic is added
SERVER_MODEL_STATUS_DOWNLOADING,
SERVER_MODEL_STATUS_DOWNLOADED,
SERVER_MODEL_STATUS_UNLOADED,
SERVER_MODEL_STATUS_LOADING,
SERVER_MODEL_STATUS_LOADED,
SERVER_MODEL_STATUS_SLEEPING
};
static server_model_status server_model_status_from_string(const std::string & status_str) {
if (status_str == "unloaded") {
return SERVER_MODEL_STATUS_UNLOADED;
}
if (status_str == "loading") {
return SERVER_MODEL_STATUS_LOADING;
}
if (status_str == "loaded") {
return SERVER_MODEL_STATUS_LOADED;
}
if (status_str == "sleeping") {
return SERVER_MODEL_STATUS_SLEEPING;
}
throw std::runtime_error("invalid server model status");
}
enum server_model_source {
SERVER_MODEL_SOURCE_PRESET,
SERVER_MODEL_SOURCE_MODELS_DIR,
SERVER_MODEL_SOURCE_CACHE,
};
static std::string server_model_status_to_string(server_model_status status) {
switch (status) {
case SERVER_MODEL_STATUS_UNLOADED: return "unloaded";
case SERVER_MODEL_STATUS_LOADING: return "loading";
case SERVER_MODEL_STATUS_LOADED: return "loaded";
case SERVER_MODEL_STATUS_SLEEPING: return "sleeping";
default: return "unknown";
case SERVER_MODEL_STATUS_DOWNLOADING: return "downloading";
case SERVER_MODEL_STATUS_DOWNLOADED: return "downloaded";
case SERVER_MODEL_STATUS_UNLOADED: return "unloaded";
case SERVER_MODEL_STATUS_LOADING: return "loading";
case SERVER_MODEL_STATUS_LOADED: return "loaded";
case SERVER_MODEL_STATUS_SLEEPING: return "sleeping";
default: return "unknown";
}
}
static std::string server_model_source_to_string(server_model_source source) {
switch (source) {
case SERVER_MODEL_SOURCE_PRESET: return "preset";
case SERVER_MODEL_SOURCE_MODELS_DIR: return "models_dir";
case SERVER_MODEL_SOURCE_CACHE: return "cache";
default: return "unknown";
}
}
struct server_model_meta {
server_model_source source = SERVER_MODEL_SOURCE_CACHE;
common_preset preset;
std::string name;
std::set<std::string> aliases; // additional names that resolve to this model
@ -63,11 +71,11 @@ struct server_model_meta {
server_model_status status = SERVER_MODEL_STATUS_UNLOADED;
int64_t last_used = 0; // for LRU unloading
std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
json loaded_info; // info to be reflected via /v1/models endpoint
json loaded_info; // info to be reflected via /v1/models endpoint ; if in DOWNLOADING state, it should contain download progress info
int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown
mtmd_caps multimodal; // multimodal capabilities
bool need_download = false; // whether the model needs to be downloaded before loading
// bool need_download = false; // whether the model needs to be downloaded before loading // TODO @ngxson: implement this
bool is_ready() const {
return status == SERVER_MODEL_STATUS_LOADED;
@ -85,12 +93,15 @@ struct server_model_meta {
void update_caps();
};
struct subprocess_s;
struct server_models_routes;
struct server_subproc; // defined in server-models.cpp
struct server_models {
friend struct server_models_routes;
private:
struct instance_t {
std::shared_ptr<subprocess_s> subproc; // shared between main thread and monitoring thread
std::shared_ptr<server_subproc> subproc; // shared between main thread and monitoring thread
std::thread th;
server_model_meta meta;
FILE * stdin_file = nullptr;
@ -107,6 +118,9 @@ private:
// set to true while load_models() is executing a reload; load() will wait until clear
bool is_reloading = false;
// if true, the next get_meta() will trigger a reload of model list
bool need_reload = false;
common_preset_context ctx_preset;
common_params base_params;
@ -122,9 +136,14 @@ private:
// not thread-safe, caller must hold mutex
void add_model(server_model_meta && meta);
// notify SSE clients
void notify_sse(const std::string & event, const std::string & model_id, const json & data = nullptr);
public:
server_models(const common_params & params, int argc, char ** argv);
server_response sse; // for real-time updates via SSE endpoint
// (re-)load the list of models from various sources and prepare the metadata mapping
// - if this is called the first time, simply populate the metadata
// - if this is called subsequently (e.g. when refreshing from disk):
@ -147,13 +166,24 @@ public:
void unload(const std::string & name);
void unload_all();
// download a new model, progress is reported via SSE
// to stop the download, call unload()
void download(common_params_model && model, common_download_opts && opts);
// update the status of a model instance (thread-safe)
void update_status(const std::string & name, server_model_status status, int exit_code);
void update_loaded_info(const std::string & name, std::string & raw_info);
void update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok = true);
// remove a cache model from disk and update the list (thread-safe)
// note: only cache models can be removed; returns false if the model doesn't exist or is not a cache model
bool remove(const std::string & name);
// wait until the model instance is fully loaded (thread-safe)
// note: predicate is called while holding the lock
// return when the model no longer in "loading" state
void wait_until_loading_finished(const std::string & name);
void wait(const std::string & name, std::function<bool(const server_model_meta &)> predicate);
void wait(std::unique_lock<std::mutex> & lk, const std::string & name, std::function<bool(const server_model_meta &)> predicate);
// ensure the model is in ready state (thread-safe)
// return false if model is ready
@ -176,8 +206,9 @@ public:
struct server_models_routes {
common_params params;
json ui_settings = json::object(); // Primary: new name
json webui_settings = json::object(); // Deprecated: use ui_settings (kept for compat)
json ui_settings = json::object(); // Primary: new name
json webui_settings = json::object(); // Deprecated: use ui_settings (kept for compat)
std::atomic<bool> stopping = false; // for graceful disconnecting SSE clients during shutdown
server_models models;
server_models_routes(const common_params & params, int argc, char ** argv)
: params(params), models(params, argc, argv) {
@ -206,6 +237,10 @@ struct server_models_routes {
server_http_context::handler_t get_router_models;
server_http_context::handler_t post_router_models_load;
server_http_context::handler_t post_router_models_unload;
// management API
server_http_context::handler_t get_router_models_sse;
server_http_context::handler_t post_router_models;
server_http_context::handler_t del_router_models;
};
/**

View File

@ -331,6 +331,17 @@ void server_response::send(server_task_result_ptr && result) {
}
}
void server_response::broadcast(server_task_result_ptr && result) {
std::unique_lock<std::mutex> lock(mutex_results);
for (const auto & id_task : waiting_task_ids) {
RES_DBG("task id = %d pushed to result queue\n", id_task);
server_task_result_ptr res_copy(result->clone());
res_copy->id = id_task; // override id with target task id
queue_results.emplace_back(std::move(res_copy));
}
condition_results.notify_all();
}
void server_response::terminate() {
running = false;
condition_results.notify_all();

View File

@ -154,11 +154,15 @@ public:
// Send a new result to a waiting id_task
void send(server_task_result_ptr && result);
// broadcast a new result to all waiting tasks
// (used by router mode)
void broadcast(server_task_result_ptr && result);
// terminate the waiting loop
void terminate();
};
// utility class to make working with server_queue and server_response easier
// RAII wrapper to make working with server_queue and server_response easier
// it provides a generator-like API for server responses
// support pooling connection state and aggregating multiple results
struct server_response_reader {

View File

@ -312,6 +312,9 @@ struct server_task_result {
}
virtual json to_json() = 0;
virtual ~server_task_result() = default;
virtual server_task_result * clone() const {
GGML_ABORT("not implemented for this task type");
}
};
// using shared_ptr for polymorphism of server_task_result
@ -649,3 +652,12 @@ struct server_prompt_cache {
void update();
};
// used exclusively by router mode
struct server_task_result_router : server_task_result {
json data;
virtual json to_json() override { return data; }
virtual server_task_result * clone() const override {
return new server_task_result_router(*this);
}
};

View File

@ -174,8 +174,11 @@ int llama_server(int argc, char ** argv) {
routes.get_props = models_routes->get_router_props;
routes.get_models = models_routes->get_router_models;
ctx_http.post("/models", ex_wrapper(models_routes->post_router_models));
ctx_http.post("/models/load", ex_wrapper(models_routes->post_router_models_load));
ctx_http.post("/models/unload", ex_wrapper(models_routes->post_router_models_unload));
ctx_http.get ("/models/sse", ex_wrapper(models_routes->get_router_models_sse));
ctx_http.del ("/models", ex_wrapper(models_routes->del_router_models));
}
ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
@ -261,6 +264,7 @@ int llama_server(int argc, char ** argv) {
clean_up = [&models_routes]() {
SRV_INF("%s: cleaning up before exit...\n", __func__);
if (models_routes.has_value()) {
models_routes->stopping.store(true); // maybe redundant, but just to be safe
models_routes->models.unload_all();
}
llama_backend_free();
@ -274,6 +278,10 @@ int llama_server(int argc, char ** argv) {
ctx_http.is_ready.store(true);
shutdown_handler = [&](int) {
if (models_routes.has_value()) {
// important to disconnect any SSE clients
models_routes->stopping.store(true);
}
ctx_http.stop();
};

View File

@ -1,3 +1,4 @@
import threading
import pytest
from utils import *
@ -253,3 +254,98 @@ def test_router_reload_models():
assert "model-reload-c" in ids, "newly added model should appear"
finally:
os.remove(preset_path)
MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16"
MODEL_DOWNLOAD_TIMEOUT = 300
def _listen_sse(server: ServerProcess, collected: list, stop: threading.Event):
"""Collect /models/sse events into `collected` until `stop` is set."""
url = f"http://{server.server_host}:{server.server_port}/models/sse"
try:
with requests.get(url, stream=True, timeout=MODEL_DOWNLOAD_TIMEOUT) as resp:
for line_bytes in resp.iter_lines():
if stop.is_set():
break
line = line_bytes.decode("utf-8")
if line.startswith("data: "):
collected.append(json.loads(line[6:]))
except Exception:
pass
def _wait_for_sse_event(collected: list, event_type: str, model: str, timeout: int) -> bool:
deadline = time.time() + timeout
while time.time() < deadline:
if any(e.get("event") == event_type and e.get("model") == model for e in collected):
return True
time.sleep(0.5)
return False
def test_router_download_model():
"""Case 1: download a model, verify SSE events and GET /models."""
global server
server.start()
# Ensure the model is not present before we start
server.make_request("DELETE", f"/models?model={MODEL_DOWNLOAD_ID}")
sse_events: list = []
stop = threading.Event()
sse_thread = threading.Thread(
target=_listen_sse, args=(server, sse_events, stop), daemon=True
)
sse_thread.start()
# Trigger the download
res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID})
assert res.status_code == 200
assert res.body.get("success") is True
# Wait for download_finished SSE event
finished = _wait_for_sse_event(
sse_events, "download_finished", MODEL_DOWNLOAD_ID, MODEL_DOWNLOAD_TIMEOUT
)
stop.set()
assert finished, "Never received download_finished SSE event"
assert any(
e.get("event") == "download_progress" and e.get("model") == MODEL_DOWNLOAD_ID
for e in sse_events
), "No download_progress events received"
# Model should now appear in GET /models
ids = _get_model_ids(is_reload=False)
assert MODEL_DOWNLOAD_ID in ids, f"{MODEL_DOWNLOAD_ID} not found in /models after download"
def test_router_delete_model():
"""Case 2: delete the downloaded model, verify it disappears from GET /models."""
global server
server.start()
# Ensure the model exists (download it if needed)
if MODEL_DOWNLOAD_ID not in _get_model_ids(is_reload=False):
res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID})
assert res.status_code == 200
sse_events: list = []
stop = threading.Event()
threading.Thread(
target=_listen_sse, args=(server, sse_events, stop), daemon=True
).start()
finished = _wait_for_sse_event(
sse_events, "download_finished", MODEL_DOWNLOAD_ID, MODEL_DOWNLOAD_TIMEOUT
)
stop.set()
assert finished, "Model did not finish downloading before delete test"
# Delete the model
del_res = server.make_request("DELETE", f"/models?model={MODEL_DOWNLOAD_ID}")
assert del_res.status_code == 200
assert del_res.body.get("success") is True
# Model should no longer appear in GET /models
ids = _get_model_ids(is_reload=False)
assert MODEL_DOWNLOAD_ID not in ids, f"{MODEL_DOWNLOAD_ID} still present after deletion"

View File

@ -340,6 +340,9 @@ class ServerProcess:
elif method == "POST":
response = requests.post(url, headers=headers, json=data, timeout=timeout)
parse_body = True
elif method == "DELETE":
response = requests.delete(url, headers=headers, timeout=timeout)
parse_body = True
elif method == "OPTIONS":
response = requests.options(url, headers=headers, timeout=timeout)
else: