mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Self-decoding: Adds support for suffix decoding (#1646)
* speculative: implement suffix-tree decoder * speculative: add support to cache and tuner
This commit is contained in:
parent
52efa12fda
commit
260622faf6
@ -86,6 +86,8 @@ add_library(${TARGET} STATIC
|
|||||||
unicode.h
|
unicode.h
|
||||||
ngram-mod.cpp
|
ngram-mod.cpp
|
||||||
ngram-mod.h
|
ngram-mod.h
|
||||||
|
suffix-tree.cpp
|
||||||
|
suffix-tree.h
|
||||||
regex-partial.cpp
|
regex-partial.cpp
|
||||||
regex-partial.h
|
regex-partial.h
|
||||||
jinja/lexer.cpp
|
jinja/lexer.cpp
|
||||||
|
|||||||
@ -1070,6 +1070,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
|
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
|
||||||
} else if (value == "ngram-mod") {
|
} else if (value == "ngram-mod") {
|
||||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
|
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
|
||||||
|
} else if (value == "suffix") {
|
||||||
|
params.speculative.type = COMMON_SPECULATIVE_TYPE_SUFFIX;
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument("unknown speculative decoding type without draft model");
|
throw std::invalid_argument("unknown speculative decoding type without draft model");
|
||||||
}
|
}
|
||||||
@ -1102,6 +1104,29 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||||||
params.speculative.ngram_min_hits = value;
|
params.speculative.ngram_min_hits = value;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (arg == "--suffix-pattern-len") {
|
||||||
|
CHECK_ARG
|
||||||
|
int value = std::stoi(argv[i]);
|
||||||
|
if (value < 1) {
|
||||||
|
throw std::invalid_argument("suffix pattern length must be at least 1");
|
||||||
|
}
|
||||||
|
params.speculative.suffix_min_match_len = value;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (arg == "--suffix-max-depth") {
|
||||||
|
CHECK_ARG
|
||||||
|
int value = std::stoi(argv[i]);
|
||||||
|
if (value < 1) {
|
||||||
|
throw std::invalid_argument("suffix max depth must be at least 1");
|
||||||
|
}
|
||||||
|
params.speculative.suffix_max_depth = value;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (arg == "--suffix-corpus") {
|
||||||
|
CHECK_ARG
|
||||||
|
params.speculative.suffix_corpus = argv[i];
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (arg == "-a" || arg == "--alias") {
|
if (arg == "-a" || arg == "--alias") {
|
||||||
CHECK_ARG
|
CHECK_ARG
|
||||||
params.model_alias = argv[i];
|
params.model_alias = argv[i];
|
||||||
@ -2661,12 +2686,15 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
|||||||
"number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max });
|
"number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max });
|
||||||
options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" });
|
options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" });
|
||||||
options.push_back({ "*", "--draft-p-min P", "minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min });
|
options.push_back({ "*", "--draft-p-min P", "minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min });
|
||||||
options.push_back({ "*", "--spec-type Name [none | ngram - cache | ngram - simple | ngram - map - k | ngram - map - k4v | ngram - mod]", "type of speculative decoding to use when no draft model is provided (default: %d)\n", (int)params.speculative.type});
|
options.push_back({ "*", "--spec-type Name [none | ngram - cache | ngram - simple | ngram - map - k | ngram - map - k4v | ngram - mod | suffix]", "type of speculative decoding to use when no draft model is provided (default: %d)\n", (int)params.speculative.type});
|
||||||
options.push_back({ "*", "--spec-ngram-size-n N", "ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)\n",params.speculative.ngram_size_n });
|
options.push_back({ "*", "--spec-ngram-size-n N", "ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)\n",params.speculative.ngram_size_n });
|
||||||
|
|
||||||
options.push_back({ "*", "--spec-ngram-size-m N", "ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)\n", params.speculative.ngram_size_m });
|
options.push_back({ "*", "--spec-ngram-size-m N", "ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)\n", params.speculative.ngram_size_m });
|
||||||
|
|
||||||
options.push_back({ "*", "--spec-ngram-min-hits N", "minimum hits for ngram-map speculative decoding (default: %d)\n", params.speculative.ngram_min_hits });
|
options.push_back({ "*", "--spec-ngram-min-hits N", "minimum hits for ngram-map speculative decoding (default: %d)\n", params.speculative.ngram_min_hits });
|
||||||
|
options.push_back({ "*", "--suffix-pattern-len N", "minimum context match length for suffix decoding (default: %d)", params.speculative.suffix_min_match_len });
|
||||||
|
options.push_back({ "*", "--suffix-max-depth N", "suffix tree maximum depth for suffix decoding (default: %d)", params.speculative.suffix_max_depth });
|
||||||
|
options.push_back({ "*", "--suffix-corpus PATH", "corpus file to pre-warm the suffix tree: .json (array of strings or conversation messages) or .bin (raw int32 token IDs)" });
|
||||||
options.push_back({ "*", "--spec-autotune", "automatically tune speculative params to maximize tokens/sec" });
|
options.push_back({ "*", "--spec-autotune", "automatically tune speculative params to maximize tokens/sec" });
|
||||||
|
|
||||||
options.push_back({ "retrieval" });
|
options.push_back({ "retrieval" });
|
||||||
|
|||||||
@ -146,6 +146,7 @@ enum common_speculative_type {
|
|||||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
|
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
|
||||||
COMMON_SPECULATIVE_TYPE_NGRAM_MOD,
|
COMMON_SPECULATIVE_TYPE_NGRAM_MOD,
|
||||||
COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, // self-speculative decoding with 3-level n-gram cache
|
COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, // self-speculative decoding with 3-level n-gram cache
|
||||||
|
COMMON_SPECULATIVE_TYPE_SUFFIX, // self-speculative suffix-decoding (arXiv:2411.04975)
|
||||||
COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type
|
COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -182,6 +183,11 @@ struct common_params_speculative {
|
|||||||
|
|
||||||
std::shared_ptr<common_ngram_mod> ngram_mod;
|
std::shared_ptr<common_ngram_mod> ngram_mod;
|
||||||
|
|
||||||
|
// suffix-decoding specific
|
||||||
|
int32_t suffix_min_match_len = 5; // minimum context match length
|
||||||
|
int32_t suffix_max_depth = 64; // suffix tree maximum depth
|
||||||
|
std::string suffix_corpus; // path to corpus file for offline pre-warming (.json or .bin)
|
||||||
|
|
||||||
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT
|
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT
|
||||||
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
||||||
|
|
||||||
|
|||||||
@ -107,12 +107,13 @@ void spec_tuner::reset_exploration() {
|
|||||||
void spec_tuner::write_best(common_params_speculative & params) const {
|
void spec_tuner::write_best(common_params_speculative & params) const {
|
||||||
for (const auto & coord : coords) {
|
for (const auto & coord : coords) {
|
||||||
float val = coord.arms[coord.best_idx].value;
|
float val = coord.arms[coord.best_idx].value;
|
||||||
if (coord.name == "n_max") params.n_max = (int32_t)val;
|
if (coord.name == "n_max") params.n_max = (int32_t)val;
|
||||||
else if (coord.name == "p_min") params.p_min = val;
|
else if (coord.name == "p_min") params.p_min = val;
|
||||||
else if (coord.name == "n_min") params.n_min = (int32_t)val;
|
else if (coord.name == "n_min") params.n_min = (int32_t)val;
|
||||||
else if (coord.name == "ngram_size_n") params.ngram_size_n = (uint16_t)val;
|
else if (coord.name == "ngram_size_n") params.ngram_size_n = (uint16_t)val;
|
||||||
else if (coord.name == "ngram_size_m") params.ngram_size_m = (uint16_t)val;
|
else if (coord.name == "ngram_size_m") params.ngram_size_m = (uint16_t)val;
|
||||||
else if (coord.name == "ngram_min_hits") params.ngram_min_hits = (uint16_t)val;
|
else if (coord.name == "ngram_min_hits") params.ngram_min_hits = (uint16_t)val;
|
||||||
|
else if (coord.name == "suffix_min_match_len") params.suffix_min_match_len = (int32_t)val;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -155,6 +156,21 @@ void spec_tuner::init(common_speculative_type type, const common_params_speculat
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (type == COMMON_SPECULATIVE_TYPE_SUFFIX) {
|
||||||
|
{
|
||||||
|
spec_tuner_coord coord;
|
||||||
|
coord.name = "p_min";
|
||||||
|
coord.build_grid_float(0.0f, 0.95f, 11, user_params.p_min);
|
||||||
|
coords.push_back(std::move(coord));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
spec_tuner_coord coord;
|
||||||
|
coord.name = "suffix_min_match_len";
|
||||||
|
coord.build_grid_int(1, 12, 1, user_params.suffix_min_match_len);
|
||||||
|
coords.push_back(std::move(coord));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Ngram can change only n_max/n_min per call
|
// Ngram can change only n_max/n_min per call
|
||||||
if (type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD) {
|
if (type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD) {
|
||||||
{
|
{
|
||||||
@ -168,12 +184,13 @@ void spec_tuner::init(common_speculative_type type, const common_params_speculat
|
|||||||
|
|
||||||
for (auto & coord : coords) {
|
for (auto & coord : coords) {
|
||||||
float user_val = 0.0f;
|
float user_val = 0.0f;
|
||||||
if (coord.name == "n_max") user_val = (float)user_params.n_max;
|
if (coord.name == "n_max") user_val = (float)user_params.n_max;
|
||||||
else if (coord.name == "p_min") user_val = user_params.p_min;
|
else if (coord.name == "p_min") user_val = user_params.p_min;
|
||||||
else if (coord.name == "n_min") user_val = (float)user_params.n_min;
|
else if (coord.name == "n_min") user_val = (float)user_params.n_min;
|
||||||
else if (coord.name == "ngram_size_n") user_val = (float)user_params.ngram_size_n;
|
else if (coord.name == "ngram_size_n") user_val = (float)user_params.ngram_size_n;
|
||||||
else if (coord.name == "ngram_size_m") user_val = (float)user_params.ngram_size_m;
|
else if (coord.name == "ngram_size_m") user_val = (float)user_params.ngram_size_m;
|
||||||
else if (coord.name == "ngram_min_hits") user_val = (float)user_params.ngram_min_hits;
|
else if (coord.name == "ngram_min_hits") user_val = (float)user_params.ngram_min_hits;
|
||||||
|
else if (coord.name == "suffix_min_match_len") user_val = (float)user_params.suffix_min_match_len;
|
||||||
|
|
||||||
coord.user_idx = coord.find_nearest_arm(user_val);
|
coord.user_idx = coord.find_nearest_arm(user_val);
|
||||||
coord.best_idx = 0;
|
coord.best_idx = 0;
|
||||||
@ -201,12 +218,13 @@ void spec_tuner::propose(common_params_speculative & params) {
|
|||||||
coord.current_idx = coord.select_epsilon_greedy(epsilon);
|
coord.current_idx = coord.select_epsilon_greedy(epsilon);
|
||||||
|
|
||||||
float val = coord.arms[coord.current_idx].value;
|
float val = coord.arms[coord.current_idx].value;
|
||||||
if (coord.name == "n_max") params.n_max = (int32_t)val;
|
if (coord.name == "n_max") params.n_max = (int32_t)val;
|
||||||
else if (coord.name == "p_min") params.p_min = val;
|
else if (coord.name == "p_min") params.p_min = val;
|
||||||
else if (coord.name == "n_min") params.n_min = (int32_t)val;
|
else if (coord.name == "n_min") params.n_min = (int32_t)val;
|
||||||
else if (coord.name == "ngram_size_n") params.ngram_size_n = (uint16_t)val;
|
else if (coord.name == "ngram_size_n") params.ngram_size_n = (uint16_t)val;
|
||||||
else if (coord.name == "ngram_size_m") params.ngram_size_m = (uint16_t)val;
|
else if (coord.name == "ngram_size_m") params.ngram_size_m = (uint16_t)val;
|
||||||
else if (coord.name == "ngram_min_hits") params.ngram_min_hits = (uint16_t)val;
|
else if (coord.name == "ngram_min_hits") params.ngram_min_hits = (uint16_t)val;
|
||||||
|
else if (coord.name == "suffix_min_match_len") params.suffix_min_match_len = (int32_t)val;
|
||||||
}
|
}
|
||||||
|
|
||||||
enforce_constraints(params);
|
enforce_constraints(params);
|
||||||
@ -346,6 +364,7 @@ void spec_tuner::print_best() const {
|
|||||||
else if (coord.name == "ngram_size_n") oss << "--spec-ngram-size-n ";
|
else if (coord.name == "ngram_size_n") oss << "--spec-ngram-size-n ";
|
||||||
else if (coord.name == "ngram_size_m") oss << "--spec-ngram-size-m ";
|
else if (coord.name == "ngram_size_m") oss << "--spec-ngram-size-m ";
|
||||||
else if (coord.name == "ngram_min_hits") oss << "--spec-ngram-min-hits ";
|
else if (coord.name == "ngram_min_hits") oss << "--spec-ngram-min-hits ";
|
||||||
|
else if (coord.name == "suffix_min_match_len") oss << "--suffix-pattern-len ";
|
||||||
else oss << "--" << coord.name << " ";
|
else oss << "--" << coord.name << " ";
|
||||||
|
|
||||||
if (is_int) oss << (int)coord.arms[coord.best_idx].value << " ";
|
if (is_int) oss << (int)coord.arms[coord.best_idx].value << " ";
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
#include "ngram-map.h"
|
#include "ngram-map.h"
|
||||||
#include "ngram-mod.h"
|
#include "ngram-mod.h"
|
||||||
#include "sampling.h"
|
#include "sampling.h"
|
||||||
|
#include "suffix-tree.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
@ -26,7 +27,8 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
|
|||||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
|
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
|
||||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
|
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
|
||||||
COMMON_SPECULATIVE_TYPE_NGRAM_MOD,
|
COMMON_SPECULATIVE_TYPE_NGRAM_MOD,
|
||||||
COMMON_SPECULATIVE_TYPE_NGRAM_CACHE
|
COMMON_SPECULATIVE_TYPE_NGRAM_CACHE,
|
||||||
|
COMMON_SPECULATIVE_TYPE_SUFFIX
|
||||||
};
|
};
|
||||||
|
|
||||||
const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
|
const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
|
||||||
@ -38,7 +40,8 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
|
|||||||
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
|
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
|
||||||
{"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
|
{"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
|
||||||
{"ngram_mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD},
|
{"ngram_mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD},
|
||||||
{"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE}
|
{"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE},
|
||||||
|
{"suffix", COMMON_SPECULATIVE_TYPE_SUFFIX}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct common_speculative_config {
|
struct common_speculative_config {
|
||||||
@ -790,6 +793,126 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct common_speculative_state_suffix : public common_speculative_state {
|
||||||
|
common_suffix_tree tree;
|
||||||
|
common_suffix_tree corpus_tree;
|
||||||
|
bool has_corpus = false;
|
||||||
|
size_t cache_size = 0;
|
||||||
|
|
||||||
|
// Acceptance feedback
|
||||||
|
size_t n_draft_last = 0;
|
||||||
|
bool had_accept = false;
|
||||||
|
int n_low = 0;
|
||||||
|
float base_p_min = 0.1f;
|
||||||
|
float eff_p_min = 0.1f;
|
||||||
|
|
||||||
|
common_speculative_state_suffix(
|
||||||
|
enum common_speculative_type type,
|
||||||
|
int max_depth,
|
||||||
|
const std::string & corpus_path,
|
||||||
|
const llama_model * model)
|
||||||
|
: common_speculative_state(type)
|
||||||
|
, tree(max_depth)
|
||||||
|
, corpus_tree(max_depth)
|
||||||
|
{
|
||||||
|
if (!corpus_path.empty()) {
|
||||||
|
std::function<std::vector<llama_token>(const std::string &)> tokenize_fn;
|
||||||
|
if (model) {
|
||||||
|
tokenize_fn = [model](const std::string & text) -> std::vector<llama_token> {
|
||||||
|
return common_tokenize(model, text, false, true);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
has_corpus = corpus_tree.load_corpus(corpus_path, tokenize_fn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void begin(const llama_tokens & prompt) override {
|
||||||
|
cache_size = 0;
|
||||||
|
n_draft_last = 0;
|
||||||
|
had_accept = false;
|
||||||
|
n_low = 0;
|
||||||
|
GGML_UNUSED(prompt);
|
||||||
|
}
|
||||||
|
|
||||||
|
void draft(
|
||||||
|
const common_params_speculative & params,
|
||||||
|
const llama_tokens & prompt_tgt,
|
||||||
|
llama_token id_last,
|
||||||
|
llama_tokens & result) override {
|
||||||
|
|
||||||
|
base_p_min = params.p_min;
|
||||||
|
if (n_draft_last > 0 && !had_accept) {
|
||||||
|
if (++n_low >= 3) {
|
||||||
|
eff_p_min = std::min(eff_p_min + 0.1f, 0.5f);
|
||||||
|
n_low = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
had_accept = false;
|
||||||
|
|
||||||
|
if (cache_size < prompt_tgt.size() + 1) {
|
||||||
|
llama_tokens tokens_new;
|
||||||
|
tokens_new.reserve(prompt_tgt.size() + 1 - cache_size);
|
||||||
|
for (size_t j = cache_size; j < prompt_tgt.size(); ++j) {
|
||||||
|
tokens_new.push_back(prompt_tgt[j]);
|
||||||
|
}
|
||||||
|
tokens_new.push_back(id_last);
|
||||||
|
|
||||||
|
tree.extend(tokens_new.data(), (int)tokens_new.size());
|
||||||
|
cache_size = prompt_tgt.size() + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ctx_len = std::min((int)(prompt_tgt.size() + 1), tree.max_depth());
|
||||||
|
llama_tokens context;
|
||||||
|
context.reserve(ctx_len);
|
||||||
|
const int ctx_start = (int)prompt_tgt.size() + 1 - ctx_len;
|
||||||
|
for (int j = ctx_start; j < (int)prompt_tgt.size(); ++j) {
|
||||||
|
context.push_back(prompt_tgt[j]);
|
||||||
|
}
|
||||||
|
context.push_back(id_last);
|
||||||
|
const int min_match_len = std::max(1, params.suffix_min_match_len);
|
||||||
|
|
||||||
|
result = tree.speculate(
|
||||||
|
context.data(), (int)context.size(),
|
||||||
|
params.n_max,
|
||||||
|
eff_p_min,
|
||||||
|
1,
|
||||||
|
min_match_len);
|
||||||
|
|
||||||
|
if (has_corpus) {
|
||||||
|
auto corpus_result = corpus_tree.speculate(
|
||||||
|
context.data(), (int)context.size(),
|
||||||
|
params.n_max,
|
||||||
|
eff_p_min,
|
||||||
|
1,
|
||||||
|
min_match_len);
|
||||||
|
if (corpus_result.size() > result.size()) {
|
||||||
|
result = std::move(corpus_result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
n_draft_last = result.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
void accept(uint16_t n_accepted) override {
|
||||||
|
if (n_draft_last == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
had_accept = true;
|
||||||
|
const double f_acc = (double)n_accepted / (double)n_draft_last;
|
||||||
|
if (f_acc < 0.5) {
|
||||||
|
if (++n_low >= 3) {
|
||||||
|
eff_p_min = std::min(eff_p_min + 0.1f, 0.5f);
|
||||||
|
n_low = 0;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n_low = 0;
|
||||||
|
if (eff_p_min > base_p_min) {
|
||||||
|
eff_p_min = std::max(eff_p_min - 0.05f, base_p_min);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct common_speculative {
|
struct common_speculative {
|
||||||
std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states
|
std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states
|
||||||
common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
|
common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
|
||||||
@ -843,6 +966,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
|
|||||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v";
|
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v";
|
||||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod";
|
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod";
|
||||||
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache";
|
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache";
|
||||||
|
case COMMON_SPECULATIVE_TYPE_SUFFIX: return "suffix";
|
||||||
default: return "unknown";
|
default: return "unknown";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -912,6 +1036,7 @@ common_speculative * common_speculative_init(
|
|||||||
bool has_ngram_map_k = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
|
bool has_ngram_map_k = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
|
||||||
bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V);
|
bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V);
|
||||||
bool has_ngram_mod = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD);
|
bool has_ngram_mod = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD);
|
||||||
|
bool has_suffix = (params.type == COMMON_SPECULATIVE_TYPE_SUFFIX);
|
||||||
|
|
||||||
// In a more complex implementation we could use the same implementation but with different parameters.
|
// In a more complex implementation we could use the same implementation but with different parameters.
|
||||||
// This was initially used in PR-18471 but removed to simplify the code.
|
// This was initially used in PR-18471 but removed to simplify the code.
|
||||||
@ -945,6 +1070,9 @@ common_speculative * common_speculative_init(
|
|||||||
if (has_ngram_cache) {
|
if (has_ngram_cache) {
|
||||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
|
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
|
||||||
}
|
}
|
||||||
|
if (has_suffix) {
|
||||||
|
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_SUFFIX, params));
|
||||||
|
}
|
||||||
if (has_mtp) {
|
if (has_mtp) {
|
||||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
|
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
|
||||||
}
|
}
|
||||||
@ -1023,6 +1151,13 @@ common_speculative * common_speculative_init(
|
|||||||
impls.push_back(std::make_unique<common_speculative_state_ngram_cache>(state));
|
impls.push_back(std::make_unique<common_speculative_state_ngram_cache>(state));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case COMMON_SPECULATIVE_TYPE_SUFFIX: {
|
||||||
|
int depth = config.params.suffix_max_depth > 0 ? config.params.suffix_max_depth : 64;
|
||||||
|
const llama_model * model = llama_get_model(ctx_tgt);
|
||||||
|
impls.push_back(std::make_unique<common_speculative_state_suffix>(
|
||||||
|
config.type, depth, config.params.suffix_corpus, model));
|
||||||
|
break;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
260
common/suffix-tree.cpp
Normal file
260
common/suffix-tree.cpp
Normal file
@ -0,0 +1,260 @@
|
|||||||
|
#include "suffix-tree.h"
|
||||||
|
#include "log.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <fstream>
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
|
using json = nlohmann::json;
|
||||||
|
|
||||||
|
common_suffix_tree::common_suffix_tree(int max_depth)
|
||||||
|
: _max_depth(max_depth)
|
||||||
|
, _root(std::make_unique<common_suffix_node>())
|
||||||
|
{}
|
||||||
|
|
||||||
|
common_suffix_tree::~common_suffix_tree() = default;
|
||||||
|
|
||||||
|
// Returns the tree to its freshly-constructed state: a new empty root, no token
// history and an insertion counter of zero. The depth limit is kept.
void common_suffix_tree::clear() {
    _root = std::make_unique<common_suffix_node>();

    _n_inserted = 0;
    _tokens.clear();
}
|
||||||
|
|
||||||
|
void common_suffix_tree::extend(const llama_token * tokens, int n_tokens) {
|
||||||
|
if (n_tokens <= 0) return;
|
||||||
|
|
||||||
|
const int old_size = (int)_tokens.size();
|
||||||
|
_tokens.insert(_tokens.end(), tokens, tokens + n_tokens);
|
||||||
|
const int new_size = (int)_tokens.size();
|
||||||
|
|
||||||
|
// Insert/update suffixes that are affected by the new tokens.
|
||||||
|
// For any position i, the suffix covers tokens[i .. min(i+max_depth, end)].
|
||||||
|
// Positions within max_depth of the old end had truncated suffixes that
|
||||||
|
// can now be extended with new tokens.
|
||||||
|
const int reinsert_from = std::max(0, old_size - _max_depth);
|
||||||
|
|
||||||
|
for (int i = reinsert_from; i < new_size; ++i) {
|
||||||
|
if (i < _n_inserted) {
|
||||||
|
const int old_len = std::min(old_size - i, _max_depth);
|
||||||
|
const int new_len = std::min(new_size - i, _max_depth);
|
||||||
|
if (new_len > old_len) {
|
||||||
|
_extend_suffix(i, old_len, new_len);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
_insert_suffix(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_n_inserted = new_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_suffix_tree::_insert_suffix(int start_pos) {
|
||||||
|
const int total = (int)_tokens.size();
|
||||||
|
const int len = std::min(total - start_pos, _max_depth);
|
||||||
|
if (len <= 0) return;
|
||||||
|
|
||||||
|
common_suffix_node * node = _root.get();
|
||||||
|
|
||||||
|
for (int i = 0; i < len; ++i) {
|
||||||
|
const llama_token tok = _tokens[start_pos + i];
|
||||||
|
auto it = node->children.find(tok);
|
||||||
|
if (it == node->children.end()) {
|
||||||
|
auto child = std::make_unique<common_suffix_node>();
|
||||||
|
auto * child_ptr = child.get();
|
||||||
|
child_ptr->count = 1;
|
||||||
|
node->children[tok] = std::move(child);
|
||||||
|
node = child_ptr;
|
||||||
|
} else {
|
||||||
|
node = it->second.get();
|
||||||
|
node->count++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_suffix_tree::_extend_suffix(int start_pos, int old_len, int new_len) {
|
||||||
|
common_suffix_node * node = _root.get();
|
||||||
|
|
||||||
|
for (int i = 0; i < old_len; ++i) {
|
||||||
|
const llama_token tok = _tokens[start_pos + i];
|
||||||
|
auto it = node->children.find(tok);
|
||||||
|
if (it == node->children.end()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
node = it->second.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = old_len; i < new_len; ++i) {
|
||||||
|
const llama_token tok = _tokens[start_pos + i];
|
||||||
|
auto it = node->children.find(tok);
|
||||||
|
if (it == node->children.end()) {
|
||||||
|
auto child = std::make_unique<common_suffix_node>();
|
||||||
|
auto * child_ptr = child.get();
|
||||||
|
child_ptr->count = 1;
|
||||||
|
node->children[tok] = std::move(child);
|
||||||
|
node = child_ptr;
|
||||||
|
} else {
|
||||||
|
node = it->second.get();
|
||||||
|
node->count++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<llama_token> common_suffix_tree::speculate(
|
||||||
|
const llama_token * context, int n_context,
|
||||||
|
int max_spec_tokens,
|
||||||
|
float min_token_prob,
|
||||||
|
int min_match_count,
|
||||||
|
int min_match_len) const {
|
||||||
|
|
||||||
|
std::vector<llama_token> best_draft;
|
||||||
|
|
||||||
|
if (!_root || n_context <= 0 || max_spec_tokens <= 0) return best_draft;
|
||||||
|
|
||||||
|
if (n_context > _max_depth) {
|
||||||
|
context += (n_context - _max_depth);
|
||||||
|
n_context = _max_depth;
|
||||||
|
}
|
||||||
|
|
||||||
|
float best_score = 0.0f;
|
||||||
|
|
||||||
|
for (int match_len = std::max(1, min_match_len); match_len <= n_context; ++match_len) {
|
||||||
|
const llama_token * ctx = context + (n_context - match_len);
|
||||||
|
|
||||||
|
const common_suffix_node * node = _root.get();
|
||||||
|
bool matched = true;
|
||||||
|
for (int i = 0; i < match_len; ++i) {
|
||||||
|
auto it = node->children.find(ctx[i]);
|
||||||
|
if (it == node->children.end()) {
|
||||||
|
matched = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
node = it->second.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!matched) break;
|
||||||
|
if (node->count < min_match_count) continue;
|
||||||
|
if (node->children.empty()) continue;
|
||||||
|
|
||||||
|
// Speculate: greedily follow highest-count child
|
||||||
|
// Probability decays multiplicatively: prob *= child_count / parent_count
|
||||||
|
const int draft_limit = std::min(max_spec_tokens, match_len + 8);
|
||||||
|
|
||||||
|
std::vector<llama_token> draft;
|
||||||
|
float score = 0.0f;
|
||||||
|
float prob = 1.0f;
|
||||||
|
const common_suffix_node * cur = node;
|
||||||
|
|
||||||
|
for (int i = 0; i < draft_limit; ++i) {
|
||||||
|
if (cur->children.empty()) break;
|
||||||
|
|
||||||
|
llama_token best_tok = -1;
|
||||||
|
int64_t best_count = 0;
|
||||||
|
for (const auto & [token, child] : cur->children) {
|
||||||
|
if (child->count > best_count) {
|
||||||
|
best_count = child->count;
|
||||||
|
best_tok = token;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
prob *= (float)best_count / (float)cur->count;
|
||||||
|
if (prob < min_token_prob) break;
|
||||||
|
|
||||||
|
score += prob;
|
||||||
|
draft.push_back(best_tok);
|
||||||
|
cur = cur->children.at(best_tok).get();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (score > best_score && !draft.empty()) {
|
||||||
|
best_score = score;
|
||||||
|
best_draft = std::move(draft);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return best_draft;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void _extract_texts(const json & node, std::vector<std::string> & out) {
|
||||||
|
if (node.is_string()) {
|
||||||
|
const std::string s = node.get<std::string>();
|
||||||
|
if (!s.empty()) out.push_back(s);
|
||||||
|
} else if (node.is_array()) {
|
||||||
|
for (const auto & item : node) {
|
||||||
|
_extract_texts(item, out);
|
||||||
|
}
|
||||||
|
} else if (node.is_object()) {
|
||||||
|
if (node.contains("content") && node["content"].is_string()) {
|
||||||
|
const std::string s = node["content"].get<std::string>();
|
||||||
|
if (!s.empty()) out.push_back(s);
|
||||||
|
} else if (node.contains("messages")) {
|
||||||
|
_extract_texts(node["messages"], out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool common_suffix_tree::load_corpus(
|
||||||
|
const std::string & path,
|
||||||
|
std::function<std::vector<llama_token>(const std::string &)> tokenize_fn) {
|
||||||
|
|
||||||
|
bool is_json = path.size() >= 5 &&
|
||||||
|
path.compare(path.size() - 5, 5, ".json") == 0;
|
||||||
|
|
||||||
|
if (is_json) {
|
||||||
|
if (!tokenize_fn) {
|
||||||
|
LOG_ERR("%s: JSON corpus requires a tokenizer but none was provided (path: '%s')\n",
|
||||||
|
__func__, path.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::ifstream f(path);
|
||||||
|
if (!f.is_open()) {
|
||||||
|
LOG_ERR("%s: failed to open corpus file '%s'\n", __func__, path.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
json root;
|
||||||
|
try {
|
||||||
|
f >> root;
|
||||||
|
} catch (const json::exception & e) {
|
||||||
|
LOG_ERR("%s: JSON parse error in '%s': %s\n", __func__, path.c_str(), e.what());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::vector<std::string> texts;
|
||||||
|
_extract_texts(root, texts);
|
||||||
|
if (texts.empty()) {
|
||||||
|
LOG_WRN("%s: no text content found in corpus '%s'\n", __func__, path.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
size_t total_tokens = 0;
|
||||||
|
for (const auto & text : texts) {
|
||||||
|
auto tokens = tokenize_fn(text);
|
||||||
|
if (!tokens.empty()) {
|
||||||
|
extend(tokens.data(), (int)tokens.size());
|
||||||
|
total_tokens += tokens.size();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LOG_DBG("%s: loaded JSON corpus — %zu texts, %zu tokens from '%s'\n",
|
||||||
|
__func__, texts.size(), total_tokens, path.c_str());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Binary format: raw int32 token IDs
|
||||||
|
FILE * fp = std::fopen(path.c_str(), "rb");
|
||||||
|
if (!fp) {
|
||||||
|
LOG_ERR("%s: failed to open corpus file '%s'\n", __func__, path.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::vector<llama_token> tokens;
|
||||||
|
int32_t tok;
|
||||||
|
while (std::fread(&tok, sizeof(tok), 1, fp) == 1) {
|
||||||
|
tokens.push_back(tok);
|
||||||
|
}
|
||||||
|
std::fclose(fp);
|
||||||
|
if (tokens.empty()) {
|
||||||
|
LOG_WRN("%s: suffix corpus file '%s' is empty\n", __func__, path.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
extend(tokens.data(), (int)tokens.size());
|
||||||
|
LOG_DBG("%s: loaded binary corpus — %zu tokens from '%s'\n",
|
||||||
|
__func__, tokens.size(), path.c_str());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
62
common/suffix-tree.h
Normal file
62
common/suffix-tree.h
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <functional>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
// A trie-based suffix tree for suffix-decoding speculative decoding.
|
||||||
|
//
|
||||||
|
// Stores all suffixes (up to max_depth) of the token history.
|
||||||
|
// Used to find matching patterns in context and generate draft tokens
|
||||||
|
// by following the most frequent continuation path.
|
||||||
|
//
|
||||||
|
// Reference: "Suffix Decoding" (Saxena et al., 2024) — arXiv:2411.04975
|
||||||
|
|
||||||
|
struct common_suffix_node {
|
||||||
|
int64_t count = 0;
|
||||||
|
std::unordered_map<llama_token, std::unique_ptr<common_suffix_node>> children;
|
||||||
|
};
|
||||||
|
|
||||||
|
class common_suffix_tree {
|
||||||
|
public:
|
||||||
|
explicit common_suffix_tree(int max_depth = 64);
|
||||||
|
~common_suffix_tree();
|
||||||
|
|
||||||
|
// Append tokens to the history and insert new suffixes into the trie.
|
||||||
|
// Incremental: only processes suffixes that haven't been inserted yet.
|
||||||
|
void extend(const llama_token * tokens, int n_tokens);
|
||||||
|
|
||||||
|
void clear();
|
||||||
|
|
||||||
|
// Generate draft tokens by matching the context in the trie.
|
||||||
|
// Tries multiple context lengths and returns the draft with the best score.
|
||||||
|
std::vector<llama_token> speculate(
|
||||||
|
const llama_token * context, int n_context,
|
||||||
|
int max_spec_tokens,
|
||||||
|
float min_token_prob = 0.1f,
|
||||||
|
int min_match_count = 1,
|
||||||
|
int min_match_len = 5) const;
|
||||||
|
|
||||||
|
// Load an offline corpus to pre-warm the tree before any request.
|
||||||
|
// Supported formats (.json or .bin)
|
||||||
|
bool load_corpus(
|
||||||
|
const std::string & path,
|
||||||
|
std::function<std::vector<llama_token>(const std::string &)> tokenize_fn = {});
|
||||||
|
|
||||||
|
int max_depth() const { return _max_depth; }
|
||||||
|
int token_count() const { return (int)_tokens.size(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
int _max_depth;
|
||||||
|
std::unique_ptr<common_suffix_node> _root;
|
||||||
|
std::vector<llama_token> _tokens;
|
||||||
|
int _n_inserted = 0;
|
||||||
|
|
||||||
|
void _insert_suffix(int start_pos);
|
||||||
|
void _extend_suffix(int start_pos, int old_len, int new_len);
|
||||||
|
};
|
||||||
Loading…
x
Reference in New Issue
Block a user