Add Unicode allowlist (#1597)

* initial commit

* cleanup

* fix whitelist arg parsing and simplify keyword search state

* rename white* to allow*

* add vocab_pieces init function, rename update functions, delete accidentally added file

* delete temporary bias code

* auto-generate fill function with script data inside

* deduplicate allowlist unicode rule parsing

* minor cleanup

* delete unnecessary header

* refactor allowlist to support sequential rule sets via keywords

* add early exit for zero-rules case

* delete accidentally added file
This commit is contained in:
dungquixote42 2026-04-10 12:22:57 -04:00 committed by GitHub
parent 5720a4131a
commit 869b83bc49
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 2393 additions and 4 deletions

View File

@ -1119,6 +1119,7 @@ src/unicode.o: \
$(CXX) $(CXXFLAGS) -c $< -o $@
src/unicode-data.o: \
src/unicode-script-data.cpp \
src/unicode-data.cpp \
src/unicode-data.h
$(CXX) $(CXXFLAGS) -c $< -o $@

View File

@ -1698,6 +1698,30 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.banned_n = std::stoi(argv[i]);
return true;
}
if (arg == "--allowlist-unicode-rule") {
CHECK_ARG
if (params.allow_ruless.size() == 0) {
params.allow_ruless.push_back({});
}
params.allow_ruless.back().push_back(argparse_allowlist_unicode_rule(argv[i]));
return true;
}
if (arg == "--allowlist-pieces") {
CHECK_ARG
params.allow_pieces.push_back(argv[i]);
return true;
}
if (arg == "--allowlist-keyword") {
CHECK_ARG
params.allow_kws.push_back(argv[i]);
params.allow_ruless.push_back({});
return true;
}
if (arg == "--allowlist-keyword-delay") {
CHECK_ARG
params.allow_kw_delay = std::stoul(argv[i]);
return true;
}
if (arg == "-ld" || arg == "--logdir") {
CHECK_ARG
params.logdir = argv[i];
@ -2442,9 +2466,16 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " --top-n-sigma t", "top-n-sigma parmeter (default: %.1f, 0.0 = disabled)", (double)sparams.top_n_sigma});
options.push_back({ "*", " --adaptive-target", "adaptive-p sampling: (default: %.2f, <0.0 = disabled)", (double)sparams.adaptive_target});
options.push_back({ "*", " --adaptive-decay", "adaptive-p sampling: (default: %.2f)", (double)sparams.adaptive_decay});
options.push_back({ "*", " --adaptive-updt-w-cur", "adaptive-p sampling: (default: %s)", sparams.adaptive_updt_w_cur ? "true" : "false"});
options.push_back({ "*", " --banned-string-file", "file path of the list of banned strings on each line" });
options.push_back({ "*", " --banned-n", "number of tokens banned in the phrase during rewind. -1 means all tokens: (default: %d)",params.banned_n });
options.push_back({ "*", " --adaptive-updt-w-cur", "adaptive-p sampling: (default: %s)", sparams.adaptive_updt_w_cur ? "true" : "false"});
options.push_back({ "*", " --allowlist-unicode-rule",
"rule for allowlisting unicode script and/or codepoints. disabled without any rule. format: `LOWER..UPPER,SCRIPT:BIAS`\n"
"if unspecified: LOWER = 0, UPPER = -1(=max), SCRIPT=\"\", BIAS = 0. at least one of LOWER, UPPER, or SCRIPT is required\n" });
options.push_back({ "*", " --allowlist-pieces", "allowlist each token in argument. inherits max BIAS in --allowlist-unicode-rule. overrides --allowlist-unicode-rule" });
options.push_back({ "*", " --allowlist-keyword", "keyword to expire earlier allowlist rules if matched during generation. does not affect later rules" });
options.push_back({ "*", " --allowlist-keyword-delay",
"# tokens to delay matching for the first keyword (default: %zu)", params.allow_kw_delay });
options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n"
"i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"
"or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" });
@ -4557,3 +4588,37 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
}
//
// Argparse utils
//
std::tuple<uint32_t, uint32_t, std::string, float> argparse_allowlist_unicode_rule(std::string argstr) {
// format:
// LOWER..UPPER,SCRIPT:BIAS
auto subs = string_split(argstr, ":");
float bias = subs.size() == 1 ? 0 : std::stof(subs[1]);
subs = string_split(subs[0], ",");
std::string script = std::all_of(subs.back().begin(), subs.back().end(), [](char c) {
return std::isalpha(c);
}) ? string_lower(subs.back()) : "*";
if (script == "ascii") {
return { 0x000000, 0x00007F, "*", bias };
}
uint32_t first = 0;
uint32_t last = -1;
if ((script == "*") || (subs.size() > 1)) {
subs = string_split(subs.front(), ".");
if (!subs.front().empty()) {
first = std::stoul(subs.front());
}
if (!subs.back().empty()) {
last = std::stoul(subs.back());
}
}
return { std::min(first, last), std::max(first, last), script, bias };
}

View File

@ -288,6 +288,16 @@ struct gpt_params {
size_t n_buffer = 0; // number of token buffers for string ban
bool can_ban_phrases = true; // whether to ban strings
std::vector<std::vector<std::tuple<
uint32_t // lower codepoint
,uint32_t // upper codepoint
,std::string // unicode script name
,float // bias
>>> allow_ruless;
std::vector<std::string> allow_pieces; // each token to allowlist
std::vector<std::string> allow_kws; // keywords
size_t allow_kw_delay; // minimum n_decoded before first keyword is active
std::vector<llama_model_kv_override> kv_overrides;
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
std::vector<std::pair<int,int>> offload_policy;
@ -735,3 +745,9 @@ void yaml_dump_non_result_info(
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
std::string string_format(const char* fmt, ...);
//
// Argparse utils
//
std::tuple<uint32_t, uint32_t, std::string, float> argparse_allowlist_unicode_rule(std::string argstr);

View File

@ -505,8 +505,14 @@ static llama_token_data_array llama_sampling_prepare_impl(
cur.resize(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
if ((ctx_sampling->server_biases != nullptr) && (ctx_sampling->server_biases->size() == n_vocab)) {
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id] + ctx_sampling->server_biases->at(token_id), 0.0f};
}
} else {
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
}
ctx_sampling->cur_p = { cur.data(), cur.size(), false };

View File

@ -134,6 +134,8 @@ struct common_sampler {
llama_token_data_array cur_p; // current candidates
std::mt19937 rng;
std::vector<float>* server_biases;
};

View File

@ -361,6 +361,7 @@ void server_slot::prompt_load(server_prompt_cache& prompt_cache, const server_to
void server_slot::reset() {
n_prompt_tokens = 0;
last_gentxt_size = 0;
generated_text = "";
truncated = false;
stopped_eos = false;
@ -394,6 +395,12 @@ void server_slot::reset() {
ban_regex.clear();
ban_regex_ci.clear();
allow_ruless.clear();
allow_pieces.clear();
allow_kws.clear();
allow_kw_delay = 0;
allow_idx = 0;
// Reset speculative decoding stats
n_draft_total = 0;
n_draft_accepted = 0;
@ -851,6 +858,19 @@ server_slot* server_context::get_available_slot(const server_task& task) {
return ret;
}
int32_t server_context::populate_vocab_pieces() {
const int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
if (vocab_pieces.size() == n_vocab) {
return n_vocab;
}
vocab_pieces.clear();
vocab_pieces.reserve(n_vocab);
for (int32_t id = 0; id < n_vocab; ++id) {
vocab_pieces.push_back(common_token_to_piece(ctx, id, true));
}
return n_vocab;
}
bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) {
slot_params defaults;
defaults.speculative = params_base.speculative;
@ -1338,6 +1358,106 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
slot.banned_n = json_value(data, "banned_n", params_base.banned_n);
}
do // populate allowlist biases
{
// TODO: JSON parsing for rules and keywords
slot.allow_ruless = params_base.allow_ruless;
if (slot.allow_ruless.size() == 0) {
slot.allow_biasess.clear();
break;
}
slot.allow_kws = params_base.allow_kws;
slot.allow_pieces = params_base.allow_pieces;
const auto& allowlist_piece_array = data.find("allowlist_piece_array");
if (allowlist_piece_array != data.end() && allowlist_piece_array->is_array()) {
slot.allow_pieces.clear();
for (const auto& piece: *allowlist_piece_array) {
if (piece.is_string()) {
slot.allow_pieces.push_back(piece.get<std::string>());
}
}
}
slot.allow_kw_delay = json_value(data, "allowlist_keyword_delay", params_base.allow_kw_delay);
// end of allowlist criteria update
const int32_t n_vocab = populate_vocab_pieces();
std::unordered_set<llama_token> allow_settoken;
for (const auto& piece: slot.allow_pieces) {
for (const auto token: common_tokenize(model, piece, false, true)) {
allow_settoken.insert(token);
}
}
auto n_rules = slot.allow_ruless.size();
if (n_rules > slot.allow_kws.size() + 1) {
// one more rules than keyword, last rules do not expire
n_rules = slot.allow_kws.size() + 1;
slot.allow_ruless.resize(n_rules);
} else if (n_rules < slot.allow_kws.size()) {
// every rules expire
slot.allow_kws.resize(n_rules);
}
slot.allow_biasess.resize(n_rules);
for (size_t i = 0; i < n_rules; ++i) {
const auto& rules = slot.allow_ruless[i];
if ((i < slot.allow_ruless_prev.size()) && (rules == slot.allow_ruless_prev[i])) {
continue;
}
LLAMA_LOG_DEBUG("%s: allowlist %zu is new\n", __func__, i);
auto& biases = slot.allow_biasess[i];
biases.resize(n_vocab);
std::vector<uint32_t> cpts;
std::vector<std::string> scripts;
for (size_t id = 0; id < n_vocab; ++id) {
const size_t n_cpt = llama_fill_from_utf8(&vocab_pieces[id], &cpts, &scripts);
float bias = -INFINITY;
// each codepoint must be found in
for (size_t j = 0; j < n_cpt; ++j) {
bool in_rule = false;
// at least one rule
for (const auto& rule: rules) {
const bool in_range = (std::get<0>(rule) <= cpts[j]) && (cpts[j] <= std::get<1>(rule));
in_rule = in_range && ((std::get<2>(rule) == "*") || std::get<2>(rule) == scripts[j]);
if (in_rule) {
// earlier rule has higher priority
bias = std::max(bias, std::get<3>(rule));
break;
}
}
if (!in_rule) {
if ((scripts[j] == "common") || (scripts[j] == "inherited")) {
// for common or inherited codepoints (e.g. whitespace), defer to other codepoints in the token
continue;
}
// to shadow realm
bias = -INFINITY;
break;
}
}
biases[id] = bias;
}
float max_bias = -INFINITY;
for (const auto& rule: rules) {
max_bias = std::max(max_bias, std::get<3>(rule));
}
for (const auto token: allow_settoken) {
biases[token] = max_bias;
}
}
} while (false);
slot.allow_ruless_prev = slot.allow_ruless;
if (llama_model_has_recurrent(llama_get_model(slot.ctx))) {
params_base.can_ban_phrases = false;
bool do_checkpoint = params_base.ctx_checkpoints_n > 0;
@ -1498,6 +1618,7 @@ bool server_context::process_token(completion_token_output& result, server_slot&
slot.sampled = result.tok;
// search stop word and delete it
slot.last_gentxt_size = slot.generated_text.size();
slot.generated_text += token_str;
slot.has_next_token = true;
@ -1930,6 +2051,16 @@ void server_context::send_embedding(const server_slot& slot, const llama_batch&
queue_results.send(std::move(res));
}
void server_context::apply_server_biases(server_slot& slot) {
auto& server_biases = slot.ctx_sampling->server_biases;
if (slot.allow_idx < slot.allow_biasess.size()) {
server_biases = &slot.allow_biasess[slot.allow_idx];
} else {
server_biases = nullptr;
}
}
void server_context::request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, server_tokens&& inputs) {
server_task task;
task.id = id_task;
@ -3422,6 +3553,8 @@ void server_context::speculative_decoding_accept() {
size_t n_draft = slot.drafted.size();
apply_server_biases(slot);
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);
@ -3502,6 +3635,8 @@ void server_context::speculative_decoding_accept() {
}
common_sampler_review(slot.ctx_sampling, slot.token_buffer.size(), slot.rewind_status);
update_allowlist_state(slot);
}
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int)ids.size() - 1, (int)slot.drafted.size(), slot.n_past);
LOG_VERBOSE("speculative decoding result", {
@ -3677,7 +3812,7 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) {
size_t n_keep_cache = 0;
if (ban_pos > 0) {
n_keep_cache = (size_t)(ban_pos - 1);
}
}
if (n_keep_cache > slot.cache_tokens.size()) {
n_keep_cache = slot.cache_tokens.size();
@ -3769,6 +3904,25 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
}
}
void server_context::update_allowlist_state(server_slot& slot) {
const auto& kws = slot.allow_kws;
auto& idx = slot.allow_idx;
if ((slot.allow_kw_delay > slot.n_decoded) || (idx >= kws.size())) {
return;
}
// search for keyword
auto kw = kws[idx];
auto pos = slot.generated_text.find(kw, std::max(0, slot.last_gentxt_size - (int32_t)kw.length() + 1));
while (pos != std::string::npos) {
if (++idx >= kws.size()) {
break;
}
kw = kws[idx];
pos = slot.generated_text.find(kw, pos + 1);
}
}
void server_context::process_batch_tokens(int32_t & n_batch) {
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
@ -3901,6 +4055,8 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
}
}
apply_server_biases(slot);
const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, tok_idx);
common_sampler_accept(slot.ctx_sampling, ctx, id, true);
@ -3944,6 +4100,8 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
common_sampler_review(slot.ctx_sampling, slot.token_buffer.size(), slot.rewind_status);
update_allowlist_state(slot);
slot.i_batch = -1;
}
if (mtp_warmup_needed && !batch_mtp_hidden_state.empty()) {

View File

@ -64,6 +64,7 @@ struct server_slot {
server_tokens prompt_tokens;
server_tokens cache_tokens;
int32_t last_gentxt_size = 0;
std::string generated_text;
// idx of draft tokens in the main batch
@ -102,6 +103,15 @@ struct server_slot {
int32_t banned_n = 1;
std::map<int32_t, std::set<llama_token>> positional_bans;
// allowlist
std::vector<std::vector<std::tuple<uint32_t, uint32_t, std::string, float>>> allow_ruless_prev;
std::vector<std::vector<std::tuple<uint32_t, uint32_t, std::string, float>>> allow_ruless;
std::vector<std::string> allow_pieces;
std::vector<std::string> allow_kws;
size_t allow_kw_delay = 0;
std::vector<std::vector<float>> allow_biasess;
size_t allow_idx = 0;
server_prompt server_cached_prompt;
void prompt_save(server_prompt_cache& prompt_cache) const;
@ -222,6 +232,8 @@ struct server_context {
std::vector<llama_lora_adapter_container> lora_adapters;
std::vector<control_vector_container> control_vectors;
std::vector<std::string> vocab_pieces;
gpt_params params_base;
llama_batch batch;
@ -284,6 +296,8 @@ struct server_context {
server_slot* get_available_slot(const server_task& task);
int32_t populate_vocab_pieces();
bool launch_slot_with_task(server_slot& slot, server_task& task);
void kv_cache_clear();
@ -313,6 +327,8 @@ struct server_context {
void send_embedding(const server_slot& slot, const llama_batch& batch);
void apply_server_biases(server_slot& slot);
void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, server_tokens&& inputs);
void request_cancel(int id_task);
@ -361,6 +377,8 @@ struct server_context {
void buffer_and_check_string_ban(server_slot& slot, completion_token_output& result);
void update_allowlist_state(server_slot& slot);
json model_meta() const;
// Re-aggregates all active vectors and updates the model state

View File

@ -1559,4 +1559,6 @@ llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_
#endif // LLAMA_API_INTERNAL
size_t llama_fill_from_utf8(void* utf8, void* cpts, void* scripts);
#endif // LLAMA_H

View File

@ -0,0 +1,110 @@
from collections import defaultdict
import requests
MAX_CODEPOINTS = 0x110000
SCRIPT_DATA_URL = "https://www.unicode.org/Public/UCD/latest/ucd/Scripts.txt"
res = requests.get(SCRIPT_DATA_URL)
res.raise_for_status()
data = res.content.decode()
cptL_cptU_script = []
for line in data.splitlines():
line = line.split()
if len(line) <= 1 or line[0] == "#":
continue
cpt = line[0].split("..")
if len(cpt) == 1:
cpt += cpt
cpt_lower, cpt_upper = cpt
cpt_lower = int(cpt_lower, 16)
if cpt_lower >= MAX_CODEPOINTS:
break
cpt_upper = int(cpt_upper, 16)
if cpt_upper >= MAX_CODEPOINTS:
break
assert line[1] == ";"
script = line[2].lower()
assert line[3] == "#"
# categ = line[4]
# assert len(categ) == 2
cptL_cptU_script.append([cpt_lower, cpt_upper, script])
cptL_cptU_script.sort(key=lambda x: x[0]) # just in case
# merge neighboring codepoints that belong to same script
im = 0 # merge index
for cpt_lower, cpt_upper, script in cptL_cptU_script[1:]:
if (cptL_cptU_script[im][2] == script) and (cptL_cptU_script[im][1] + 1 == cpt_lower):
cptL_cptU_script[im][1] = cpt_upper
else:
im += 1
cptL_cptU_script[im] = [cpt_lower, cpt_upper, script]
del cptL_cptU_script[im + 1:]
def out(line=""):
print(line, end='\n') # noqa
# Generate 'unicode-script-data.cpp':
# python scripts/gen-unicode-script-data.py > src/unicode-script-data.cpp
out("""\
// generated with scripts/gen-unicode-script-data.py
#include "unicode.h"
#include "unicode-data.h"
""")
out("""\
size_t unicode_fill_from_utf8(std::string* utf8, std::vector<uint32_t>* dst_cpts, std::vector<std::string>* dst_scripts) {
if (utf8 == nullptr) {
return 0;
}
""")
out("static const std::vector<std::string> unicode_scripts = {")
for _, _, script in cptL_cptU_script:
out(" \"%s\"," % script)
out("};")
out("static const std::vector<uint32_t> unicode_script_lasts = {")
for _, cpt_upper, _ in cptL_cptU_script:
out(" 0x%06X," % cpt_upper)
out("};")
out("""\
const auto cpts = unicode_cpts_from_utf8(*utf8);
const size_t n_cpt = cpts.size();
std::vector<std::string> scripts;
scripts.reserve(n_cpt);
for (const auto& cpt: cpts) {
const auto it = std::lower_bound(unicode_script_lasts.begin(), unicode_script_lasts.end(), cpt);
if (it != unicode_script_lasts.end()) {
scripts.push_back(unicode_scripts[std::distance(unicode_script_lasts.begin(), it)]);
}
}
if (dst_cpts != nullptr) {
*dst_cpts = cpts;
}
if (dst_scripts != nullptr) {
*dst_scripts = scripts;
}
return n_cpt;
}
""")

View File

@ -60,6 +60,7 @@ add_library(llama
unicode.h
unicode.cpp
unicode-data.cpp
unicode-script-data.cpp
)
target_include_directories(llama PUBLIC . ../include)

View File

@ -9122,3 +9122,6 @@ void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float
ctx->draft_input_hidden_state = hidden_state;
}
size_t llama_fill_from_utf8(void* utf8, void* cpts, void* scripts) {
return unicode_fill_from_utf8((std::string*)utf8, (std::vector<uint32_t>*)cpts, (std::vector<std::string>*)scripts);
}

2005
src/unicode-script-data.cpp Normal file

File diff suppressed because it is too large Load Diff

View File

@ -88,6 +88,8 @@ struct unicode_cpt_flags {
}
};
size_t unicode_fill_from_utf8(std::string* utf8, std::vector<uint32_t>* dst_cpts, std::vector<std::string>* dst_scripts);
size_t unicode_len_utf8(char src);
std::string unicode_cpt_to_utf8 (uint32_t cpt);