mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
server: variance based checkpoint eviction
Co-authored-by: gabucino <gabucino>
This commit is contained in:
parent
7ccf1d2095
commit
3476dd6a40
@ -498,6 +498,18 @@ common_webui common_webui_from_name(const std::string& format) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
common_checkpoint_eviction common_checkpoint_eviction_from_name(const std::string & format) {
|
||||||
|
if (format == "auto") {
|
||||||
|
return COMMON_CHECKPOINT_EVICTION_AUTO;
|
||||||
|
} else if (format == "fifo") {
|
||||||
|
return COMMON_CHECKPOINT_EVICTION_FIFO;
|
||||||
|
} else if (format == "variance") {
|
||||||
|
return COMMON_CHECKPOINT_EVICTION_VARIANCE;
|
||||||
|
} else {
|
||||||
|
return COMMON_CHECKPOINT_EVICTION_AUTO;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
thinking_tokens thinking_tokens_from_string(const std::string& format) {
|
thinking_tokens thinking_tokens_from_string(const std::string& format) {
|
||||||
thinking_tokens think_token;
|
thinking_tokens think_token;
|
||||||
std::string token_string = string_strip(format);
|
std::string token_string = string_strip(format);
|
||||||
@ -2772,6 +2784,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||||||
params.ctx_checkpoints_tolerance = std::stoi(argv[i]);
|
params.ctx_checkpoints_tolerance = std::stoi(argv[i]);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (arg == "--ctx-checkpoints-eviction") {
|
||||||
|
CHECK_ARG
|
||||||
|
params.ctx_checkpoint_eviction= common_checkpoint_eviction_from_name(std::string(argv[i]));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (arg == "-cram" || arg == "--cache-ram") {
|
if (arg == "-cram" || arg == "--cache-ram") {
|
||||||
CHECK_ARG
|
CHECK_ARG
|
||||||
params.cache_ram_mib = std::stoi(argv[i]);
|
params.cache_ram_mib = std::stoi(argv[i]);
|
||||||
@ -2982,6 +2999,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
|||||||
options.push_back({ "*", "--ctx-checkpoints N", "max number of context checkpoints to create per slot (default: %d)",params.ctx_checkpoints_n});
|
options.push_back({ "*", "--ctx-checkpoints N", "max number of context checkpoints to create per slot (default: %d)",params.ctx_checkpoints_n});
|
||||||
options.push_back({ "*", "--ctx-checkpoints-interval N", "minimum number of tokens between each context checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_interval});
|
options.push_back({ "*", "--ctx-checkpoints-interval N", "minimum number of tokens between each context checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_interval});
|
||||||
options.push_back({ "*", "--ctx-checkpoints-tolerance N", "the number of tokens before the full prompt to create the checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_tolerance});
|
options.push_back({ "*", "--ctx-checkpoints-tolerance N", "the number of tokens before the full prompt to create the checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_tolerance});
|
||||||
|
options.push_back({ "*", "--ctx-checkpoints-eviction NAME", "Eviction strategy for checkpoint. Accepts fifo, variance and auto. Auto defaults to variance. Variance preserves coverage and maintains uniform interval. (default: variance)" });
|
||||||
options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib });
|
options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib });
|
||||||
options.push_back({ "*", "-crs, --cache-ram-similarity N", "max of similarity of prompt tokens to cache tokens that triggers prompt cache (default: %.2f).",params.cache_ram_similarity });
|
options.push_back({ "*", "-crs, --cache-ram-similarity N", "max of similarity of prompt tokens to cache tokens that triggers prompt cache (default: %.2f).",params.cache_ram_similarity });
|
||||||
options.push_back({ "*", "-cram-n-min --cache-ram-n-min N", "minimum number of the cached tokens that triggers prompt cache (default: %d).", params.cache_ram_n_min });
|
options.push_back({ "*", "-cram-n-min --cache-ram-n-min N", "minimum number of the cached tokens that triggers prompt cache (default: %d).", params.cache_ram_n_min });
|
||||||
|
|||||||
@ -127,8 +127,17 @@ enum common_webui {
|
|||||||
COMMON_WEBUI_LLAMACPP,
|
COMMON_WEBUI_LLAMACPP,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum common_checkpoint_eviction {
|
||||||
|
COMMON_CHECKPOINT_EVICTION_AUTO,
|
||||||
|
COMMON_CHECKPOINT_EVICTION_FIFO,
|
||||||
|
COMMON_CHECKPOINT_EVICTION_VARIANCE
|
||||||
|
};
|
||||||
|
|
||||||
common_webui common_webui_from_name(const std::string& format);
|
common_webui common_webui_from_name(const std::string& format);
|
||||||
|
|
||||||
|
common_checkpoint_eviction common_checkpoint_eviction_from_name(const std::string & format);
|
||||||
|
|
||||||
|
|
||||||
struct thinking_tokens {
|
struct thinking_tokens {
|
||||||
bool exclude = true;
|
bool exclude = true;
|
||||||
std::string begin = "<think>";
|
std::string begin = "<think>";
|
||||||
@ -527,6 +536,7 @@ struct gpt_params {
|
|||||||
int32_t ctx_checkpoints_n = 32; // max number of context checkpoints per slot
|
int32_t ctx_checkpoints_n = 32; // max number of context checkpoints per slot
|
||||||
int32_t ctx_checkpoints_interval = 512; // minimum number of tokens between each context checkpoints
|
int32_t ctx_checkpoints_interval = 512; // minimum number of tokens between each context checkpoints
|
||||||
int32_t ctx_checkpoints_tolerance = 5; // the number of tokens before the full prompt to create the checkpoint
|
int32_t ctx_checkpoints_tolerance = 5; // the number of tokens before the full prompt to create the checkpoint
|
||||||
|
common_checkpoint_eviction ctx_checkpoint_eviction = COMMON_CHECKPOINT_EVICTION_VARIANCE;
|
||||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||||
int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram
|
int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram
|
||||||
float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens
|
float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens
|
||||||
|
|||||||
@ -3594,6 +3594,51 @@ void server_context::apply_checkpoint(server_slot & slot) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::list<server_prompt_checkpoint>::iterator evict_checkpoint_by_variance(server_slot & slot, std::list<server_prompt_checkpoint> & ckpts) {
|
||||||
|
auto it = ckpts.begin();
|
||||||
|
if (ckpts.size() < 3) {
|
||||||
|
return it;
|
||||||
|
} else if (ckpts.size() == 3) {
|
||||||
|
std::advance(it, 1);
|
||||||
|
return it;
|
||||||
|
}
|
||||||
|
std::vector<int64_t> tokens;
|
||||||
|
tokens.reserve(ckpts.size());
|
||||||
|
for (const auto & ckpt : ckpts) {
|
||||||
|
tokens.push_back(int64_t(ckpt.pos_max));
|
||||||
|
}
|
||||||
|
// Remove the checkpoint that makes the distribution most even after removal.
|
||||||
|
// For each interior checkpoint, compute the variance of gaps that would result
|
||||||
|
// if it were removed. Pick the one with the lowest variance (most uniform spacing).
|
||||||
|
size_t best_idx = 1;
|
||||||
|
const size_t n = tokens.size();
|
||||||
|
const size_t start = 1; // never remove the first
|
||||||
|
const size_t end = n - 1; // never remove the last
|
||||||
|
double max_pos = tokens[n - 1];
|
||||||
|
// To avoid doing double for loop to calculate variance,
|
||||||
|
// We only need to find the one with the min product of two consecutive gaps.
|
||||||
|
// Why:
|
||||||
|
// Gap between checkpoints: x1, x2, .., x_n-1.
|
||||||
|
// The average is constant because first and last checkpoint is never removed
|
||||||
|
// Variance of the gap after removing i_th checkpoint is:
|
||||||
|
// x1^2+..+(x_n-1)^2+2*x_i*x_(i+1) - average^2
|
||||||
|
// Find the minimum variance is finding min { x_i*x_(i+1) }
|
||||||
|
double diff = (tokens[start] - tokens[start - 1]);
|
||||||
|
double diff2 = (tokens[start + 1] - tokens[start]);
|
||||||
|
double best_variance = diff * (diff2 / max_pos);
|
||||||
|
for (size_t i = start+1; i < end; i++) {
|
||||||
|
diff = tokens[i] - tokens[i - 1];
|
||||||
|
diff2 = tokens[i + 1] - tokens[i];
|
||||||
|
double variance = diff * (diff2 / max_pos);
|
||||||
|
if (variance < best_variance) {
|
||||||
|
best_variance = variance;
|
||||||
|
best_idx = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::advance(it, best_idx);
|
||||||
|
return it;
|
||||||
|
}
|
||||||
|
|
||||||
bool server_context::create_checkpoint(server_slot & slot) {
|
bool server_context::create_checkpoint(server_slot & slot) {
|
||||||
bool do_checkpoint = !slot.image_just_processed;
|
bool do_checkpoint = !slot.image_just_processed;
|
||||||
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
|
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
|
||||||
@ -3609,12 +3654,15 @@ bool server_context::create_checkpoint(server_slot & slot) {
|
|||||||
const int64_t t_start = ggml_time_us();
|
const int64_t t_start = ggml_time_us();
|
||||||
while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.ctx_checkpoints_n) {
|
while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.ctx_checkpoints_n) {
|
||||||
// make room for the new checkpoint, if needed
|
// make room for the new checkpoint, if needed
|
||||||
const auto & cur = slot.server_cached_prompt.checkpoints.front();
|
auto it = slot.server_cached_prompt.checkpoints.begin();
|
||||||
|
if (params_base.ctx_checkpoint_eviction == COMMON_CHECKPOINT_EVICTION_VARIANCE ||
|
||||||
|
params_base.ctx_checkpoint_eviction == COMMON_CHECKPOINT_EVICTION_AUTO) {
|
||||||
|
it = evict_checkpoint_by_variance(slot, slot.server_cached_prompt.checkpoints);
|
||||||
|
}
|
||||||
|
const auto & cur = *it;
|
||||||
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||||
cur.pos_min, cur.pos_max, cur.n_tokens, (float)cur.data.size() / 1024 / 1024);
|
cur.pos_min, cur.pos_max, cur.n_tokens, (float)cur.data.size() / 1024 / 1024);
|
||||||
|
slot.server_cached_prompt.checkpoints.erase(it);
|
||||||
slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto & cur = slot.server_cached_prompt.checkpoints.emplace_back();
|
auto & cur = slot.server_cached_prompt.checkpoints.emplace_back();
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user