diff --git a/common/arg.cpp b/common/arg.cpp index 3c0e91e398..8d2f722a65 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3342,6 +3342,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.offline = true; } ).set_env("LLAMA_ARG_OFFLINE")); + add_opt(common_arg( + {"--measure-only"}, + "Load the model to measure memory requirements, print to stdout, then exit", + [](common_params & params) { + params.measure_only = true; + } + )); add_opt(common_arg( {"-lv", "--verbosity", "--log-verbosity"}, "N", string_format("Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:\n" diff --git a/common/common.h b/common/common.h index 8f202dbe55..1e875e887b 100644 --- a/common/common.h +++ b/common/common.h @@ -511,6 +511,7 @@ struct common_params { int32_t control_vector_layer_end = -1; // layer range for control vector bool offline = false; bool skip_download = false; // skip model file downloading + bool measure_only = false; // load model with no_alloc to measure memory, print to stdout, then exit int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index df2460a1ce..324dd9e62f 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -8,8 +8,6 @@ #include // TODO: remove this once we use HTTP client from download.h #include -#include "../../src/llama-ext.h" - #include #include #include @@ -894,76 +892,119 @@ void server_models::unload_lru(const buft_memory_map & bmm_req) { } } -static buft_memory_map get_model_memory_per_buft(const common_preset & preset) { - common_params params; - preset.apply_to_params(params); - - if (params.model.path.empty()) { - return {}; +buft_memory_map server_models::estimate_model_memory(const std::string & name) { + std::vector child_args; + std::vector child_env; + { + std::lock_guard lk(mutex); + auto & meta = mapping[name].meta; + child_args = meta.preset.to_args(bin_path); + child_env = base_env; } + child_args.push_back("--measure-only"); + child_args.push_back("--offline"); - struct log_ud_t { - struct { - ggml_log_callback callback; - void * user_data; - } original; - ggml_log_level min_level; - } log_ud; - llama_log_get(&log_ud.original.callback, &log_ud.original.user_data); - log_ud.min_level = GGML_LOG_LEVEL_WARN; + SRV_INF("estimating memory for model name=%s\n", name.c_str()); - llama_log_set([](ggml_log_level level, const char * text, void * ud) { - log_ud_t * d = (log_ud_t *) ud; - const ggml_log_level eff = level >= d->min_level ? level : GGML_LOG_LEVEL_DEBUG; - d->original.callback(eff, text, d->original.user_data); - }, &log_ud); + std::vector argv = to_char_ptr_array(child_args); + std::vector envp = to_char_ptr_array(child_env); - llama_model_params mparams = common_model_params_to_llama(params); - mparams.no_alloc = true; - mparams.use_mmap = false; - mparams.use_mlock = false; - - llama_model_ptr model{llama_model_load_from_file(params.model.path.c_str(), mparams)}; - - if (!model) { - llama_log_set(log_ud.original.callback, log_ud.original.user_data); - return {}; - } - - llama_context_params cparams = common_context_params_to_llama(params); - llama_context_ptr ctx{llama_init_from_model(model.get(), cparams)}; - llama_log_set(log_ud.original.callback, log_ud.original.user_data); - - if (!ctx) { + subprocess_s proc; + int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr; + if (subprocess_create_ex(argv.data(), options, envp.data(), &proc) != 0) { + SRV_ERR("failed to spawn measure process for model name=%s\n", name.c_str()); return {}; } buft_memory_map result; - for (const auto & [buft, data] : llama_get_memory_breakdown(ctx.get())) { - size_t total = data.total(); - if (total > 0) { - result[buft] = total; + FILE * out = subprocess_stdout(&proc); + if (out) { + char buffer[4096]; + while (fgets(buffer, sizeof(buffer), out) != nullptr) { + LOG("[measure:%s] %s", name.c_str(), buffer); + std::string line(buffer); + if (string_starts_with(line, "measure:")) { + std::istringstream iss(line.substr(strlen("measure:"))); + std::string buft_name; + size_t size = 0; + if (iss >> buft_name >> size) { + ggml_backend_buffer_type_t buft = nullptr; + for (size_t i = 0; i < ggml_backend_dev_count(); i++) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + ggml_backend_buffer_type_t dev_buft = ggml_backend_dev_buffer_type(dev); + if (dev_buft && buft_name == ggml_backend_buft_name(dev_buft)) { + buft = dev_buft; + break; + } + } + if (buft) { + result[buft] = size; + } else { + SRV_WRN("unknown buft name '%s' from measure child for model name=%s\n", + buft_name.c_str(), name.c_str()); + } + } + } } } + int exit_code = 0; + subprocess_join(&proc, &exit_code); + subprocess_destroy(&proc); + + if (exit_code != 0) { + SRV_ERR("measure process for model name=%s exited with code %d\n", name.c_str(), exit_code); + return {}; + } + + SRV_INF("memory estimation complete for model name=%s\n", name.c_str()); return result; } +void server_models::join_completed_bg_tasks() { + std::vector> to_join; + { + std::lock_guard lk(mutex); + for (auto it = bg_tasks.begin(); it != bg_tasks.end(); ) { + if (it->second->done.load()) { + to_join.push_back(std::move(it->second)); + it = bg_tasks.erase(it); + } else { + ++it; + } + } + } + for (auto & task : to_join) { + if (task->th.joinable()) { + task->th.join(); + } + } +} + + void server_models::load(const std::string & name) { if (!has_model(name)) { throw std::runtime_error("model name=" + name + " is not found"); } + join_completed_bg_tasks(); + buft_memory_map bmm_req; if (base_params.models_memory_margin > 0) { - // determine the required memory by the model upon its first load - std::lock_guard lk(mutex); - auto & meta = mapping[name].meta; - if (meta.bmm_req.empty()) { - meta.bmm_req = get_model_memory_per_buft(meta.preset); + { + std::lock_guard lk(mutex); + bmm_req = mapping[name].meta.bmm_req; + } + if (bmm_req.empty()) { + bmm_req = estimate_model_memory(name); + if (bmm_req.empty()) { + SRV_WRN("failed to estimate memory for model %s, memory limits will not apply\n", name.c_str()); + } + { + std::lock_guard lk(mutex); + mapping[name].meta.bmm_req = bmm_req; + } } - - bmm_req = meta.bmm_req; } unload_lru(bmm_req); @@ -1249,6 +1290,7 @@ void server_models::unload(const std::string & name) { void server_models::unload_all() { std::vector to_join; + std::vector> bg_to_join; { std::lock_guard lk(mutex); for (auto & [name, inst] : mapping) { @@ -1264,15 +1306,26 @@ void server_models::unload_all() { // moving the thread to join list to avoid deadlock to_join.push_back(std::move(inst.th)); } + for (auto & [name, task] : bg_tasks) { + bg_to_join.push_back(std::move(task)); + } + bg_tasks.clear(); } for (auto & th : to_join) { if (th.joinable()) { th.join(); } } + for (auto & task : bg_to_join) { + if (task && task->th.joinable()) { + task->th.join(); + } + } } void server_models::update_status(const std::string & name, server_model_status status, int exit_code) { + join_completed_bg_tasks(); + std::unique_lock lk(mutex); auto it = mapping.find(name); if (it != mapping.end()) { diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 84eb8407b4..7914a1c38a 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -7,6 +7,7 @@ #include "server-http.h" #include "server-queue.h" +#include #include #include #include @@ -118,6 +119,13 @@ private: std::condition_variable cv_stop; std::set stopping_models; + // background tasks for download/estimate/load pipelines, keyed by model name + struct bg_task { + std::thread th; + std::atomic done{false}; + }; + std::map> bg_tasks; + // set to true while load_models() is executing a reload; load() will wait until clear bool is_reloading = false; @@ -154,6 +162,12 @@ private: // not thread-safe, caller must hold mutex bool limits_exceeded(const buft_memory_map & bmm_req) const; + // estimate model memory by spawning a child process with --measure-only + // returns the buft memory map, or empty map on failure (caller must NOT hold mutex) + buft_memory_map estimate_model_memory(const std::string & name); + + // join and remove completed background tasks + void join_completed_bg_tasks(); public: server_models(const common_params & params, int argc, char ** argv); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 78ab0318cf..307576eb32 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -11,6 +11,8 @@ #include "llama.h" #include "log.h" +#include "../../src/llama-ext.h" + #include #include #include @@ -120,6 +122,47 @@ int llama_server(int argc, char ** argv) { // struct that contains llama context and inference server_context ctx_server; + llama_backend_init(); + llama_numa_init(params.numa); + + if (params.measure_only) { + llama_model_params mparams = common_model_params_to_llama(params); + mparams.no_alloc = true; + mparams.use_mmap = false; + mparams.use_mlock = false; + + llama_model_ptr model{llama_model_load_from_file(params.model.path.c_str(), mparams)}; + if (!model) { + LOG_ERR("%s: failed to load model for measurement\n", __func__); + llama_backend_free(); + return 1; + } + + llama_context_params cparams = common_context_params_to_llama(params); + llama_context_ptr ctx{llama_init_from_model(model.get(), cparams)}; + if (!ctx) { + LOG_ERR("%s: failed to create context for measurement\n", __func__); + llama_backend_free(); + return 1; + } + + common_log_pause(common_log_main()); + for (const auto & [buft, data] : llama_get_memory_breakdown(ctx.get())) { + size_t total = data.total(); + if (total > 0) { + fprintf(stdout, "measure:%s %zu\n", ggml_backend_buft_name(buft), total); + } + } + fflush(stdout); + common_log_resume(common_log_main()); + + llama_backend_free(); + return 0; + } + + LOG_INF("build_info: %s\n", llama_build_info()); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + server_http_context ctx_http; if (!ctx_http.init(params)) { SRV_ERR("%s", "failed to initialize HTTP server\n");