cont : clean-up

This commit is contained in:
Georgi Gerganov 2026-04-16 14:32:47 +03:00 committed by Ruben Ortlam
parent 3046b8853a
commit a35afd504f
3 changed files with 80 additions and 63 deletions

View File

@ -645,7 +645,7 @@ struct common_params {
std::string models_dir = ""; // directory containing models for the router server
std::string models_preset = ""; // directory containing model presets for the router server
int models_max = 4; // maximum number of models to load simultaneously
int models_memory_margin = 1024; // MB of free memory to preserve per device (0 = disabled)
int models_memory_margin = 1024; // MiB of free memory to preserve per device (0 = disabled)
bool models_autoload = true; // automatically load models when requested via the router server
std::string models_preset_hf = ""; // show a warning about remote presets on router loaded (if not empty)

View File

@ -248,11 +248,11 @@ server_models::server_models(
bin_path = get_server_exec_path().string();
} catch (const std::exception & e) {
bin_path = argv[0];
LOG_WRN("failed to get server executable path: %s\n", e.what());
LOG_WRN("using original argv[0] as fallback: %s\n", argv[0]);
SRV_WRN("failed to get server executable path: %s\n", e.what());
SRV_WRN("using original argv[0] as fallback: %s\n", argv[0]);
}
const uint64_t memory_margin = (uint64_t)base_params.models_memory_margin * 1024 * 1024;
const size_t memory_margin = (size_t) base_params.models_memory_margin * 1024 * 1024;
if (memory_margin > 0) {
const size_t n_devs = ggml_backend_dev_count();
@ -261,11 +261,10 @@ server_models::server_models(
size_t free, total;
ggml_backend_dev_memory(dev, &free, &total);
if (total > 0) {
const uint64_t available = (free > memory_margin) ? free - memory_margin : 0;
available_memory_per_device[dev] = available;
SRV_DBG("device %s: available memory after margin=%lu MiB\n",
ggml_backend_dev_name(dev),
(unsigned long)(available / (1024 * 1024)));
const size_t available = (free > memory_margin) ? free - memory_margin : 0;
dmm_available[dev] = available;
SRV_DBG("device %s: available memory after margin=%zu MiB\n",
ggml_backend_dev_name(dev), available / (1024 * 1024));
}
}
}
@ -804,52 +803,57 @@ std::vector<server_model_meta> server_models::get_all_meta() {
return result;
}
uint64_t server_models::get_memory_exceeded(const model_memory_map & new_model_memory_per_device) const {
model_memory_map total_memory_per_device;
int server_models::can_fit(const device_memory_map & dmm_req) const {
device_memory_map dmm_total;
for (const auto & m : mapping) {
if (m.second.meta.is_running()) {
for (const auto & [key, value] : m.second.meta.memory_usage_per_device) {
total_memory_per_device[key] += value;
for (const auto & [dev, mem] : m.second.meta.dmm_req) {
dmm_total[dev] += mem;
}
}
}
auto get = [](const model_memory_map & m, ggml_backend_dev_t k) {
auto it = m.find(k);
return it != m.end() ? it->second : 0;
auto get = [](const device_memory_map & dmm, ggml_backend_dev_t dev) {
auto it = dmm.find(dev);
return it != dmm.end() ? it->second : 0;
};
size_t count_memory_exceeded = 0;
int res = 0;
for (const auto & [key, limit] : available_memory_per_device) {
const uint64_t total_memory = get(total_memory_per_device, key);
const uint64_t new_memory = get(new_model_memory_per_device, key);
SRV_DBG("device %s: total=%lu MB, new=%lu MB, limit=%lu MB\n",
ggml_backend_dev_name(key),
(unsigned long)(total_memory / (1024 * 1024)),
(unsigned long)(new_memory / (1024 * 1024)),
(unsigned long)(limit / (1024 * 1024)));
for (const auto & [dev, limit] : dmm_available) {
const size_t mem_total = get(dmm_total, dev);
const size_t mem_new = get(dmm_req, dev);
if (total_memory + new_memory > limit) {
count_memory_exceeded++;
SRV_DBG("device %s: total=%zu MiB, new=%zu MiB, limit=%zu MiB\n",
ggml_backend_dev_name(dev),
mem_total / (1024 * 1024), mem_new / (1024 * 1024), limit / (1024 * 1024));
if (mem_total + mem_new > limit) {
res++;
}
}
return count_memory_exceeded;
return res;
}
void server_models::unload_lru(const model_memory_map & new_model_memory_per_device) {
const bool check_memory = base_params.models_memory_margin > 0 && !available_memory_per_device.empty();
void server_models::unload_lru(const device_memory_map & dmm_req) {
const bool check_active = base_params.models_max > 0;
const bool check_memory = base_params.models_memory_margin > 0;
if (base_params.models_max <= 0 && !check_memory) {
if (!check_active && !check_memory) {
return; // no limit
}
if (check_memory) {
GGML_ASSERT(!dmm_available.empty());
}
while (true) {
std::string lru_model_name = "";
std::string lru_model_name;
int64_t lru_last_used = ggml_time_ms();
size_t count_active = 0;
size_t count_memory_exceeded = 0;
int count_active = 0;
int count_exceed = 0;
{
std::unique_lock<std::mutex> lk(mutex);
for (const auto & m : mapping) {
@ -861,14 +865,17 @@ void server_models::unload_lru(const model_memory_map & new_model_memory_per_dev
}
}
}
count_memory_exceeded = get_memory_exceeded(new_model_memory_per_device);
if (check_memory) {
count_exceed = can_fit(dmm_req);
}
}
bool count_exceeded = base_params.models_max > 0 &&
(count_active + 1) > (size_t)base_params.models_max;
if (!lru_model_name.empty() && (count_exceeded || count_memory_exceeded > 0)) {
SRV_INF("limits reached (count=%zu, memory margin exceeded on %zu device(s)), removing LRU name=%s\n",
count_active, count_memory_exceeded, lru_model_name.c_str());
const bool active_exceeded = check_active && count_active >= base_params.models_max;
const bool memory_exceeded = check_memory && count_exceed > 0;
if (!lru_model_name.empty() && (active_exceeded || memory_exceeded)) {
SRV_INF("limits reached (count=%d, memory margin exceeded on %d device(s)), removing LRU name=%s\n",
count_active, count_exceed, lru_model_name.c_str());
unload(lru_model_name);
// wait for unload to complete
{
@ -883,11 +890,11 @@ void server_models::unload_lru(const model_memory_map & new_model_memory_per_dev
}
}
static model_memory_map get_model_memory_per_device(const common_preset& preset) {
static device_memory_map get_model_memory_per_device(const common_preset & preset) {
common_params params;
preset.apply_to_params(params);
if(params.model.path.empty()) {
if (params.model.path.empty()) {
return {};
}
@ -927,7 +934,7 @@ static model_memory_map get_model_memory_per_device(const common_preset& preset)
return {};
}
model_memory_map result;
device_memory_map result;
const size_t n_devs = ggml_backend_dev_count();
for (size_t i = 0; i < n_devs; i++) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
@ -945,18 +952,19 @@ void server_models::load(const std::string & name) {
throw std::runtime_error("model name=" + name + " is not found");
}
model_memory_map new_model_memory_per_device;
device_memory_map dmm_req;
if (base_params.models_memory_margin > 0) {
// determine the required memory by the model upon its first load
std::lock_guard<std::mutex> lk(mutex);
auto & meta = mapping[name].meta;
if (meta.memory_usage_per_device.empty()) {
meta.memory_usage_per_device = get_model_memory_per_device(meta.preset);
if (meta.dmm_req.empty()) {
meta.dmm_req = get_model_memory_per_device(meta.preset);
}
new_model_memory_per_device = meta.memory_usage_per_device;
dmm_req = meta.dmm_req;
}
unload_lru(new_model_memory_per_device);
unload_lru(dmm_req);
std::unique_lock<std::mutex> lk(mutex);
// edge case: block until any in-progress reload has finished so we always load
@ -973,17 +981,24 @@ void server_models::load(const std::string & name) {
// exceeding models_max. Without this, the window between unload_lru()
// releasing its lock and this lock_guard acquiring allows multiple
// threads to each observe capacity and all proceed to load.
if (base_params.models_max > 0 || base_params.models_memory_margin > 0) {
size_t count_active = 0;
for (const auto & m : mapping) {
if (m.second.meta.is_running()) {
count_active++;
{
const bool check_active = base_params.models_max > 0;
const bool check_memory = base_params.models_memory_margin > 0;
if (check_active || check_memory) {
int count_active = 0;
for (const auto & m : mapping) {
if (m.second.meta.is_running()) {
count_active++;
}
}
const bool active_exceeded = check_active && count_active >= base_params.models_max;
const bool memory_exceeded = check_memory && can_fit(dmm_req) > 0;
if (active_exceeded || memory_exceeded) {
throw std::runtime_error("model limit reached, try again later");
}
}
bool count_exceeded = base_params.models_max > 0 && count_active >= (size_t)base_params.models_max;
bool memory_exceeded = get_memory_exceeded(new_model_memory_per_device) > 0;
if (count_exceeded || memory_exceeded) {
throw std::runtime_error("model limit reached, try again later");
}
}

View File

@ -61,7 +61,7 @@ static std::string server_model_source_to_string(server_model_source source) {
}
}
using model_memory_map = std::map<ggml_backend_dev_t, uint64_t>;
using device_memory_map = std::map<ggml_backend_dev_t, size_t>;
struct server_model_meta {
server_model_source source = SERVER_MODEL_SOURCE_CACHE;
@ -72,7 +72,7 @@ struct server_model_meta {
int port = 0;
server_model_status status = SERVER_MODEL_STATUS_UNLOADED;
int64_t last_used = 0; // for LRU unloading
model_memory_map memory_usage_per_device; // bytes used per device
device_memory_map dmm_req; // bytes required per device
std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
json loaded_info; // info to be reflected via /v1/models endpoint ; if in DOWNLOADING state, it should contain download progress info
int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
@ -132,12 +132,12 @@ private:
common_preset base_preset; // base preset from llama-server CLI args
// available memory per device
model_memory_map available_memory_per_device;
device_memory_map dmm_available;
void update_meta(const std::string & name, const server_model_meta & meta);
// unload least recently used models if the limit is reached
void unload_lru(const model_memory_map & new_model_memory_per_device);
void unload_lru(const device_memory_map & dmm_req);
// not thread-safe, caller must hold mutex
void add_model(server_model_meta && meta);
@ -145,8 +145,10 @@ private:
// notify SSE clients
void notify_sse(const std::string & event, const std::string & model_id, const json & data = nullptr);
// return number of devices where the memory limit would be exceeded
// return 0 if the new model would fit on all devices
// not thread-safe, caller must hold mutex
uint64_t get_memory_exceeded(const model_memory_map & new_model_memory_per_device) const;
int can_fit(const device_memory_map & dmm_req) const;
public:
server_models(const common_params & params, int argc, char ** argv);