diff --git a/common/arg.cpp b/common/arg.cpp index 8f54b5c814..5297d90753 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -396,7 +396,7 @@ static bool parse_bool_value(const std::string & value) { // CLI argument parsing functions // -bool common_params_handle_models(common_params & params, llama_example curr_ex) { +bool common_params_handle_models(common_params & params, llama_example curr_ex, common_download_callback * callback) { const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(), params.speculative.types.end(), COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end(); @@ -408,6 +408,10 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex) opts.download_mtp = spec_type_draft_mtp; opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty(); + if (callback) { + opts.callback = callback; + } + // sub-models (draft, mmproj, vocoder) are explicitly specified by the user, // so we should not auto-discover mtp/mmproj siblings for them common_download_opts sub_opts = opts; @@ -584,8 +588,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); } - // export_graph_ops loads only metadata - const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS; + const bool skip_model_download = + // server will call common_params_handle_models() later, so we skip it here + ctx_arg.ex == LLAMA_EXAMPLE_SERVER || + // export_graph_ops loads only metadata + ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS; if (!skip_model_download) { // handle model and download @@ -594,7 +601,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context // model is required (except for server) // TODO @ngxson : maybe show a list of available models in CLI in this case if (params.model.path.empty() - && ctx_arg.ex != 
LLAMA_EXAMPLE_SERVER && !params.usage && !params.completion) { throw std::invalid_argument("error: --model is required\n"); diff --git a/common/arg.h b/common/arg.h index 0010f2a9ac..c061fc60f7 100644 --- a/common/arg.h +++ b/common/arg.h @@ -1,6 +1,7 @@ #pragma once #include "common.h" +#include "download.h" #include #include @@ -133,7 +134,10 @@ void common_params_add_preset_options(std::vector & args); // return true if the model is ready to use // throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc) // if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. ETag check failed) -bool common_params_handle_models(common_params & params, llama_example curr_ex); +bool common_params_handle_models( + common_params & params, + llama_example curr_ex, + common_download_callback * callback = nullptr); // initialize argument parser context - used by test-arg-parser and preset common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); diff --git a/tools/server/README-dev.md b/tools/server/README-dev.md index 2796d28350..5959745e47 100644 --- a/tools/server/README-dev.md +++ b/tools/server/README-dev.md @@ -204,9 +204,9 @@ Instead of building everything from the ground up (like what most AI agents will The flow for downloading a new model: - POST request comes in --> `post_router_models` --> validation -- `server_models::download()` is called - - Sets up a new thread `inst.th` and runs the download inside -- If a stop request comes in, set `stop_download` to `true` +- A new `llama-server` subprocess is spawned in the special `SERVER_CHILD_MODE_DOWNLOAD` mode +- The child process runs the download and reports status back to the router via stdin/stdout +- If a stop request comes in, the router asks the child process to stop (same mechanism as for a model running in a child process) -
Otherwise, upon completion, we call `load_models()` to refresh the list of models ### Notable Related PRs diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 3f9391cacb..0a25b414ed 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -931,6 +931,8 @@ private: bool sleeping = false; + int64_t t_last_load_progress_ms = 0; + void destroy() { spec.reset(); ctx_dft.reset(); @@ -1244,6 +1246,10 @@ private: } if (has_mmproj) { + if (callback_state) { + callback_state(SERVER_STATE_LOADING, {{"stage", "mmproj_model"}}); + } + if (!is_resume) { mtmd_helper_log_set(common_log_default_callback, nullptr); } diff --git a/tools/server/server-context.h b/tools/server/server-context.h index c7218a12ed..952f825f72 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -53,7 +53,7 @@ struct server_context_meta { }; enum server_state { - // SERVER_STATE_DOWNLOADING, + SERVER_STATE_DOWNLOADING, SERVER_STATE_LOADING, SERVER_STATE_READY, SERVER_STATE_SLEEPING, @@ -61,6 +61,7 @@ enum server_state { static std::string server_state_to_str(server_state state) { switch (state) { + case SERVER_STATE_DOWNLOADING: return "downloading"; case SERVER_STATE_LOADING: return "loading"; case SERVER_STATE_READY: return "ready"; case SERVER_STATE_SLEEPING: return "sleeping"; @@ -69,6 +70,7 @@ static std::string server_state_to_str(server_state state) { } static server_state server_state_from_str(const std::string & str) { + if (str == "downloading") return SERVER_STATE_DOWNLOADING; if (str == "loading") return SERVER_STATE_LOADING; if (str == "ready") return SERVER_STATE_READY; if (str == "sleeping") return SERVER_STATE_SLEEPING; diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 68eefdffac..a87e4e423e 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -64,6 +64,17 @@ struct server_subproc { return sproc.has_value() && 
subprocess_alive(&sproc.value()); } + void request_exit() { + if (sproc.has_value()) { + FILE * stdin_file = subprocess_stdin(&sproc.value()); + if (stdin_file) { + fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT); + fflush(stdin_file); + } + } + stopped.store(true, std::memory_order_relaxed); + } + void terminate() { if (!sproc.has_value()) { return; @@ -323,7 +334,7 @@ void server_models::notify_sse(const std::string & event, const std::string & mo } void server_models::load_models() { - // Phase 1: load presets from all sources — pure I/O, no lock needed + // Phase 1: load presets from all sources - pure I/O, no lock needed // 1. cached models common_presets cached_models = ctx_preset.load_from_cache(); SRV_INF("Loaded %zu cached model presets\n", cached_models.size()); @@ -376,7 +387,7 @@ void server_models::load_models() { return source_map.count(name) ? source_map.at(name) : SERVER_MODEL_SOURCE_PRESET; }; - // Helpers that read `mapping` — must be called while holding the lock. + // Helpers that read `mapping` - must be called while holding the lock. 
std::unordered_set custom_names; for (const auto & [name, preset] : custom_presets) custom_names.insert(name); auto join_set = [](const std::set & s) { @@ -523,7 +534,7 @@ void server_models::load_models() { } } - // join outside the lock — monitoring thread calls update_status (needs lock) + // join outside the lock - monitoring thread calls update_status (needs lock) lk.unlock(); for (auto & th : threads_to_join) th.join(); lk.lock(); @@ -622,7 +633,7 @@ void server_models::load_models() { apply_stop_timeout(); - // clear reload flag before unlocking for autoload — load() blocks on !is_reloading, + // clear reload flag before unlocking for autoload - load() blocks on !is_reloading, // so clearing it here (while still locked) prevents a deadlock in the autoload calls below is_reloading = false; cv.notify_all(); @@ -815,17 +826,23 @@ void server_models::unload_lru() { } void server_models::load(const std::string & name) { - if (!has_model(name)) { - throw std::runtime_error("model name=" + name + " is not found"); + load(name, load_options{}); +} + +void server_models::load(const std::string & name, const load_options & opts) { + if (!opts.custom_meta.has_value()) { + if (!has_model(name)) { + throw std::runtime_error("model name=" + name + " is not found"); + } + unload_lru(); } - unload_lru(); std::unique_lock lk(mutex); // edge case: block until any in-progress reload has finished so we always load // against the freshest preset and a consistent mapping state cv.wait(lk, [this]() { return !is_reloading; }); - auto meta = mapping[name].meta; + auto meta = opts.custom_meta.has_value() ? 
*opts.custom_meta : mapping[name].meta; if (meta.status != SERVER_MODEL_STATUS_UNLOADED) { SRV_INF("model %s is not ready\n", name.c_str()); return; @@ -869,6 +886,12 @@ void server_models::load(const std::string & name) { std::vector child_env = base_env; // copy child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port)); + if (opts.mode == SERVER_CHILD_MODE_DOWNLOAD) { + inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING; + child_env.push_back("LLAMA_SERVER_CHILD_MODE=download"); + child_env.push_back("LLAMA_ARG_HF_REPO=" + name); + } + SRV_INF("%s", "spawning server instance with args:\n"); for (const auto & arg : child_args) { SRV_INF(" %s\n", arg.c_str()); @@ -886,13 +909,17 @@ void server_models::load(const std::string & name) { if (result != 0) { throw std::runtime_error("failed to spawn server instance"); } - - inst.stdin_file = subprocess_stdin(&inst.subproc->get()); } // start a thread to manage the child process // captured variables are guaranteed to be destroyed only after the thread is joined - inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() { + inst.th = std::thread([ + this, name, + child_proc = inst.subproc, + port = inst.meta.port, + stop_timeout = inst.meta.stop_timeout, + child_mode = opts.mode + ]() { FILE * stdin_file = subprocess_stdin(&child_proc->get()); FILE * stdout_file = subprocess_stdout(&child_proc->get()); // combined stdout/stderr @@ -925,7 +952,7 @@ void server_models::load(const std::string & name) { return is_stopping() || child_proc->stopped.load(std::memory_order_acquire); }); } - // child crashed or finished on its own — skip graceful shutdown sequence + // child crashed or finished on its own, skip graceful shutdown sequence if (child_proc->stopped.load(std::memory_order_acquire)) { return; } @@ -973,10 +1000,14 @@ void server_models::load(const std::string & name) { subprocess_destroy(&child_proc->get()); // update status 
and exit code - this->update_status(name, { - SERVER_MODEL_STATUS_UNLOADED, - exit_code - }); + if (child_mode == SERVER_CHILD_MODE_DOWNLOAD) { + // instance will be cleaned up on next load_models() call + } else { + this->update_status(name, { + SERVER_MODEL_STATUS_UNLOADED, + exit_code + }); + } SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code); }); @@ -984,7 +1015,7 @@ void server_models::load(const std::string & name) { { auto & old_instance = mapping[name]; // old process should have exited already, but just in case, we clean it up here - if (old_instance.subproc->is_alive()) { + if (old_instance.subproc && old_instance.subproc->is_alive()) { SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str()); old_instance.subproc->terminate(); // force kill } @@ -1001,92 +1032,13 @@ void server_models::load(const std::string & name) { cv.notify_all(); } -// callback for model downloading functionality -struct server_models_download_res : public common_download_callback { - common_params_model model; - common_download_opts opts; - - std::function should_stop; - std::function on_progress; - - bool is_ok = false; - - bool run() { - try { - common_download_model(model, opts); - is_ok = true; - } catch (const std::exception & e) { - auto model_name = model.get_name(); - SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what()); - is_ok = false; - } - return is_ok; - } - void on_start(const common_download_progress & p) override { - on_progress(p); - } - void on_update(const common_download_progress & p) override { - on_progress(p); - } - void on_done(const common_download_progress &, bool ok) override { - is_ok = ok; - } - bool is_cancelled() const override { - return should_stop(); - } -}; - -void server_models::download(common_params_model && model, common_download_opts && opts) { - std::string name = model.get_name(); - GGML_ASSERT(name == model.hf_repo); - - std::unique_lock lk(mutex); - 
if (mapping.find(name) != mapping.end()) { - throw std::runtime_error("model name=" + name + " already exists"); - } - - instance_t inst; - inst.meta.name = name; - inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING; - inst.subproc = std::make_shared(); - - auto dl = std::make_unique(); - dl->model = model; // copy - dl->opts = opts; // copy - - dl->should_stop = [sp = inst.subproc]() { - return sp->stopped.load(std::memory_order_relaxed); - }; - - dl->on_progress = [this, name](const common_download_progress & p) { - update_download_progress(name, p, false); - }; - - inst.th = std::thread([this, dl = std::move(dl)]() { - dl->opts.callback = dl.get(); - bool ok = dl->run(); - auto model_name = dl->model.get_name(); - SRV_INF("download finished for model name=%s with status=%s\n", - model_name.c_str(), ok ? "success" : "failure"); - update_download_progress(model_name, {}, true, ok); - // need_reload is set inside update_download_progress under the mutex; - // the next load_models() call will clean up this instance - }); - - mapping[name] = std::move(inst); - notify_sse("status_update", name, { - {"status", server_model_status_to_string(SERVER_MODEL_STATUS_DOWNLOADING)}, - }); - cv.notify_all(); -} - void server_models::unload(const std::string & name) { std::unique_lock lk(mutex); auto it = mapping.find(name); if (it != mapping.end()) { if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) { SRV_INF("cancelling download for model name=%s\n", name.c_str()); - it->second.subproc->stopped.store(true, std::memory_order_relaxed); + it->second.subproc->request_exit(); // for convenience, we wait the status change here wait(lk, name, [](const server_model_meta & new_meta) { return new_meta.status != SERVER_MODEL_STATUS_DOWNLOADING; @@ -1198,37 +1150,65 @@ void server_models::update_download_progress(const std::string & name, const com } bool server_models::remove(const std::string & name) { - auto meta = get_meta(name); + // do everything under one lock 
acquisition; avoid get_meta() / + // unload() because they can trigger load_models() which erases + // transient DOWNLOADING / DOWNLOADED entries as a side-effect + std::unique_lock lk(mutex); - if (!meta.has_value()) { + auto it = mapping.find(name); + if (it == mapping.end()) { throw std::runtime_error("model name=" + name + " is not found"); } - if (meta->source != SERVER_MODEL_SOURCE_CACHE) { + if (it->second.meta.source != SERVER_MODEL_SOURCE_CACHE) { throw std::runtime_error("model name=" + name + " is not removable (not from cache)"); } - unload(name); // cancel download or stop running instance - { - std::unique_lock lk(mutex); - // a cancelled download lands on DOWNLOADED; a stopped instance lands on UNLOADED - wait(lk, name, [](const server_model_meta & new_meta) { - return new_meta.status == SERVER_MODEL_STATUS_UNLOADED - || new_meta.status == SERVER_MODEL_STATUS_DOWNLOADED; - }); - // join before erasing - after status reaches UNLOADED/DOWNLOADED the thread no - // longer acquires this mutex, so joining while holding it is safe - if (mapping[name].th.joinable()) { - mapping[name].th.join(); + if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) { + // cancel in-flight download + SRV_INF("cancelling download for model name=%s\n", name.c_str()); + it->second.subproc->request_exit(); + } else if (it->second.meta.is_running()) { + // stop running instance + SRV_INF("stopping model instance name=%s\n", name.c_str()); + stopping_models.insert(name); + if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) { + it->second.subproc->terminate(); } - // remove the model from disk (hold lock to prevent concurrent load) - bool ok = common_download_remove(name); - if (ok) { - mapping.erase(name); - } - SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? 
"succeeded" : "failed"); - notify_sse("model_remove", name, {}); - return ok; + cv_stop.notify_all(); } + + // wait until the monitoring thread finishes + wait(lk, name, [](const server_model_meta & meta) { + return meta.status == SERVER_MODEL_STATUS_UNLOADED + || meta.status == SERVER_MODEL_STATUS_DOWNLOADED; + }); + + // re-find after wait - load_models() may have erased the entry during the wait + it = mapping.find(name); + if (it == mapping.end()) { + // load_models() already joined the thread and erased the entry; + // we just need to clean up the cached files on disk + lk.unlock(); + bool ok = common_download_remove(name); + SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial"); + notify_sse("model_remove", name, {}); + return true; + } + + // join before erasing - thread no longer acquires this mutex + if (it->second.th.joinable()) { + it->second.th.join(); + } + + // remove from disk (best-effort: cancelled downloads may have no cached files) + bool ok = common_download_remove(name); + mapping.erase(name); + if (!ok) { + SRV_WRN("removing model name=%s from disk returned false (no cached files?)\n", name.c_str()); + } + SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial"); + notify_sse("model_remove", name, {}); + return true; } void server_models::wait(const std::string & name, std::function predicate) { @@ -1243,7 +1223,9 @@ void server_models::wait(std::unique_lock & lk, const std::string & return predicate(it->second.meta); } - return false; + // model was removed from mapping by another code path (e.g. load_models()). + // nothing left to wait for - tell the caller to proceed. 
+ return true; }); } @@ -1328,6 +1310,31 @@ void server_models::handle_child_state(const std::string & name, const std::stri } switch (state) { + case SERVER_STATE_DOWNLOADING: + { + std::string result = json_value(payload, "result", std::string()); + std::string url = json_value(payload, "url", std::string()); + auto request_exit = [&]() { + std::lock_guard lk(mutex); + auto it = mapping.find(name); + if (it != mapping.end()) { + return it->second.subproc->request_exit(); + } + }; + if (result == "download_finished") { + update_download_progress(name, {}, true, true); + request_exit(); + } else if (result == "download_failed") { + update_download_progress(name, {}, true, false); + request_exit(); + } else if (!url.empty()) { + common_download_progress p; + p.url = url; + p.downloaded = json_value(payload, "downloaded", (size_t)0); + p.total = json_value(payload, "total", (size_t)0); + update_download_progress(name, p, false); + } + } break; case SERVER_STATE_LOADING: { update_status(name, { @@ -1366,6 +1373,90 @@ bool server_child::is_child() { return router_port != nullptr; } +server_child_mode server_child::get_mode() { + const char * mode = std::getenv("LLAMA_SERVER_CHILD_MODE"); + std::string mode_str(mode ? 
mode : ""); + if (mode_str == "download") { + return SERVER_CHILD_MODE_DOWNLOAD; + } else { + return SERVER_CHILD_MODE_NORMAL; + } +} + +struct server_download_state : public common_download_callback { + server_child * self; + std::function should_stop; + std::atomic last_progress_time{0}; // multiple files downloading in different threads + bool is_ok = false; + + server_download_state(server_child * s) : self(s) {} + + bool run(common_params & params) { + try { + common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, this); + is_ok = true; + } catch (const std::exception & e) { + auto model_name = params.model.get_name(); + SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what()); + is_ok = false; + } + return is_ok; + } + void on_progress(const common_download_progress & p) { + json data = { + {"url", p.url}, + {"downloaded", p.downloaded}, + {"total", p.total}, + }; + self->notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), data); + } + void on_start(const common_download_progress & p) override { + on_progress(p); + } + void on_update(const common_download_progress & p) override { + int64_t now = ggml_time_ms(); + // throttle progress updates to avoid flooding logs + if (now - last_progress_time.load(std::memory_order_relaxed) >= 100) { + on_progress(p); + last_progress_time.store(now, std::memory_order_relaxed); + } + } + void on_done(const common_download_progress & p, bool) override { + on_progress(p); + } + bool is_cancelled() const override { + return should_stop ? 
should_stop() : false; + } +}; + +int server_child::run_download(common_params & params) { + auto cancelled = std::make_shared>(false); + + // monitor stdin for cancellation command from the router + std::thread signal_thread = setup([cancelled](int) { + cancelled->store(true, std::memory_order_relaxed); + }); + + server_download_state dl(this); + dl.should_stop = [cancelled]() { + return cancelled->load(std::memory_order_relaxed); + }; + + bool ok = dl.run(params); + + notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), { + {"result", ok ? "download_finished" : "download_failed"}, + }); + + // router should send CMD_ROUTER_TO_CHILD_EXIT after receiving the result + if (signal_thread.joinable()) { + signal_thread.join(); + } + + SRV_INF("download completed %s\n", ok ? "successfully" : "with errors"); + return 0; +} + std::thread server_child::setup(const std::function & shutdown_handler) { // setup thread for monitoring stdin return std::thread([shutdown_handler]() { @@ -1639,7 +1730,7 @@ void server_models_routes::init_routes() { res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST)); return res; } - if (!model->is_running()) { + if (!model->is_running() && model->status != SERVER_MODEL_STATUS_DOWNLOADING) { res_err(res, format_error_response("model is not running", ERROR_TYPE_INVALID_REQUEST)); return res; } @@ -1680,8 +1771,9 @@ void server_models_routes::init_routes() { model.hf_repo = name; opts.bearer_token = params.hf_token; - opts.download_mmproj = true; - opts.download_mtp = true; + // note: we only check main model, no need sidecar here + opts.download_mmproj = false; + opts.download_mtp = false; // first, only check if the model is valid and can be downloaded opts.skip_download = true; @@ -1702,10 +1794,21 @@ void server_models_routes::init_routes() { throw std::invalid_argument("model validation failed, unable to download"); } + // reject if model already exists + if (models.has_model(name)) { + throw 
std::invalid_argument("model '" + name + "' already exists"); + } + // then, proceed with the actual download - opts.skip_download = false; SRV_INF("starting download for model '%s'\n", name.c_str()); - models.download(std::move(model), std::move(opts)); + { + server_models::load_options load_opts; + load_opts.mode = SERVER_CHILD_MODE_DOWNLOAD; + load_opts.custom_meta = server_model_meta{}; + load_opts.custom_meta->source = SERVER_MODEL_SOURCE_CACHE; + load_opts.custom_meta->name = name; + models.load(name, load_opts); + } res_ok(res, {{"success", true}}); return res; @@ -1719,10 +1822,7 @@ void server_models_routes::init_routes() { throw std::invalid_argument("model must be a non-empty string"); } - bool ok = models.remove(name); - if (!ok) { - throw std::runtime_error("failed to remove model '" + name + "'"); - } + models.remove(name); // throws on error res_ok(res, {{"success", true}}); return res; diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 17759b00a5..9ed4aeead0 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -40,6 +40,11 @@ enum server_model_source { SERVER_MODEL_SOURCE_CACHE, }; +enum server_child_mode { + SERVER_CHILD_MODE_NORMAL, // load the model and run normally + SERVER_CHILD_MODE_DOWNLOAD, // download the model and exit +}; + static std::string server_model_status_to_string(server_model_status status) { switch (status) { case SERVER_MODEL_STATUS_DOWNLOADING: return "downloading"; @@ -105,7 +110,6 @@ private: std::shared_ptr subproc; // shared between main thread and monitoring thread std::thread th; server_model_meta meta; - FILE * stdin_file = nullptr; }; std::mutex mutex; @@ -161,16 +165,19 @@ public: // return a copy of all model metadata (thread-safe) std::vector get_all_meta(); + struct load_options { + server_child_mode mode = SERVER_CHILD_MODE_NORMAL; + // used for spawning a downloading child process + std::optional custom_meta = std::nullopt; + }; + // load and unload model 
instances // these functions are thread-safe void load(const std::string & name); + void load(const std::string & name, const load_options & opts); void unload(const std::string & name); void unload_all(); - // download a new model, progress is reported via SSE - // to stop the download, call unload() - void download(common_params_model && model, common_download_opts && opts); - struct update_status_args { server_model_status status; int exit_code = 0; // only valid if status == UNLOADED @@ -213,9 +220,12 @@ public: struct server_child { // serializes the notify_to_router writes std::mutex mtx_stdout; + std::atomic is_finished_downloading = false; // set by run_download // return true if the current process is a child server instance bool is_child(); + server_child_mode get_mode(); + int run_download(common_params & params); // register the shutdown_handler to be called by the router // return the monitoring thread (to be joined by the caller) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index bf3680b9f0..dd4b1c507c 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -134,6 +134,7 @@ int llama_server(int argc, char ** argv) { // // register API routes + server_child child; // only used in non-router mode server_routes routes(params, ctx_server); server_tools tools; @@ -254,11 +255,21 @@ int llama_server(int argc, char ** argv) { ctx_http.post("/tools", ex_wrapper(tools.handle_post)); } + // + // Handle downloading model + // + + if (child.is_child() && child.get_mode() == SERVER_CHILD_MODE_DOWNLOAD) { + return child.run_download(params); + } else if (!is_router_server) { + // single-model mode (NOT spawned by router) + common_params_handle_models(params, LLAMA_EXAMPLE_SERVER); + } + // // Start the server // - server_child child; // only used in non-router mode std::function clean_up; if (is_router_server) { diff --git a/tools/server/tests/unit/test_router.py b/tools/server/tests/unit/test_router.py index 11c77ca7aa..41e95f4c5f 
100644 --- a/tools/server/tests/unit/test_router.py +++ b/tools/server/tests/unit/test_router.py @@ -257,14 +257,25 @@ def test_router_reload_models(): MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16" -MODEL_DOWNLOAD_TIMEOUT = 300 +MODEL_DOWNLOAD_TIMEOUT = 30 -def _listen_sse(server: ServerProcess, collected: list, stop: threading.Event): - """Collect /models/sse events into `collected` until `stop` is set.""" +def _listen_sse( + server: ServerProcess, collected: list, stop: threading.Event, ready: threading.Event | None = None +): + """Collect /models/sse events into `collected` until `stop` is set. + + When `ready` is provided, it is set once the streaming response is open, + i.e. the server has accepted the connection and registered us as a + subscriber. Callers that trigger one-shot events (e.g. download_finished) + must wait on `ready` before acting, otherwise the event can be broadcast + before this client is subscribed and be lost. + """ url = f"http://{server.server_host}:{server.server_port}/models/sse" try: with requests.get(url, stream=True, timeout=MODEL_DOWNLOAD_TIMEOUT) as resp: + if ready is not None: + ready.set() for line_bytes in resp.iter_lines(): if stop.is_set(): break @@ -294,11 +305,17 @@ def test_router_download_model(): sse_events: list = [] stop = threading.Event() + sse_ready = threading.Event() sse_thread = threading.Thread( - target=_listen_sse, args=(server, sse_events, stop), daemon=True + target=_listen_sse, args=(server, sse_events, stop, sse_ready), daemon=True ) sse_thread.start() + # wait for the SSE client to be subscribed before triggering the download, + # otherwise the one-shot download_finished event can be broadcast before + # this client is registered and be lost + assert sse_ready.wait(10), "SSE client failed to connect" + # Trigger the download res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID}) assert res.status_code == 200 @@ -328,13 +345,17 @@ def test_router_delete_model(): # 
Ensure the model exists (download it if needed) if MODEL_DOWNLOAD_ID not in _get_model_ids(is_reload=False): - res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID}) - assert res.status_code == 200 sse_events: list = [] stop = threading.Event() + sse_ready = threading.Event() threading.Thread( - target=_listen_sse, args=(server, sse_events, stop), daemon=True + target=_listen_sse, args=(server, sse_events, stop, sse_ready), daemon=True ).start() + # subscribe before triggering the download so the one-shot + # download_finished event is not lost (see test_router_download_model) + assert sse_ready.wait(10), "SSE client failed to connect" + res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID}) + assert res.status_code == 200 finished = _wait_for_sse_event( sse_events, "download_finished", MODEL_DOWNLOAD_ID, MODEL_DOWNLOAD_TIMEOUT )