diff --git a/app/download.cpp b/app/download.cpp index f7ac55dedc..7227baadcb 100644 --- a/app/download.cpp +++ b/app/download.cpp @@ -38,7 +38,8 @@ int llama_download(int argc, char ** argv) { } try { - common_params_handle_models(params, LLAMA_EXAMPLE_DOWNLOAD, {}); + common_models_handler handler = common_models_handler_init(params, LLAMA_EXAMPLE_DOWNLOAD); + common_models_handler_apply(handler, params); } catch (const std::exception & e) { fprintf(stderr, "error: %s\n", e.what()); return 1; diff --git a/common/arg.cpp b/common/arg.cpp index 494df2073c..841a38e961 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -297,60 +297,6 @@ struct handle_model_result { std::string preset_path; }; -static handle_model_result common_params_handle_model(struct common_params_model & model, - const common_download_opts & opts) { - handle_model_result result; - - // TODO @ngxson : refactor this into a new common_model_download_context - - if (!model.docker_repo.empty()) { - model.path = common_docker_resolve_model(model.docker_repo); - } else if (!model.hf_repo.empty()) { - // If -m was used with -hf, treat the model "path" as the hf_file to download - if (model.hf_file.empty() && !model.path.empty()) { - model.hf_file = model.path; - model.path = ""; - } - common_download_opts hf_opts = opts; - auto download_result = common_download_model(model, hf_opts); - - if (!download_result.preset_path.empty()) { - result.found_preset = true; - result.preset_path = download_result.preset_path; - return result; // skip everything else if preset.ini is used - } - - if (download_result.model_path.empty()) { - throw std::runtime_error("failed to download model from Hugging Face"); - } - - model.path = download_result.model_path; - - if (!download_result.mmproj_path.empty()) { - result.found_mmproj = true; - result.mmproj.path = download_result.mmproj_path; - } - - if (!download_result.mtp_path.empty()) { - result.found_mtp = true; - result.mtp.path = download_result.mtp_path; - } - } else if (!model.url.empty()) { - if (model.path.empty()) { - auto f = string_split(model.url, '#').front(); - f = string_split(f, '?').front(); - model.path = fs_get_cache_file(string_split(f, '/').back()); - } - - auto download_result = common_download_model(model, opts); - if (download_result.model_path.empty()) { - throw std::runtime_error("failed to download model from " + model.url); - } - } - - return result; -} - const std::vector kv_cache_types = { GGML_TYPE_F32, GGML_TYPE_F16, @@ -395,77 +341,204 @@ static bool parse_bool_value(const std::string & value) { } // -// CLI argument parsing functions +// common_models_handler // -bool common_params_handle_models(common_params & params, llama_example curr_ex, const common_params_handle_models_params & handle_params) { - const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(), - params.speculative.types.end(), - COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end(); +static std::string get_default_local_path(const std::string & url) { + auto f = string_split(url, '#').front(); + f = string_split(f, '?').front(); + return fs_get_cache_file(string_split(f, '/').back()); +} +common_models_handler common_models_handler_init(const common_params & params, llama_example curr_ex) { + common_download_hf_plan plan; common_download_opts opts; + + const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(), + params.speculative.types.end(), + COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end(); + + // only download mmproj if the current example is using it + bool use_mmproj = false; + for (const auto & ex : mmproj_examples) { + if (curr_ex == ex) { + use_mmproj = true; + break; + } + } + opts.bearer_token = params.hf_token; opts.offline = params.offline; - opts.skip_download = params.skip_download; opts.download_mtp = spec_type_draft_mtp; - opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty(); - opts.preset_only = handle_params.preset_only; + opts.download_mmproj = use_mmproj && !params.no_mmproj + && params.mmproj.path.empty() && params.mmproj.url.empty(); - if (handle_params.callback) { - opts.callback = handle_params.callback; + if (!params.model.hf_repo.empty()) { + plan = common_download_get_hf_plan(params.model, opts); } - // sub-models (draft, mmproj, vocoder) are explicitly specified by the user, - // so we should not auto-discover mtp/mmproj siblings for them - common_download_opts sub_opts = opts; - sub_opts.download_mtp = false; - sub_opts.download_mmproj = false; + return common_models_handler{plan, opts}; +} - try { - auto res = common_params_handle_model(params.model, opts); - if (res.found_preset) { - if (!params.models_preset.empty()) { - throw std::invalid_argument("cannot use both --models-preset and -hf with a preset.ini file"); +bool common_models_handler_is_preset_repo(const common_models_handler & handler) { + return !handler.plan.preset.url.empty(); +} + +static std::vector build_url_tasks(const common_params_model & model, common_download_opts opts) { + auto parts = common_download_get_all_parts(model.url); + std::vector tasks; + + // single-part: download straight to model.path if the user gave one (-m), else the cache default + if (parts.size() == 1) { + common_download_task task; + task.url = parts[0]; + task.local_path = model.path.empty() ? get_default_local_path(parts[0]) : model.path; + task.opts = opts; + tasks.push_back(std::move(task)); + return tasks; + } + + // multi-part: place each part under the user's -m directory (if given), else the cache default + std::string base_dir; + if (!model.path.empty()) { + auto pos = model.path.rfind('/'); + base_dir = pos == std::string::npos ? std::string(".") : model.path.substr(0, pos); + } + + for (const auto & part : parts) { + common_download_task task; + task.url = part; + task.opts = opts; + + std::string local = get_default_local_path(part); + if (!base_dir.empty()) { + auto pos = local.rfind('/'); + std::string name = pos == std::string::npos ? local : local.substr(pos + 1); + local = base_dir + "/" + name; + } + task.local_path = local; + tasks.push_back(std::move(task)); + } + return tasks; +} + +void common_models_handler_apply(common_models_handler & handler, common_params & params, common_download_callback * callback) { + std::vector tasks; + + auto & plan = handler.plan; + + auto opts = handler.opts; // copy + opts.callback = callback; + + // handle plain "url" if needed + auto handle_url = [&](common_params_model & model) { + if (!model.url.empty()) { + if (model.path.empty()) { + model.path = get_default_local_path(model.url); } + } + }; + handle_url(params.model); + handle_url(params.mmproj); + handle_url(params.vocoder.model); + handle_url(params.speculative.draft.mparams); + + // optionally, if docker repo is set, resolve it + if (!params.model.docker_repo.empty()) { + params.model.url = common_docker_resolve_model(params.model.docker_repo); + params.model.path = get_default_local_path(params.model.url); + } + + // handle plain "url" tasks (non-hf) + if (!params.model.url.empty()) { + auto url_tasks = build_url_tasks(params.model, opts); + // the first part is what gets loaded, so point params.model.path at it + if (!url_tasks.empty()) { + std::string first_path = url_tasks.front().local_path; + url_tasks.front().on_done = [&]() { params.model.path = first_path; }; + } + for (auto & task : url_tasks) { + tasks.push_back(std::move(task)); + } + } + if (!params.mmproj.url.empty()) { + common_download_task task; + task.url = params.mmproj.url; + task.local_path = params.mmproj.path; + task.opts = opts; + tasks.push_back(task); + } + if (!params.vocoder.model.url.empty()) { + common_download_task task; + task.url = params.vocoder.model.url; + task.local_path = params.vocoder.model.path; + task.opts = opts; + tasks.push_back(task); + } + if (!params.speculative.draft.mparams.url.empty()) { + common_download_task task; + task.url = params.speculative.draft.mparams.url; + task.local_path = params.speculative.draft.mparams.path; + task.opts = opts; + tasks.push_back(task); + } + + // handle hf_plan tasks + if (!plan.model_files.empty()) { + for (size_t i = 0; i < plan.model_files.size(); ++i) { + auto & model_file = plan.model_files[i]; + bool is_first = (i == 0); + tasks.emplace_back(model_file, opts, [&, is_first]() { + if (is_first) { + // only use first part as model path + params.model.path = hf_cache::finalize_file(model_file); + } else { + hf_cache::finalize_file(model_file); + } + }); + } + } + if (!plan.mmproj.local_path.empty()) { + tasks.emplace_back(plan.mmproj, opts, [&]() { + params.mmproj.path = hf_cache::finalize_file(plan.mmproj); + }); + } + if (!plan.mtp.local_path.empty()) { + tasks.emplace_back(plan.mtp, opts, [&]() { + // only fall back to the discovered MTP head when no draft was explicitly provided + if (params.speculative.draft.mparams.empty()) { + params.speculative.draft.mparams.path = hf_cache::finalize_file(plan.mtp); + } else { + hf_cache::finalize_file(plan.mtp); + } + }); + } + if (!plan.preset.local_path.empty()) { + tasks.emplace_back(plan.preset, opts, [&]() { // if HF repo is a preset repo, we simply run server in router mode with the preset.ini file params.models_preset_hf = params.model.hf_repo; // only for showing a warning - params.models_preset = res.preset_path; + params.models_preset = hf_cache::finalize_file(plan.preset); params.model = common_params_model{}; // make sure to clear model, so server starts in router mode - return true; - } + }); + } - if (params.no_mmproj) { - params.mmproj = {}; - } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) { - // optionally, handle mmproj model when -hf is specified - params.mmproj = res.mmproj; - } - // only download mmproj if the current example is using it - for (const auto & ex : mmproj_examples) { - if (curr_ex == ex) { - common_params_handle_model(params.mmproj, sub_opts); - break; - } - } + // run all tasks in parallel + if (!params.offline) { + common_download_run_tasks(tasks); + } - // when --spec-type mtp is set and no draft model was provided explicitly, - // fall back to the MTP head discovered alongside the -hf model - if (spec_type_draft_mtp && res.found_mtp && - params.speculative.draft.mparams.path.empty() && - params.speculative.draft.mparams.hf_repo.empty() && - params.speculative.draft.mparams.url.empty()) { - params.speculative.draft.mparams.path = res.mtp.path; + // download successful, update params with the downloaded paths + for (const auto & task : tasks) { + if (task.on_done) { + task.on_done(); } - common_params_handle_model(params.speculative.draft.mparams, sub_opts); - common_params_handle_model(params.vocoder.model, sub_opts); - return true; - } catch (const common_skip_download_exception &) { - return false; - } catch (const std::exception &) { - throw; } } +// +// CLI argument parsing functions +// + static bool common_params_parse_ex(int argc, char ** argv, common_params_context & ctx_arg) { common_params & params = ctx_arg.params; @@ -601,7 +674,8 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context if (!skip_model_download) { // handle model and download - common_params_handle_models(params, ctx_arg.ex, {}); + common_models_handler handler = common_models_handler_init(params, ctx_arg.ex); + common_models_handler_apply(handler, params); // model is required (except for server) // TODO @ngxson : maybe show a list of available models in CLI in this case diff --git a/common/arg.h b/common/arg.h index fdfc04bc7a..508e33d29e 100644 --- a/common/arg.h +++ b/common/arg.h @@ -8,6 +8,7 @@ #include #include #include +#include // pseudo-env variable to identify preset-only arguments #define COMMON_ARG_PRESET_LOAD_ON_STARTUP "__PRESET_LOAD_ON_STARTUP" @@ -130,19 +131,19 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map & args); -struct common_params_handle_models_params { - common_download_callback * callback = nullptr; - bool preset_only = false; // if true, only check & download remote preset (for router mode) +struct common_models_handler { + common_download_hf_plan plan; + common_download_opts opts; }; -// populate model paths (main model, mmproj, etc) from -hf if necessary -// return true if the model is ready to use -// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc) -// if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. ETag check failed) -bool common_params_handle_models( - common_params & params, - llama_example curr_ex, - const common_params_handle_models_params & handle_params); +// initialize downloading opts and hf_plan if needed, but does not download anything yet +common_models_handler common_models_handler_init(const common_params & params, llama_example curr_ex); + +// check if the model is a preset repo (i.e. has a preset file) +bool common_models_handler_is_preset_repo(const common_models_handler & handler); + +// download and update params with the downloaded model path +void common_models_handler_apply(common_models_handler & handler, common_params & params, common_download_callback * callback = nullptr); // initialize argument parser context - used by test-arg-parser and preset common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); diff --git a/common/common.h b/common/common.h index 279af46c54..94147d5d8c 100644 --- a/common/common.h +++ b/common/common.h @@ -291,13 +291,13 @@ struct common_params_sampling { }; struct common_params_model { - std::string path = ""; // model local path // NOLINT - std::string url = ""; // model url to download // NOLINT - std::string hf_repo = ""; // HF repo // NOLINT - std::string hf_file = ""; // HF file // NOLINT - std::string docker_repo = ""; // Docker repo // NOLINT + std::string path = ""; // model local path + std::string url = ""; // model url to download + std::string hf_repo = ""; // HF repo + std::string hf_file = ""; // HF file + std::string docker_repo = ""; // Docker repo - std::string get_name() { + std::string get_name() const { if (!hf_repo.empty()) { return hf_repo; } @@ -306,6 +306,10 @@ struct common_params_model { } return path; } + + bool empty() const { + return get_name().empty(); + } }; // draft-model-based speculative decoding parameters @@ -368,7 +372,7 @@ struct common_params_speculative { common_params_speculative_ngram_cache ngram_cache; bool has_dft() const { - return !draft.mparams.path.empty() || !draft.mparams.hf_repo.empty(); + return !draft.mparams.empty(); } uint32_t need_n_rs_seq() const { @@ -520,7 +524,6 @@ struct common_params { int32_t control_vector_layer_start = -1; // layer range for control vector int32_t control_vector_layer_end = -1; // layer range for control vector bool offline = false; - bool skip_download = false; // skip model file downloading int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line diff --git a/common/download.cpp b/common/download.cpp index 5b55c76a11..6b69a44188 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -292,10 +292,6 @@ static int common_download_file_single_online(const std::string & url, const bool file_exists = std::filesystem::exists(path); - if (!file_exists && opts.skip_download) { - return -2; // file is missing and download is disabled - } - if (file_exists && skip_etag) { LOG_DBG("%s: using cached file: %s\n", __func__, path.c_str()); return 304; // 304 Not Modified - fake cached response @@ -362,9 +358,6 @@ static int common_download_file_single_online(const std::string & url, return 304; // 304 Not Modified - fake cached response } // pass this point, the file exists but is different from the server version, so we need to redownload it - if (opts.skip_download) { - return -2; // special code to indicate that the download was skipped due to etag mismatch - } if (remove(path.c_str()) != 0) { LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); return -1; @@ -691,19 +684,8 @@ static void list_available_gguf_files(const hf_cache::hf_files & files) { } } -struct hf_plan { - hf_cache::hf_file primary; - hf_cache::hf_files model_files; - hf_cache::hf_file mmproj; - hf_cache::hf_file mtp; - hf_cache::hf_file preset; // if set, only this file is downloaded -}; - -static hf_plan get_hf_plan(const common_params_model & model, - const common_download_opts & opts, - bool download_mmproj, - bool download_mtp) { - hf_plan plan; +common_download_hf_plan common_download_get_hf_plan(const common_params_model & model, const common_download_opts & opts) { + common_download_hf_plan plan; hf_cache::hf_files all; auto [repo, tag] = common_download_split_repo_tag(model.hf_repo); @@ -752,127 +734,49 @@ static hf_plan get_hf_plan(const common_params_model & model, plan.primary = primary; plan.model_files = get_split_files(all, primary); - if (download_mmproj) { + if (opts.download_mmproj) { plan.mmproj = find_best_mmproj(all, primary.path); } - - if (download_mtp) { + if (opts.download_mtp) { plan.mtp = find_best_mtp(all, primary.path); } return plan; } -struct download_task { - std::string url; - std::string path; -}; - -static std::vector get_url_tasks(const common_params_model & model) { - auto split = get_gguf_split_info(model.url); - - if (split.count <= 1) { - return {{model.url, model.path}}; - } - - auto filename = split.prefix; - if (auto pos = split.prefix.rfind('/'); pos != std::string::npos) { - filename = split.prefix.substr(pos + 1); - } - - auto parent_path = std::filesystem::path(model.path).parent_path(); - auto prefix_path = (parent_path / filename).string(); - - std::vector tasks; - for (int i = 1; i <= split.count; i++) { - auto suffix = string_format("-%05d-of-%05d.gguf", i, split.count); - tasks.push_back({split.prefix + suffix, prefix_path + suffix}); - } - return tasks; -} - -common_download_model_result common_download_model(const common_params_model & model, - const common_download_opts & opts) { - common_download_model_result result; - std::vector tasks; - hf_plan hf; - - bool download_mmproj = opts.download_mmproj; - bool download_mtp = opts.download_mtp; - bool preset_only = opts.preset_only; - bool is_hf = !model.hf_repo.empty(); - - if (is_hf) { - hf = get_hf_plan(model, opts, download_mmproj, download_mtp); - if (!hf.preset.path.empty()) { - // if preset.ini exists, only download that file alone - tasks.push_back({hf.preset.url, hf.preset.local_path}); - } else if (!preset_only) { - // only add other files if we're NOT in preset-only mode (normal run, non-router) - for (const auto & f : hf.model_files) { - tasks.push_back({f.url, f.local_path}); - } - if (!hf.mmproj.path.empty()) { - tasks.push_back({hf.mmproj.url, hf.mmproj.local_path}); - } - if (!hf.mtp.path.empty()) { - tasks.push_back({hf.mtp.url, hf.mtp.local_path}); - } - } - } else if (!model.url.empty()) { - tasks = get_url_tasks(model); - } else { - result.model_path = model.path; - return result; - } - - if (tasks.empty()) { - return result; - } - +void common_download_run_tasks(const std::vector & tasks) { std::vector> futures; for (const auto & task : tasks) { futures.push_back(std::async(std::launch::async, - [&task, &opts, is_hf]() { - return common_download_file_single(task.url, task.path, opts, is_hf); + [&task]() { + return common_download_file_single(task.url, task.local_path, task.opts, task.is_hf); } )); } - for (auto & f : futures) { - int status = f.get(); - if (status == -2 && opts.skip_download) { - throw common_skip_download_exception(); - } + for (size_t i = 0; i < futures.size(); ++i) { + std::string url = tasks[i].url; + int status = futures[i].get(); bool is_ok = is_http_status_ok(status); if (!is_ok) { - return {}; + throw std::runtime_error(string_format("Download '%s' failed with status code: %d", url.c_str(), status)); } } +} - if (is_hf) { - if (!hf.preset.path.empty()) { - // if preset.ini is used, do not set other paths - result.preset_path = hf_cache::finalize_file(hf.preset); - } else { - for (const auto & f : hf.model_files) { - hf_cache::finalize_file(f); - } - result.model_path = hf.primary.final_path; +std::vector common_download_get_all_parts(const std::string & url) { + auto split = get_gguf_split_info(url); - if (!hf.mmproj.path.empty()) { - result.mmproj_path = hf_cache::finalize_file(hf.mmproj); - } - - if (!hf.mtp.path.empty()) { - result.mtp_path = hf_cache::finalize_file(hf.mtp); - } - } - } else { - result.model_path = model.path; + if (split.count <= 1) { + return {url}; } - return result; + std::vector parts; + for (int i = 1; i <= split.count; i++) { + auto suffix = string_format("-%05d-of-%05d.gguf", i, split.count); + parts.push_back(split.prefix + suffix); + } + return parts; } // diff --git a/common/download.h b/common/download.h index 755e34ea8c..816e1c7f58 100644 --- a/common/download.h +++ b/common/download.h @@ -1,7 +1,10 @@ #pragma once +#include "hf-cache.h" + #include #include +#include struct common_params_model; @@ -47,67 +50,40 @@ struct common_cached_model_info { } }; -// Options for common_download_model and common_download_file_single +// Options for common_download_file_single struct common_download_opts { std::string bearer_token; common_header_list headers; bool offline = false; - bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid bool download_mmproj = false; bool download_mtp = false; - bool preset_only = false; // if true, only check & download remote preset (for router mode) common_download_callback * callback = nullptr; }; -// Result of common_download_model -struct common_download_model_result { - std::string model_path; - std::string mmproj_path; - std::string mtp_path; - std::string preset_path; +struct common_download_task { + common_download_opts opts; + std::string url; + std::string local_path; + std::function on_done; + bool is_hf = false; + + common_download_task() = default; + common_download_task(hf_cache::hf_file f, + const common_download_opts & opts, + std::function on_done = nullptr) + : opts(opts), url(f.url), local_path(f.local_path), on_done(on_done), is_hf(true) {} }; -// throw if the file is missing or invalid (e.g. ETag check failed) -struct common_skip_download_exception : public std::runtime_error { - common_skip_download_exception() : std::runtime_error("skip download") {} -}; +void common_download_run_tasks(const std::vector & tasks); -// Download model from HuggingFace repo or URL -// -// input (via model struct): -// - model.hf_repo: HF repo with optional tag, see common_download_split_repo_tag -// - model.hf_file: specific file in the repo (requires hf_repo) -// - model.url: simple download (used if hf_repo is empty) -// - model.path: local file path -// -// tag matching (for HF repos without model.hf_file): -// - if tag is specified, searches for GGUF matching that quantization -// - if no tag, searches for Q4_K_M, then Q4_0, then first available GGUF -// -// split GGUF: multi-part files like "model-00001-of-00003.gguf" are automatically -// detected and all parts are downloaded -// -// caching: -// - HF repos: uses HuggingFace cache -// - URLs: uses ETag-based caching -// -// when opts.offline=true, no network requests are made -// when download_mmproj=true, searches for mmproj in same directory as model or any parent directory -// then with the closest quantization bits -// when download_mtp=true, applies the same sibling search for an MTP-head GGUF -// -// returns result with model_path, mmproj_path and mtp_path (empty when not found / on failure) -common_download_model_result common_download_model( - const common_params_model & model, - const common_download_opts & opts = {} -); +// if url is a multi-part GGUF file, returns all parts, otherwise returns the single file +std::vector common_download_get_all_parts(const std::string & url); // returns list of cached models std::vector common_list_cached_models(); // download single file from url to local path // returns status code or -1 on error -// returns -2 if the download was skipped due to ETag mismatch (file outdated, skip_download=true) // skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash) int common_download_file_single(const std::string & url, const std::string & path, @@ -124,3 +100,12 @@ std::string common_docker_resolve_model(const std::string & docker); // - if tag is present, removes only files matching that tag (and orphaned blobs) // returns true if anything was removed bool common_download_remove(const std::string & hf_repo_with_tag); + +struct common_download_hf_plan { + hf_cache::hf_file primary; + hf_cache::hf_files model_files; + hf_cache::hf_file mmproj; + hf_cache::hf_file mtp; + hf_cache::hf_file preset; // if set, only this file is downloaded +}; +common_download_hf_plan common_download_get_hf_plan(const common_params_model & model, const common_download_opts & opts); diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index 55970c0745..2695f58785 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -1035,25 +1035,23 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (!params.hf_repo.empty()) { for (size_t i = 0; i < params.hf_repo.size(); i++) { - common_params_model model; - - if (params.hf_file.empty() || params.hf_file[i].empty()) { - model.hf_repo = params.hf_repo[i]; - } else { - model.hf_repo = params.hf_repo[i]; - model.hf_file = params.hf_file[i]; + common_params p; + p.hf_token = params.hf_token; + p.offline = params.offline; + p.model.hf_repo = params.hf_repo[i]; + if (!params.hf_file.empty() && !params.hf_file[i].empty()) { + p.model.hf_file = params.hf_file[i]; } - common_download_opts opts; - opts.bearer_token = params.hf_token; - opts.offline = params.offline; - auto download_result = common_download_model(model, opts); - if (download_result.model_path.empty()) { + // only the text model file is needed + common_models_handler models_handler = common_models_handler_init(p, LLAMA_EXAMPLE_BENCH); + common_models_handler_apply(models_handler, p); + if (p.model.path.empty()) { fprintf(stderr, "error: failed to download model from HuggingFace\n"); exit(1); } - params.model.push_back(download_result.model_path); + params.model.push_back(p.model.path); } } diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index a4df3ef108..bb2f43a10d 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -223,8 +223,8 @@ void server_model_meta::update_caps() { "LLAMA_ARG_HF_REPO_FILE", }); params.offline = true; - // params.skip_download = true; // TODO: ideally, we should validate the model here, but it takes too much time - common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {}); + common_models_handler handler = common_models_handler_init(params, LLAMA_EXAMPLE_SERVER); + common_models_handler_apply(handler, params); // note: this won't download the model because offline=true if (params.mmproj.path.empty()) { multimodal = { false, false }; } else { @@ -1393,9 +1393,8 @@ struct server_download_state : public common_download_callback { bool run(common_params & params) { try { - common_params_handle_models_params p; - p.callback = this; - common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, p); + common_models_handler handler = common_models_handler_init(params, LLAMA_EXAMPLE_SERVER); + common_models_handler_apply(handler, params, this); is_ok = true; } catch (const std::exception & e) { auto model_name = params.model.get_name(); @@ -1768,23 +1767,14 @@ void server_models_routes::init_routes() { throw std::invalid_argument("model must be a non-empty string"); } - common_params_model model; - common_download_opts opts; + common_params p; + p.model.hf_repo = name; + p.hf_token = params.hf_token; - model.hf_repo = name; - opts.bearer_token = params.hf_token; - // note: we only check main model, no need sidecar here - opts.download_mmproj = false; - opts.download_mtp = false; - - // first, only check if the model is valid and can be downloaded - opts.skip_download = true; + // validate by fetching metadata bool ok = false; try { - auto validation = common_download_model(model, opts); - ok = !validation.model_path.empty(); - } catch (const common_skip_download_exception &) { - // model is valid and will be downloaded + common_models_handler_init(p, LLAMA_EXAMPLE_SERVER); ok = true; } catch (...) { SRV_ERR("unknown error while validating model '%s'\n", name.c_str()); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 4165c1015e..680590871f 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -89,15 +89,16 @@ int llama_server(int argc, char ** argv) { llama_backend_init(); llama_numa_init(params.numa); - // note: router mode also accepts -hf remote-preset, so we need to check that first - if (!params.model.hf_repo.empty()) { - try { - common_params_handle_models_params handle_params; - handle_params.preset_only = true; - common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, handle_params); - } catch (const std::exception & e) { - // ignored for now + common_models_handler models_handler; + try { + models_handler = common_models_handler_init(params, LLAMA_EXAMPLE_SERVER); + if (common_models_handler_is_preset_repo(models_handler)) { + // apply the preset and start the server in router mode + common_models_handler_apply(models_handler, params); } + } catch (const std::exception & e) { + SRV_ERR("failed to fetch model metadata: %s\n", e.what()); + return 1; } // router server never loads a model and must not touch the GPU @@ -274,7 +275,12 @@ int llama_server(int argc, char ** argv) { return child.run_download(params); } else if (!is_router_server) { // single-model mode (NOT spawned by router) - common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {}); + try { + common_models_handler_apply(models_handler, params); + } catch (const std::exception & e) { + SRV_ERR("failed to download model: %s\n", e.what()); + return 1; + } } //