move model memory estimation to subprocess

This commit is contained in:
Ruben Ortlam 2026-05-13 17:50:11 +02:00
parent 384a495a00
commit dbc5f7ec82
5 changed files with 168 additions and 50 deletions

View File

@ -3342,6 +3342,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.offline = true;
}
).set_env("LLAMA_ARG_OFFLINE"));
add_opt(common_arg(
{"--measure-only"},
"Load the model to measure memory requirements, print to stdout, then exit",
[](common_params & params) {
params.measure_only = true;
}
));
add_opt(common_arg(
{"-lv", "--verbosity", "--log-verbosity"}, "N",
string_format("Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:\n"

View File

@ -511,6 +511,7 @@ struct common_params {
int32_t control_vector_layer_end = -1; // layer range for control vector
bool offline = false;
bool skip_download = false; // skip model file downloading
bool measure_only = false; // load model with no_alloc to measure memory, print to stdout, then exit
int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line

View File

@ -8,8 +8,6 @@
#include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
#include <sheredom/subprocess.h>
#include "../../src/llama-ext.h"
#include <functional>
#include <optional>
#include <algorithm>
@ -894,76 +892,119 @@ void server_models::unload_lru(const buft_memory_map & bmm_req) {
}
}
static buft_memory_map get_model_memory_per_buft(const common_preset & preset) {
common_params params;
preset.apply_to_params(params);
if (params.model.path.empty()) {
return {};
buft_memory_map server_models::estimate_model_memory(const std::string & name) {
std::vector<std::string> child_args;
std::vector<std::string> child_env;
{
std::lock_guard<std::mutex> lk(mutex);
auto & meta = mapping[name].meta;
child_args = meta.preset.to_args(bin_path);
child_env = base_env;
}
child_args.push_back("--measure-only");
child_args.push_back("--offline");
struct log_ud_t {
struct {
ggml_log_callback callback;
void * user_data;
} original;
ggml_log_level min_level;
} log_ud;
llama_log_get(&log_ud.original.callback, &log_ud.original.user_data);
log_ud.min_level = GGML_LOG_LEVEL_WARN;
SRV_INF("estimating memory for model name=%s\n", name.c_str());
llama_log_set([](ggml_log_level level, const char * text, void * ud) {
log_ud_t * d = (log_ud_t *) ud;
const ggml_log_level eff = level >= d->min_level ? level : GGML_LOG_LEVEL_DEBUG;
d->original.callback(eff, text, d->original.user_data);
}, &log_ud);
std::vector<char *> argv = to_char_ptr_array(child_args);
std::vector<char *> envp = to_char_ptr_array(child_env);
llama_model_params mparams = common_model_params_to_llama(params);
mparams.no_alloc = true;
mparams.use_mmap = false;
mparams.use_mlock = false;
llama_model_ptr model{llama_model_load_from_file(params.model.path.c_str(), mparams)};
if (!model) {
llama_log_set(log_ud.original.callback, log_ud.original.user_data);
return {};
}
llama_context_params cparams = common_context_params_to_llama(params);
llama_context_ptr ctx{llama_init_from_model(model.get(), cparams)};
llama_log_set(log_ud.original.callback, log_ud.original.user_data);
if (!ctx) {
subprocess_s proc;
int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
if (subprocess_create_ex(argv.data(), options, envp.data(), &proc) != 0) {
SRV_ERR("failed to spawn measure process for model name=%s\n", name.c_str());
return {};
}
buft_memory_map result;
for (const auto & [buft, data] : llama_get_memory_breakdown(ctx.get())) {
size_t total = data.total();
if (total > 0) {
result[buft] = total;
FILE * out = subprocess_stdout(&proc);
if (out) {
char buffer[4096];
while (fgets(buffer, sizeof(buffer), out) != nullptr) {
LOG("[measure:%s] %s", name.c_str(), buffer);
std::string line(buffer);
if (string_starts_with(line, "measure:")) {
std::istringstream iss(line.substr(strlen("measure:")));
std::string buft_name;
size_t size = 0;
if (iss >> buft_name >> size) {
ggml_backend_buffer_type_t buft = nullptr;
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
ggml_backend_buffer_type_t dev_buft = ggml_backend_dev_buffer_type(dev);
if (dev_buft && buft_name == ggml_backend_buft_name(dev_buft)) {
buft = dev_buft;
break;
}
}
if (buft) {
result[buft] = size;
} else {
SRV_WRN("unknown buft name '%s' from measure child for model name=%s\n",
buft_name.c_str(), name.c_str());
}
}
}
}
}
int exit_code = 0;
subprocess_join(&proc, &exit_code);
subprocess_destroy(&proc);
if (exit_code != 0) {
SRV_ERR("measure process for model name=%s exited with code %d\n", name.c_str(), exit_code);
return {};
}
SRV_INF("memory estimation complete for model name=%s\n", name.c_str());
return result;
}
void server_models::join_completed_bg_tasks() {
std::vector<std::unique_ptr<bg_task>> to_join;
{
std::lock_guard<std::mutex> lk(mutex);
for (auto it = bg_tasks.begin(); it != bg_tasks.end(); ) {
if (it->second->done.load()) {
to_join.push_back(std::move(it->second));
it = bg_tasks.erase(it);
} else {
++it;
}
}
}
for (auto & task : to_join) {
if (task->th.joinable()) {
task->th.join();
}
}
}
void server_models::load(const std::string & name) {
if (!has_model(name)) {
throw std::runtime_error("model name=" + name + " is not found");
}
join_completed_bg_tasks();
buft_memory_map bmm_req;
if (base_params.models_memory_margin > 0) {
// determine the required memory by the model upon its first load
std::lock_guard<std::mutex> lk(mutex);
auto & meta = mapping[name].meta;
if (meta.bmm_req.empty()) {
meta.bmm_req = get_model_memory_per_buft(meta.preset);
{
std::lock_guard<std::mutex> lk(mutex);
bmm_req = mapping[name].meta.bmm_req;
}
if (bmm_req.empty()) {
bmm_req = estimate_model_memory(name);
if (bmm_req.empty()) {
SRV_WRN("failed to estimate memory for model %s, memory limits will not apply\n", name.c_str());
}
{
std::lock_guard<std::mutex> lk(mutex);
mapping[name].meta.bmm_req = bmm_req;
}
}
bmm_req = meta.bmm_req;
}
unload_lru(bmm_req);
@ -1249,6 +1290,7 @@ void server_models::unload(const std::string & name) {
void server_models::unload_all() {
std::vector<std::thread> to_join;
std::vector<std::unique_ptr<bg_task>> bg_to_join;
{
std::lock_guard<std::mutex> lk(mutex);
for (auto & [name, inst] : mapping) {
@ -1264,15 +1306,26 @@ void server_models::unload_all() {
// moving the thread to join list to avoid deadlock
to_join.push_back(std::move(inst.th));
}
for (auto & [name, task] : bg_tasks) {
bg_to_join.push_back(std::move(task));
}
bg_tasks.clear();
}
for (auto & th : to_join) {
if (th.joinable()) {
th.join();
}
}
for (auto & task : bg_to_join) {
if (task && task->th.joinable()) {
task->th.join();
}
}
}
void server_models::update_status(const std::string & name, server_model_status status, int exit_code) {
join_completed_bg_tasks();
std::unique_lock<std::mutex> lk(mutex);
auto it = mapping.find(name);
if (it != mapping.end()) {

View File

@ -7,6 +7,7 @@
#include "server-http.h"
#include "server-queue.h"
#include <atomic>
#include <mutex>
#include <condition_variable>
#include <functional>
@ -118,6 +119,13 @@ private:
std::condition_variable cv_stop;
std::set<std::string> stopping_models;
// background tasks for download/estimate/load pipelines, keyed by model name
struct bg_task {
std::thread th;
std::atomic<bool> done{false};
};
std::map<std::string, std::unique_ptr<bg_task>> bg_tasks;
// set to true while load_models() is executing a reload; load() will wait until clear
bool is_reloading = false;
@ -154,6 +162,12 @@ private:
// not thread-safe, caller must hold mutex
bool limits_exceeded(const buft_memory_map & bmm_req) const;
// estimate model memory by spawning a child process with --measure-only
// returns the buft memory map, or empty map on failure (caller must NOT hold mutex)
buft_memory_map estimate_model_memory(const std::string & name);
// join and remove completed background tasks
void join_completed_bg_tasks();
public:
server_models(const common_params & params, int argc, char ** argv);

View File

@ -11,6 +11,8 @@
#include "llama.h"
#include "log.h"
#include "../../src/llama-ext.h"
#include <atomic>
#include <clocale>
#include <exception>
@ -120,6 +122,47 @@ int llama_server(int argc, char ** argv) {
// struct that contains llama context and inference
server_context ctx_server;
llama_backend_init();
llama_numa_init(params.numa);
if (params.measure_only) {
llama_model_params mparams = common_model_params_to_llama(params);
mparams.no_alloc = true;
mparams.use_mmap = false;
mparams.use_mlock = false;
llama_model_ptr model{llama_model_load_from_file(params.model.path.c_str(), mparams)};
if (!model) {
LOG_ERR("%s: failed to load model for measurement\n", __func__);
llama_backend_free();
return 1;
}
llama_context_params cparams = common_context_params_to_llama(params);
llama_context_ptr ctx{llama_init_from_model(model.get(), cparams)};
if (!ctx) {
LOG_ERR("%s: failed to create context for measurement\n", __func__);
llama_backend_free();
return 1;
}
common_log_pause(common_log_main());
for (const auto & [buft, data] : llama_get_memory_breakdown(ctx.get())) {
size_t total = data.total();
if (total > 0) {
fprintf(stdout, "measure:%s %zu\n", ggml_backend_buft_name(buft), total);
}
}
fflush(stdout);
common_log_resume(common_log_main());
llama_backend_free();
return 0;
}
LOG_INF("build_info: %s\n", llama_build_info());
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
server_http_context ctx_http;
if (!ctx_http.init(params)) {
SRV_ERR("%s", "failed to initialize HTTP server\n");