diff --git a/tools/server/README.md b/tools/server/README.md index eb730e713a..5efdad0954 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -1859,9 +1859,33 @@ Example events: { "model": "...", - "event": "download_finished", + "event": "model_status", "data": { - "status": "loading" + "status": "loading", + "progress": { + "stage": "fit_params", + "value": 0.5 // from 0.0 to 1.0 ; note: not all stages have this "value" + } + } +} + +{ + "model": "...", + "event": "model_status", + "data": { + "status": "loaded", + "info": { + // note: only include info on first load + // waking up from sleep doesn't have this + } + } +} + +{ + "model": "...", + "event": "model_status", + "data": { + "status": "sleeping" } } diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 3de1335ec2..531b106e55 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -833,6 +833,8 @@ private: bool sleeping = false; + int64_t t_last_load_progress_ms = 0; + void destroy() { spec.reset(); ctx_dft.reset(); @@ -863,6 +865,30 @@ private: sleeping = new_state; } + static bool load_progress_callback(float progress, void * user_data) { + auto * ctx = static_cast(user_data); + GGML_ASSERT(ctx); + // always emit the first and final sample; throttle the rest to one per 200ms + { + auto & t_last = ctx->t_last_load_progress_ms; + const int64_t t_now = ggml_time_ms(); + const bool first = t_last == 0; + const bool done = progress >= 1.0f; + const bool throttled = !first && !done && (t_now - t_last) < 200; + if (throttled) { + return true; + } + t_last = t_now; + } + if (ctx->callback_state) { + ctx->callback_state(SERVER_STATE_LOADING, { + {"stage", "text_model"}, + {"value", progress}, + }); + } + return true; + } + // load the model and initialize llama_context // this may also be called to resume from sleeping state bool load_model(common_params & params) { @@ -916,6 +942,10 @@ private: // optionally reserve VRAM for the draft / MTP context before fitting the target model if (params_base.fit_params) { + if (callback_state) { + callback_state(SERVER_STATE_LOADING, {{"stage", "fit_params"}}); + } + const bool spec_mtp = std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(), COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end(); @@ -991,6 +1021,13 @@ private: } } + // attach a progress callback + { + t_last_load_progress_ms = 0; + params_base.load_progress_callback = load_progress_callback; + params_base.load_progress_callback_user_data = this; + } + llama_init = common_init_from_params(params_base); model_tgt = llama_init->model(); @@ -1008,6 +1045,10 @@ private: add_bos_token = llama_vocab_get_add_bos(vocab); if (params_base.speculative.has_dft()) { + if (callback_state) { + callback_state(SERVER_STATE_LOADING, {{"stage", "spec_model"}}); + } + // TODO speculative: move to common/speculative.cpp? const auto & params_spec = params_base.speculative.draft; @@ -1079,6 +1120,10 @@ private: } if (has_mmproj) { + if (callback_state) { + callback_state(SERVER_STATE_LOADING, {{"stage", "mmproj_model"}}); + } + if (!is_resume) { mtmd_helper_log_set(common_log_default_callback, nullptr); } @@ -1259,6 +1304,10 @@ private: return init(); } + if (callback_state) { + callback_state(SERVER_STATE_READY, {}); + } + return true; } @@ -1335,6 +1384,9 @@ private: const bool enable_thinking = params_base.enable_reasoning != 0 && template_supports_thinking; SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking); + // IMPORTANT: chat_params is reused across sleeping / resuming states, + // never store llama_context/llama_model pointers in chat_params, + // as they may be invalidated after sleeping chat_params = { /* use_jinja */ params_base.use_jinja, /* prefill_assistant */ params_base.prefill_assistant, @@ -3734,7 +3786,10 @@ struct server_res_generator : server_http_res { void server_context::set_state_callback(server_state_callback_t callback) { impl->callback_state = std::move(callback); impl->queue_tasks.on_sleeping_state([this](bool sleeping) { - impl->callback_state(sleeping ? SERVER_STATE_SLEEPING : SERVER_STATE_READY, {}); + if (sleeping) { + impl->callback_state(SERVER_STATE_SLEEPING, {}); + } + // for sleeping == false, event is emitted by load_model() }); } diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index a569c8be3c..68eefdffac 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -442,6 +442,7 @@ void server_models::load_models() { /* last_used */ 0, /* args */ std::vector(), /* loaded_info */ {}, + /* progress */ {}, /* exit_code */ 0, /* stop_timeout */ DEFAULT_STOP_TIMEOUT, /* multimodal */ mtmd_caps{false, false}, @@ -608,6 +609,7 @@ void server_models::load_models() { /* last_used */ 0, /* args */ std::vector(), /* loaded_info */ {}, + /* progress */ {}, /* exit_code */ 0, /* stop_timeout */ DEFAULT_STOP_TIMEOUT, /* multimodal */ mtmd_caps{false, false}, @@ -1140,6 +1142,9 @@ void server_models::update_status(const std::string & name, const update_status_ if (!args.loaded_info.is_null()) { meta.loaded_info = args.loaded_info; } + if (!args.progress.is_null()) { + meta.progress = args.progress; + } } // broadcast status change to SSE { @@ -1152,6 +1157,9 @@ void server_models::update_status(const std::string & name, const update_status_ if (!args.loaded_info.is_null()) { data["info"] = args.loaded_info; } + if (!args.progress.is_null()) { + data["progress"] = args.progress; + } // note: notify_sse doesn't acquire the lock, so no deadlock here notify_sse("status_change", name, data); } @@ -1322,8 +1330,12 @@ void server_models::handle_child_state(const std::string & name, const std::stri switch (state) { case SERVER_STATE_LOADING: { - // do nothing for now - // TODO: report loading progress for first load and wakeup from sleep + update_status(name, { + SERVER_MODEL_STATUS_LOADING, + 0, + nullptr, // no loaded_info yet + payload, + }); } break; case SERVER_STATE_READY: { @@ -1331,7 +1343,8 @@ void server_models::handle_child_state(const std::string & name, const std::stri SERVER_MODEL_STATUS_LOADED, 0, // note: payload can be empty if this is a wakeup from sleep - payload.size() > 0 ? payload : nullptr + payload.size() > 0 ? payload : nullptr, + {}, // reset progress info }); } break; case SERVER_STATE_SLEEPING: @@ -1384,6 +1397,7 @@ void server_child::notify_to_router(const std::string & state, const json & payl {"state", state}, {"payload", payload}, }; + std::lock_guard lk(mtx_stdout); common_log_pause(common_log_main()); fflush(stdout); fprintf(stdout, "%s%s\n", CMD_CHILD_TO_ROUTER_STATE, safe_json_to_str(data).c_str()); diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 40a0e078c6..17759b00a5 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -72,6 +72,7 @@ struct server_model_meta { int64_t last_used = 0; // for LRU unloading std::vector args; // args passed to the model instance, will be populated by render_args() json loaded_info; // info to be reflected via /v1/models endpoint ; if in DOWNLOADING state, it should contain download progress info + json progress; // reflect load or download progress info, if any int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED) int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown mtmd_caps multimodal; // multimodal capabilities @@ -170,12 +171,14 @@ public: // to stop the download, call unload() void download(common_params_model && model, common_download_opts && opts); - // update the status of a model instance (thread-safe) struct update_status_args { server_model_status status; int exit_code = 0; // only valid if status == UNLOADED json loaded_info = nullptr; + json progress = nullptr; }; + // update the status of a model instance (thread-safe) + // also send SSE notification to /models/sse endpoint void update_status(const std::string & name, const update_status_args & args); void update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok = true); @@ -208,6 +211,9 @@ public: }; struct server_child { + // serializes the notify_to_router writes + std::mutex mtx_stdout; + // return true if the current process is a child server instance bool is_child();