server: variance based checkpoint eviction

Co-authored-by: gabucino <gabucino>
This commit is contained in:
firecoperana 2026-06-23 20:22:06 -05:00
parent 7ccf1d2095
commit 3476dd6a40
3 changed files with 80 additions and 4 deletions

View File

@ -498,6 +498,18 @@ common_webui common_webui_from_name(const std::string& format) {
} }
} }
common_checkpoint_eviction common_checkpoint_eviction_from_name(const std::string & format) {
if (format == "auto") {
return COMMON_CHECKPOINT_EVICTION_AUTO;
} else if (format == "fifo") {
return COMMON_CHECKPOINT_EVICTION_FIFO;
} else if (format == "variance") {
return COMMON_CHECKPOINT_EVICTION_VARIANCE;
} else {
return COMMON_CHECKPOINT_EVICTION_AUTO;
}
}
thinking_tokens thinking_tokens_from_string(const std::string& format) { thinking_tokens thinking_tokens_from_string(const std::string& format) {
thinking_tokens think_token; thinking_tokens think_token;
std::string token_string = string_strip(format); std::string token_string = string_strip(format);
@ -2772,6 +2784,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.ctx_checkpoints_tolerance = std::stoi(argv[i]); params.ctx_checkpoints_tolerance = std::stoi(argv[i]);
return true; return true;
} }
if (arg == "--ctx-checkpoints-eviction") {
CHECK_ARG
params.ctx_checkpoint_eviction= common_checkpoint_eviction_from_name(std::string(argv[i]));
return true;
}
if (arg == "-cram" || arg == "--cache-ram") { if (arg == "-cram" || arg == "--cache-ram") {
CHECK_ARG CHECK_ARG
params.cache_ram_mib = std::stoi(argv[i]); params.cache_ram_mib = std::stoi(argv[i]);
@ -2982,6 +2999,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "--ctx-checkpoints N", "max number of context checkpoints to create per slot (default: %d)",params.ctx_checkpoints_n}); options.push_back({ "*", "--ctx-checkpoints N", "max number of context checkpoints to create per slot (default: %d)",params.ctx_checkpoints_n});
options.push_back({ "*", "--ctx-checkpoints-interval N", "minimum number of tokens between each context checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_interval}); options.push_back({ "*", "--ctx-checkpoints-interval N", "minimum number of tokens between each context checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_interval});
options.push_back({ "*", "--ctx-checkpoints-tolerance N", "the number of tokens before the full prompt to create the checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_tolerance}); options.push_back({ "*", "--ctx-checkpoints-tolerance N", "the number of tokens before the full prompt to create the checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_tolerance});
options.push_back({ "*", "--ctx-checkpoints-eviction NAME", "Eviction strategy for checkpoint. Accepts fifo, variance and auto. Auto defaults to variance. Variance preserves coverage and maintains uniform interval. (default: variance)" });
options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib }); options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib });
options.push_back({ "*", "-crs, --cache-ram-similarity N", "max of similarity of prompt tokens to cache tokens that triggers prompt cache (default: %.2f).",params.cache_ram_similarity }); options.push_back({ "*", "-crs, --cache-ram-similarity N", "max of similarity of prompt tokens to cache tokens that triggers prompt cache (default: %.2f).",params.cache_ram_similarity });
options.push_back({ "*", "-cram-n-min --cache-ram-n-min N", "minimum number of the cached tokens that triggers prompt cache (default: %d).", params.cache_ram_n_min }); options.push_back({ "*", "-cram-n-min --cache-ram-n-min N", "minimum number of the cached tokens that triggers prompt cache (default: %d).", params.cache_ram_n_min });

View File

@ -127,8 +127,17 @@ enum common_webui {
COMMON_WEBUI_LLAMACPP, COMMON_WEBUI_LLAMACPP,
}; };
enum common_checkpoint_eviction {
COMMON_CHECKPOINT_EVICTION_AUTO,
COMMON_CHECKPOINT_EVICTION_FIFO,
COMMON_CHECKPOINT_EVICTION_VARIANCE
};
common_webui common_webui_from_name(const std::string& format); common_webui common_webui_from_name(const std::string& format);
common_checkpoint_eviction common_checkpoint_eviction_from_name(const std::string & format);
struct thinking_tokens { struct thinking_tokens {
bool exclude = true; bool exclude = true;
std::string begin = "<think>"; std::string begin = "<think>";
@ -527,6 +536,7 @@ struct gpt_params {
int32_t ctx_checkpoints_n = 32; // max number of context checkpoints per slot int32_t ctx_checkpoints_n = 32; // max number of context checkpoints per slot
int32_t ctx_checkpoints_interval = 512; // minimum number of tokens between each context checkpoints int32_t ctx_checkpoints_interval = 512; // minimum number of tokens between each context checkpoints
int32_t ctx_checkpoints_tolerance = 5; // the number of tokens before the full prompt to create the checkpoint int32_t ctx_checkpoints_tolerance = 5; // the number of tokens before the full prompt to create the checkpoint
common_checkpoint_eviction ctx_checkpoint_eviction = COMMON_CHECKPOINT_EVICTION_VARIANCE;
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram
float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens

View File

@ -3594,6 +3594,51 @@ void server_context::apply_checkpoint(server_slot & slot) {
} }
} }
static std::list<server_prompt_checkpoint>::iterator evict_checkpoint_by_variance(server_slot & slot, std::list<server_prompt_checkpoint> & ckpts) {
auto it = ckpts.begin();
if (ckpts.size() < 3) {
return it;
} else if (ckpts.size() == 3) {
std::advance(it, 1);
return it;
}
std::vector<int64_t> tokens;
tokens.reserve(ckpts.size());
for (const auto & ckpt : ckpts) {
tokens.push_back(int64_t(ckpt.pos_max));
}
// Remove the checkpoint that makes the distribution most even after removal.
// For each interior checkpoint, compute the variance of gaps that would result
// if it were removed. Pick the one with the lowest variance (most uniform spacing).
size_t best_idx = 1;
const size_t n = tokens.size();
const size_t start = 1; // never remove the first
const size_t end = n - 1; // never remove the last
double max_pos = tokens[n - 1];
// To avoid doing double for loop to calculate variance,
// We only need to find the one with the min product of two consecutive gaps.
// Why:
// Gap between checkpoints: x1, x2, .., x_n-1.
// The average is constant because first and last checkpoint is never removed
// Variance of the gap after removing i_th checkpoint is:
// x1^2+..+(x_n-1)^2+2*x_i*x_(i+1) - average^2
// Find the minimum variance is finding min { x_i*x_(i+1) }
double diff = (tokens[start] - tokens[start - 1]);
double diff2 = (tokens[start + 1] - tokens[start]);
double best_variance = diff * (diff2 / max_pos);
for (size_t i = start+1; i < end; i++) {
diff = tokens[i] - tokens[i - 1];
diff2 = tokens[i + 1] - tokens[i];
double variance = diff * (diff2 / max_pos);
if (variance < best_variance) {
best_variance = variance;
best_idx = i;
}
}
std::advance(it, best_idx);
return it;
}
bool server_context::create_checkpoint(server_slot & slot) { bool server_context::create_checkpoint(server_slot & slot) {
bool do_checkpoint = !slot.image_just_processed; bool do_checkpoint = !slot.image_just_processed;
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id); int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
@ -3609,12 +3654,15 @@ bool server_context::create_checkpoint(server_slot & slot) {
const int64_t t_start = ggml_time_us(); const int64_t t_start = ggml_time_us();
while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.ctx_checkpoints_n) { while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.ctx_checkpoints_n) {
// make room for the new checkpoint, if needed // make room for the new checkpoint, if needed
const auto & cur = slot.server_cached_prompt.checkpoints.front(); auto it = slot.server_cached_prompt.checkpoints.begin();
if (params_base.ctx_checkpoint_eviction == COMMON_CHECKPOINT_EVICTION_VARIANCE ||
params_base.ctx_checkpoint_eviction == COMMON_CHECKPOINT_EVICTION_AUTO) {
it = evict_checkpoint_by_variance(slot, slot.server_cached_prompt.checkpoints);
}
const auto & cur = *it;
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, cur.n_tokens, (float)cur.data.size() / 1024 / 1024); cur.pos_min, cur.pos_max, cur.n_tokens, (float)cur.data.size() / 1024 / 1024);
slot.server_cached_prompt.checkpoints.erase(it);
slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin());
} }
auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(); auto & cur = slot.server_cached_prompt.checkpoints.emplace_back();