From 024930c6ad80ea9eb61159d807571e64b187bd99 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Fri, 26 Jun 2026 14:36:03 +0200 Subject: [PATCH] arg: fix handling --spec-draft-hf and --hf-repo-v (#25043) * arg: fix handling --spec-draft-hf and --hf-repo-v * fix missing mparams.hf_file --- common/arg.cpp | 52 +++++++++++++++++++++++++++++++++++++++++++------- common/arg.h | 2 ++ 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 841a38e961..841ca3ce2e 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -352,6 +352,8 @@ static std::string get_default_local_path(const std::string & url) { common_models_handler common_models_handler_init(const common_params & params, llama_example curr_ex) { common_download_hf_plan plan; + common_download_hf_plan plan_spec; + common_download_hf_plan plan_voc; common_download_opts opts; const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(), @@ -377,7 +379,15 @@ common_models_handler common_models_handler_init(const common_params & params, l plan = common_download_get_hf_plan(params.model, opts); } - return common_models_handler{plan, opts}; + if (!params.speculative.draft.mparams.hf_repo.empty()) { + plan_spec = common_download_get_hf_plan(params.speculative.draft.mparams, opts); + } + + if (!params.vocoder.model.hf_repo.empty()) { + plan_voc = common_download_get_hf_plan(params.vocoder.model, opts); + } + + return common_models_handler{plan, plan_spec, plan_voc, opts}; } bool common_models_handler_is_preset_repo(const common_models_handler & handler) { @@ -425,7 +435,9 @@ static std::vector build_url_tasks(const common_params_mod void common_models_handler_apply(common_models_handler & handler, common_params & params, common_download_callback * callback) { std::vector tasks; - auto & plan = handler.plan; + auto & plan = handler.plan; + auto & plan_spec = handler.plan_spec; + auto & plan_voc = handler.plan_voc; auto opts = handler.opts; // copy opts.callback = callback; @@ -484,19 +496,22 @@ void common_models_handler_apply(common_models_handler & handler, common_params } // handle hf_plan tasks - if (!plan.model_files.empty()) { - for (size_t i = 0; i < plan.model_files.size(); ++i) { - auto & model_file = plan.model_files[i]; + auto add_tasks = [&opts, &tasks](const hf_cache::hf_files & model_files, common_params_model & model) { + for (size_t i = 0; i < model_files.size(); ++i) { + auto & model_file = model_files[i]; bool is_first = (i == 0); tasks.emplace_back(model_file, opts, [&, is_first]() { if (is_first) { // only use first part as model path - params.model.path = hf_cache::finalize_file(model_file); + model.path = hf_cache::finalize_file(model_file); } else { hf_cache::finalize_file(model_file); } }); } + }; + if (!plan.model_files.empty()) { + add_tasks(plan.model_files, params.model); } if (!plan.mmproj.local_path.empty()) { tasks.emplace_back(plan.mmproj, opts, [&]() { @@ -522,9 +537,31 @@ void common_models_handler_apply(common_models_handler & handler, common_params }); } + // handle plan_spec (e.g. --spec-draft-hf) + if (!plan_spec.model_files.empty()) { + add_tasks(plan_spec.model_files, params.speculative.draft.mparams); + } + + // handle vocoder plan (e.g. --hf-repo-v) + if (!plan_voc.model_files.empty()) { + add_tasks(plan_voc.model_files, params.vocoder.model); + } + // run all tasks in parallel if (!params.offline) { - common_download_run_tasks(tasks); + // if duplicated files are found, only download once (but still call on_done for each task) + std::unordered_map unique_tasks; + for (auto & task : tasks) { + auto it = unique_tasks.find(task.local_path); + if (it == unique_tasks.end()) { + unique_tasks[task.local_path] = &task; + } + } + std::vector unique_tasks_vec; + for (auto & pair : unique_tasks) { + unique_tasks_vec.push_back(*pair.second); + } + common_download_run_tasks(unique_tasks_vec); } // download successful, update params with the downloaded paths @@ -3711,6 +3748,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "draft model for speculative decoding (default: unused)", [](common_params & params, const std::string & value) { params.speculative.draft.mparams.path = value; + params.speculative.draft.mparams.hf_file = value; // will be used if --spec-draft-hf is set } ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL")); add_opt(common_arg( diff --git a/common/arg.h b/common/arg.h index 508e33d29e..54a38b9cce 100644 --- a/common/arg.h +++ b/common/arg.h @@ -133,6 +133,8 @@ void common_params_add_preset_options(std::vector & args); struct common_models_handler { common_download_hf_plan plan; + common_download_hf_plan plan_spec; + common_download_hf_plan plan_voc; common_download_opts opts; };