Self-decoding: Adds support for suffix decoding (#1646)

* speculative: implement suffix-tree decoder

* speculative: add support to cache and tuner
This commit is contained in:
Samuel Oliveira Alves 2026-04-18 11:10:10 -03:00 committed by GitHub
parent 52efa12fda
commit 260622faf6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 533 additions and 21 deletions

View File

@ -86,6 +86,8 @@ add_library(${TARGET} STATIC
unicode.h
ngram-mod.cpp
ngram-mod.h
suffix-tree.cpp
suffix-tree.h
regex-partial.cpp
regex-partial.h
jinja/lexer.cpp

View File

@ -1070,6 +1070,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
} else if (value == "ngram-mod") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
} else if (value == "suffix") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_SUFFIX;
} else {
throw std::invalid_argument("unknown speculative decoding type without draft model");
}
@ -1102,6 +1104,29 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.speculative.ngram_min_hits = value;
return true;
}
if (arg == "--suffix-pattern-len") {
CHECK_ARG
int value = std::stoi(argv[i]);
if (value < 1) {
throw std::invalid_argument("suffix pattern length must be at least 1");
}
params.speculative.suffix_min_match_len = value;
return true;
}
if (arg == "--suffix-max-depth") {
CHECK_ARG
int value = std::stoi(argv[i]);
if (value < 1) {
throw std::invalid_argument("suffix max depth must be at least 1");
}
params.speculative.suffix_max_depth = value;
return true;
}
if (arg == "--suffix-corpus") {
CHECK_ARG
params.speculative.suffix_corpus = argv[i];
return true;
}
if (arg == "-a" || arg == "--alias") {
CHECK_ARG
params.model_alias = argv[i];
@ -2661,12 +2686,15 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
"number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max });
options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" });
options.push_back({ "*", "--draft-p-min P", "minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min });
options.push_back({ "*", "--spec-type Name [none | ngram - cache | ngram - simple | ngram - map - k | ngram - map - k4v | ngram - mod]", "type of speculative decoding to use when no draft model is provided (default: %d)\n", (int)params.speculative.type});
options.push_back({ "*", "--spec-type Name [none | ngram - cache | ngram - simple | ngram - map - k | ngram - map - k4v | ngram - mod | suffix]", "type of speculative decoding to use when no draft model is provided (default: %d)\n", (int)params.speculative.type});
options.push_back({ "*", "--spec-ngram-size-n N", "ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)\n",params.speculative.ngram_size_n });
options.push_back({ "*", "--spec-ngram-size-m N", "ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)\n", params.speculative.ngram_size_m });
options.push_back({ "*", "--spec-ngram-min-hits N", "minimum hits for ngram-map speculative decoding (default: %d)\n", params.speculative.ngram_min_hits });
options.push_back({ "*", "--suffix-pattern-len N", "minimum context match length for suffix decoding (default: %d)", params.speculative.suffix_min_match_len });
options.push_back({ "*", "--suffix-max-depth N", "suffix tree maximum depth for suffix decoding (default: %d)", params.speculative.suffix_max_depth });
options.push_back({ "*", "--suffix-corpus PATH", "corpus file to pre-warm the suffix tree: .json (array of strings or conversation messages) or .bin (raw int32 token IDs)" });
options.push_back({ "*", "--spec-autotune", "automatically tune speculative params to maximize tokens/sec" });
options.push_back({ "retrieval" });

View File

@ -146,6 +146,7 @@ enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
COMMON_SPECULATIVE_TYPE_NGRAM_MOD,
COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, // self-speculative decoding with 3-level n-gram cache
COMMON_SPECULATIVE_TYPE_SUFFIX, // self-speculative suffix-decoding (arXiv:2411.04975)
COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type
};
@ -182,6 +183,11 @@ struct common_params_speculative {
std::shared_ptr<common_ngram_mod> ngram_mod;
// suffix-decoding specific
int32_t suffix_min_match_len = 5; // minimum context match length
int32_t suffix_max_depth = 64; // suffix tree maximum depth
std::string suffix_corpus; // path to corpus file for offline pre-warming (.json or .bin)
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT

View File

@ -107,12 +107,13 @@ void spec_tuner::reset_exploration() {
void spec_tuner::write_best(common_params_speculative & params) const {
for (const auto & coord : coords) {
float val = coord.arms[coord.best_idx].value;
if (coord.name == "n_max") params.n_max = (int32_t)val;
else if (coord.name == "p_min") params.p_min = val;
else if (coord.name == "n_min") params.n_min = (int32_t)val;
else if (coord.name == "ngram_size_n") params.ngram_size_n = (uint16_t)val;
else if (coord.name == "ngram_size_m") params.ngram_size_m = (uint16_t)val;
else if (coord.name == "ngram_min_hits") params.ngram_min_hits = (uint16_t)val;
if (coord.name == "n_max") params.n_max = (int32_t)val;
else if (coord.name == "p_min") params.p_min = val;
else if (coord.name == "n_min") params.n_min = (int32_t)val;
else if (coord.name == "ngram_size_n") params.ngram_size_n = (uint16_t)val;
else if (coord.name == "ngram_size_m") params.ngram_size_m = (uint16_t)val;
else if (coord.name == "ngram_min_hits") params.ngram_min_hits = (uint16_t)val;
else if (coord.name == "suffix_min_match_len") params.suffix_min_match_len = (int32_t)val;
}
}
@ -155,6 +156,21 @@ void spec_tuner::init(common_speculative_type type, const common_params_speculat
}
}
if (type == COMMON_SPECULATIVE_TYPE_SUFFIX) {
{
spec_tuner_coord coord;
coord.name = "p_min";
coord.build_grid_float(0.0f, 0.95f, 11, user_params.p_min);
coords.push_back(std::move(coord));
}
{
spec_tuner_coord coord;
coord.name = "suffix_min_match_len";
coord.build_grid_int(1, 12, 1, user_params.suffix_min_match_len);
coords.push_back(std::move(coord));
}
}
// Ngram can change only n_max/n_min per call
if (type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD) {
{
@ -168,12 +184,13 @@ void spec_tuner::init(common_speculative_type type, const common_params_speculat
for (auto & coord : coords) {
float user_val = 0.0f;
if (coord.name == "n_max") user_val = (float)user_params.n_max;
else if (coord.name == "p_min") user_val = user_params.p_min;
else if (coord.name == "n_min") user_val = (float)user_params.n_min;
else if (coord.name == "ngram_size_n") user_val = (float)user_params.ngram_size_n;
else if (coord.name == "ngram_size_m") user_val = (float)user_params.ngram_size_m;
else if (coord.name == "ngram_min_hits") user_val = (float)user_params.ngram_min_hits;
if (coord.name == "n_max") user_val = (float)user_params.n_max;
else if (coord.name == "p_min") user_val = user_params.p_min;
else if (coord.name == "n_min") user_val = (float)user_params.n_min;
else if (coord.name == "ngram_size_n") user_val = (float)user_params.ngram_size_n;
else if (coord.name == "ngram_size_m") user_val = (float)user_params.ngram_size_m;
else if (coord.name == "ngram_min_hits") user_val = (float)user_params.ngram_min_hits;
else if (coord.name == "suffix_min_match_len") user_val = (float)user_params.suffix_min_match_len;
coord.user_idx = coord.find_nearest_arm(user_val);
coord.best_idx = 0;
@ -201,12 +218,13 @@ void spec_tuner::propose(common_params_speculative & params) {
coord.current_idx = coord.select_epsilon_greedy(epsilon);
float val = coord.arms[coord.current_idx].value;
if (coord.name == "n_max") params.n_max = (int32_t)val;
else if (coord.name == "p_min") params.p_min = val;
else if (coord.name == "n_min") params.n_min = (int32_t)val;
else if (coord.name == "ngram_size_n") params.ngram_size_n = (uint16_t)val;
else if (coord.name == "ngram_size_m") params.ngram_size_m = (uint16_t)val;
else if (coord.name == "ngram_min_hits") params.ngram_min_hits = (uint16_t)val;
if (coord.name == "n_max") params.n_max = (int32_t)val;
else if (coord.name == "p_min") params.p_min = val;
else if (coord.name == "n_min") params.n_min = (int32_t)val;
else if (coord.name == "ngram_size_n") params.ngram_size_n = (uint16_t)val;
else if (coord.name == "ngram_size_m") params.ngram_size_m = (uint16_t)val;
else if (coord.name == "ngram_min_hits") params.ngram_min_hits = (uint16_t)val;
else if (coord.name == "suffix_min_match_len") params.suffix_min_match_len = (int32_t)val;
}
enforce_constraints(params);
@ -346,6 +364,7 @@ void spec_tuner::print_best() const {
else if (coord.name == "ngram_size_n") oss << "--spec-ngram-size-n ";
else if (coord.name == "ngram_size_m") oss << "--spec-ngram-size-m ";
else if (coord.name == "ngram_min_hits") oss << "--spec-ngram-min-hits ";
else if (coord.name == "suffix_min_match_len") oss << "--suffix-pattern-len ";
else oss << "--" << coord.name << " ";
if (is_int) oss << (int)coord.arms[coord.best_idx].value << " ";

View File

@ -8,6 +8,7 @@
#include "ngram-map.h"
#include "ngram-mod.h"
#include "sampling.h"
#include "suffix-tree.h"
#include <algorithm>
#include <cstring>
@ -26,7 +27,8 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
COMMON_SPECULATIVE_TYPE_NGRAM_MOD,
COMMON_SPECULATIVE_TYPE_NGRAM_CACHE
COMMON_SPECULATIVE_TYPE_NGRAM_CACHE,
COMMON_SPECULATIVE_TYPE_SUFFIX
};
const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
@ -38,7 +40,8 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
{"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
{"ngram_mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD},
{"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE}
{"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE},
{"suffix", COMMON_SPECULATIVE_TYPE_SUFFIX}
};
struct common_speculative_config {
@ -790,6 +793,126 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
}
};
struct common_speculative_state_suffix : public common_speculative_state {
common_suffix_tree tree;
common_suffix_tree corpus_tree;
bool has_corpus = false;
size_t cache_size = 0;
// Acceptance feedback
size_t n_draft_last = 0;
bool had_accept = false;
int n_low = 0;
float base_p_min = 0.1f;
float eff_p_min = 0.1f;
common_speculative_state_suffix(
enum common_speculative_type type,
int max_depth,
const std::string & corpus_path,
const llama_model * model)
: common_speculative_state(type)
, tree(max_depth)
, corpus_tree(max_depth)
{
if (!corpus_path.empty()) {
std::function<std::vector<llama_token>(const std::string &)> tokenize_fn;
if (model) {
tokenize_fn = [model](const std::string & text) -> std::vector<llama_token> {
return common_tokenize(model, text, false, true);
};
}
has_corpus = corpus_tree.load_corpus(corpus_path, tokenize_fn);
}
}
void begin(const llama_tokens & prompt) override {
cache_size = 0;
n_draft_last = 0;
had_accept = false;
n_low = 0;
GGML_UNUSED(prompt);
}
void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
base_p_min = params.p_min;
if (n_draft_last > 0 && !had_accept) {
if (++n_low >= 3) {
eff_p_min = std::min(eff_p_min + 0.1f, 0.5f);
n_low = 0;
}
}
had_accept = false;
if (cache_size < prompt_tgt.size() + 1) {
llama_tokens tokens_new;
tokens_new.reserve(prompt_tgt.size() + 1 - cache_size);
for (size_t j = cache_size; j < prompt_tgt.size(); ++j) {
tokens_new.push_back(prompt_tgt[j]);
}
tokens_new.push_back(id_last);
tree.extend(tokens_new.data(), (int)tokens_new.size());
cache_size = prompt_tgt.size() + 1;
}
const int ctx_len = std::min((int)(prompt_tgt.size() + 1), tree.max_depth());
llama_tokens context;
context.reserve(ctx_len);
const int ctx_start = (int)prompt_tgt.size() + 1 - ctx_len;
for (int j = ctx_start; j < (int)prompt_tgt.size(); ++j) {
context.push_back(prompt_tgt[j]);
}
context.push_back(id_last);
const int min_match_len = std::max(1, params.suffix_min_match_len);
result = tree.speculate(
context.data(), (int)context.size(),
params.n_max,
eff_p_min,
1,
min_match_len);
if (has_corpus) {
auto corpus_result = corpus_tree.speculate(
context.data(), (int)context.size(),
params.n_max,
eff_p_min,
1,
min_match_len);
if (corpus_result.size() > result.size()) {
result = std::move(corpus_result);
}
}
n_draft_last = result.size();
}
void accept(uint16_t n_accepted) override {
if (n_draft_last == 0) {
return;
}
had_accept = true;
const double f_acc = (double)n_accepted / (double)n_draft_last;
if (f_acc < 0.5) {
if (++n_low >= 3) {
eff_p_min = std::min(eff_p_min + 0.1f, 0.5f);
n_low = 0;
}
} else {
n_low = 0;
if (eff_p_min > base_p_min) {
eff_p_min = std::max(eff_p_min - 0.05f, base_p_min);
}
}
}
};
struct common_speculative {
std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states
common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
@ -843,6 +966,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v";
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod";
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache";
case COMMON_SPECULATIVE_TYPE_SUFFIX: return "suffix";
default: return "unknown";
}
}
@ -912,6 +1036,7 @@ common_speculative * common_speculative_init(
bool has_ngram_map_k = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V);
bool has_ngram_mod = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD);
bool has_suffix = (params.type == COMMON_SPECULATIVE_TYPE_SUFFIX);
// In a more complex implementation we could use the same implementation but with different parameters.
// This was initially used in PR-18471 but removed to simplify the code.
@ -945,6 +1070,9 @@ common_speculative * common_speculative_init(
if (has_ngram_cache) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
}
if (has_suffix) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_SUFFIX, params));
}
if (has_mtp) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
}
@ -1023,6 +1151,13 @@ common_speculative * common_speculative_init(
impls.push_back(std::make_unique<common_speculative_state_ngram_cache>(state));
break;
}
case COMMON_SPECULATIVE_TYPE_SUFFIX: {
int depth = config.params.suffix_max_depth > 0 ? config.params.suffix_max_depth : 64;
const llama_model * model = llama_get_model(ctx_tgt);
impls.push_back(std::make_unique<common_speculative_state_suffix>(
config.type, depth, config.params.suffix_corpus, model));
break;
}
default:
break;
}

260
common/suffix-tree.cpp Normal file
View File

@ -0,0 +1,260 @@
#include "suffix-tree.h"
#include "log.h"
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <fstream>
#include <nlohmann/json.hpp>
using json = nlohmann::json;
common_suffix_tree::common_suffix_tree(int max_depth)
: _max_depth(max_depth)
, _root(std::make_unique<common_suffix_node>())
{}
common_suffix_tree::~common_suffix_tree() = default;
void common_suffix_tree::clear() {
_root = std::make_unique<common_suffix_node>();
_tokens.clear();
_n_inserted = 0;
}
void common_suffix_tree::extend(const llama_token * tokens, int n_tokens) {
if (n_tokens <= 0) return;
const int old_size = (int)_tokens.size();
_tokens.insert(_tokens.end(), tokens, tokens + n_tokens);
const int new_size = (int)_tokens.size();
// Insert/update suffixes that are affected by the new tokens.
// For any position i, the suffix covers tokens[i .. min(i+max_depth, end)].
// Positions within max_depth of the old end had truncated suffixes that
// can now be extended with new tokens.
const int reinsert_from = std::max(0, old_size - _max_depth);
for (int i = reinsert_from; i < new_size; ++i) {
if (i < _n_inserted) {
const int old_len = std::min(old_size - i, _max_depth);
const int new_len = std::min(new_size - i, _max_depth);
if (new_len > old_len) {
_extend_suffix(i, old_len, new_len);
}
} else {
_insert_suffix(i);
}
}
_n_inserted = new_size;
}
void common_suffix_tree::_insert_suffix(int start_pos) {
const int total = (int)_tokens.size();
const int len = std::min(total - start_pos, _max_depth);
if (len <= 0) return;
common_suffix_node * node = _root.get();
for (int i = 0; i < len; ++i) {
const llama_token tok = _tokens[start_pos + i];
auto it = node->children.find(tok);
if (it == node->children.end()) {
auto child = std::make_unique<common_suffix_node>();
auto * child_ptr = child.get();
child_ptr->count = 1;
node->children[tok] = std::move(child);
node = child_ptr;
} else {
node = it->second.get();
node->count++;
}
}
}
void common_suffix_tree::_extend_suffix(int start_pos, int old_len, int new_len) {
common_suffix_node * node = _root.get();
for (int i = 0; i < old_len; ++i) {
const llama_token tok = _tokens[start_pos + i];
auto it = node->children.find(tok);
if (it == node->children.end()) {
return;
}
node = it->second.get();
}
for (int i = old_len; i < new_len; ++i) {
const llama_token tok = _tokens[start_pos + i];
auto it = node->children.find(tok);
if (it == node->children.end()) {
auto child = std::make_unique<common_suffix_node>();
auto * child_ptr = child.get();
child_ptr->count = 1;
node->children[tok] = std::move(child);
node = child_ptr;
} else {
node = it->second.get();
node->count++;
}
}
}
std::vector<llama_token> common_suffix_tree::speculate(
const llama_token * context, int n_context,
int max_spec_tokens,
float min_token_prob,
int min_match_count,
int min_match_len) const {
std::vector<llama_token> best_draft;
if (!_root || n_context <= 0 || max_spec_tokens <= 0) return best_draft;
if (n_context > _max_depth) {
context += (n_context - _max_depth);
n_context = _max_depth;
}
float best_score = 0.0f;
for (int match_len = std::max(1, min_match_len); match_len <= n_context; ++match_len) {
const llama_token * ctx = context + (n_context - match_len);
const common_suffix_node * node = _root.get();
bool matched = true;
for (int i = 0; i < match_len; ++i) {
auto it = node->children.find(ctx[i]);
if (it == node->children.end()) {
matched = false;
break;
}
node = it->second.get();
}
if (!matched) break;
if (node->count < min_match_count) continue;
if (node->children.empty()) continue;
// Speculate: greedily follow highest-count child
// Probability decays multiplicatively: prob *= child_count / parent_count
const int draft_limit = std::min(max_spec_tokens, match_len + 8);
std::vector<llama_token> draft;
float score = 0.0f;
float prob = 1.0f;
const common_suffix_node * cur = node;
for (int i = 0; i < draft_limit; ++i) {
if (cur->children.empty()) break;
llama_token best_tok = -1;
int64_t best_count = 0;
for (const auto & [token, child] : cur->children) {
if (child->count > best_count) {
best_count = child->count;
best_tok = token;
}
}
prob *= (float)best_count / (float)cur->count;
if (prob < min_token_prob) break;
score += prob;
draft.push_back(best_tok);
cur = cur->children.at(best_tok).get();
}
if (score > best_score && !draft.empty()) {
best_score = score;
best_draft = std::move(draft);
}
}
return best_draft;
}
static void _extract_texts(const json & node, std::vector<std::string> & out) {
if (node.is_string()) {
const std::string s = node.get<std::string>();
if (!s.empty()) out.push_back(s);
} else if (node.is_array()) {
for (const auto & item : node) {
_extract_texts(item, out);
}
} else if (node.is_object()) {
if (node.contains("content") && node["content"].is_string()) {
const std::string s = node["content"].get<std::string>();
if (!s.empty()) out.push_back(s);
} else if (node.contains("messages")) {
_extract_texts(node["messages"], out);
}
}
}
bool common_suffix_tree::load_corpus(
const std::string & path,
std::function<std::vector<llama_token>(const std::string &)> tokenize_fn) {
bool is_json = path.size() >= 5 &&
path.compare(path.size() - 5, 5, ".json") == 0;
if (is_json) {
if (!tokenize_fn) {
LOG_ERR("%s: JSON corpus requires a tokenizer but none was provided (path: '%s')\n",
__func__, path.c_str());
return false;
}
std::ifstream f(path);
if (!f.is_open()) {
LOG_ERR("%s: failed to open corpus file '%s'\n", __func__, path.c_str());
return false;
}
json root;
try {
f >> root;
} catch (const json::exception & e) {
LOG_ERR("%s: JSON parse error in '%s': %s\n", __func__, path.c_str(), e.what());
return false;
}
std::vector<std::string> texts;
_extract_texts(root, texts);
if (texts.empty()) {
LOG_WRN("%s: no text content found in corpus '%s'\n", __func__, path.c_str());
return false;
}
size_t total_tokens = 0;
for (const auto & text : texts) {
auto tokens = tokenize_fn(text);
if (!tokens.empty()) {
extend(tokens.data(), (int)tokens.size());
total_tokens += tokens.size();
}
}
LOG_DBG("%s: loaded JSON corpus — %zu texts, %zu tokens from '%s'\n",
__func__, texts.size(), total_tokens, path.c_str());
return true;
}
// Binary format: raw int32 token IDs
FILE * fp = std::fopen(path.c_str(), "rb");
if (!fp) {
LOG_ERR("%s: failed to open corpus file '%s'\n", __func__, path.c_str());
return false;
}
std::vector<llama_token> tokens;
int32_t tok;
while (std::fread(&tok, sizeof(tok), 1, fp) == 1) {
tokens.push_back(tok);
}
std::fclose(fp);
if (tokens.empty()) {
LOG_WRN("%s: suffix corpus file '%s' is empty\n", __func__, path.c_str());
return false;
}
extend(tokens.data(), (int)tokens.size());
LOG_DBG("%s: loaded binary corpus — %zu tokens from '%s'\n",
__func__, tokens.size(), path.c_str());
return true;
}

62
common/suffix-tree.h Normal file
View File

@ -0,0 +1,62 @@
#pragma once
#include "llama.h"
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
// A trie-based suffix tree for suffix-decoding speculative decoding.
//
// Stores all suffixes (up to max_depth) of the token history.
// Used to find matching patterns in context and generate draft tokens
// by following the most frequent continuation path.
//
// Reference: "Suffix Decoding" (Saxena et al., 2024) — arXiv:2411.04975
struct common_suffix_node {
int64_t count = 0;
std::unordered_map<llama_token, std::unique_ptr<common_suffix_node>> children;
};
class common_suffix_tree {
public:
explicit common_suffix_tree(int max_depth = 64);
~common_suffix_tree();
// Append tokens to the history and insert new suffixes into the trie.
// Incremental: only processes suffixes that haven't been inserted yet.
void extend(const llama_token * tokens, int n_tokens);
void clear();
// Generate draft tokens by matching the context in the trie.
// Tries multiple context lengths and returns the draft with the best score.
std::vector<llama_token> speculate(
const llama_token * context, int n_context,
int max_spec_tokens,
float min_token_prob = 0.1f,
int min_match_count = 1,
int min_match_len = 5) const;
// Load an offline corpus to pre-warm the tree before any request.
// Supported formats (.json or .bin)
bool load_corpus(
const std::string & path,
std::function<std::vector<llama_token>(const std::string &)> tokenize_fn = {});
int max_depth() const { return _max_depth; }
int token_count() const { return (int)_tokens.size(); }
private:
int _max_depth;
std::unique_ptr<common_suffix_node> _root;
std::vector<llama_token> _tokens;
int _n_inserted = 0;
void _insert_suffix(int start_pos);
void _extend_suffix(int start_pos, int old_len, int new_len);
};