server: real-time model load progress tracking via /models/sse (#24828)

* server: real-time model load progress tracking via /models/sse

* update docs

* add mutex for notify_to_router

* correct docs
This commit is contained in:
Xuan-Son Nguyen 2026-06-21 11:58:14 +02:00 committed by GitHub
parent 8a118ee86c
commit d6d899580d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 106 additions and 7 deletions

View File

@ -1859,9 +1859,33 @@ Example events:
{
"model": "...",
"event": "download_finished",
"event": "model_status",
"data": {
"status": "loading"
"status": "loading",
"progress": {
"stage": "fit_params",
"value": 0.5 // from 0.0 to 1.0 ; note: not all stages have this "value"
}
}
}
{
"model": "...",
"event": "model_status",
"data": {
"status": "loaded",
"info": {
// note: "info" is only included on the first load;
// it is omitted when the model wakes up from sleep
}
}
}
{
"model": "...",
"event": "model_status",
"data": {
"status": "sleeping"
}
}

View File

@ -833,6 +833,8 @@ private:
bool sleeping = false;
int64_t t_last_load_progress_ms = 0;
void destroy() {
spec.reset();
ctx_dft.reset();
@ -863,6 +865,30 @@ private:
sleeping = new_state;
}
// Model-load progress hook: forwards `progress` to the registered state callback
// as a SERVER_STATE_LOADING event for the "text_model" stage.
//
// user_data must point to the owning server_context_impl (asserted).
// Always returns true (presumably this tells the loader to keep going - confirm
// against the loader's callback contract).
static bool load_progress_callback(float progress, void * user_data) {
    auto * ctx = static_cast<server_context_impl *>(user_data);
    GGML_ASSERT(ctx);

    // rate limiting: the first sample (timestamp still 0) and the final sample
    // (progress >= 1.0) always pass; anything in between is dropped when less
    // than 200ms elapsed since the last sample that was let through
    const int64_t t_now    = ggml_time_ms();
    const int64_t t_prev   = ctx->t_last_load_progress_ms;
    const bool    is_first = t_prev == 0;
    const bool    is_final = progress >= 1.0f;
    if (!is_first && !is_final && (t_now - t_prev) < 200) {
        return true;
    }
    ctx->t_last_load_progress_ms = t_now;

    // nothing to report to when no state callback has been registered
    if (!ctx->callback_state) {
        return true;
    }

    ctx->callback_state(SERVER_STATE_LOADING, {
        {"stage", "text_model"},
        {"value", progress},
    });
    return true;
}
// load the model and initialize llama_context
// this may also be called to resume from sleeping state
bool load_model(common_params & params) {
@ -916,6 +942,10 @@ private:
// optionally reserve VRAM for the draft / MTP context before fitting the target model
if (params_base.fit_params) {
if (callback_state) {
callback_state(SERVER_STATE_LOADING, {{"stage", "fit_params"}});
}
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
params_base.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
@ -991,6 +1021,13 @@ private:
}
}
// attach a progress callback
{
t_last_load_progress_ms = 0;
params_base.load_progress_callback = load_progress_callback;
params_base.load_progress_callback_user_data = this;
}
llama_init = common_init_from_params(params_base);
model_tgt = llama_init->model();
@ -1008,6 +1045,10 @@ private:
add_bos_token = llama_vocab_get_add_bos(vocab);
if (params_base.speculative.has_dft()) {
if (callback_state) {
callback_state(SERVER_STATE_LOADING, {{"stage", "spec_model"}});
}
// TODO speculative: move to common/speculative.cpp?
const auto & params_spec = params_base.speculative.draft;
@ -1079,6 +1120,10 @@ private:
}
if (has_mmproj) {
if (callback_state) {
callback_state(SERVER_STATE_LOADING, {{"stage", "mmproj_model"}});
}
if (!is_resume) {
mtmd_helper_log_set(common_log_default_callback, nullptr);
}
@ -1259,6 +1304,10 @@ private:
return init();
}
if (callback_state) {
callback_state(SERVER_STATE_READY, {});
}
return true;
}
@ -1335,6 +1384,9 @@ private:
const bool enable_thinking = params_base.enable_reasoning != 0 && template_supports_thinking;
SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking);
// IMPORTANT: chat_params is reused across sleeping / resuming states,
// never store llama_context/llama_model pointers in chat_params,
// as they may be invalidated after sleeping
chat_params = {
/* use_jinja */ params_base.use_jinja,
/* prefill_assistant */ params_base.prefill_assistant,
@ -3734,7 +3786,10 @@ struct server_res_generator : server_http_res {
// Install the state-change callback and route the task queue's sleeping-state
// notifications to it.
//
// Only the transition INTO sleep is reported from here (SERVER_STATE_SLEEPING
// with an empty payload). The wake-up transition is deliberately not reported
// here: load_model() emits SERVER_STATE_READY itself once loading has finished,
// so emitting it from this hook as well would produce a duplicate event.
void server_context::set_state_callback(server_state_callback_t callback) {
    impl->callback_state = std::move(callback);
    impl->queue_tasks.on_sleeping_state([this](bool sleeping) {
        // for sleeping == false, event is emitted by load_model()
        if (!sleeping) {
            return;
        }
        // same guard as the other call sites: an empty callback is a no-op
        if (impl->callback_state) {
            impl->callback_state(SERVER_STATE_SLEEPING, {});
        }
    });
}

View File

@ -442,6 +442,7 @@ void server_models::load_models() {
/* last_used */ 0,
/* args */ std::vector<std::string>(),
/* loaded_info */ {},
/* progress */ {},
/* exit_code */ 0,
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
/* multimodal */ mtmd_caps{false, false},
@ -608,6 +609,7 @@ void server_models::load_models() {
/* last_used */ 0,
/* args */ std::vector<std::string>(),
/* loaded_info */ {},
/* progress */ {},
/* exit_code */ 0,
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
/* multimodal */ mtmd_caps{false, false},
@ -1140,6 +1142,9 @@ void server_models::update_status(const std::string & name, const update_status_
if (!args.loaded_info.is_null()) {
meta.loaded_info = args.loaded_info;
}
if (!args.progress.is_null()) {
meta.progress = args.progress;
}
}
// broadcast status change to SSE
{
@ -1152,6 +1157,9 @@ void server_models::update_status(const std::string & name, const update_status_
if (!args.loaded_info.is_null()) {
data["info"] = args.loaded_info;
}
if (!args.progress.is_null()) {
data["progress"] = args.progress;
}
// note: notify_sse doesn't acquire the lock, so no deadlock here
notify_sse("status_change", name, data);
}
@ -1322,8 +1330,12 @@ void server_models::handle_child_state(const std::string & name, const std::stri
switch (state) {
case SERVER_STATE_LOADING:
{
// do nothing for now
// TODO: report loading progress for first load and wakeup from sleep
update_status(name, {
SERVER_MODEL_STATUS_LOADING,
0,
nullptr, // no loaded_info yet
payload,
});
} break;
case SERVER_STATE_READY:
{
@ -1331,7 +1343,8 @@ void server_models::handle_child_state(const std::string & name, const std::stri
SERVER_MODEL_STATUS_LOADED,
0,
// note: payload can be empty if this is a wakeup from sleep
payload.size() > 0 ? payload : nullptr
payload.size() > 0 ? payload : nullptr,
{}, // reset progress info
});
} break;
case SERVER_STATE_SLEEPING:
@ -1384,6 +1397,7 @@ void server_child::notify_to_router(const std::string & state, const json & payl
{"state", state},
{"payload", payload},
};
std::lock_guard<std::mutex> lk(mtx_stdout);
common_log_pause(common_log_main());
fflush(stdout);
fprintf(stdout, "%s%s\n", CMD_CHILD_TO_ROUTER_STATE, safe_json_to_str(data).c_str());

View File

@ -72,6 +72,7 @@ struct server_model_meta {
int64_t last_used = 0; // for LRU unloading
std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
json loaded_info; // info to be reflected via /v1/models endpoint ; if in DOWNLOADING state, it should contain download progress info
json progress; // reflect load or download progress info, if any
int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown
mtmd_caps multimodal; // multimodal capabilities
@ -170,12 +171,14 @@ public:
// to stop the download, call unload()
void download(common_params_model && model, common_download_opts && opts);
// update the status of a model instance (thread-safe)
// Argument bundle for update_status().
// Callers build it positionally (brace-init without field names), so the order
// of the fields below is part of the contract - do not reorder them.
struct update_status_args {
// new status to record for the model instance
server_model_status status;
int exit_code = 0; // only valid if status == UNLOADED
// null (the default) leaves the stored loaded_info untouched; a non-null value
// replaces it and is also sent as "info" in the SSE notification
json loaded_info = nullptr;
// null (the default) leaves the stored progress untouched; a non-null value
// replaces it and is also sent as "progress" in the SSE notification
// NOTE(review): because null means "no change", passing an empty/default json
// here presumably cannot clear a previously stored progress - confirm how the
// progress is expected to be reset once loading is done
json progress = nullptr;
};
// update the status of a model instance (thread-safe)
// also send SSE notification to /models/sse endpoint
void update_status(const std::string & name, const update_status_args & args);
void update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok = true);
@ -208,6 +211,9 @@ public:
};
struct server_child {
// serializes the notify_to_router writes
std::mutex mtx_stdout;
// return true if the current process is a child server instance
bool is_child();