diff --git a/common/common.cpp b/common/common.cpp
index 544eed2c..e189c79f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -498,6 +498,18 @@ common_webui common_webui_from_name(const std::string& format) {
     }
 }
 
+common_checkpoint_eviction common_checkpoint_eviction_from_name(const std::string & format) {
+    if (format == "auto") {
+        return COMMON_CHECKPOINT_EVICTION_AUTO;
+    } else if (format == "fifo") {
+        return COMMON_CHECKPOINT_EVICTION_FIFO;
+    } else if (format == "variance") {
+        return COMMON_CHECKPOINT_EVICTION_VARIANCE;
+    } else { // unrecognized names silently fall back to auto
+        return COMMON_CHECKPOINT_EVICTION_AUTO;
+    }
+}
+
 thinking_tokens thinking_tokens_from_string(const std::string& format) {
     thinking_tokens think_token;
     std::string token_string = string_strip(format);
@@ -2772,6 +2784,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.ctx_checkpoints_tolerance = std::stoi(argv[i]);
         return true;
     }
+    if (arg == "--ctx-checkpoints-eviction") {
+        CHECK_ARG
+        params.ctx_checkpoint_eviction = common_checkpoint_eviction_from_name(argv[i]);
+        return true;
+    }
     if (arg == "-cram" || arg == "--cache-ram") {
         CHECK_ARG
         params.cache_ram_mib = std::stoi(argv[i]);
@@ -2982,6 +2999,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "--ctx-checkpoints N", "max number of context checkpoints to create per slot (default: %d)",params.ctx_checkpoints_n});
     options.push_back({ "*", "--ctx-checkpoints-interval N", "minimum number of tokens between each context checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_interval});
     options.push_back({ "*", "--ctx-checkpoints-tolerance N", "the number of tokens before the full prompt to create the checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_tolerance});
+    options.push_back({ "*", "--ctx-checkpoints-eviction NAME", "Eviction strategy for checkpoint. Accepts fifo, variance and auto. Auto defaults to variance. Variance preserves coverage and maintains uniform interval. (default: variance)" });
     options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib });
     options.push_back({ "*", "-crs, --cache-ram-similarity N", "max of similarity of prompt tokens to cache tokens that triggers prompt cache (default: %.2f).",params.cache_ram_similarity });
     options.push_back({ "*", "-cram-n-min --cache-ram-n-min N", "minimum number of the cached tokens that triggers prompt cache (default: %d).", params.cache_ram_n_min });
diff --git a/common/common.h b/common/common.h
index a6946528..b99848d1 100644
--- a/common/common.h
+++ b/common/common.h
@@ -127,8 +127,17 @@ enum common_webui {
     COMMON_WEBUI_LLAMACPP,
 };
 
+enum common_checkpoint_eviction {
+    COMMON_CHECKPOINT_EVICTION_AUTO,    // handled like VARIANCE when a checkpoint is evicted
+    COMMON_CHECKPOINT_EVICTION_FIFO,    // evict the checkpoint at the front of the list
+    COMMON_CHECKPOINT_EVICTION_VARIANCE // evict the checkpoint whose removal keeps the spacing most uniform
+};
+
 common_webui common_webui_from_name(const std::string& format);
 
+common_checkpoint_eviction common_checkpoint_eviction_from_name(const std::string & format);
+
+
 struct thinking_tokens {
     bool exclude = true;
     std::string begin = "";
@@ -527,6 +536,7 @@ struct gpt_params {
     int32_t ctx_checkpoints_n = 32; // max number of context checkpoints per slot
     int32_t ctx_checkpoints_interval = 512; // minimum number of tokens between each context checkpoints
     int32_t ctx_checkpoints_tolerance = 5; // the number of tokens before the full prompt to create the checkpoint
+    common_checkpoint_eviction ctx_checkpoint_eviction = COMMON_CHECKPOINT_EVICTION_VARIANCE; // which checkpoint to erase once ctx_checkpoints_n is reached
     int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
     int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram
     float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens
diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp
index c2563b23..b23de8bd 100644
--- a/examples/server/server-context.cpp
+++ b/examples/server/server-context.cpp
@@ -3594,6 +3594,40 @@ void server_context::apply_checkpoint(server_slot & slot) {
     }
 }
 
+// Picks the checkpoint to evict so that the remaining ones stay as evenly spaced as possible.
+//
+// The first and the last checkpoint are never candidates. Removing an interior checkpoint merges
+// the two gaps around it (g_i and g_i+1, measured in pos_max) into one gap of g_i + g_i+1.
+// The number of gaps and their sum do not depend on which candidate is removed, so the mean gap
+// is fixed and the variance only depends on the sum of squares, which grows by
+//     (g_i + g_i+1)^2 - g_i^2 - g_i+1^2 = 2 * g_i * g_i+1
+// The most uniform spacing is therefore obtained by removing the checkpoint with the smallest
+// product of adjacent gaps, which a single pass finds. The products are compared as exact 64-bit
+// integers: dividing them by a common scale cannot change the winner and misbehaves when that
+// scale is zero.
+//
+// Returns begin() when there are fewer than three checkpoints (no interior candidate exists).
+// NOTE(review): assumes the list is ordered by increasing pos_max - confirm against the callers.
+static std::list::iterator evict_checkpoint_by_variance(server_slot & /*slot*/, std::list & ckpts) {
+    if (ckpts.size() < 3) {
+        return ckpts.begin();
+    }
+    const auto last = std::prev(ckpts.end());
+    auto best = ckpts.end();
+    int64_t best_product = 0;
+    for (auto cur = std::next(ckpts.begin()); cur != last; ++cur) {
+        const int64_t gap_before = int64_t(cur->pos_max) - int64_t(std::prev(cur)->pos_max);
+        const int64_t gap_after = int64_t(std::next(cur)->pos_max) - int64_t(cur->pos_max);
+        const int64_t product = gap_before * gap_after;
+        // strict '<' keeps the earliest candidate on ties
+        if (best == ckpts.end() || product < best_product) {
+            best_product = product;
+            best = cur;
+        }
+    }
+    return best;
+}
+
 bool server_context::create_checkpoint(server_slot & slot) {
     bool do_checkpoint = !slot.image_just_processed;
     int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
@@ -3609,12 +3643,15 @@ bool server_context::create_checkpoint(server_slot & slot) {
         const int64_t t_start = ggml_time_us();
         while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.ctx_checkpoints_n) {
             // make room for the new checkpoint, if needed
-            const auto & cur = slot.server_cached_prompt.checkpoints.front();
-
+            auto it = slot.server_cached_prompt.checkpoints.begin();
+            if (params_base.ctx_checkpoint_eviction == COMMON_CHECKPOINT_EVICTION_VARIANCE ||
+                params_base.ctx_checkpoint_eviction == COMMON_CHECKPOINT_EVICTION_AUTO) {
+                it = evict_checkpoint_by_variance(slot, slot.server_cached_prompt.checkpoints);
+            }
+            const auto & cur = *it;
             SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
                 cur.pos_min, cur.pos_max, cur.n_tokens, (float)cur.data.size() / 1024 / 1024);
-
-            slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin());
+            slot.server_cached_prompt.checkpoints.erase(it);
         }
 
         auto & cur = slot.server_cached_prompt.checkpoints.emplace_back();