diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index b03dec9b..8fbe1d2f 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -86,6 +86,8 @@ add_library(${TARGET} STATIC unicode.h ngram-mod.cpp ngram-mod.h + suffix-tree.cpp + suffix-tree.h regex-partial.cpp regex-partial.h jinja/lexer.cpp diff --git a/common/common.cpp b/common/common.cpp index 7aaff9a5..d553d7d1 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1070,6 +1070,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V; } else if (value == "ngram-mod") { params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD; + } else if (value == "suffix") { + params.speculative.type = COMMON_SPECULATIVE_TYPE_SUFFIX; } else { throw std::invalid_argument("unknown speculative decoding type without draft model"); } @@ -1102,6 +1104,29 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.speculative.ngram_min_hits = value; return true; } + if (arg == "--suffix-pattern-len") { + CHECK_ARG + int value = std::stoi(argv[i]); + if (value < 1) { + throw std::invalid_argument("suffix pattern length must be at least 1"); + } + params.speculative.suffix_min_match_len = value; + return true; + } + if (arg == "--suffix-max-depth") { + CHECK_ARG + int value = std::stoi(argv[i]); + if (value < 1) { + throw std::invalid_argument("suffix max depth must be at least 1"); + } + params.speculative.suffix_max_depth = value; + return true; + } + if (arg == "--suffix-corpus") { + CHECK_ARG + params.speculative.suffix_corpus = argv[i]; + return true; + } if (arg == "-a" || arg == "--alias") { CHECK_ARG params.model_alias = argv[i]; @@ -2661,12 +2686,15 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param "number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max }); options.push_back({ "*", "--draft-min, 
--draft-n-min N", "minimum number of draft tokens to use for speculative decoding" }); options.push_back({ "*", "--draft-p-min P", "minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min }); - options.push_back({ "*", "--spec-type Name [none | ngram - cache | ngram - simple | ngram - map - k | ngram - map - k4v | ngram - mod]", "type of speculative decoding to use when no draft model is provided (default: %d)\n", (int)params.speculative.type}); + options.push_back({ "*", "--spec-type Name [none | ngram - cache | ngram - simple | ngram - map - k | ngram - map - k4v | ngram - mod | suffix]", "type of speculative decoding to use when no draft model is provided (default: %d)\n", (int)params.speculative.type}); options.push_back({ "*", "--spec-ngram-size-n N", "ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)\n",params.speculative.ngram_size_n }); options.push_back({ "*", "--spec-ngram-size-m N", "ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)\n", params.speculative.ngram_size_m }); options.push_back({ "*", "--spec-ngram-min-hits N", "minimum hits for ngram-map speculative decoding (default: %d)\n", params.speculative.ngram_min_hits }); + options.push_back({ "*", "--suffix-pattern-len N", "minimum context match length for suffix decoding (default: %d)", params.speculative.suffix_min_match_len }); + options.push_back({ "*", "--suffix-max-depth N", "suffix tree maximum depth for suffix decoding (default: %d)", params.speculative.suffix_max_depth }); + options.push_back({ "*", "--suffix-corpus PATH", "corpus file to pre-warm the suffix tree: .json (array of strings or conversation messages) or .bin (raw int32 token IDs)" }); options.push_back({ "*", "--spec-autotune", "automatically tune speculative params to maximize tokens/sec" }); options.push_back({ "retrieval" }); diff --git a/common/common.h b/common/common.h index 
326c1b6e..df93c1d4 100644 --- a/common/common.h +++ b/common/common.h @@ -146,6 +146,7 @@ enum common_speculative_type { COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values COMMON_SPECULATIVE_TYPE_NGRAM_MOD, COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, // self-speculative decoding with 3-level n-gram cache + COMMON_SPECULATIVE_TYPE_SUFFIX, // self-speculative suffix-decoding (arXiv:2411.04975) COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type }; @@ -182,6 +183,11 @@ struct common_params_speculative { std::shared_ptr ngram_mod; + // suffix-decoding specific + int32_t suffix_min_match_len = 5; // minimum context match length + int32_t suffix_max_depth = 64; // suffix tree maximum depth + std::string suffix_corpus; // path to corpus file for offline pre-warming (.json or .bin) + std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT diff --git a/common/spec-tuner.cpp b/common/spec-tuner.cpp index 705671b3..90bfdfa1 100644 --- a/common/spec-tuner.cpp +++ b/common/spec-tuner.cpp @@ -107,12 +107,13 @@ void spec_tuner::reset_exploration() { void spec_tuner::write_best(common_params_speculative & params) const { for (const auto & coord : coords) { float val = coord.arms[coord.best_idx].value; - if (coord.name == "n_max") params.n_max = (int32_t)val; - else if (coord.name == "p_min") params.p_min = val; - else if (coord.name == "n_min") params.n_min = (int32_t)val; - else if (coord.name == "ngram_size_n") params.ngram_size_n = (uint16_t)val; - else if (coord.name == "ngram_size_m") params.ngram_size_m = (uint16_t)val; - else if (coord.name == "ngram_min_hits") params.ngram_min_hits = (uint16_t)val; + if (coord.name == "n_max") params.n_max = (int32_t)val; + else if (coord.name == "p_min") params.p_min = val; + else if (coord.name == "n_min") params.n_min = (int32_t)val; + else 
if (coord.name == "ngram_size_n") params.ngram_size_n = (uint16_t)val; + else if (coord.name == "ngram_size_m") params.ngram_size_m = (uint16_t)val; + else if (coord.name == "ngram_min_hits") params.ngram_min_hits = (uint16_t)val; + else if (coord.name == "suffix_min_match_len") params.suffix_min_match_len = (int32_t)val; } } @@ -155,6 +156,21 @@ void spec_tuner::init(common_speculative_type type, const common_params_speculat } } + if (type == COMMON_SPECULATIVE_TYPE_SUFFIX) { + { + spec_tuner_coord coord; + coord.name = "p_min"; + coord.build_grid_float(0.0f, 0.95f, 11, user_params.p_min); + coords.push_back(std::move(coord)); + } + { + spec_tuner_coord coord; + coord.name = "suffix_min_match_len"; + coord.build_grid_int(1, 12, 1, user_params.suffix_min_match_len); + coords.push_back(std::move(coord)); + } + } + // Ngram can change only n_max/n_min per call if (type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD) { { @@ -168,12 +184,13 @@ void spec_tuner::init(common_speculative_type type, const common_params_speculat for (auto & coord : coords) { float user_val = 0.0f; - if (coord.name == "n_max") user_val = (float)user_params.n_max; - else if (coord.name == "p_min") user_val = user_params.p_min; - else if (coord.name == "n_min") user_val = (float)user_params.n_min; - else if (coord.name == "ngram_size_n") user_val = (float)user_params.ngram_size_n; - else if (coord.name == "ngram_size_m") user_val = (float)user_params.ngram_size_m; - else if (coord.name == "ngram_min_hits") user_val = (float)user_params.ngram_min_hits; + if (coord.name == "n_max") user_val = (float)user_params.n_max; + else if (coord.name == "p_min") user_val = user_params.p_min; + else if (coord.name == "n_min") user_val = (float)user_params.n_min; + else if (coord.name == "ngram_size_n") user_val = (float)user_params.ngram_size_n; + else if (coord.name == "ngram_size_m") user_val = (float)user_params.ngram_size_m; + else if (coord.name == "ngram_min_hits") user_val = (float)user_params.ngram_min_hits; + 
else if (coord.name == "suffix_min_match_len") user_val = (float)user_params.suffix_min_match_len;
 
         coord.user_idx = coord.find_nearest_arm(user_val);
         coord.best_idx = 0;
@@ -201,12 +218,13 @@ void spec_tuner::propose(common_params_speculative & params) {
 
         coord.current_idx = coord.select_epsilon_greedy(epsilon);
         float val = coord.arms[coord.current_idx].value;
-        if      (coord.name == "n_max")          params.n_max          = (int32_t)val;
-        else if (coord.name == "p_min")          params.p_min          = val;
-        else if (coord.name == "n_min")          params.n_min          = (int32_t)val;
-        else if (coord.name == "ngram_size_n")   params.ngram_size_n   = (uint16_t)val;
-        else if (coord.name == "ngram_size_m")   params.ngram_size_m   = (uint16_t)val;
-        else if (coord.name == "ngram_min_hits") params.ngram_min_hits = (uint16_t)val;
+        if      (coord.name == "n_max")                params.n_max                = (int32_t)val;
+        else if (coord.name == "p_min")                params.p_min                = val;
+        else if (coord.name == "n_min")                params.n_min                = (int32_t)val;
+        else if (coord.name == "ngram_size_n")         params.ngram_size_n         = (uint16_t)val;
+        else if (coord.name == "ngram_size_m")         params.ngram_size_m         = (uint16_t)val;
+        else if (coord.name == "ngram_min_hits")       params.ngram_min_hits       = (uint16_t)val;
+        else if (coord.name == "suffix_min_match_len") params.suffix_min_match_len = (int32_t)val;
     }
 
     enforce_constraints(params);
@@ -346,6 +364,7 @@ void spec_tuner::print_best() const {
         else if (coord.name == "ngram_size_n") oss << "--spec-ngram-size-n ";
         else if (coord.name == "ngram_size_m") oss << "--spec-ngram-size-m ";
         else if (coord.name == "ngram_min_hits") oss << "--spec-ngram-min-hits ";
+        else if (coord.name == "suffix_min_match_len") oss << "--suffix-pattern-len ";
         else oss << "--" << coord.name << " ";
 
         if (is_int) oss << (int)coord.arms[coord.best_idx].value << " ";
diff --git a/common/speculative.cpp b/common/speculative.cpp
index 5acef4b4..00696b65 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -8,6 +8,7 @@
 #include "ngram-map.h"
 #include "ngram-mod.h"
 #include "sampling.h"
+#include "suffix-tree.h"
 
 #include <algorithm>
 #include <cstring>
@@ -26,7 +27,8 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
     COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
     COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
     COMMON_SPECULATIVE_TYPE_NGRAM_MOD,
-    COMMON_SPECULATIVE_TYPE_NGRAM_CACHE
+    COMMON_SPECULATIVE_TYPE_NGRAM_CACHE,
+    COMMON_SPECULATIVE_TYPE_SUFFIX
 };
 
 const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
@@ -38,7 +40,8 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
     {"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
     {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
     {"ngram_mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD},
-    {"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE}
+    {"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE},
+    {"suffix", COMMON_SPECULATIVE_TYPE_SUFFIX}
 };
 
 struct common_speculative_config {
@@ -790,6 +793,126 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
     }
 };
 
+// Self-speculative drafting on suffix tries (arXiv:2411.04975): tree grows online from the
+// tokens handed to draft(), corpus_tree is filled once from the optional corpus file.
+struct common_speculative_state_suffix : public common_speculative_state {
+    common_suffix_tree tree;
+    common_suffix_tree corpus_tree;
+    bool   has_corpus = false; // corpus_tree.load_corpus() succeeded
+    size_t cache_size = 0;     // tokens of the current sequence already fed to tree
+
+    // Acceptance feedback
+    size_t n_draft_last = 0;     // size of the most recent draft
+    bool   had_accept   = false; // accept() ran since the last draft()
+    int    n_low        = 0;     // consecutive low-acceptance rounds
+    float  base_p_min   = -1.0f; // params.p_min last seen by draft(), < 0 before the first call
+    float  eff_p_min    = 0.1f;  // threshold handed to speculate()
+
+    common_speculative_state_suffix(
+            enum common_speculative_type type,
+            int max_depth,
+            const std::string & corpus_path,
+            const llama_model * model)
+        : common_speculative_state(type)
+        , tree(max_depth)
+        , corpus_tree(max_depth)
+    {
+        if (!corpus_path.empty()) {
+            std::function<std::vector<llama_token>(const std::string &)> tokenize_fn;
+            if (model) {
+                tokenize_fn = [model](const std::string & text) -> std::vector<llama_token> {
+                    return common_tokenize(model, text, false, true);
+                };
+            }
+            has_corpus = corpus_tree.load_corpus(corpus_path, tokenize_fn);
+        }
+    }
+
+    // New sequence: reset the per-sequence bookkeeping. NOTE(review): tree is not
+    // cleared, so earlier sequences stay in the trie - confirm this is intended.
+    void begin(const llama_tokens & prompt) override {
+        cache_size   = 0;
+        n_draft_last = 0;
+        had_accept   = false;
+        n_low        = 0;
+        GGML_UNUSED(prompt);
+    }
+
+    // Three low-acceptance rounds in a row raise eff_p_min by 0.1. The ceiling is
+    // 0.5, or base_p_min when the configured value is already stricter than that.
+    void note_low_acceptance() {
+        if (++n_low >= 3) {
+            eff_p_min = std::min(eff_p_min + 0.1f, std::max(base_p_min, 0.5f));
+            n_low     = 0;
+        }
+    }
+
+    // Feed tokens not seen yet into tree, then write a draft of at most
+    // params.n_max tokens continuing prompt_tgt + id_last into result.
+    void draft(
+            const common_params_speculative & params,
+            const llama_tokens & prompt_tgt,
+            llama_token id_last,
+            llama_tokens & result) override {
+        // Re-seed the adaptive threshold when the configured p_min changes. It used to
+        // start at 0.1 regardless, so a larger --draft-p-min never took effect.
+        if (params.p_min != base_p_min) {
+            base_p_min = params.p_min;
+            eff_p_min  = base_p_min;
+        }
+        // a draft that never reached accept() counts as a low-acceptance round
+        if (n_draft_last > 0 && !had_accept) {
+            note_low_acceptance();
+        }
+        had_accept = false;
+
+        const int n_seq = (int)prompt_tgt.size() + 1; // prompt plus id_last
+        if (cache_size < (size_t)n_seq) {
+            llama_tokens tokens_new;
+            tokens_new.reserve((size_t)n_seq - cache_size);
+            tokens_new.insert(tokens_new.end(), prompt_tgt.begin() + cache_size, prompt_tgt.end());
+            tokens_new.push_back(id_last);
+            tree.extend(tokens_new.data(), (int)tokens_new.size());
+            cache_size = (size_t)n_seq;
+        }
+
+        // the last min(n_seq, max_depth) tokens, always ending with id_last
+        const int ctx_len = std::max(1, std::min(n_seq, tree.max_depth()));
+        llama_tokens context(prompt_tgt.end() - (ctx_len - 1), prompt_tgt.end());
+        context.push_back(id_last);
+        const int min_match_len = std::max(1, params.suffix_min_match_len);
+
+        result = tree.speculate(
+            context.data(), (int)context.size(),
+            params.n_max, eff_p_min, 1, min_match_len);
+        if (has_corpus) {
+            auto corpus_result = corpus_tree.speculate(
+                context.data(), (int)context.size(),
+                params.n_max, eff_p_min, 1, min_match_len);
+            if (corpus_result.size() > result.size()) {
+                result = std::move(corpus_result);
+            }
+        }
+        n_draft_last = result.size();
+    }
+
+    // Feedback for the last draft: less than half accepted is a low round, anything
+    // else clears the streak and relaxes eff_p_min by 0.05 towards base_p_min.
+    void accept(uint16_t n_accepted) override {
+        if (n_draft_last == 0) {
+            return;
+        }
+        had_accept = true;
+        const double f_acc = (double)n_accepted / (double)n_draft_last;
+        if (f_acc < 0.5) {
+            note_low_acceptance();
+        } else {
+            n_low     = 0;
+            eff_p_min = std::max(eff_p_min - 0.05f, base_p_min);
+        }
+    }
+};
+
 struct common_speculative {
     std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their
states common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) @@ -843,6 +966,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) { case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v"; case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod"; case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache"; + case COMMON_SPECULATIVE_TYPE_SUFFIX: return "suffix"; default: return "unknown"; } } @@ -912,6 +1036,7 @@ common_speculative * common_speculative_init( bool has_ngram_map_k = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V); bool has_ngram_mod = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD); + bool has_suffix = (params.type == COMMON_SPECULATIVE_TYPE_SUFFIX); // In a more complex implementation we could use the same implementation but with different parameters. // This was initially used in PR-18471 but removed to simplify the code. @@ -945,6 +1070,9 @@ common_speculative * common_speculative_init( if (has_ngram_cache) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params)); } + if (has_suffix) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_SUFFIX, params)); + } if (has_mtp) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params)); } @@ -1023,6 +1151,13 @@ common_speculative * common_speculative_init( impls.push_back(std::make_unique(state)); break; } + case COMMON_SPECULATIVE_TYPE_SUFFIX: { + int depth = config.params.suffix_max_depth > 0 ? 
config.params.suffix_max_depth : 64;
+                const llama_model * model = llama_get_model(ctx_tgt);
+                impls.push_back(std::make_unique<common_speculative_state_suffix>(
+                    config.type, depth, config.params.suffix_corpus, model));
+                break;
+            }
             default:
                 break;
         }
diff --git a/common/suffix-tree.cpp b/common/suffix-tree.cpp
new file mode 100644
index 00000000..1f3152a2
--- /dev/null
+++ b/common/suffix-tree.cpp
@@ -0,0 +1,277 @@
+#include "suffix-tree.h"
+#include "log.h"
+
+#include <nlohmann/json.hpp>
+
+#include <algorithm>
+#include <cstdio>
+#include <fstream>
+#include <memory>
+
+using json = nlohmann::json;
+
+common_suffix_tree::common_suffix_tree(int max_depth)
+    : _max_depth(std::max(1, max_depth)) // a depth below 1 can neither store nor match anything
+    , _root(std::make_unique<common_suffix_node>())
+{}
+
+common_suffix_tree::~common_suffix_tree() = default;
+
+void common_suffix_tree::clear() {
+    _root = std::make_unique<common_suffix_node>();
+    _tokens.clear();
+    _n_inserted = 0;
+}
+
+void common_suffix_tree::extend(const llama_token * tokens, int n_tokens) {
+    if (tokens == nullptr || n_tokens <= 0) {
+        return;
+    }
+
+    const int old_size = (int)_tokens.size();
+    _tokens.insert(_tokens.end(), tokens, tokens + n_tokens);
+    const int new_size = (int)_tokens.size();
+
+    // Every start position i owns the path tokens[i .. min(i + max_depth, end)).
+    // Positions within max_depth of the old end were stored truncated and can
+    // now grow; positions at or past the old end are inserted from scratch.
+    const int reinsert_from = std::max(0, old_size - _max_depth);
+
+    for (int i = reinsert_from; i < new_size; ++i) {
+        if (i < _n_inserted) {
+            const int old_len = std::min(old_size - i, _max_depth);
+            const int new_len = std::min(new_size - i, _max_depth);
+            if (new_len > old_len) {
+                _extend_suffix(i, old_len, new_len);
+            }
+        } else {
+            _insert_suffix(i);
+        }
+    }
+
+    _n_inserted = new_size;
+}
+
+void common_suffix_tree::_insert_suffix(int start_pos) {
+    const int len = std::min((int)_tokens.size() - start_pos, _max_depth);
+    if (len <= 0) {
+        return;
+    }
+    // a fresh suffix is an extension that starts at the root
+    _extend_suffix(start_pos, 0, len);
+}
+
+void common_suffix_tree::_extend_suffix(int start_pos, int old_len, int new_len) {
+    common_suffix_node * node = _root.get();
+
+    // walk the part that is already stored; those nodes were counted on insert
+    for (int i = 0; i < old_len; ++i) {
+        const auto it = node->children.find(_tokens[start_pos + i]);
+        if (it == node->children.end()) {
+            return;
+        }
+        node = it->second.get();
+    }
+
+    // append the new tail: create missing nodes, count one visit per node
+    for (int i = old_len; i < new_len; ++i) {
+        auto & child = node->children[_tokens[start_pos + i]];
+        if (!child) {
+            child = std::make_unique<common_suffix_node>();
+        }
+        child->count++;
+        node = child.get();
+    }
+}
+
+std::vector<llama_token> common_suffix_tree::speculate(
+        const llama_token * context, int n_context,
+        int max_spec_tokens,
+        float min_token_prob,
+        int min_match_count,
+        int min_match_len) const {
+
+    std::vector<llama_token> best_draft;
+
+    if (!_root || context == nullptr || n_context <= 0 || max_spec_tokens <= 0) {
+        return best_draft;
+    }
+
+    // paths are at most max_depth long, so older context can never match
+    if (n_context > _max_depth) {
+        context  += (n_context - _max_depth);
+        n_context = _max_depth;
+    }
+
+    float best_score = 0.0f;
+
+    for (int match_len = std::max(1, min_match_len); match_len <= n_context; ++match_len) {
+        const llama_token * ctx = context + (n_context - match_len);
+
+        const common_suffix_node * node = _root.get();
+        for (int i = 0; i < match_len && node != nullptr; ++i) {
+            const auto it = node->children.find(ctx[i]);
+            node = it == node->children.end() ? nullptr : it->second.get();
+        }
+
+        // no path for this suffix: stop, longer suffixes are not tried
+        if (node == nullptr) {
+            break;
+        }
+        if (node->count < min_match_count || node->children.empty()) {
+            continue;
+        }
+
+        // Greedy walk along the most frequent child. prob is the running product
+        // of count(child) / count(parent); the walk stops once it drops below
+        // min_token_prob. The score of a draft is the sum of those products.
+        const int draft_limit = std::min(max_spec_tokens, match_len + 8);
+
+        std::vector<llama_token> draft;
+        float score = 0.0f;
+        float prob  = 1.0f;
+        const common_suffix_node * cur = node;
+
+        while ((int)draft.size() < draft_limit && cur->count > 0) {
+            llama_token best_tok = -1;
+            const common_suffix_node * best_child = nullptr;
+            for (const auto & [token, child] : cur->children) {
+                // ties go to the smaller token id, so the draft does not depend
+                // on the iteration order of the hash map
+                if (best_child == nullptr || child->count > best_child->count ||
+                        (child->count == best_child->count && token < best_tok)) {
+                    best_tok   = token;
+                    best_child = child.get();
+                }
+            }
+            if (best_child == nullptr) {
+                break; // leaf
+            }
+
+            prob *= (float)best_child->count / (float)cur->count;
+            if (prob < min_token_prob) {
+                break;
+            }
+
+            score += prob;
+            draft.push_back(best_tok);
+            cur = best_child;
+        }
+
+        if (!draft.empty() && score > best_score) {
+            best_score = score;
+            best_draft = std::move(draft);
+        }
+    }
+
+    return best_draft;
+}
+
+// Collect every non-empty string reachable in node: plain strings, arrays
+// (recursively), objects with a string "content", otherwise their "messages".
+static void extract_texts(const json & node, std::vector<std::string> & out) {
+    if (node.is_string()) {
+        const auto & text = node.get_ref<const std::string &>();
+        if (!text.empty()) {
+            out.push_back(text);
+        }
+    } else if (node.is_array()) {
+        for (const auto & item : node) {
+            extract_texts(item, out);
+        }
+    } else if (node.is_object()) {
+        const auto content = node.find("content");
+        if (content != node.end() && content->is_string()) {
+            extract_texts(*content, out);
+        } else if (node.contains("messages")) {
+            extract_texts(node["messages"], out);
+        }
+    }
+}
+
+bool common_suffix_tree::load_corpus(
+        const std::string & path,
+        std::function<std::vector<llama_token>(const std::string &)> tokenize_fn) {
+
+    const auto has_ext = [&path](const std::string & ext) {
+        return path.size() >= ext.size() &&
+               path.compare(path.size() - ext.size(), ext.size(), ext) == 0;
+    };
+
+    if (has_ext(".json")) {
+        if (!tokenize_fn) {
+            LOG_ERR("%s: JSON corpus requires a tokenizer but none was provided (path: '%s')\n",
+                    __func__, path.c_str());
+            return false;
+        }
+        std::ifstream f(path);
+        if (!f.is_open()) {
+            LOG_ERR("%s: failed to open corpus file '%s'\n", __func__, path.c_str());
+            return false;
+        }
+        json root;
+        try {
+            f >> root;
+        } catch (const json::exception & e) {
+            LOG_ERR("%s: JSON parse error in '%s': %s\n", __func__, path.c_str(), e.what());
+            return false;
+        }
+        std::vector<std::string> texts;
+        extract_texts(root, texts);
+        if (texts.empty()) {
+            LOG_WRN("%s: no text content found in corpus '%s'\n", __func__, path.c_str());
+            return false;
+        }
+        size_t total_tokens = 0;
+        for (const auto & text : texts) {
+            const auto tokens = tokenize_fn(text);
+            if (!tokens.empty()) {
+                extend(tokens.data(), (int)tokens.size());
+                total_tokens += tokens.size();
+            }
+        }
+        if (total_tokens == 0) {
+            LOG_WRN("%s: corpus '%s' produced no tokens\n", __func__, path.c_str());
+            return false;
+        }
+        LOG_DBG("%s: loaded JSON corpus — %zu texts, %zu tokens from '%s'\n",
+                __func__, texts.size(), total_tokens, path.c_str());
+        return true;
+    }
+
+    // Anything else is read as raw int32 token IDs in host byte order.
+    if (!has_ext(".bin")) {
+        LOG_WRN("%s: '%s' is neither .json nor .bin, reading it as raw int32 token IDs\n",
+                __func__, path.c_str());
+    }
+    const std::unique_ptr<FILE, int (*)(FILE *)> fp(std::fopen(path.c_str(), "rb"), &std::fclose);
+    if (!fp) {
+        LOG_ERR("%s: failed to open corpus file '%s'\n", __func__, path.c_str());
+        return false;
+    }
+    std::vector<llama_token> tokens;
+    int32_t chunk[4096];
+    size_t  n_read;
+    while ((n_read = std::fread(chunk, sizeof(chunk[0]), 4096, fp.get())) > 0) {
+        for (size_t i = 0; i < n_read; ++i) {
+            // a negative value is not a usable token id and could otherwise be drafted
+            if (chunk[i] < 0) {
+                LOG_ERR("%s: invalid token id %d in corpus file '%s'\n", __func__, (int)chunk[i], path.c_str());
+                return false;
+            }
+        }
+        tokens.insert(tokens.end(), chunk, chunk + n_read);
+    }
+    if (std::ferror(fp.get())) {
+        LOG_ERR("%s: read error in corpus file '%s'\n", __func__, path.c_str());
+        return false;
+    }
+    if (tokens.empty()) {
+        LOG_WRN("%s: suffix corpus file '%s' is empty\n", __func__, path.c_str());
+        return false;
+    }
+    extend(tokens.data(), (int)tokens.size());
+    LOG_DBG("%s: loaded binary corpus — %zu tokens from '%s'\n",
+            __func__, tokens.size(), path.c_str());
+    return true;
+}
diff --git a/common/suffix-tree.h b/common/suffix-tree.h
new file mode 100644
index 00000000..553452af
--- /dev/null
+++ b/common/suffix-tree.h
@@ -0,0 +1,70 @@
+#pragma once
+
+#include "llama.h"
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+// A trie-based suffix tree for suffix-decoding speculative decoding.
+//
+// Stores all suffixes (up to max_depth) of the token history.
+// Used to find matching patterns in context and generate draft tokens
+// by following the most frequent continuation path.
+//
+// Reference: "Suffix Decoding" (Saxena et al., 2024) — arXiv:2411.04975
+
+struct common_suffix_node {
+    int64_t count = 0; // number of inserted suffixes that pass through this node
+    std::unordered_map<llama_token, std::unique_ptr<common_suffix_node>> children;
+};
+
+class common_suffix_tree {
+public:
+    // max_depth limits the length of every stored path; values below 1 are raised to 1.
+    explicit common_suffix_tree(int max_depth = 64);
+    ~common_suffix_tree();
+
+    // Append tokens to the history and insert new suffixes into the trie.
+    // Incremental: only processes suffixes that haven't been inserted yet.
+    void extend(const llama_token * tokens, int n_tokens);
+
+    // Drop the trie and the token history.
+    void clear();
+
+    // Generate draft tokens by matching the context in the trie.
+    // Tries context suffixes from min_match_len tokens upwards and returns the
+    // draft with the best score (empty when nothing matches). A draft holds at
+    // most min(max_spec_tokens, match length + 8) tokens and ends once the path
+    // probability drops below min_token_prob; matches with a count below
+    // min_match_count are skipped.
+    std::vector<llama_token> speculate(
+        const llama_token * context, int n_context,
+        int max_spec_tokens,
+        float min_token_prob = 0.1f,
+        int min_match_count = 1,
+        int min_match_len = 5) const;
+
+    // Load an offline corpus to pre-warm the tree before any request.
+    //   *.json          - strings, arrays of them, or objects with "content" /
+    //                     "messages"; requires tokenize_fn
+    //   everything else - raw int32 token IDs (the .bin format)
+    // Returns false if the file cannot be read or yields no usable tokens.
+    bool load_corpus(
+        const std::string & path,
+        std::function<std::vector<llama_token>(const std::string &)> tokenize_fn = {});
+
+    int max_depth() const { return _max_depth; }
+    int token_count() const { return (int)_tokens.size(); }
+
+private:
+    int _max_depth;
+    std::unique_ptr<common_suffix_node> _root;
+    std::vector<llama_token> _tokens; // token history, in insertion order
+    int _n_inserted = 0;              // number of start positions already in the trie
+
+    void _insert_suffix(int start_pos);
+    void _extend_suffix(int start_pos, int old_len, int new_len);
+};