suffix-spec: load corpus in chunks (#1721)

This commit is contained in:
Samuel Oliveira Alves 2026-05-04 01:56:07 -03:00 committed by GitHub
parent 418d60a909
commit a342831115
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,6 +2,7 @@
#include "log.h"
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <fstream>
@ -193,10 +194,42 @@ static void _extract_texts(const json & node, std::vector<std::string> & out) {
}
}
namespace {
constexpr size_t SUFFIX_CORPUS_BINARY_CHUNK_TOKENS = 1u << 15;
constexpr uint64_t SUFFIX_CORPUS_MAX_INSERT_WORK = 256ull * 1024ull * 1024ull;
static uint64_t suffix_estimated_insert_work(size_t n_tokens, int max_depth) {
return (uint64_t) n_tokens * (uint64_t) std::max(max_depth, 1);
}
static bool suffix_corpus_check_limit(const std::string & path, size_t n_tokens, int max_depth) {
const uint64_t estimated_work = suffix_estimated_insert_work(n_tokens, max_depth);
if (estimated_work <= SUFFIX_CORPUS_MAX_INSERT_WORK) {
return true;
}
LOG_ERR("load_corpus: refusing suffix corpus '%s' - estimated insert work %llu exceeds limit %llu (tokens=%zu, depth=%d); reduce corpus size or --suffix-max-depth\n",
path.c_str(),
(unsigned long long) estimated_work,
(unsigned long long) SUFFIX_CORPUS_MAX_INSERT_WORK,
n_tokens,
max_depth);
return false;
}
static double suffix_elapsed_ms(const std::chrono::steady_clock::time_point & started) {
return std::chrono::duration<double, std::milli>(std::chrono::steady_clock::now() - started).count();
}
} // namespace
bool common_suffix_tree::load_corpus(
const std::string & path,
std::function<std::vector<llama_token>(const std::string &)> tokenize_fn) {
const auto load_started = std::chrono::steady_clock::now();
bool is_json = path.size() >= 5 &&
path.compare(path.size() - 5, 5, ".json") == 0;
@ -224,16 +257,36 @@ bool common_suffix_tree::load_corpus(
LOG_WRN("%s: no text content found in corpus '%s'\n", __func__, path.c_str());
return false;
}
LOG_INF("load_corpus: loading suffix JSON corpus '%s' (%zu texts, depth=%d)\n",
path.c_str(), texts.size(), _max_depth);
size_t total_tokens = 0;
for (const auto & text : texts) {
for (size_t i = 0; i < texts.size(); ++i) {
const auto & text = texts[i];
auto tokens = tokenize_fn(text);
if (!tokens.empty()) {
extend(tokens.data(), (int)tokens.size());
total_tokens += tokens.size();
const size_t projected_tokens = total_tokens + tokens.size();
if (!suffix_corpus_check_limit(path, projected_tokens, _max_depth)) {
clear();
return false;
}
extend(tokens.data(), (int) tokens.size());
total_tokens = projected_tokens;
}
}
LOG_DBG("%s: loaded JSON corpus — %zu texts, %zu tokens from '%s'\n",
__func__, texts.size(), total_tokens, path.c_str());
if (total_tokens == 0) {
LOG_WRN("%s: no tokens were extracted from suffix corpus '%s'\n",
__func__, path.c_str());
clear();
return false;
}
LOG_INF("load_corpus: done loading suffix JSON corpus '%s' - %zu texts, %zu tokens in %.1f ms\n",
path.c_str(), texts.size(), total_tokens, suffix_elapsed_ms(load_started));
return true;
}
@ -243,18 +296,69 @@ bool common_suffix_tree::load_corpus(
LOG_ERR("%s: failed to open corpus file '%s'\n", __func__, path.c_str());
return false;
}
std::vector<llama_token> tokens;
int32_t tok;
while (std::fread(&tok, sizeof(tok), 1, fp) == 1) {
tokens.push_back(tok);
size_t total_tokens_est = 0;
if (std::fseek(fp, 0, SEEK_END) == 0) {
const long file_size = std::ftell(fp);
if (file_size >= 0) {
total_tokens_est = (size_t) file_size / sizeof(int32_t);
if ((size_t) file_size % sizeof(int32_t) != 0) {
LOG_WRN("%s: suffix corpus '%s' has %zu trailing bytes; ignoring the remainder\n",
__func__, path.c_str(), (size_t) file_size % sizeof(int32_t));
}
}
std::rewind(fp);
}
if (total_tokens_est > 0 && !suffix_corpus_check_limit(path, total_tokens_est, _max_depth)) {
std::fclose(fp);
return false;
}
LOG_INF("load_corpus: loading suffix binary corpus '%s' (%zu tokens, depth=%d)\n",
path.c_str(), total_tokens_est, _max_depth);
std::vector<int32_t> raw_tokens(SUFFIX_CORPUS_BINARY_CHUNK_TOKENS);
std::vector<llama_token> tokens(SUFFIX_CORPUS_BINARY_CHUNK_TOKENS);
size_t total_tokens = 0;
while (true) {
const size_t n_read = std::fread(raw_tokens.data(), sizeof(int32_t), raw_tokens.size(), fp);
if (n_read == 0) {
break;
}
const size_t projected_tokens = total_tokens + n_read;
if (!suffix_corpus_check_limit(path, projected_tokens, _max_depth)) {
std::fclose(fp);
clear();
return false;
}
for (size_t i = 0; i < n_read; ++i) {
tokens[i] = raw_tokens[i];
}
extend(tokens.data(), (int) n_read);
total_tokens = projected_tokens;
}
const bool read_error = std::ferror(fp) != 0;
std::fclose(fp);
if (tokens.empty()) {
if (read_error) {
LOG_ERR("%s: read error while loading suffix corpus '%s'\n", __func__, path.c_str());
clear();
return false;
}
if (total_tokens == 0) {
LOG_WRN("%s: suffix corpus file '%s' is empty\n", __func__, path.c_str());
return false;
}
extend(tokens.data(), (int)tokens.size());
LOG_DBG("%s: loaded binary corpus — %zu tokens from '%s'\n",
__func__, tokens.size(), path.c_str());
LOG_INF("load_corpus: done loading suffix binary corpus '%s' - %zu tokens in %.1f ms\n",
path.c_str(), total_tokens, suffix_elapsed_ms(load_started));
return true;
}