use memory margin instead of total size limit, apply to each device separately

This commit is contained in:
Ruben Ortlam 2026-04-02 09:24:53 +02:00
parent 4ed48154b0
commit c749b6882c
6 changed files with 146 additions and 70 deletions

View File

@ -3100,12 +3100,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_MAX"));
add_opt(common_arg(
{"--models-memory-max"}, "N",
string_format("for router server, maximum memory usage in MB (default: %d, 0 = unlimited)", params.models_memory_max),
{"--models-memory-margin"}, "N",
string_format("for router server, MB of memory to leave free, per device (default: %d, 0 = unlimited)", params.models_memory_margin),
[](common_params & params, int value) {
params.models_memory_max = value;
params.models_memory_margin = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_MEMORY_MAX"));
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_MEMORY_MARGIN"));
add_opt(common_arg(
{"--models-autoload"},
{"--no-models-autoload"},

View File

@ -645,7 +645,7 @@ struct common_params {
std::string models_dir = ""; // directory containing models for the router server
std::string models_preset = ""; // directory containing model presets for the router server
int models_max = 4; // maximum number of models to load simultaneously
int models_memory_max = 0; // maximum memory usage in MB (0 = unlimited, estimated from model files)
int models_memory_margin = 1024; // MB of free memory to preserve per device (0 = disabled)
bool models_autoload = true; // automatically load models when requested via the router server
std::string models_preset_hf = ""; // show a warning about remote presets on router loaded (if not empty)

View File

@ -1548,6 +1548,12 @@ extern "C" {
LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain);
LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain);
// Returns the projected memory use (model + context + compute) in bytes
// for the given device within this context. Returns 0 if the device is not used.
LLAMA_API uint64_t llama_context_device_memory(
const struct llama_context * ctx,
ggml_backend_dev_t device);
//
// training
//

View File

@ -4088,6 +4088,19 @@ void llama_perf_context_reset(llama_context * ctx) {
ctx->perf_reset();
}
uint64_t llama_context_device_memory(const llama_context * ctx, ggml_backend_dev_t device) {
const bool is_host = ggml_backend_dev_type(device) == GGML_BACKEND_DEVICE_TYPE_CPU;
uint64_t total = 0;
for (const auto & [buft, mb] : ctx->memory_breakdown()) {
const bool matches = is_host ? ggml_backend_buft_is_host(buft) :
ggml_backend_buft_get_device(buft) == device;
if (matches) {
total += mb.total();
}
}
return total;
}
//
// training
//

View File

@ -248,6 +248,21 @@ server_models::server_models(
LOG_WRN("failed to get server executable path: %s\n", e.what());
LOG_WRN("using original argv[0] as fallback: %s\n", argv[0]);
}
const uint64_t memory_margin = base_params.models_memory_margin * 1024 * 1024;
if (memory_margin > 0) {
const size_t n_devs = ggml_backend_dev_count();
for (size_t i = 0; i < n_devs; i++) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
size_t free, total;
ggml_backend_dev_memory(dev, &free, &total);
if (total > 0) {
memory_per_device[dev] = (free > memory_margin) ? free - memory_margin : 0;
}
}
}
load_models();
}
@ -433,20 +448,20 @@ void server_models::load_models() {
// FIRST LOAD: add all models, then unlock for autoloading
for (const auto & [name, preset] : final_presets) {
server_model_meta meta{
/* source */ get_source(name),
/* preset */ preset,
/* name */ name,
/* aliases */ {},
/* tags */ {},
/* port */ 0,
/* status */ SERVER_MODEL_STATUS_UNLOADED,
/* last_used */ 0,
/* memory_mb */ 0,
/* args */ std::vector<std::string>(),
/* loaded_info */ {},
/* exit_code */ 0,
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
/* multimodal */ mtmd_caps{false, false},
/* source */ get_source(name),
/* preset */ preset,
/* name */ name,
/* aliases */ {},
/* tags */ {},
/* port */ 0,
/* status */ SERVER_MODEL_STATUS_UNLOADED,
/* last_used */ 0,
/* memory_per_device */ {},
/* args */ std::vector<std::string>(),
/* loaded_info */ {},
/* exit_code */ 0,
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
/* multimodal */ mtmd_caps{false, false},
// /* need_download */ false,
};
add_model(std::move(meta));
@ -600,20 +615,20 @@ void server_models::load_models() {
for (const auto & [name, preset] : final_presets) {
if (mapping.find(name) == mapping.end()) {
server_model_meta meta{
/* source */ get_source(name),
/* preset */ preset,
/* name */ name,
/* aliases */ {},
/* tags */ {},
/* port */ 0,
/* status */ SERVER_MODEL_STATUS_UNLOADED,
/* last_used */ 0,
/* memory_mb */ 0,
/* args */ std::vector<std::string>(),
/* loaded_info */ {},
/* exit_code */ 0,
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
/* multimodal */ mtmd_caps{false, false},
/* source */ get_source(name),
/* preset */ preset,
/* name */ name,
/* aliases */ {},
/* tags */ {},
/* port */ 0,
/* status */ SERVER_MODEL_STATUS_UNLOADED,
/* last_used */ 0,
/* memory_per_device */ {},
/* args */ std::vector<std::string>(),
/* loaded_info */ {},
/* exit_code */ 0,
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
/* multimodal */ mtmd_caps{false, false},
// /* need_download */ false,
};
add_model(std::move(meta));
@ -782,36 +797,63 @@ std::vector<server_model_meta> server_models::get_all_meta() {
return result;
}
void server_models::unload_lru(uint64_t new_model_memory_mb) {
if (base_params.models_max <= 0 && base_params.models_memory_max <= 0) {
uint64_t server_models::get_memory_exceeded(const model_memory_map& new_model_memory_per_device) const {
model_memory_map total_memory_per_device;
for (const auto & m : mapping) {
if (m.second.meta.is_running()) {
for (const auto& [key, value] : m.second.meta.memory_per_device) {
total_memory_per_device[key] += value;
}
}
}
auto get = [](const model_memory_map & m, ggml_backend_dev_t k) {
auto it = m.find(k);
return it != m.end() ? it->second : 0;
};
uint64_t memory_exceeded = 0;
for (const auto& [key, limit] : memory_per_device) {
if (get(new_model_memory_per_device, key) + get(total_memory_per_device, key) > limit) {
memory_exceeded++;
}
}
return memory_exceeded;
}
void server_models::unload_lru(const model_memory_map& new_model_memory_per_device) {
const bool check_memory = base_params.models_memory_margin > 0 && !memory_per_device.empty();
if (base_params.models_max <= 0 && !check_memory) {
return; // no limit
}
while (true) {
std::string lru_model_name = "";
int64_t lru_last_used = ggml_time_ms();
size_t count_active = 0;
uint64_t total_memory_mb = 0;
uint64_t memory_exceeded = 0;
{
std::unique_lock<std::mutex> lk(mutex);
for (const auto & m : mapping) {
if (m.second.meta.is_running()) {
count_active++;
total_memory_mb += m.second.meta.memory_mb;
if (m.second.meta.last_used < lru_last_used) {
lru_model_name = m.first;
lru_last_used = m.second.meta.last_used;
}
}
}
memory_exceeded = get_memory_exceeded(new_model_memory_per_device);
}
bool count_exceeded = base_params.models_max > 0 &&
(count_active + 1) >= (size_t)base_params.models_max;
uint64_t projected_memory = total_memory_mb + new_model_memory_mb;
bool memory_exceeded = base_params.models_memory_max > 0 &&
projected_memory >= (uint64_t)base_params.models_memory_max;
if (!lru_model_name.empty() && (count_exceeded || memory_exceeded)) {
SRV_INF("limits reached (count=%zu, memory=%lu MB + %lu MB new), removing LRU name=%s\n",
count_active, (unsigned long)total_memory_mb, (unsigned long)new_model_memory_mb, lru_model_name.c_str());
if (!lru_model_name.empty() && (count_exceeded || memory_exceeded > 0)) {
SRV_INF("limits reached (count=%zu, memory margin exceeded on %zu device(s)), removing LRU name=%s\n",
count_active, memory_exceeded, lru_model_name.c_str());
unload(lru_model_name);
// wait for unload to complete
{
@ -826,12 +868,12 @@ void server_models::unload_lru(uint64_t new_model_memory_mb) {
}
}
static uint64_t get_model_memory_mb(const common_preset& preset) {
static model_memory_map get_model_memory_per_device(const common_preset& preset) {
common_params params;
preset.apply_to_params(params);
if(params.model.path.empty()) {
return 0;
return {};
}
struct log_ud_t {
@ -855,18 +897,32 @@ static uint64_t get_model_memory_mb(const common_preset& preset) {
mparams.use_mmap = false;
mparams.use_mlock = false;
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
llama_log_set(log_ud.original.callback, log_ud.original.user_data);
llama_model_ptr model{llama_model_load_from_file(params.model.path.c_str(), mparams)};
if (!model) {
return 0;
llama_log_set(log_ud.original.callback, log_ud.original.user_data);
return {};
}
uint64_t size_bytes = llama_model_size(model);
llama_model_free(model);
llama_context_params cparams = common_context_params_to_llama(params);
llama_context_ptr ctx{llama_init_from_model(model.get(), cparams)};
llama_log_set(log_ud.original.callback, log_ud.original.user_data);
return size_bytes / (1024 * 1024);
if (!ctx) {
return {};
}
model_memory_map result;
const size_t n_devs = ggml_backend_dev_count();
for (size_t i = 0; i < n_devs; i++) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
uint64_t bytes = llama_context_device_memory(ctx.get(), dev);
if (bytes > 0) {
result[dev] = bytes;
}
}
return result;
}
void server_models::load(const std::string & name) {
@ -874,23 +930,18 @@ void server_models::load(const std::string & name) {
throw std::runtime_error("model name=" + name + " is not found");
}
uint64_t new_model_memory_mb = 0;
if (base_params.models_memory_max > 0) {
model_memory_map new_model_memory_per_device;
if (base_params.models_memory_margin > 0) {
std::lock_guard<std::mutex> lk(mutex);
auto & meta = mapping[name].meta;
if (meta.memory_mb > 0) {
new_model_memory_mb = meta.memory_mb;
} else {
new_model_memory_mb = get_model_memory_mb(meta.preset);
meta.memory_mb = new_model_memory_mb;
}
if (new_model_memory_mb > 0) {
SRV_INF("model %s memory requirements: %lu MB\n", name.c_str(),
(unsigned long)new_model_memory_mb);
if (meta.memory_per_device.empty()) {
meta.memory_per_device = get_model_memory_per_device(meta.preset);
}
new_model_memory_per_device = meta.memory_per_device;
}
unload_lru(new_model_memory_mb);
unload_lru(new_model_memory_per_device);
std::unique_lock<std::mutex> lk(mutex);
// edge case: block until any in-progress reload has finished so we always load
@ -907,17 +958,15 @@ void server_models::load(const std::string & name) {
// exceeding models_max. Without this, the window between unload_lru()
// releasing its lock and this lock_guard acquiring allows multiple
// threads to each observe capacity and all proceed to load.
if (base_params.models_max > 0 || base_params.models_memory_max > 0) {
if (base_params.models_max > 0 || base_params.models_memory_margin > 0) {
size_t count_active = 0;
uint64_t total_memory_mb = 0;
for (const auto & m : mapping) {
if (m.second.meta.is_running()) {
count_active++;
total_memory_mb += m.second.meta.memory_mb;
}
}
bool count_exceeded = base_params.models_max > 0 && count_active >= (size_t)base_params.models_max;
bool memory_exceeded = base_params.models_memory_max > 0 && total_memory_mb >= (uint64_t)base_params.models_memory_max;
bool memory_exceeded = get_memory_exceeded(new_model_memory_per_device) > 0;
if (count_exceeded || memory_exceeded) {
throw std::runtime_error("model limit reached, try again later");
}

View File

@ -61,6 +61,8 @@ static std::string server_model_source_to_string(server_model_source source) {
}
}
using model_memory_map = std::map<ggml_backend_dev_t, uint64_t>;
struct server_model_meta {
server_model_source source = SERVER_MODEL_SOURCE_CACHE;
common_preset preset;
@ -70,7 +72,7 @@ struct server_model_meta {
int port = 0;
server_model_status status = SERVER_MODEL_STATUS_UNLOADED;
int64_t last_used = 0; // for LRU unloading
uint64_t memory_mb = 0; // size in MB
model_memory_map memory_per_device; // projected bytes per device
std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
json loaded_info; // info to be reflected via /v1/models endpoint ; if in DOWNLOADING state, it should contain download progress info
int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
@ -129,10 +131,13 @@ private:
std::vector<std::string> base_env;
common_preset base_preset; // base preset from llama-server CLI args
// available memory per device
std::map<ggml_backend_dev_t, uint64_t> memory_per_device;
void update_meta(const std::string & name, const server_model_meta & meta);
// unload least recently used models if the limit is reached
void unload_lru(uint64_t new_model_memory_mb = 0);
void unload_lru(const model_memory_map& new_model_memory_per_device);
// not thread-safe, caller must hold mutex
void add_model(server_model_meta && meta);
@ -140,6 +145,9 @@ private:
// notify SSE clients
void notify_sse(const std::string & event, const std::string & model_id, const json & data = nullptr);
// not thread-safe, caller must hold mutex
uint64_t get_memory_exceeded(const model_memory_map& new_model_memory_per_device) const;
public:
server_models(const common_params & params, int argc, char ** argv);