fit : wrap llama_device_memory_data

This commit is contained in:
Georgi Gerganov 2026-06-12 18:12:24 +03:00
parent 02182fc5b9
commit 3518061868
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 40 additions and 14 deletions

View File

@ -26,7 +26,7 @@ class common_params_fit_exception : public std::runtime_error {
using std::runtime_error::runtime_error;
};
std::vector<llama_device_memory_data> common_get_device_memory_data(
static std::vector<llama_device_memory_data> common_get_device_memory_data_impl(
const char * path_model,
const llama_model_params * mparams,
const llama_context_params * cparams,
@ -150,6 +150,29 @@ std::vector<llama_device_memory_data> common_get_device_memory_data(
return ret;
}
common_device_memory_data_vec common_get_device_memory_data(
const char * path_model,
const llama_model_params * mparams,
const llama_context_params * cparams,
std::vector<ggml_backend_dev_t> & devs,
uint32_t & hp_ngl,
uint32_t & hp_n_ctx_train,
uint32_t & hp_n_expert,
ggml_log_level log_level) {
std::vector<llama_device_memory_data> impl = common_get_device_memory_data_impl(
path_model, mparams, cparams, devs, hp_ngl, hp_n_ctx_train, hp_n_expert, log_level);
common_device_memory_data_vec ret(impl.size());
for (size_t i = 0; i < impl.size(); i++) {
ret[i].total = impl[i].total;
ret[i].free = impl[i].free;
ret[i].model = impl[i].mb.model;
ret[i].context = impl[i].mb.context;
ret[i].compute = impl[i].mb.compute;
}
return ret;
}
static void common_params_fit_impl(
const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
@ -169,7 +192,7 @@ static void common_params_fit_impl(
// step 1: get data for default parameters and check whether any changes are necessary in the first place
LOG_TRC("%s: getting device memory data for initial parameters:\n", __func__);
const dmds_t dmds_full = common_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
const dmds_t dmds_full = common_get_device_memory_data_impl(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
const size_t nd = devs.size(); // number of devices
std::vector<int64_t> margins; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits
@ -304,7 +327,7 @@ static void common_params_fit_impl(
int64_t sum_projected_used_min_ctx = 0;
cparams->n_ctx = n_ctx_min;
const dmds_t dmds_min_ctx = common_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
const dmds_t dmds_min_ctx = common_get_device_memory_data_impl(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
if (nd == 0) {
sum_projected_used_min_ctx = dmds_min_ctx.back().mb.total();
} else {
@ -482,7 +505,7 @@ static void common_params_fit_impl(
llama_model_params mparams_copy = *mparams;
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy);
const dmds_t dmd_nl = common_get_device_memory_data(
const dmds_t dmd_nl = common_get_device_memory_data_impl(
path_model, &mparams_copy, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
LOG_TRC("%s: memory for test allocation by device:\n", func_name);
@ -510,7 +533,7 @@ static void common_params_fit_impl(
mparams->tensor_buft_overrides = tensor_buft_overrides;
LOG_TRC("%s: getting device memory data with all MoE tensors moved to system memory:\n", __func__);
const dmds_t dmds_cpu_moe = common_get_device_memory_data(
const dmds_t dmds_cpu_moe = common_get_device_memory_data_impl(
path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
for (size_t id = 0; id < nd; id++) {
@ -940,7 +963,7 @@ void common_fit_print(
uint32_t hp_nct = 0; // hparams.n_ctx_train
uint32_t hp_nex = 0; // hparams.n_expert
auto dmd = common_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, GGML_LOG_LEVEL_ERROR);
auto dmd = common_get_device_memory_data_impl(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, GGML_LOG_LEVEL_ERROR);
GGML_ASSERT(dmd.size() == devs.size() + 1);
for (size_t id = 0; id < devs.size(); id++) {

View File

@ -34,12 +34,18 @@ void common_fit_print(
void common_memory_breakdown_print(const llama_context * ctx);
// TODO: convert this to common_device_memory_data that wraps llama_device_memory_data
// add API for accessing the internal `llama-ext.h` information
struct llama_device_memory_data;
struct common_device_memory_data {
int64_t total;
int64_t free;
size_t model;
size_t context;
size_t compute;
};
using common_device_memory_data_vec = std::vector<common_device_memory_data>;
// Load a model + context with no_alloc and return the per-device memory breakdown.
std::vector<llama_device_memory_data> common_get_device_memory_data(
common_device_memory_data_vec common_get_device_memory_data(
const char * path_model,
const llama_model_params * mparams,
const llama_context_params * cparams,

View File

@ -866,10 +866,7 @@ private:
}
for (size_t j = 0; j < devs.size(); ++j) {
const size_t bytes =
(measure_model_bytes ? dmd[j].mb.model : 0) +
dmd[j].mb.context +
dmd[j].mb.compute;
const size_t bytes = (measure_model_bytes ? dmd[j].model : 0) + dmd[j].context + dmd[j].compute;
total += bytes;
for (size_t i = 0; i < tgt_devices.size(); i++) {
if (tgt_devices[i] == devs[j]) {