mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Add Unicode allowlist (#1597)
* initial commit * cleanup * fix whitelist arg parsing and simplify keyword search state * rename white* to allow* * add vocab_pieces init function, rename update functions, delete accidentally added file * delete temporary bias code * auto-generate fill function with script data inside * deduplicate allowlist unicode rule parsing * minor cleanup * delete unnecessary header * refactor allowlist to support sequential rule sets via keywords * add early exit for zero-rules case * delete accidentally added file
This commit is contained in:
parent
5720a4131a
commit
869b83bc49
1
Makefile
1
Makefile
@ -1119,6 +1119,7 @@ src/unicode.o: \
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
|
||||
src/unicode-data.o: \
|
||||
src/unicode-script-data.cpp \
|
||||
src/unicode-data.cpp \
|
||||
src/unicode-data.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
|
||||
@ -1698,6 +1698,30 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
params.banned_n = std::stoi(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "--allowlist-unicode-rule") {
|
||||
CHECK_ARG
|
||||
if (params.allow_ruless.size() == 0) {
|
||||
params.allow_ruless.push_back({});
|
||||
}
|
||||
params.allow_ruless.back().push_back(argparse_allowlist_unicode_rule(argv[i]));
|
||||
return true;
|
||||
}
|
||||
if (arg == "--allowlist-pieces") {
|
||||
CHECK_ARG
|
||||
params.allow_pieces.push_back(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "--allowlist-keyword") {
|
||||
CHECK_ARG
|
||||
params.allow_kws.push_back(argv[i]);
|
||||
params.allow_ruless.push_back({});
|
||||
return true;
|
||||
}
|
||||
if (arg == "--allowlist-keyword-delay") {
|
||||
CHECK_ARG
|
||||
params.allow_kw_delay = std::stoul(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "-ld" || arg == "--logdir") {
|
||||
CHECK_ARG
|
||||
params.logdir = argv[i];
|
||||
@ -2442,9 +2466,16 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
||||
options.push_back({ "*", " --top-n-sigma t", "top-n-sigma parmeter (default: %.1f, 0.0 = disabled)", (double)sparams.top_n_sigma});
|
||||
options.push_back({ "*", " --adaptive-target", "adaptive-p sampling: (default: %.2f, <0.0 = disabled)", (double)sparams.adaptive_target});
|
||||
options.push_back({ "*", " --adaptive-decay", "adaptive-p sampling: (default: %.2f)", (double)sparams.adaptive_decay});
|
||||
options.push_back({ "*", " --adaptive-updt-w-cur", "adaptive-p sampling: (default: %s)", sparams.adaptive_updt_w_cur ? "true" : "false"});
|
||||
options.push_back({ "*", " --banned-string-file", "file path of the list of banned strings on each line" });
|
||||
options.push_back({ "*", " --banned-n", "number of tokens banned in the phrase during rewind. -1 means all tokens: (default: %d)",params.banned_n });
|
||||
options.push_back({ "*", " --adaptive-updt-w-cur", "adaptive-p sampling: (default: %s)", sparams.adaptive_updt_w_cur ? "true" : "false"});
|
||||
options.push_back({ "*", " --allowlist-unicode-rule",
|
||||
"rule for allowlisting unicode script and/or codepoints. disabled without any rule. format: `LOWER..UPPER,SCRIPT:BIAS`\n"
|
||||
"if unspecified: LOWER = 0, UPPER = -1(=max), SCRIPT=\"\", BIAS = 0. at least one of LOWER, UPPER, or SCRIPT is required\n" });
|
||||
options.push_back({ "*", " --allowlist-pieces", "allowlist each token in argument. inherits max BIAS in --allowlist-unicode-rule. overrides --allowlist-unicode-rule" });
|
||||
options.push_back({ "*", " --allowlist-keyword", "keyword to expire earlier allowlist rules if matched during generation. does not affect later rules" });
|
||||
options.push_back({ "*", " --allowlist-keyword-delay",
|
||||
"# tokens to delay matching for the first keyword (default: %zu)", params.allow_kw_delay });
|
||||
options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n"
|
||||
"i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"
|
||||
"or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" });
|
||||
@ -4557,3 +4588,37 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
|
||||
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
|
||||
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
|
||||
}
|
||||
|
||||
//
|
||||
// Argparse utils
|
||||
//
|
||||
|
||||
std::tuple<uint32_t, uint32_t, std::string, float> argparse_allowlist_unicode_rule(std::string argstr) {
|
||||
// format:
|
||||
// LOWER..UPPER,SCRIPT:BIAS
|
||||
|
||||
auto subs = string_split(argstr, ":");
|
||||
float bias = subs.size() == 1 ? 0 : std::stof(subs[1]);
|
||||
|
||||
subs = string_split(subs[0], ",");
|
||||
std::string script = std::all_of(subs.back().begin(), subs.back().end(), [](char c) {
|
||||
return std::isalpha(c);
|
||||
}) ? string_lower(subs.back()) : "*";
|
||||
if (script == "ascii") {
|
||||
return { 0x000000, 0x00007F, "*", bias };
|
||||
}
|
||||
|
||||
uint32_t first = 0;
|
||||
uint32_t last = -1;
|
||||
if ((script == "*") || (subs.size() > 1)) {
|
||||
subs = string_split(subs.front(), ".");
|
||||
if (!subs.front().empty()) {
|
||||
first = std::stoul(subs.front());
|
||||
}
|
||||
if (!subs.back().empty()) {
|
||||
last = std::stoul(subs.back());
|
||||
}
|
||||
}
|
||||
|
||||
return { std::min(first, last), std::max(first, last), script, bias };
|
||||
}
|
||||
|
||||
@ -288,6 +288,16 @@ struct gpt_params {
|
||||
size_t n_buffer = 0; // number of token buffers for string ban
|
||||
bool can_ban_phrases = true; // whether to ban strings
|
||||
|
||||
std::vector<std::vector<std::tuple<
|
||||
uint32_t // lower codepoint
|
||||
,uint32_t // upper codepoint
|
||||
,std::string // unicode script name
|
||||
,float // bias
|
||||
>>> allow_ruless;
|
||||
std::vector<std::string> allow_pieces; // each token to allowlist
|
||||
std::vector<std::string> allow_kws; // keywords
|
||||
size_t allow_kw_delay; // minimum n_decoded before first keyword is active
|
||||
|
||||
std::vector<llama_model_kv_override> kv_overrides;
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
std::vector<std::pair<int,int>> offload_policy;
|
||||
@ -735,3 +745,9 @@ void yaml_dump_non_result_info(
|
||||
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
|
||||
|
||||
std::string string_format(const char* fmt, ...);
|
||||
|
||||
//
|
||||
// Argparse utils
|
||||
//
|
||||
|
||||
std::tuple<uint32_t, uint32_t, std::string, float> argparse_allowlist_unicode_rule(std::string argstr);
|
||||
|
||||
@ -505,8 +505,14 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||
|
||||
cur.resize(n_vocab);
|
||||
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||
if ((ctx_sampling->server_biases != nullptr) && (ctx_sampling->server_biases->size() == n_vocab)) {
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur[token_id] = llama_token_data{token_id, logits[token_id] + ctx_sampling->server_biases->at(token_id), 0.0f};
|
||||
}
|
||||
} else {
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||
}
|
||||
}
|
||||
|
||||
ctx_sampling->cur_p = { cur.data(), cur.size(), false };
|
||||
|
||||
@ -134,6 +134,8 @@ struct common_sampler {
|
||||
llama_token_data_array cur_p; // current candidates
|
||||
|
||||
std::mt19937 rng;
|
||||
|
||||
std::vector<float>* server_biases;
|
||||
};
|
||||
|
||||
|
||||
|
||||
@ -361,6 +361,7 @@ void server_slot::prompt_load(server_prompt_cache& prompt_cache, const server_to
|
||||
|
||||
void server_slot::reset() {
|
||||
n_prompt_tokens = 0;
|
||||
last_gentxt_size = 0;
|
||||
generated_text = "";
|
||||
truncated = false;
|
||||
stopped_eos = false;
|
||||
@ -394,6 +395,12 @@ void server_slot::reset() {
|
||||
ban_regex.clear();
|
||||
ban_regex_ci.clear();
|
||||
|
||||
allow_ruless.clear();
|
||||
allow_pieces.clear();
|
||||
allow_kws.clear();
|
||||
allow_kw_delay = 0;
|
||||
allow_idx = 0;
|
||||
|
||||
// Reset speculative decoding stats
|
||||
n_draft_total = 0;
|
||||
n_draft_accepted = 0;
|
||||
@ -851,6 +858,19 @@ server_slot* server_context::get_available_slot(const server_task& task) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
int32_t server_context::populate_vocab_pieces() {
|
||||
const int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
|
||||
if (vocab_pieces.size() == n_vocab) {
|
||||
return n_vocab;
|
||||
}
|
||||
vocab_pieces.clear();
|
||||
vocab_pieces.reserve(n_vocab);
|
||||
for (int32_t id = 0; id < n_vocab; ++id) {
|
||||
vocab_pieces.push_back(common_token_to_piece(ctx, id, true));
|
||||
}
|
||||
return n_vocab;
|
||||
}
|
||||
|
||||
bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) {
|
||||
slot_params defaults;
|
||||
defaults.speculative = params_base.speculative;
|
||||
@ -1338,6 +1358,106 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
|
||||
slot.banned_n = json_value(data, "banned_n", params_base.banned_n);
|
||||
}
|
||||
|
||||
do // populate allowlist biases
|
||||
{
|
||||
// TODO: JSON parsing for rules and keywords
|
||||
slot.allow_ruless = params_base.allow_ruless;
|
||||
if (slot.allow_ruless.size() == 0) {
|
||||
slot.allow_biasess.clear();
|
||||
break;
|
||||
}
|
||||
slot.allow_kws = params_base.allow_kws;
|
||||
|
||||
slot.allow_pieces = params_base.allow_pieces;
|
||||
const auto& allowlist_piece_array = data.find("allowlist_piece_array");
|
||||
if (allowlist_piece_array != data.end() && allowlist_piece_array->is_array()) {
|
||||
slot.allow_pieces.clear();
|
||||
for (const auto& piece: *allowlist_piece_array) {
|
||||
if (piece.is_string()) {
|
||||
slot.allow_pieces.push_back(piece.get<std::string>());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
slot.allow_kw_delay = json_value(data, "allowlist_keyword_delay", params_base.allow_kw_delay);
|
||||
// end of allowlist criteria update
|
||||
|
||||
const int32_t n_vocab = populate_vocab_pieces();
|
||||
|
||||
std::unordered_set<llama_token> allow_settoken;
|
||||
for (const auto& piece: slot.allow_pieces) {
|
||||
for (const auto token: common_tokenize(model, piece, false, true)) {
|
||||
allow_settoken.insert(token);
|
||||
}
|
||||
}
|
||||
|
||||
auto n_rules = slot.allow_ruless.size();
|
||||
if (n_rules > slot.allow_kws.size() + 1) {
|
||||
// at most one more rule set than keywords; the last rule set then never expires
|
||||
n_rules = slot.allow_kws.size() + 1;
|
||||
slot.allow_ruless.resize(n_rules);
|
||||
} else if (n_rules < slot.allow_kws.size()) {
|
||||
// fewer rule sets than keywords: drop the surplus keywords, so every rule set expires on a keyword
|
||||
slot.allow_kws.resize(n_rules);
|
||||
}
|
||||
slot.allow_biasess.resize(n_rules);
|
||||
|
||||
for (size_t i = 0; i < n_rules; ++i) {
|
||||
const auto& rules = slot.allow_ruless[i];
|
||||
if ((i < slot.allow_ruless_prev.size()) && (rules == slot.allow_ruless_prev[i])) {
|
||||
continue;
|
||||
}
|
||||
LLAMA_LOG_DEBUG("%s: allowlist %zu is new\n", __func__, i);
|
||||
|
||||
auto& biases = slot.allow_biasess[i];
|
||||
biases.resize(n_vocab);
|
||||
|
||||
std::vector<uint32_t> cpts;
|
||||
std::vector<std::string> scripts;
|
||||
for (size_t id = 0; id < n_vocab; ++id) {
|
||||
const size_t n_cpt = llama_fill_from_utf8(&vocab_pieces[id], &cpts, &scripts);
|
||||
float bias = -INFINITY;
|
||||
|
||||
// each codepoint must be found in
|
||||
for (size_t j = 0; j < n_cpt; ++j) {
|
||||
bool in_rule = false;
|
||||
|
||||
// at least one rule
|
||||
for (const auto& rule: rules) {
|
||||
const bool in_range = (std::get<0>(rule) <= cpts[j]) && (cpts[j] <= std::get<1>(rule));
|
||||
in_rule = in_range && ((std::get<2>(rule) == "*") || std::get<2>(rule) == scripts[j]);
|
||||
if (in_rule) {
|
||||
// earlier rule has higher priority
|
||||
bias = std::max(bias, std::get<3>(rule));
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!in_rule) {
|
||||
if ((scripts[j] == "common") || (scripts[j] == "inherited")) {
|
||||
// for common or inherited codepoints (e.g. whitespace), defer to other codepoints in the token
|
||||
continue;
|
||||
}
|
||||
|
||||
// to shadow realm
|
||||
bias = -INFINITY;
|
||||
break;
|
||||
}
|
||||
}
|
||||
biases[id] = bias;
|
||||
}
|
||||
|
||||
float max_bias = -INFINITY;
|
||||
for (const auto& rule: rules) {
|
||||
max_bias = std::max(max_bias, std::get<3>(rule));
|
||||
}
|
||||
for (const auto token: allow_settoken) {
|
||||
biases[token] = max_bias;
|
||||
}
|
||||
}
|
||||
} while (false);
|
||||
slot.allow_ruless_prev = slot.allow_ruless;
|
||||
|
||||
if (llama_model_has_recurrent(llama_get_model(slot.ctx))) {
|
||||
params_base.can_ban_phrases = false;
|
||||
bool do_checkpoint = params_base.ctx_checkpoints_n > 0;
|
||||
@ -1498,6 +1618,7 @@ bool server_context::process_token(completion_token_output& result, server_slot&
|
||||
slot.sampled = result.tok;
|
||||
|
||||
// search stop word and delete it
|
||||
slot.last_gentxt_size = slot.generated_text.size();
|
||||
slot.generated_text += token_str;
|
||||
slot.has_next_token = true;
|
||||
|
||||
@ -1930,6 +2051,16 @@ void server_context::send_embedding(const server_slot& slot, const llama_batch&
|
||||
queue_results.send(std::move(res));
|
||||
}
|
||||
|
||||
// Point the slot's sampler at the bias vector selected by `slot.allow_idx`.
// When the index is past the end of `slot.allow_biasess`, the sampler's
// `server_biases` pointer is reset to nullptr.
void server_context::apply_server_biases(server_slot& slot) {
    const bool has_active_biases = slot.allow_idx < slot.allow_biasess.size();

    slot.ctx_sampling->server_biases = has_active_biases
        ? &slot.allow_biasess[slot.allow_idx]
        : nullptr;
}
|
||||
|
||||
void server_context::request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, server_tokens&& inputs) {
|
||||
server_task task;
|
||||
task.id = id_task;
|
||||
@ -3422,6 +3553,8 @@ void server_context::speculative_decoding_accept() {
|
||||
|
||||
size_t n_draft = slot.drafted.size();
|
||||
|
||||
apply_server_biases(slot);
|
||||
|
||||
// the accepted tokens from the speculation
|
||||
const auto ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);
|
||||
|
||||
@ -3502,6 +3635,8 @@ void server_context::speculative_decoding_accept() {
|
||||
}
|
||||
|
||||
common_sampler_review(slot.ctx_sampling, slot.token_buffer.size(), slot.rewind_status);
|
||||
|
||||
update_allowlist_state(slot);
|
||||
}
|
||||
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int)ids.size() - 1, (int)slot.drafted.size(), slot.n_past);
|
||||
LOG_VERBOSE("speculative decoding result", {
|
||||
@ -3677,7 +3812,7 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) {
|
||||
size_t n_keep_cache = 0;
|
||||
if (ban_pos > 0) {
|
||||
n_keep_cache = (size_t)(ban_pos - 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (n_keep_cache > slot.cache_tokens.size()) {
|
||||
n_keep_cache = slot.cache_tokens.size();
|
||||
@ -3769,6 +3904,25 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
|
||||
}
|
||||
}
|
||||
|
||||
// Advance `slot.allow_idx` past every keyword of `slot.allow_kws` that is found in the
// generated text. Nothing happens while fewer than `allow_kw_delay` tokens have been
// decoded, or once all keywords have already been consumed.
void server_context::update_allowlist_state(server_slot& slot) {
    const auto& keywords = slot.allow_kws;
    auto& active = slot.allow_idx;

    if ((slot.allow_kw_delay > slot.n_decoded) || (active >= keywords.size())) {
        return;
    }

    const std::string& text = slot.generated_text;
    const std::string* keyword = &keywords[active];

    // The first search starts (keyword length - 1) characters before `last_gentxt_size`,
    // clamped to 0, so a keyword that begins before that offset and ends after it is still found.
    auto pos = text.find(*keyword, std::max(0, slot.last_gentxt_size - static_cast<int32_t>(keyword->length()) + 1));

    // Every hit consumes one keyword; the following keyword is searched from one character
    // past the start of the previous hit.
    for (; pos != std::string::npos; pos = text.find(*keyword, pos + 1)) {
        ++active;
        if (active >= keywords.size()) {
            break;
        }
        keyword = &keywords[active];
    }
}
|
||||
|
||||
void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
||||
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
|
||||
@ -3901,6 +4055,8 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
}
|
||||
}
|
||||
|
||||
apply_server_biases(slot);
|
||||
|
||||
const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, tok_idx);
|
||||
|
||||
common_sampler_accept(slot.ctx_sampling, ctx, id, true);
|
||||
@ -3944,6 +4100,8 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
|
||||
common_sampler_review(slot.ctx_sampling, slot.token_buffer.size(), slot.rewind_status);
|
||||
|
||||
update_allowlist_state(slot);
|
||||
|
||||
slot.i_batch = -1;
|
||||
}
|
||||
if (mtp_warmup_needed && !batch_mtp_hidden_state.empty()) {
|
||||
|
||||
@ -64,6 +64,7 @@ struct server_slot {
|
||||
server_tokens prompt_tokens;
|
||||
server_tokens cache_tokens;
|
||||
|
||||
int32_t last_gentxt_size = 0;
|
||||
std::string generated_text;
|
||||
|
||||
// idx of draft tokens in the main batch
|
||||
@ -102,6 +103,15 @@ struct server_slot {
|
||||
int32_t banned_n = 1;
|
||||
std::map<int32_t, std::set<llama_token>> positional_bans;
|
||||
|
||||
// allowlist
|
||||
std::vector<std::vector<std::tuple<uint32_t, uint32_t, std::string, float>>> allow_ruless_prev;
|
||||
std::vector<std::vector<std::tuple<uint32_t, uint32_t, std::string, float>>> allow_ruless;
|
||||
std::vector<std::string> allow_pieces;
|
||||
std::vector<std::string> allow_kws;
|
||||
size_t allow_kw_delay = 0;
|
||||
std::vector<std::vector<float>> allow_biasess;
|
||||
size_t allow_idx = 0;
|
||||
|
||||
server_prompt server_cached_prompt;
|
||||
|
||||
void prompt_save(server_prompt_cache& prompt_cache) const;
|
||||
@ -222,6 +232,8 @@ struct server_context {
|
||||
std::vector<llama_lora_adapter_container> lora_adapters;
|
||||
std::vector<control_vector_container> control_vectors;
|
||||
|
||||
std::vector<std::string> vocab_pieces;
|
||||
|
||||
gpt_params params_base;
|
||||
|
||||
llama_batch batch;
|
||||
@ -284,6 +296,8 @@ struct server_context {
|
||||
|
||||
server_slot* get_available_slot(const server_task& task);
|
||||
|
||||
int32_t populate_vocab_pieces();
|
||||
|
||||
bool launch_slot_with_task(server_slot& slot, server_task& task);
|
||||
|
||||
void kv_cache_clear();
|
||||
@ -313,6 +327,8 @@ struct server_context {
|
||||
|
||||
void send_embedding(const server_slot& slot, const llama_batch& batch);
|
||||
|
||||
void apply_server_biases(server_slot& slot);
|
||||
|
||||
void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, server_tokens&& inputs);
|
||||
|
||||
void request_cancel(int id_task);
|
||||
@ -361,6 +377,8 @@ struct server_context {
|
||||
|
||||
void buffer_and_check_string_ban(server_slot& slot, completion_token_output& result);
|
||||
|
||||
void update_allowlist_state(server_slot& slot);
|
||||
|
||||
json model_meta() const;
|
||||
|
||||
// Re-aggregates all active vectors and updates the model state
|
||||
|
||||
@ -1559,4 +1559,6 @@ llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_
|
||||
|
||||
#endif // LLAMA_API_INTERNAL
|
||||
|
||||
size_t llama_fill_from_utf8(void* utf8, void* cpts, void* scripts);
|
||||
|
||||
#endif // LLAMA_H
|
||||
|
||||
110
scripts/gen-unicode-script-data.py
Normal file
110
scripts/gen-unicode-script-data.py
Normal file
@ -0,0 +1,110 @@
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import requests
|
||||
|
||||
# Generate 'unicode-script-data.cpp':
#   python scripts/gen-unicode-script-data.py > src/unicode-script-data.cpp
#
# The script downloads Scripts.txt from the Unicode Character Database, turns it into a
# sorted, gap-free table of (last code point, script name) entries and prints a C++
# function that maps every code point of a UTF-8 string to its script name.

# Code points at or above this bound are ignored.
MAX_CODEPOINTS = 0x110000

SCRIPT_DATA_URL = "https://www.unicode.org/Public/UCD/latest/ucd/Scripts.txt"

# Script name emitted for code points that the downloaded data does not list.
UNKNOWN_SCRIPT = "unknown"


# a timeout keeps the generator from hanging forever on a stalled connection
res = requests.get(SCRIPT_DATA_URL, timeout=60)
res.raise_for_status()
data = res.content.decode()

# Data lines look like `LOWER[..UPPER] ; Script # ...` (hexadecimal code points).
# Collect them as [lower, upper, script] with the script name lower-cased.
cptL_cptU_script = []
for line in data.splitlines():
    line = line.split()
    if len(line) <= 1 or line[0].startswith("#"):
        continue

    cpt = line[0].split("..")
    if len(cpt) == 1:
        # single code point: lower == upper
        cpt = [cpt[0], cpt[0]]
    cpt_lower, cpt_upper = cpt

    cpt_lower = int(cpt_lower, 16)
    cpt_upper = int(cpt_upper, 16)
    if cpt_lower >= MAX_CODEPOINTS or cpt_upper >= MAX_CODEPOINTS:
        # Skip only this entry. The input is sorted below precisely because its order is
        # not relied upon, so stopping here could silently drop valid later entries.
        continue

    assert line[1] == ";"

    script = line[2].lower()

    assert line[3] == "#"

    # categ = line[4]
    # assert len(categ) == 2

    cptL_cptU_script.append([cpt_lower, cpt_upper, script])

cptL_cptU_script.sort(key=lambda x: x[0])  # just in case

# merge neighboring ranges that belong to the same script
merged = []
for cpt_lower, cpt_upper, script in cptL_cptU_script:
    if merged and merged[-1][2] == script and merged[-1][1] + 1 == cpt_lower:
        merged[-1][1] = cpt_upper
    else:
        merged.append([cpt_lower, cpt_upper, script])

# The generated lookup only stores the LAST code point of each range and binary-searches
# it. Without explicit entries for the gaps, a code point between two listed ranges would
# be reported with the script of the following range. Fill every gap, and the tail up to
# MAX_CODEPOINTS - 1, with UNKNOWN_SCRIPT so the table covers the whole code space.
cptL_cptU_script = []
next_cpt = 0
for cpt_lower, cpt_upper, script in merged:
    assert cpt_lower >= next_cpt, "overlapping ranges would break the binary search"
    if cpt_lower > next_cpt:
        cptL_cptU_script.append([next_cpt, cpt_lower - 1, UNKNOWN_SCRIPT])
    cptL_cptU_script.append([cpt_lower, cpt_upper, script])
    next_cpt = cpt_upper + 1
if next_cpt < MAX_CODEPOINTS:
    cptL_cptU_script.append([next_cpt, MAX_CODEPOINTS - 1, UNKNOWN_SCRIPT])


def out(line=""):
    print(line, end='\n')  # noqa


out("""\
// generated with scripts/gen-unicode-script-data.py

#include "unicode.h"
#include "unicode-data.h"
""")

out("""\
size_t unicode_fill_from_utf8(std::string* utf8, std::vector<uint32_t>* dst_cpts, std::vector<std::string>* dst_scripts) {
    if (utf8 == nullptr) {
        return 0;
    }
""")

# script name of each table entry ...
out("    static const std::vector<std::string> unicode_scripts = {")
for _, _, script in cptL_cptU_script:
    out("        \"%s\"," % script)
out("    };")

# ... and the last code point of the same entry (strictly increasing)
out("    static const std::vector<uint32_t> unicode_script_lasts = {")
for _, cpt_upper, _ in cptL_cptU_script:
    out("        0x%06X," % cpt_upper)
out("    };")

# The fallback branch keeps `scripts` index-aligned with `cpts` even if a code point lies
# beyond the table; callers index both vectors with the same position.
out("""\
    const auto cpts = unicode_cpts_from_utf8(*utf8);
    const size_t n_cpt = cpts.size();

    std::vector<std::string> scripts;
    scripts.reserve(n_cpt);

    for (const auto& cpt: cpts) {
        const auto it = std::lower_bound(unicode_script_lasts.begin(), unicode_script_lasts.end(), cpt);
        if (it != unicode_script_lasts.end()) {
            scripts.push_back(unicode_scripts[std::distance(unicode_script_lasts.begin(), it)]);
        } else {
            scripts.push_back("@UNKNOWN_SCRIPT@");
        }
    }

    if (dst_cpts != nullptr) {
        *dst_cpts = cpts;
    }
    if (dst_scripts != nullptr) {
        *dst_scripts = scripts;
    }

    return n_cpt;
}
""".replace("@UNKNOWN_SCRIPT@", UNKNOWN_SCRIPT))
|
||||
@ -60,6 +60,7 @@ add_library(llama
|
||||
unicode.h
|
||||
unicode.cpp
|
||||
unicode-data.cpp
|
||||
unicode-script-data.cpp
|
||||
)
|
||||
|
||||
target_include_directories(llama PUBLIC . ../include)
|
||||
|
||||
@ -9122,3 +9122,6 @@ void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float
|
||||
ctx->draft_input_hidden_state = hidden_state;
|
||||
}
|
||||
|
||||
size_t llama_fill_from_utf8(void* utf8, void* cpts, void* scripts) {
|
||||
return unicode_fill_from_utf8((std::string*)utf8, (std::vector<uint32_t>*)cpts, (std::vector<std::string>*)scripts);
|
||||
}
|
||||
|
||||
2005
src/unicode-script-data.cpp
Normal file
2005
src/unicode-script-data.cpp
Normal file
File diff suppressed because it is too large
Load Diff
@ -88,6 +88,8 @@ struct unicode_cpt_flags {
|
||||
}
|
||||
};
|
||||
|
||||
size_t unicode_fill_from_utf8(std::string* utf8, std::vector<uint32_t>* dst_cpts, std::vector<std::string>* dst_scripts);
|
||||
|
||||
size_t unicode_len_utf8(char src);
|
||||
|
||||
std::string unicode_cpt_to_utf8 (uint32_t cpt);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user