mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
move model memory estimation to subprocess
This commit is contained in:
parent
384a495a00
commit
dbc5f7ec82
@ -3342,6 +3342,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.offline = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_OFFLINE"));
|
||||
add_opt(common_arg(
|
||||
{"--measure-only"},
|
||||
"Load the model to measure memory requirements, print to stdout, then exit",
|
||||
[](common_params & params) {
|
||||
params.measure_only = true;
|
||||
}
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"-lv", "--verbosity", "--log-verbosity"}, "N",
|
||||
string_format("Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:\n"
|
||||
|
||||
@ -511,6 +511,7 @@ struct common_params {
|
||||
int32_t control_vector_layer_end = -1; // layer range for control vector
|
||||
bool offline = false;
|
||||
bool skip_download = false; // skip model file downloading
|
||||
bool measure_only = false; // load model with no_alloc to measure memory, print to stdout, then exit
|
||||
|
||||
int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
|
||||
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
|
||||
|
||||
@ -8,8 +8,6 @@
|
||||
#include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
|
||||
#include <sheredom/subprocess.h>
|
||||
|
||||
#include "../../src/llama-ext.h"
|
||||
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <algorithm>
|
||||
@ -894,76 +892,119 @@ void server_models::unload_lru(const buft_memory_map & bmm_req) {
|
||||
}
|
||||
}
|
||||
|
||||
static buft_memory_map get_model_memory_per_buft(const common_preset & preset) {
|
||||
common_params params;
|
||||
preset.apply_to_params(params);
|
||||
|
||||
if (params.model.path.empty()) {
|
||||
return {};
|
||||
buft_memory_map server_models::estimate_model_memory(const std::string & name) {
|
||||
std::vector<std::string> child_args;
|
||||
std::vector<std::string> child_env;
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
auto & meta = mapping[name].meta;
|
||||
child_args = meta.preset.to_args(bin_path);
|
||||
child_env = base_env;
|
||||
}
|
||||
child_args.push_back("--measure-only");
|
||||
child_args.push_back("--offline");
|
||||
|
||||
struct log_ud_t {
|
||||
struct {
|
||||
ggml_log_callback callback;
|
||||
void * user_data;
|
||||
} original;
|
||||
ggml_log_level min_level;
|
||||
} log_ud;
|
||||
llama_log_get(&log_ud.original.callback, &log_ud.original.user_data);
|
||||
log_ud.min_level = GGML_LOG_LEVEL_WARN;
|
||||
SRV_INF("estimating memory for model name=%s\n", name.c_str());
|
||||
|
||||
llama_log_set([](ggml_log_level level, const char * text, void * ud) {
|
||||
log_ud_t * d = (log_ud_t *) ud;
|
||||
const ggml_log_level eff = level >= d->min_level ? level : GGML_LOG_LEVEL_DEBUG;
|
||||
d->original.callback(eff, text, d->original.user_data);
|
||||
}, &log_ud);
|
||||
std::vector<char *> argv = to_char_ptr_array(child_args);
|
||||
std::vector<char *> envp = to_char_ptr_array(child_env);
|
||||
|
||||
llama_model_params mparams = common_model_params_to_llama(params);
|
||||
mparams.no_alloc = true;
|
||||
mparams.use_mmap = false;
|
||||
mparams.use_mlock = false;
|
||||
|
||||
llama_model_ptr model{llama_model_load_from_file(params.model.path.c_str(), mparams)};
|
||||
|
||||
if (!model) {
|
||||
llama_log_set(log_ud.original.callback, log_ud.original.user_data);
|
||||
return {};
|
||||
}
|
||||
|
||||
llama_context_params cparams = common_context_params_to_llama(params);
|
||||
llama_context_ptr ctx{llama_init_from_model(model.get(), cparams)};
|
||||
llama_log_set(log_ud.original.callback, log_ud.original.user_data);
|
||||
|
||||
if (!ctx) {
|
||||
subprocess_s proc;
|
||||
int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
|
||||
if (subprocess_create_ex(argv.data(), options, envp.data(), &proc) != 0) {
|
||||
SRV_ERR("failed to spawn measure process for model name=%s\n", name.c_str());
|
||||
return {};
|
||||
}
|
||||
|
||||
buft_memory_map result;
|
||||
for (const auto & [buft, data] : llama_get_memory_breakdown(ctx.get())) {
|
||||
size_t total = data.total();
|
||||
if (total > 0) {
|
||||
result[buft] = total;
|
||||
FILE * out = subprocess_stdout(&proc);
|
||||
if (out) {
|
||||
char buffer[4096];
|
||||
while (fgets(buffer, sizeof(buffer), out) != nullptr) {
|
||||
LOG("[measure:%s] %s", name.c_str(), buffer);
|
||||
std::string line(buffer);
|
||||
if (string_starts_with(line, "measure:")) {
|
||||
std::istringstream iss(line.substr(strlen("measure:")));
|
||||
std::string buft_name;
|
||||
size_t size = 0;
|
||||
if (iss >> buft_name >> size) {
|
||||
ggml_backend_buffer_type_t buft = nullptr;
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
|
||||
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
||||
ggml_backend_buffer_type_t dev_buft = ggml_backend_dev_buffer_type(dev);
|
||||
if (dev_buft && buft_name == ggml_backend_buft_name(dev_buft)) {
|
||||
buft = dev_buft;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (buft) {
|
||||
result[buft] = size;
|
||||
} else {
|
||||
SRV_WRN("unknown buft name '%s' from measure child for model name=%s\n",
|
||||
buft_name.c_str(), name.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int exit_code = 0;
|
||||
subprocess_join(&proc, &exit_code);
|
||||
subprocess_destroy(&proc);
|
||||
|
||||
if (exit_code != 0) {
|
||||
SRV_ERR("measure process for model name=%s exited with code %d\n", name.c_str(), exit_code);
|
||||
return {};
|
||||
}
|
||||
|
||||
SRV_INF("memory estimation complete for model name=%s\n", name.c_str());
|
||||
return result;
|
||||
}
|
||||
|
||||
void server_models::join_completed_bg_tasks() {
|
||||
std::vector<std::unique_ptr<bg_task>> to_join;
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
for (auto it = bg_tasks.begin(); it != bg_tasks.end(); ) {
|
||||
if (it->second->done.load()) {
|
||||
to_join.push_back(std::move(it->second));
|
||||
it = bg_tasks.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto & task : to_join) {
|
||||
if (task->th.joinable()) {
|
||||
task->th.join();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void server_models::load(const std::string & name) {
|
||||
if (!has_model(name)) {
|
||||
throw std::runtime_error("model name=" + name + " is not found");
|
||||
}
|
||||
|
||||
join_completed_bg_tasks();
|
||||
|
||||
buft_memory_map bmm_req;
|
||||
if (base_params.models_memory_margin > 0) {
|
||||
// determine the required memory by the model upon its first load
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
auto & meta = mapping[name].meta;
|
||||
if (meta.bmm_req.empty()) {
|
||||
meta.bmm_req = get_model_memory_per_buft(meta.preset);
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
bmm_req = mapping[name].meta.bmm_req;
|
||||
}
|
||||
if (bmm_req.empty()) {
|
||||
bmm_req = estimate_model_memory(name);
|
||||
if (bmm_req.empty()) {
|
||||
SRV_WRN("failed to estimate memory for model %s, memory limits will not apply\n", name.c_str());
|
||||
}
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
mapping[name].meta.bmm_req = bmm_req;
|
||||
}
|
||||
}
|
||||
|
||||
bmm_req = meta.bmm_req;
|
||||
}
|
||||
|
||||
unload_lru(bmm_req);
|
||||
@ -1249,6 +1290,7 @@ void server_models::unload(const std::string & name) {
|
||||
|
||||
void server_models::unload_all() {
|
||||
std::vector<std::thread> to_join;
|
||||
std::vector<std::unique_ptr<bg_task>> bg_to_join;
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
for (auto & [name, inst] : mapping) {
|
||||
@ -1264,15 +1306,26 @@ void server_models::unload_all() {
|
||||
// moving the thread to join list to avoid deadlock
|
||||
to_join.push_back(std::move(inst.th));
|
||||
}
|
||||
for (auto & [name, task] : bg_tasks) {
|
||||
bg_to_join.push_back(std::move(task));
|
||||
}
|
||||
bg_tasks.clear();
|
||||
}
|
||||
for (auto & th : to_join) {
|
||||
if (th.joinable()) {
|
||||
th.join();
|
||||
}
|
||||
}
|
||||
for (auto & task : bg_to_join) {
|
||||
if (task && task->th.joinable()) {
|
||||
task->th.join();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void server_models::update_status(const std::string & name, server_model_status status, int exit_code) {
|
||||
join_completed_bg_tasks();
|
||||
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
auto it = mapping.find(name);
|
||||
if (it != mapping.end()) {
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include "server-http.h"
|
||||
#include "server-queue.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <functional>
|
||||
@ -118,6 +119,13 @@ private:
|
||||
std::condition_variable cv_stop;
|
||||
std::set<std::string> stopping_models;
|
||||
|
||||
// background tasks for download/estimate/load pipelines, keyed by model name
|
||||
// One background worker: its thread handle plus a completion flag.
struct bg_task {
|
||||
// worker thread; joined by join_completed_bg_tasks() or unload_all()
std::thread th;
|
||||
// starts false; join_completed_bg_tasks() only reaps entries where this reads true
// NOTE(review): presumably set by the worker itself right before it returns - confirm at the spawn site
std::atomic<bool> done{false};
|
||||
};
|
||||
std::map<std::string, std::unique_ptr<bg_task>> bg_tasks;
|
||||
|
||||
// set to true while load_models() is executing a reload; load() will wait until clear
|
||||
bool is_reloading = false;
|
||||
|
||||
@ -154,6 +162,12 @@ private:
|
||||
// not thread-safe, caller must hold mutex
|
||||
bool limits_exceeded(const buft_memory_map & bmm_req) const;
|
||||
|
||||
// estimate model memory by spawning a child process with --measure-only
|
||||
// returns the buft memory map, or empty map on failure (caller must NOT hold mutex)
|
||||
buft_memory_map estimate_model_memory(const std::string & name);
|
||||
|
||||
// join and remove completed background tasks
|
||||
void join_completed_bg_tasks();
|
||||
public:
|
||||
server_models(const common_params & params, int argc, char ** argv);
|
||||
|
||||
|
||||
@ -11,6 +11,8 @@
|
||||
#include "llama.h"
|
||||
#include "log.h"
|
||||
|
||||
#include "../../src/llama-ext.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <clocale>
|
||||
#include <exception>
|
||||
@ -120,6 +122,47 @@ int llama_server(int argc, char ** argv) {
|
||||
// struct that contains llama context and inference
|
||||
server_context ctx_server;
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
if (params.measure_only) {
|
||||
llama_model_params mparams = common_model_params_to_llama(params);
|
||||
mparams.no_alloc = true;
|
||||
mparams.use_mmap = false;
|
||||
mparams.use_mlock = false;
|
||||
|
||||
llama_model_ptr model{llama_model_load_from_file(params.model.path.c_str(), mparams)};
|
||||
if (!model) {
|
||||
LOG_ERR("%s: failed to load model for measurement\n", __func__);
|
||||
llama_backend_free();
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_context_params cparams = common_context_params_to_llama(params);
|
||||
llama_context_ptr ctx{llama_init_from_model(model.get(), cparams)};
|
||||
if (!ctx) {
|
||||
LOG_ERR("%s: failed to create context for measurement\n", __func__);
|
||||
llama_backend_free();
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_log_pause(common_log_main());
|
||||
for (const auto & [buft, data] : llama_get_memory_breakdown(ctx.get())) {
|
||||
size_t total = data.total();
|
||||
if (total > 0) {
|
||||
fprintf(stdout, "measure:%s %zu\n", ggml_backend_buft_name(buft), total);
|
||||
}
|
||||
}
|
||||
fflush(stdout);
|
||||
common_log_resume(common_log_main());
|
||||
|
||||
llama_backend_free();
|
||||
return 0;
|
||||
}
|
||||
|
||||
LOG_INF("build_info: %s\n", llama_build_info());
|
||||
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
|
||||
|
||||
server_http_context ctx_http;
|
||||
if (!ctx_http.init(params)) {
|
||||
SRV_ERR("%s", "failed to initialize HTTP server\n");
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user