mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
server: fix report progress for loading spec models, add "stages" list (#24870)
* server: fix report progress for loading spec models, add "stages" list * improve * nits * nits 2
This commit is contained in:
parent
bddfd2b113
commit
7c082bc417
@ -1863,11 +1863,15 @@ Example events:
|
||||
"data": {
|
||||
"status": "loading",
|
||||
"progress": {
|
||||
"stage": "fit_params",
|
||||
"value": 0.5 // from 0.0 to 1.0 ; note: not all stages have this "value"
|
||||
"stages": ["text_model", "spec_model", "mmproj_model"],
|
||||
"current": "text_model",
|
||||
"value": 0.5
|
||||
}
|
||||
}
|
||||
}
|
||||
// note for "loading" status:
|
||||
// - subsequent events will follow the order of the "stages" list
|
||||
// - mmap may report incorrect progress on some platforms; if you need exact progress, use --no-mmap
|
||||
|
||||
{
|
||||
"model": "...",
|
||||
|
||||
@ -962,6 +962,7 @@ private:
|
||||
struct load_progress_data {
|
||||
server_context_impl * ctx;
|
||||
std::string stage;
|
||||
std::vector<std::string> stages;
|
||||
int64_t t_last_load_progress_ms = 0;
|
||||
load_progress_data(server_context_impl * ctx, const std::string & stage) : ctx(ctx), stage(stage) {}
|
||||
};
|
||||
@ -982,7 +983,8 @@ private:
|
||||
}
|
||||
if (d->ctx->callback_state) {
|
||||
d->ctx->callback_state(SERVER_STATE_LOADING, {
|
||||
{"stage", d->stage},
|
||||
{"stages", d->stages},
|
||||
{"current", d->stage},
|
||||
{"value", progress},
|
||||
});
|
||||
}
|
||||
@ -992,18 +994,42 @@ private:
|
||||
// load the model and initialize llama_context
|
||||
// this may also be called to resume from sleeping state
|
||||
bool load_model(common_params & params) {
|
||||
load_progress_data load_progress_text(this, "text_model");
|
||||
load_progress_data load_progress_text (this, "text_model");
|
||||
load_progress_data load_progress_mmproj(this, "mmproj_model");
|
||||
load_progress_data load_progress_spec (this, "spec_model");
|
||||
|
||||
bool is_resume = sleeping;
|
||||
|
||||
SRV_INF("loading model '%s'\n", params.model.path.c_str());
|
||||
const bool is_resume = sleeping;
|
||||
|
||||
params_base = params;
|
||||
params_base.n_outputs_max = server_n_outputs_max(params_base);
|
||||
|
||||
const bool has_mmproj = !params.mmproj.path.empty();
|
||||
const bool has_draft = params.speculative.has_dft();
|
||||
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
|
||||
params_base.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
|
||||
const bool has_spec = has_draft || spec_mtp;
|
||||
|
||||
if (callback_state) {
|
||||
std::vector<std::string> stages = {"text_model"};
|
||||
if (has_spec) {
|
||||
stages.push_back("spec_model");
|
||||
}
|
||||
if (has_mmproj) {
|
||||
stages.push_back("mmproj_model");
|
||||
}
|
||||
load_progress_text.stages = stages;
|
||||
load_progress_mmproj.stages = stages;
|
||||
load_progress_spec.stages = stages;
|
||||
|
||||
// trigger 0% progress
|
||||
load_progress_callback(0.0f, &load_progress_text);
|
||||
}
|
||||
|
||||
|
||||
SRV_INF("loading model '%s'\n", params.model.path.c_str());
|
||||
|
||||
std::string & mmproj_path = params_base.mmproj.path;
|
||||
bool has_mmproj = !mmproj_path.empty();
|
||||
mtmd_context_params mparams = mtmd_context_params_default();
|
||||
if (has_mmproj) {
|
||||
mparams.use_gpu = params_base.mmproj_use_gpu;
|
||||
@ -1050,16 +1076,7 @@ private:
|
||||
|
||||
// optionally reserve VRAM for the draft / MTP context before fitting the target model
|
||||
if (params_base.fit_params) {
|
||||
if (callback_state) {
|
||||
callback_state(SERVER_STATE_LOADING, {{"stage", "fit_params"}});
|
||||
}
|
||||
|
||||
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
|
||||
params_base.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
|
||||
const bool has_draft = params_base.speculative.has_dft();
|
||||
|
||||
if (has_draft || spec_mtp) {
|
||||
if (has_spec) {
|
||||
common_params params_dft = params_base;
|
||||
bool measure_model_bytes = true;
|
||||
|
||||
@ -1151,11 +1168,7 @@ private:
|
||||
|
||||
add_bos_token = llama_vocab_get_add_bos(vocab);
|
||||
|
||||
if (params_base.speculative.has_dft()) {
|
||||
if (callback_state) {
|
||||
callback_state(SERVER_STATE_LOADING, {{"stage", "spec_model"}});
|
||||
}
|
||||
|
||||
if (has_draft) {
|
||||
// TODO speculative: move to common/speculative.cpp?
|
||||
const auto & params_spec = params_base.speculative.draft;
|
||||
|
||||
@ -1178,6 +1191,10 @@ private:
|
||||
|
||||
auto mparams_dft = common_model_params_to_llama(params_dft);
|
||||
|
||||
// progress callback
|
||||
mparams_dft.progress_callback = load_progress_callback;
|
||||
mparams_dft.progress_callback_user_data = &load_progress_spec;
|
||||
|
||||
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
|
||||
if (model_dft == nullptr) {
|
||||
SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
|
||||
@ -1186,10 +1203,6 @@ private:
|
||||
|
||||
auto cparams = common_context_params_to_llama(params_dft);
|
||||
|
||||
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
|
||||
params_base.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
|
||||
|
||||
if (spec_mtp) {
|
||||
cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
|
||||
}
|
||||
@ -1203,8 +1216,10 @@ private:
|
||||
|
||||
params_base.speculative.draft.ctx_tgt = ctx_tgt;
|
||||
params_base.speculative.draft.ctx_dft = ctx_dft.get();
|
||||
} else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end()) {
|
||||
} else if (spec_mtp) {
|
||||
// no new model load, so we simply report 0.0 and 1.0 progress
|
||||
load_progress_callback(0.0f, &load_progress_spec);
|
||||
|
||||
SRV_INF("creating MTP draft context against the target model '%s'\n",
|
||||
params_base.model.path.c_str());
|
||||
|
||||
@ -1224,6 +1239,8 @@ private:
|
||||
|
||||
params_base.speculative.draft.ctx_tgt = ctx_tgt;
|
||||
params_base.speculative.draft.ctx_dft = ctx_dft.get();
|
||||
|
||||
load_progress_callback(1.0f, &load_progress_spec);
|
||||
}
|
||||
|
||||
if (has_mmproj) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user