diff --git a/tools/server/README.md b/tools/server/README.md index 5efdad0954..7fa3a4d728 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -1863,11 +1863,15 @@ Example events: "data": { "status": "loading", "progress": { - "stage": "fit_params", - "value": 0.5 // from 0.0 to 1.0 ; note: not all stages have this "value" + "stages": ["text_model", "spec_model", "mmproj_model"], + "current": "text_model", + "value": 0.5 } } } +// note for "loading" status: +// - subsequent events will follow the same order of "stages" list +// - mmap is may report incorrect progress on some platforms; if you need exact progress, use --no-mmap { "model": "...", diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 91a8eb9452..3f9391cacb 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -962,6 +962,7 @@ private: struct load_progress_data { server_context_impl * ctx; std::string stage; + std::vector stages; int64_t t_last_load_progress_ms = 0; load_progress_data(server_context_impl * ctx, const std::string & stage) : ctx(ctx), stage(stage) {} }; @@ -982,7 +983,8 @@ private: } if (d->ctx->callback_state) { d->ctx->callback_state(SERVER_STATE_LOADING, { - {"stage", d->stage}, + {"stages", d->stages}, + {"current", d->stage}, {"value", progress}, }); } @@ -992,18 +994,42 @@ private: // load the model and initialize llama_context // this may also be called to resume from sleeping state bool load_model(common_params & params) { - load_progress_data load_progress_text(this, "text_model"); + load_progress_data load_progress_text (this, "text_model"); load_progress_data load_progress_mmproj(this, "mmproj_model"); + load_progress_data load_progress_spec (this, "spec_model"); - bool is_resume = sleeping; - - SRV_INF("loading model '%s'\n", params.model.path.c_str()); + const bool is_resume = sleeping; params_base = params; params_base.n_outputs_max = server_n_outputs_max(params_base); + const bool has_mmproj = !params.mmproj.path.empty(); + const bool has_draft = params.speculative.has_dft(); + const bool spec_mtp = std::find(params_base.speculative.types.begin(), + params_base.speculative.types.end(), + COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end(); + const bool has_spec = has_draft || spec_mtp; + + if (callback_state) { + std::vector stages = {"text_model"}; + if (has_spec) { + stages.push_back("spec_model"); + } + if (has_mmproj) { + stages.push_back("mmproj_model"); + } + load_progress_text.stages = stages; + load_progress_mmproj.stages = stages; + load_progress_spec.stages = stages; + + // trigger 0% progress + load_progress_callback(0.0f, &load_progress_text); + } + + + SRV_INF("loading model '%s'\n", params.model.path.c_str()); + std::string & mmproj_path = params_base.mmproj.path; - bool has_mmproj = !mmproj_path.empty(); mtmd_context_params mparams = mtmd_context_params_default(); if (has_mmproj) { mparams.use_gpu = params_base.mmproj_use_gpu; @@ -1050,16 +1076,7 @@ private: // optionally reserve VRAM for the draft / MTP context before fitting the target model if (params_base.fit_params) { - if (callback_state) { - callback_state(SERVER_STATE_LOADING, {{"stage", "fit_params"}}); - } - - const bool spec_mtp = std::find(params_base.speculative.types.begin(), - params_base.speculative.types.end(), - COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end(); - const bool has_draft = params_base.speculative.has_dft(); - - if (has_draft || spec_mtp) { + if (has_spec) { common_params params_dft = params_base; bool measure_model_bytes = true; @@ -1151,11 +1168,7 @@ private: add_bos_token = llama_vocab_get_add_bos(vocab); - if (params_base.speculative.has_dft()) { - if (callback_state) { - callback_state(SERVER_STATE_LOADING, {{"stage", "spec_model"}}); - } - + if (has_draft) { // TODO speculative: move to common/speculative.cpp? const auto & params_spec = params_base.speculative.draft; @@ -1178,6 +1191,10 @@ private: auto mparams_dft = common_model_params_to_llama(params_dft); + // progress callback + mparams_dft.progress_callback = load_progress_callback; + mparams_dft.progress_callback_user_data = &load_progress_spec; + model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft)); if (model_dft == nullptr) { SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str()); @@ -1186,10 +1203,6 @@ private: auto cparams = common_context_params_to_llama(params_dft); - const bool spec_mtp = std::find(params_base.speculative.types.begin(), - params_base.speculative.types.end(), - COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end(); - if (spec_mtp) { cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP; } @@ -1203,8 +1216,10 @@ private: params_base.speculative.draft.ctx_tgt = ctx_tgt; params_base.speculative.draft.ctx_dft = ctx_dft.get(); - } else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(), - COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end()) { + } else if (spec_mtp) { + // no new model load, so we simply report 0.0 and 1.0 progress + load_progress_callback(0.0f, &load_progress_spec); + SRV_INF("creating MTP draft context against the target model '%s'\n", params_base.model.path.c_str()); @@ -1224,6 +1239,8 @@ private: params_base.speculative.draft.ctx_tgt = ctx_tgt; params_base.speculative.draft.ctx_dft = ctx_dft.get(); + + load_progress_callback(1.0f, &load_progress_spec); } if (has_mmproj) {