diff --git a/common/arg.cpp b/common/arg.cpp index 8f4f7d0763..a9b1a25b27 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -303,7 +303,6 @@ static handle_model_result common_params_handle_model(struct common_params_model if (!model.docker_repo.empty()) { model.path = common_docker_resolve_model(model.docker_repo); - model.name = model.docker_repo; } else if (!model.hf_repo.empty()) { // If -m was used with -hf, treat the model "path" as the hf_file to download if (model.hf_file.empty() && !model.path.empty()) { @@ -323,7 +322,6 @@ static handle_model_result common_params_handle_model(struct common_params_model throw std::runtime_error("failed to download model from Hugging Face"); } - model.name = model.hf_repo; model.path = download_result.model_path; if (!download_result.mmproj_path.empty()) { diff --git a/common/common.h b/common/common.h index 254454dcb1..f2f2202ec2 100644 --- a/common/common.h +++ b/common/common.h @@ -295,7 +295,16 @@ struct common_params_model { std::string hf_repo = ""; // HF repo // NOLINT std::string hf_file = ""; // HF file // NOLINT std::string docker_repo = ""; // Docker repo // NOLINT - std::string name = ""; // in format /[:] (tag is optional) // NOLINT + + std::string get_name() const { // repo-based model name: docker repo if set, else HF repo (same priority as common_params_handle_model) + if (!docker_repo.empty()) { + return docker_repo; + } + if (!hf_repo.empty()) { + return hf_repo; + } + return ""; // plain local path has no repo name: callers fall back to alias / model file name + } }; // draft-model-based speculative decoding parameters diff --git a/tools/server/README-dev.md b/tools/server/README-dev.md index 4c41031239..2796d28350 100644 --- a/tools/server/README-dev.md +++ b/tools/server/README-dev.md @@ -180,6 +180,17 @@ That requires `JSON.stringify` when formatted to message content: } ``` +### Router mode: how child <--> router communicates + +Upon spawning a new child process using `subprocess`, both child and router listen to the stdout/stderr (combined) + +For the direction from child to router: +- Generic messages are logs, it will be forwarded to router's stdout +- Special state update messages are 
prefixed by `cmd_child_to_router:state:`, followed by a JSON string. See `server_models::handle_child_state` for details + +For the direction from router to child: +- When the router sends `cmd_router_to_child:exit`, the child should exit gracefully; if the child is still running after `DEFAULT_STOP_TIMEOUT` seconds, it is force-killed + ### Model management API (router mode) Model management API was added via PR [#23976](https://github.com/ggml-org/llama.cpp/pull/23976) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 1f0e1bfd42..3de1335ec2 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -63,11 +63,6 @@ enum slot_state { SLOT_STATE_GENERATING, }; -enum server_state { - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded -}; - struct server_slot { int id; @@ -773,6 +768,8 @@ public: // note: chat_params must not be refreshed upon existing sleeping state server_chat_params chat_params; + server_state_callback_t callback_state = [](server_state, json) -> void {}; + server_context_impl() { mtmd_helper_log_set(common_log_default_callback, nullptr); } @@ -1244,8 +1241,8 @@ private: if (!params_base.model_alias.empty()) { // backward compat: use first alias as model name model_name = *params_base.model_alias.begin(); - } else if (!params_base.model.name.empty()) { - model_name = params_base.model.name; + } else if (!params_base.model.get_name().empty()) { + model_name = params_base.model.get_name(); } else { // fallback: derive model name from file name auto model_path = std::filesystem::path(params_base.model.path); @@ -3734,8 +3731,11 @@ struct server_res_generator : server_http_res { } }; -void server_context::on_sleeping_changed(std::function callback) { - impl->queue_tasks.on_sleeping_state(std::move(callback)); +void server_context::set_state_callback(server_state_callback_t callback) { + impl->callback_state = 
std::move(callback); + impl->queue_tasks.on_sleeping_state([this](bool sleeping) { + impl->callback_state(sleeping ? SERVER_STATE_SLEEPING : SERVER_STATE_READY, {}); + }); } // compute the number of tokens before the last user message in the prompt diff --git a/tools/server/server-context.h b/tools/server/server-context.h index 07afabb926..c7218a12ed 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -52,6 +52,31 @@ struct server_context_meta { uint64_t model_size; }; +enum server_state { + // SERVER_STATE_DOWNLOADING, + SERVER_STATE_LOADING, + SERVER_STATE_READY, + SERVER_STATE_SLEEPING, +}; + +static std::string server_state_to_str(server_state state) { + switch (state) { + case SERVER_STATE_LOADING: return "loading"; + case SERVER_STATE_READY: return "ready"; + case SERVER_STATE_SLEEPING: return "sleeping"; + default: GGML_ASSERT(false && "invalid server_state"); + } +} + +static server_state server_state_from_str(const std::string & str) { // inverse of server_state_to_str(); input may come from another process + if (str == "loading") return SERVER_STATE_LOADING; + if (str == "ready") return SERVER_STATE_READY; + if (str == "sleeping") return SERVER_STATE_SLEEPING; + throw std::runtime_error("invalid server_state string: " + str); // recoverable: the router catches std::exception instead of aborting +} + +using server_state_callback_t = std::function<void(server_state, json)>; + struct server_context { std::unique_ptr impl; @@ -79,9 +104,8 @@ struct server_context { // not thread-safe, should only be used from the main thread server_context_meta get_meta() const; - // register a callback to be called when sleeping state changes - // must be set before load_model() is called - void on_sleeping_changed(std::function callback); + // note: must be set before load_model() is called + void set_state_callback(server_state_callback_t callback); }; diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 1fffa6b6e5..a569c8be3c 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -1,5 +1,6 @@ #include "server-common.h" #include "server-models.h" +#include 
"server-context.h" #include "build-info.h" #include "preset.h" @@ -44,9 +45,7 @@ extern char **environ; #define DEFAULT_STOP_TIMEOUT 10 // seconds #define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit" -#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready" // also sent when waking up from sleep -#define CMD_CHILD_TO_ROUTER_SLEEP "cmd_child_to_router:sleep" -#define CMD_CHILD_TO_ROUTER_INFO "cmd_child_to_router:info:" // followed by json string +#define CMD_CHILD_TO_ROUTER_STATE "cmd_child_to_router:state:" // followed by json string // address for child process, this is needed because router may run on 0.0.0.0 // ref: https://github.com/ggml-org/llama.cpp/issues/17862 @@ -904,12 +903,8 @@ void server_models::load(const std::string & name) { while (fgets(buffer, vec_buf.size(), stdout_file) != nullptr) { LOG("[%5d] %s", port, buffer); std::string str(buffer); - if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_READY)) { - this->update_status(name, SERVER_MODEL_STATUS_LOADED, 0); - } else if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_INFO)) { - this->update_loaded_info(name, str); - } else if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_SLEEP)) { - this->update_status(name, SERVER_MODEL_STATUS_SLEEPING, 0); + if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_STATE)) { + this->handle_child_state(name, str); } } } else { @@ -976,7 +971,10 @@ void server_models::load(const std::string & name) { subprocess_destroy(&child_proc->get()); // update status and exit code - this->update_status(name, SERVER_MODEL_STATUS_UNLOADED, exit_code); + this->update_status(name, { + SERVER_MODEL_STATUS_UNLOADED, + exit_code + }); SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code); }); @@ -1016,7 +1014,8 @@ struct server_models_download_res : public common_download_callback { common_download_model(model, opts); is_ok = true; } catch (const std::exception & e) { - SRV_ERR("download failed for model name=%s: %s\n", model.name.c_str(), 
e.what()); + auto model_name = model.get_name(); + SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what()); is_ok = false; } return is_ok; @@ -1036,7 +1035,7 @@ struct server_models_download_res : public common_download_callback { }; void server_models::download(common_params_model && model, common_download_opts && opts) { - std::string name = model.name; + std::string name = model.get_name(); GGML_ASSERT(name == model.hf_repo); std::unique_lock lk(mutex); @@ -1064,9 +1063,10 @@ void server_models::download(common_params_model && model, common_download_opts inst.th = std::thread([this, dl = std::move(dl)]() { dl->opts.callback = dl.get(); bool ok = dl->run(); + auto model_name = dl->model.get_name(); SRV_INF("download finished for model name=%s with status=%s\n", - dl->model.name.c_str(), ok ? "success" : "failure"); - update_download_progress(dl->model.name, {}, true, ok); + model_name.c_str(), ok ? "success" : "failure"); + update_download_progress(model_name, {}, true, ok); // need_reload is set inside update_download_progress under the mutex; // the next load_models() call will clean up this instance }); @@ -1130,21 +1130,27 @@ void server_models::unload_all() { } } -void server_models::update_status(const std::string & name, server_model_status status, int exit_code) { +void server_models::update_status(const std::string & name, const update_status_args & args) { std::unique_lock lk(mutex); auto it = mapping.find(name); if (it != mapping.end()) { auto & meta = it->second.meta; - meta.status = status; - meta.exit_code = exit_code; + meta.status = args.status; + meta.exit_code = args.exit_code; + if (!args.loaded_info.is_null()) { + meta.loaded_info = args.loaded_info; + } } // broadcast status change to SSE { json data = { - {"status", server_model_status_to_string(status)}, + {"status", server_model_status_to_string(args.status)}, }; - if (status == SERVER_MODEL_STATUS_UNLOADED) { - data["exit_code"] = exit_code; + if (args.status == 
SERVER_MODEL_STATUS_UNLOADED) { + data["exit_code"] = args.exit_code; + } + if (!args.loaded_info.is_null()) { + data["info"] = args.loaded_info; } // note: notify_sse doesn't acquire the lock, so no deadlock here notify_sse("status_change", name, data); @@ -1152,29 +1158,6 @@ void server_models::update_status(const std::string & name, server_model_status cv.notify_all(); } -void server_models::update_loaded_info(const std::string & name, std::string & raw_info) { - if (!string_starts_with(raw_info, CMD_CHILD_TO_ROUTER_INFO)) { - SRV_WRN("invalid loaded info format from child for model name=%s: %s\n", name.c_str(), raw_info.c_str()); - return; - } - - json info; - try { - info = json::parse(raw_info.substr(strlen(CMD_CHILD_TO_ROUTER_INFO))); - } catch (const std::exception & e) { - SRV_WRN("failed to parse loaded info from child for model name=%s: %s\n", name.c_str(), e.what()); - return; - } - - std::unique_lock lk(mutex); - auto it = mapping.find(name); - if (it != mapping.end()) { - auto & meta = it->second.meta; - meta.loaded_info = info; - } - cv.notify_all(); -} - void server_models::update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok) { json curr; { @@ -1323,21 +1306,54 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co return proxy; } -bool server_models::is_child_server() { +void server_models::handle_child_state(const std::string & name, const std::string & raw_input) { + server_state state; + json payload; + + try { + json data = json::parse(raw_input.substr(strlen(CMD_CHILD_TO_ROUTER_STATE))); + state = server_state_from_str(json_value(data, "state", std::string())); + payload = json_value(data, "payload", json{}); + } catch (const std::exception & e) { + SRV_ERR("failed to parse child state update for name=%s: %s\n", name.c_str(), e.what()); + return; + } + + switch (state) { + case SERVER_STATE_LOADING: + { + // do nothing for now + // TODO: report loading 
progress for first load and wakeup from sleep + } break; + case SERVER_STATE_READY: + { + update_status(name, { + SERVER_MODEL_STATUS_LOADED, + 0, + // note: payload can be empty if this is a wakeup from sleep + payload.size() > 0 ? payload : nullptr + }); + } break; + case SERVER_STATE_SLEEPING: + { + update_status(name, { SERVER_MODEL_STATUS_SLEEPING }); + } break; + default: + // should never happen, but just in case + GGML_ASSERT(false && "unexpected state from child server"); + } +} + +// +// server_child +// + +bool server_child::is_child() { const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT"); return router_port != nullptr; } -std::thread server_models::setup_child_server(const std::function & shutdown_handler, const json & model_info) { - // send a notification to the router server that a model instance is ready - common_log_pause(common_log_main()); - fflush(stdout); - fprintf(stdout, "%s\n", CMD_CHILD_TO_ROUTER_READY); - fflush(stdout); - fprintf(stdout, "%s%s\n", CMD_CHILD_TO_ROUTER_INFO, safe_json_to_str(model_info).c_str()); - fflush(stdout); - common_log_resume(common_log_main()); - +std::thread server_child::setup(const std::function & shutdown_handler) { // setup thread for monitoring stdin return std::thread([shutdown_handler]() { // wait for EOF on stdin @@ -1363,10 +1379,14 @@ std::thread server_models::setup_child_server(const std::function & s }); } -void server_models::notify_router_sleeping_state(bool is_sleeping) { +void server_child::notify_to_router(const std::string & state, const json & payload) { + json data = { + {"state", state}, + {"payload", payload}, + }; common_log_pause(common_log_main()); fflush(stdout); - fprintf(stdout, "%s\n", is_sleeping ? 
CMD_CHILD_TO_ROUTER_SLEEP : CMD_CHILD_TO_ROUTER_READY); + fprintf(stdout, "%s%s\n", CMD_CHILD_TO_ROUTER_STATE, safe_json_to_str(data).c_str()); fflush(stdout); common_log_resume(common_log_main()); } @@ -1644,7 +1664,6 @@ void server_models_routes::init_routes() { common_params_model model; common_download_opts opts; - model.name = name; model.hf_repo = name; opts.bearer_token = params.hf_token; opts.download_mmproj = true; diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 98872b0461..40a0e078c6 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -171,8 +171,12 @@ public: void download(common_params_model && model, common_download_opts && opts); // update the status of a model instance (thread-safe) - void update_status(const std::string & name, server_model_status status, int exit_code); - void update_loaded_info(const std::string & name, std::string & raw_info); + struct update_status_args { + server_model_status status; + int exit_code = 0; // only valid if status == UNLOADED + json loaded_info = nullptr; + }; + void update_status(const std::string & name, const update_status_args & args); void update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok = true); // remove a cache model from disk and update the list (thread-safe) @@ -193,15 +197,27 @@ public: // proxy an HTTP request to the model instance server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used); + // handle a message sent from server_child::notify_to_router() + // raw input must start with CMD_CHILD_TO_ROUTER_STATE, followed by a JSON string + // this function is not thread-safe, it must be called from the instance's monitoring thread + // payload per state: + // state = loading -> payload = {} (TODO: add progress info) + // state = ready -> payload = model_info (json), or {} if wakeup from sleeping + // state = 
sleeping -> payload = {} + void handle_child_state(const std::string & name, const std::string & raw_input); +}; + +struct server_child { // return true if the current process is a child server instance - static bool is_child_server(); + bool is_child(); - // notify the router server that a model instance is ready + // register the shutdown_handler to be called by the router // return the monitoring thread (to be joined by the caller) - static std::thread setup_child_server(const std::function & shutdown_handler, const json & model_info); + std::thread setup(const std::function & shutdown_handler); - // notify the router server that the sleeping state has changed - static void notify_router_sleeping_state(bool sleeping); + // notify router server for status changes (e.g. loading, downloading, sleeping, etc.) + // message will be handled by server_models::handle_child_state() on the router side + void notify_to_router(const std::string & state_name, const json & payload); }; struct server_models_routes { diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 2a67bfcfed..bf3680b9f0 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -90,8 +90,10 @@ int llama_server(int argc, char ** argv) { llama_numa_init(params.numa); // router server never loads a model and must not touch the GPU + const bool is_router_server = params.model.path.empty() + && params.model.hf_repo.empty(); + // skip device enumeration so the CUDA primary context stays uncreated - const bool is_router_server = params.model.path.empty(); common_params_print_info(params, !is_router_server); if (!is_router_server) { @@ -113,8 +115,9 @@ int llama_server(int argc, char ** argv) { } // for consistency between server router mode and single-model mode, we set the same model name as alias - if (params.model_alias.empty() && !params.model.name.empty()) { - params.model_alias.insert(params.model.name); + auto model_name = params.model.get_name(); + if (params.model_alias.empty() 
&& !model_name.empty()) { + params.model_alias.insert(model_name); } // struct that contains llama context and inference @@ -255,6 +258,7 @@ int llama_server(int argc, char ** argv) { // Start the server // + server_child child; // only used in non-router mode std::function clean_up; if (is_router_server) { @@ -300,15 +304,16 @@ int llama_server(int argc, char ** argv) { return 1; } - // load the model - SRV_INF("%s", "loading model\n"); - - if (server_models::is_child_server()) { - ctx_server.on_sleeping_changed([&](bool sleeping) { - server_models::notify_router_sleeping_state(sleeping); + // setup communication child --> router if necessary + if (child.is_child()) { + ctx_server.set_state_callback([&](server_state state, json payload) { + child.notify_to_router(server_state_to_str(state), payload); }); } + // load the model + SRV_INF("%s", "loading model\n"); + if (!ctx_server.load_model(params)) { clean_up(); if (ctx_http.thread.joinable()) { @@ -365,9 +370,9 @@ int llama_server(int argc, char ** argv) { // optionally, notify router server that this instance is ready std::thread monitor_thread; - if (server_models::is_child_server()) { - json model_info = routes.get_model_info(); - monitor_thread = server_models::setup_child_server(shutdown_handler, model_info); + if (child.is_child()) { + monitor_thread = child.setup(shutdown_handler); + child.notify_to_router(server_state_to_str(SERVER_STATE_READY), routes.get_model_info()); } // this call blocks the main thread until queue_tasks.terminate() is called