arg: fix handling --spec-draft-hf and --hf-repo-v (#25043)

* arg: fix handling --spec-draft-hf and --hf-repo-v

* fix missing mparams.hf_file
This commit is contained in:
Xuan-Son Nguyen 2026-06-26 14:36:03 +02:00 committed by GitHub
parent 5397c36194
commit 024930c6ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 7 deletions

View File

@ -352,6 +352,8 @@ static std::string get_default_local_path(const std::string & url) {
common_models_handler common_models_handler_init(const common_params & params, llama_example curr_ex) {
common_download_hf_plan plan;
common_download_hf_plan plan_spec;
common_download_hf_plan plan_voc;
common_download_opts opts;
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
@ -377,7 +379,15 @@ common_models_handler common_models_handler_init(const common_params & params, l
plan = common_download_get_hf_plan(params.model, opts);
}
return common_models_handler{plan, opts};
if (!params.speculative.draft.mparams.hf_repo.empty()) {
plan_spec = common_download_get_hf_plan(params.speculative.draft.mparams, opts);
}
if (!params.vocoder.model.hf_repo.empty()) {
plan_voc = common_download_get_hf_plan(params.vocoder.model, opts);
}
return common_models_handler{plan, plan_spec, plan_voc, opts};
}
bool common_models_handler_is_preset_repo(const common_models_handler & handler) {
@ -425,7 +435,9 @@ static std::vector<common_download_task> build_url_tasks(const common_params_mod
void common_models_handler_apply(common_models_handler & handler, common_params & params, common_download_callback * callback) {
std::vector<common_download_task> tasks;
auto & plan = handler.plan;
auto & plan = handler.plan;
auto & plan_spec = handler.plan_spec;
auto & plan_voc = handler.plan_voc;
auto opts = handler.opts; // copy
opts.callback = callback;
@ -484,19 +496,22 @@ void common_models_handler_apply(common_models_handler & handler, common_params
}
// handle hf_plan tasks
if (!plan.model_files.empty()) {
for (size_t i = 0; i < plan.model_files.size(); ++i) {
auto & model_file = plan.model_files[i];
auto add_tasks = [&opts, &tasks](const hf_cache::hf_files & model_files, common_params_model & model) {
for (size_t i = 0; i < model_files.size(); ++i) {
auto & model_file = model_files[i];
bool is_first = (i == 0);
tasks.emplace_back(model_file, opts, [&, is_first]() {
if (is_first) {
// only use first part as model path
params.model.path = hf_cache::finalize_file(model_file);
model.path = hf_cache::finalize_file(model_file);
} else {
hf_cache::finalize_file(model_file);
}
});
}
};
if (!plan.model_files.empty()) {
add_tasks(plan.model_files, params.model);
}
if (!plan.mmproj.local_path.empty()) {
tasks.emplace_back(plan.mmproj, opts, [&]() {
@ -522,9 +537,31 @@ void common_models_handler_apply(common_models_handler & handler, common_params
});
}
// handle plan_spec (e.g. --spec-draft-hf)
if (!plan_spec.model_files.empty()) {
add_tasks(plan_spec.model_files, params.speculative.draft.mparams);
}
// handle vocoder plan (e.g. --hf-repo-v)
if (!plan_voc.model_files.empty()) {
add_tasks(plan_voc.model_files, params.vocoder.model);
}
// run all tasks in parallel
if (!params.offline) {
common_download_run_tasks(tasks);
// if duplicated files are found, only download once (but still call on_done for each task)
std::unordered_map<std::string, common_download_task *> unique_tasks;
for (auto & task : tasks) {
auto it = unique_tasks.find(task.local_path);
if (it == unique_tasks.end()) {
unique_tasks[task.local_path] = &task;
}
}
std::vector<common_download_task> unique_tasks_vec;
for (auto & pair : unique_tasks) {
unique_tasks_vec.push_back(*pair.second);
}
common_download_run_tasks(unique_tasks_vec);
}
// download successful, update params with the downloaded paths
@ -3711,6 +3748,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"draft model for speculative decoding (default: unused)",
[](common_params & params, const std::string & value) {
params.speculative.draft.mparams.path = value;
params.speculative.draft.mparams.hf_file = value; // will be used if --spec-draft-hf is set
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL"));
add_opt(common_arg(

View File

@ -133,6 +133,8 @@ void common_params_add_preset_options(std::vector<common_arg> & args);
struct common_models_handler {
common_download_hf_plan plan;
common_download_hf_plan plan_spec;
common_download_hf_plan plan_voc;
common_download_opts opts;
};