mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
common/grammar: fix grammar parsing issues to prevent stack overflow and hangs (#1822)
* grammar: Fix grammar root symbol check (#19761) * grammar: fix bad check for root symbol, correct error logging * add tests to demonstrate root symbol check failure # Conflicts: # tests/test-grammar-integration.cpp * common/grammar: fix grammar parsing issues to prevent stack overflow and hangs (#18604) * grammar: add test case for nullable symbol loop Reproduce stack overflow (or OOM) with ( [x]* )* found while adding GBNF support to ripgrep-edit. llama-server reproducer: curl \ -X POST \ -d '{ "messages": [{ "role": "user", "content": "write yes" }], "grammar": "root ::= ( [x]* )*" }' \ -H "Content-Type: application/json" \ http://localhost:8811/v1/chat/completions * grammar: prevent stack overflow with nullable symbol loop Fix a potential stack overflow in llama_grammar_advance_stack that could occur when processing grammars with nullable symbols that lead to infinite derivations of empty strings. The fix introduces cycle detection by tracking visited stacks to prevent infinite recursion. rg-edit regexp: llama_grammar_advance_stack rg-edit extra-args: -A20 rg-edit directive: """Rewrite: fix the following segfault: [..] ⚫ Testing segfault. Grammar: root ::= ( [x]* )* root ::= ( [x]* )* Segmentation fault build/bin/test-grammar-integration""" gptel-context: (("~/llama.cpp/src/llama-grammar.cpp") ("~/llama.cpp/tests/test-grammar-integration.cpp") ("~/llama.cpp/grammars/./list.gbnf") ("~/llama.cpp/grammars/./json_arr.gbnf") ("~/llama.cpp/grammars/./json.gbnf") ("~/llama.cpp/grammars/./japanese.gbnf") ("~/llama.cpp/grammars/./english.gbnf") ("~/llama.cpp/grammars/./chess.gbnf") ("~/llama.cpp/grammars/./c.gbnf") ("~/llama.cpp/grammars/./arithmetic.gbnf") ("~/llama.cpp/grammars/./README.md")) * grammar: convert recursive llama_grammar_advance_stack to iterative This change converts the function to an iterative approach using explicit stacks, which prevents deep recursion and eliminates the risk of stack overflow. rg-edit regexp: llama_grammar_advance_stack rg-edit extra-args: -A30 rg-edit directive: """Rewrite: fix the following segfault: [..] ⚫ Testing segfault. Grammar: root ::= ( [x]* )* root ::= ( [x]* )* Segmentation fault build/bin/test-grammar-integration convert from recursive to interactive""" gptel-context: (("~/llama.cpp/src/llama-grammar.cpp") ("~/llama.cpp/tests/test-grammar-integration.cpp") ("~/llama.cpp/grammars/./list.gbnf") ("~/llama.cpp/grammars/./json_arr.gbnf") ("~/llama.cpp/grammars/./json.gbnf") ("~/llama.cpp/grammars/./japanese.gbnf") ("~/llama.cpp/grammars/./english.gbnf") ("~/llama.cpp/grammars/./chess.gbnf") ("~/llama.cpp/grammars/./c.gbnf") ("~/llama.cpp/grammars/./arithmetic.gbnf") ("~/llama.cpp/grammars/./README.md")) v2: Added a `std::set` to perform tree-based lookups with O(N log N) complexity. Testing with a parallel run of `test-grammar-integration` shows a double-digit percentage increase in runtime. An `unordered_set` with O(1) hashing was also evaluated, but the overhead of constructing hash keys from pointers made it significantly slower than the rbtree implementation that only requires an ordering operator. The performance regression in the test suite appears justified by the overall reduction in algorithmic complexity. Co-developed-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com> * grammar: add test case for hang in repetition grammar processing This commit adds a new test case to the grammar integration tests that specifically targets a hang scenario in the repetition grammar parser found while adding GBNF support to ripgrep-edit. llama-server reproducer: curl \ -X POST \ -d '{ "messages": [{ "role": "user", "content": "write yes" }], "grammar": "root ::= (([^x]*){0,99}){0,99}" }' \ -H "Content-Type: application/json" \ http://localhost:8811/v1/chat/completions * grammar: add repetition threshold check The change introduces a maximum repetition threshold to avoid excessive rule expansion during grammar parsing. When parsing repetition patterns like {m,n}, the parser now calculates the potential number of rules that would be generated and throws an error if the product of previous rules and new rules exceeds the threshold. A test case was added to verify the threshold is properly enforced for deeply nested repetition patterns that would otherwise cause hangs. --------- Co-authored-by: Asbjørn Olling <asbjornolling@gmail.com> Co-authored-by: Andrea Arcangeli <aarcange@redhat.com>
This commit is contained in:
parent
c07a052315
commit
9ad8b8c6db
@ -6,6 +6,7 @@
|
|||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <set>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
#define MAX_REPETITION_THRESHOLD 2000
|
#define MAX_REPETITION_THRESHOLD 2000
|
||||||
@ -462,6 +463,7 @@ const char* llama_grammar_parser::parse_sequence(
|
|||||||
bool is_nested) {
|
bool is_nested) {
|
||||||
size_t last_sym_start = rule.size();
|
size_t last_sym_start = rule.size();
|
||||||
const char* pos = src;
|
const char* pos = src;
|
||||||
|
uint64_t n_prev_rules = 1;
|
||||||
|
|
||||||
// use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
|
// use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
|
||||||
// (though it's technically the same as -1 now)
|
// (though it's technically the same as -1 now)
|
||||||
@ -489,6 +491,18 @@ const char* llama_grammar_parser::parse_sequence(
|
|||||||
// S' ::= S |
|
// S' ::= S |
|
||||||
|
|
||||||
llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
|
llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
|
||||||
|
// Calculate the total number of rules that will be generated by this repetition
|
||||||
|
uint64_t total_rules = 1; // Start with 1 for the original rule
|
||||||
|
if (!no_max && max_times > 0) {
|
||||||
|
total_rules = max_times;
|
||||||
|
} else if (min_times > 0) {
|
||||||
|
total_rules = min_times;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n_prev_rules * total_rules >= MAX_REPETITION_THRESHOLD) {
|
||||||
|
throw std::runtime_error("number of rules that are going to be repeated multiplied by the new repetition exceeds sane defaults, please reduce the number of repetitions or rule complexity");
|
||||||
|
}
|
||||||
|
|
||||||
if (min_times == 0) {
|
if (min_times == 0) {
|
||||||
rule.resize(last_sym_start);
|
rule.resize(last_sym_start);
|
||||||
}
|
}
|
||||||
@ -517,12 +531,15 @@ const char* llama_grammar_parser::parse_sequence(
|
|||||||
if (n_opt > 0) {
|
if (n_opt > 0) {
|
||||||
rule.push_back({ LLAMA_GRETYPE_RULE_REF, last_rec_rule_id });
|
rule.push_back({ LLAMA_GRETYPE_RULE_REF, last_rec_rule_id });
|
||||||
}
|
}
|
||||||
|
n_prev_rules *= total_rules;
|
||||||
|
GGML_ASSERT(n_prev_rules >= 1);
|
||||||
};
|
};
|
||||||
|
|
||||||
while (*pos) {
|
while (*pos) {
|
||||||
if (*pos == '"') { // literal string
|
if (*pos == '"') { // literal string
|
||||||
pos++;
|
pos++;
|
||||||
last_sym_start = rule.size();
|
last_sym_start = rule.size();
|
||||||
|
n_prev_rules = 1;
|
||||||
while (*pos != '"') {
|
while (*pos != '"') {
|
||||||
if (!*pos) {
|
if (!*pos) {
|
||||||
throw std::runtime_error("unexpected end of input");
|
throw std::runtime_error("unexpected end of input");
|
||||||
@ -541,6 +558,7 @@ const char* llama_grammar_parser::parse_sequence(
|
|||||||
start_type = LLAMA_GRETYPE_CHAR_NOT;
|
start_type = LLAMA_GRETYPE_CHAR_NOT;
|
||||||
}
|
}
|
||||||
last_sym_start = rule.size();
|
last_sym_start = rule.size();
|
||||||
|
n_prev_rules = 1;
|
||||||
while (*pos != ']') {
|
while (*pos != ']') {
|
||||||
if (!*pos) {
|
if (!*pos) {
|
||||||
throw std::runtime_error("unexpected end of input");
|
throw std::runtime_error("unexpected end of input");
|
||||||
@ -571,6 +589,7 @@ const char* llama_grammar_parser::parse_sequence(
|
|||||||
auto token_pair = parse_token(vocab, pos);
|
auto token_pair = parse_token(vocab, pos);
|
||||||
const char * token_end = token_pair.second;
|
const char * token_end = token_pair.second;
|
||||||
last_sym_start = rule.size();
|
last_sym_start = rule.size();
|
||||||
|
n_prev_rules = 1;
|
||||||
rule.push_back({type, token_pair.first});
|
rule.push_back({type, token_pair.first});
|
||||||
pos = parse_space(token_end, is_nested);
|
pos = parse_space(token_end, is_nested);
|
||||||
} else if (is_word_char(*pos)) { // rule reference
|
} else if (is_word_char(*pos)) { // rule reference
|
||||||
@ -578,13 +597,15 @@ const char* llama_grammar_parser::parse_sequence(
|
|||||||
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
|
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
|
||||||
pos = parse_space(name_end, is_nested);
|
pos = parse_space(name_end, is_nested);
|
||||||
last_sym_start = rule.size();
|
last_sym_start = rule.size();
|
||||||
rule.push_back({ LLAMA_GRETYPE_RULE_REF, ref_rule_id });
|
n_prev_rules = 1;
|
||||||
}
|
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
|
||||||
else if (*pos == '(') { // grouping
|
} else if (*pos == '(') { // grouping
|
||||||
// parse nested alternates into synthesized rule
|
// parse nested alternates into synthesized rule
|
||||||
pos = parse_space(pos + 1, true);
|
pos = parse_space(pos + 1, true);
|
||||||
|
uint32_t n_rules_before = symbol_ids.size();
|
||||||
uint32_t sub_rule_id = generate_symbol_id(rule_name);
|
uint32_t sub_rule_id = generate_symbol_id(rule_name);
|
||||||
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
|
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
|
||||||
|
n_prev_rules = std::max(1u, (uint32_t)symbol_ids.size() - n_rules_before);
|
||||||
last_sym_start = rule.size();
|
last_sym_start = rule.size();
|
||||||
// output reference to synthesized rule
|
// output reference to synthesized rule
|
||||||
rule.push_back({ LLAMA_GRETYPE_RULE_REF, sub_rule_id });
|
rule.push_back({ LLAMA_GRETYPE_RULE_REF, sub_rule_id });
|
||||||
@ -595,7 +616,8 @@ const char* llama_grammar_parser::parse_sequence(
|
|||||||
}
|
}
|
||||||
else if (*pos == '.') { // any char
|
else if (*pos == '.') { // any char
|
||||||
last_sym_start = rule.size();
|
last_sym_start = rule.size();
|
||||||
rule.push_back({ LLAMA_GRETYPE_CHAR_ANY, 0 });
|
n_prev_rules = 1;
|
||||||
|
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
|
||||||
pos = parse_space(pos + 1, is_nested);
|
pos = parse_space(pos + 1, is_nested);
|
||||||
}
|
}
|
||||||
else if (*pos == '*') {
|
else if (*pos == '*') {
|
||||||
@ -858,32 +880,54 @@ static bool llama_grammar_match_token(
|
|||||||
static void llama_grammar_advance_stack(
|
static void llama_grammar_advance_stack(
|
||||||
const llama_grammar_rules & rules,
|
const llama_grammar_rules & rules,
|
||||||
const llama_grammar_stack & stack,
|
const llama_grammar_stack & stack,
|
||||||
llama_grammar_stacks & new_stacks) {
|
llama_grammar_stacks & new_stacks) {
|
||||||
if (stack.empty()) {
|
std::vector<llama_grammar_stack> todo;
|
||||||
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
todo.push_back(stack);
|
||||||
new_stacks.emplace_back(stack);
|
|
||||||
|
auto stack_cmp = [](const llama_grammar_stack & a, const llama_grammar_stack & b) {
|
||||||
|
return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end(),
|
||||||
|
[](const llama_grammar_element * pa, const llama_grammar_element * pb) {
|
||||||
|
return pa < pb; // Compare pointer addresses
|
||||||
|
}
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
std::set<llama_grammar_stack, decltype(stack_cmp)> seen(stack_cmp);
|
||||||
|
|
||||||
|
while (!todo.empty()) {
|
||||||
|
llama_grammar_stack curr_stack = std::move(todo.back());
|
||||||
|
todo.pop_back();
|
||||||
|
|
||||||
|
if (seen.find( curr_stack) != seen.end()) {
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
return;
|
seen.insert(curr_stack);
|
||||||
}
|
|
||||||
|
|
||||||
const llama_grammar_element * pos = stack.back();
|
if (curr_stack.empty()) {
|
||||||
|
if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) {
|
||||||
|
new_stacks.emplace_back(std::move(curr_stack));
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
switch (pos->type) {
|
const llama_grammar_element * pos = curr_stack.back();
|
||||||
|
|
||||||
|
switch (pos->type) {
|
||||||
case LLAMA_GRETYPE_RULE_REF: {
|
case LLAMA_GRETYPE_RULE_REF: {
|
||||||
const size_t rule_id = static_cast<size_t>(pos->value);
|
const size_t rule_id = static_cast<size_t>(pos->value);
|
||||||
const llama_grammar_element * subpos = rules[rule_id].data();
|
const llama_grammar_element * subpos = rules[rule_id].data();
|
||||||
do {
|
do {
|
||||||
// init new stack without the top (pos)
|
// init new stack without the top (pos)
|
||||||
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
llama_grammar_stack next_stack(curr_stack.begin(), curr_stack.end() - 1);
|
||||||
if (!llama_grammar_is_end_of_sequence(pos + 1)) {
|
if (!llama_grammar_is_end_of_sequence(pos + 1)) {
|
||||||
// if this rule ref is followed by another element, add that to stack
|
// if this rule ref is followed by another element, add that to stack
|
||||||
new_stack.push_back(pos + 1);
|
next_stack.push_back(pos + 1);
|
||||||
}
|
}
|
||||||
if (!llama_grammar_is_end_of_sequence(subpos)) {
|
if (!llama_grammar_is_end_of_sequence(subpos)) {
|
||||||
// if alternate is nonempty, add to stack
|
// if alternate is nonempty, add to stack
|
||||||
new_stack.push_back(subpos);
|
next_stack.push_back(subpos);
|
||||||
}
|
}
|
||||||
llama_grammar_advance_stack(rules, new_stack, new_stacks);
|
todo.push_back(std::move(next_stack));
|
||||||
while (!llama_grammar_is_end_of_sequence(subpos)) {
|
while (!llama_grammar_is_end_of_sequence(subpos)) {
|
||||||
// scan to end of alternate def
|
// scan to end of alternate def
|
||||||
subpos++;
|
subpos++;
|
||||||
@ -891,8 +935,8 @@ static void llama_grammar_advance_stack(
|
|||||||
if (subpos->type == LLAMA_GRETYPE_ALT) {
|
if (subpos->type == LLAMA_GRETYPE_ALT) {
|
||||||
// there's another alternate def of this rule to process
|
// there's another alternate def of this rule to process
|
||||||
subpos++;
|
subpos++;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
} while (true);
|
} while (true);
|
||||||
@ -903,9 +947,9 @@ static void llama_grammar_advance_stack(
|
|||||||
case LLAMA_GRETYPE_CHAR_ANY:
|
case LLAMA_GRETYPE_CHAR_ANY:
|
||||||
case LLAMA_GRETYPE_TOKEN:
|
case LLAMA_GRETYPE_TOKEN:
|
||||||
case LLAMA_GRETYPE_TOKEN_NOT:
|
case LLAMA_GRETYPE_TOKEN_NOT:
|
||||||
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) {
|
||||||
// only add the stack if it's not a duplicate of one we already have
|
// only add the stack if it's not a duplicate of one we already have
|
||||||
new_stacks.emplace_back(stack);
|
new_stacks.emplace_back(std::move(curr_stack));
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
@ -913,6 +957,7 @@ static void llama_grammar_advance_stack(
|
|||||||
// (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
|
// (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
|
||||||
// those
|
// those
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1198,13 +1243,13 @@ struct llama_grammar* llama_grammar_init_impl(
|
|||||||
// if there is a grammar, parse it
|
// if there is a grammar, parse it
|
||||||
// rules will be empty (default) if there are parse errors
|
// rules will be empty (default) if there are parse errors
|
||||||
if (!parser.parse(grammar_str) || parser.rules.empty()) {
|
if (!parser.parse(grammar_str) || parser.rules.empty()) {
|
||||||
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
|
LLAMA_LOG_ERROR("failed to parse grammar\n");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure that there is a "root" node.
|
// Ensure that the grammar contains the start symbol
|
||||||
if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) {
|
if (parser.symbol_ids.find(grammar_root) == parser.symbol_ids.end()) {
|
||||||
fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
|
LLAMA_LOG_ERROR("grammar does not contain a '%s' symbol\n", grammar_root);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1233,7 +1278,7 @@ struct llama_grammar* llama_grammar_init_impl(
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
|
if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
|
||||||
LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
|
LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu\n", i);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -15,20 +15,12 @@
|
|||||||
|
|
||||||
using json = nlohmann::ordered_json;
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
static llama_grammar* build_grammar(const std::string & grammar_str) {
|
static llama_grammar * build_grammar_with_root(const std::string & grammar_str, const char * grammar_root) {
|
||||||
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
|
return llama_grammar_init_impl(nullptr, grammar_str.c_str(), grammar_root, false, nullptr, 0, nullptr, 0);
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure we parsed correctly
|
static llama_grammar * build_grammar(const std::string & grammar_str) {
|
||||||
assert(!parsed_grammar.rules.empty());
|
return build_grammar_with_root(grammar_str, "root");
|
||||||
|
|
||||||
// Ensure we have a root node
|
|
||||||
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
|
|
||||||
|
|
||||||
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
|
|
||||||
llama_grammar* grammar = llama_grammar_init(
|
|
||||||
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
|
||||||
|
|
||||||
return grammar;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool test_build_grammar_fails(const std::string & grammar_str) {
|
static bool test_build_grammar_fails(const std::string & grammar_str) {
|
||||||
@ -801,6 +793,24 @@ static void test_quantifiers() {
|
|||||||
"0xFF 0x12 0xAB 0x00 0x00 0x00",
|
"0xFF 0x12 0xAB 0x00 0x00 0x00",
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
test_grammar(
|
||||||
|
"segfault",
|
||||||
|
// Grammar
|
||||||
|
R"""(
|
||||||
|
root ::= ( [x]* )*
|
||||||
|
)""",
|
||||||
|
// Passing strings
|
||||||
|
{
|
||||||
|
"",
|
||||||
|
"x",
|
||||||
|
"xx"
|
||||||
|
},
|
||||||
|
// Failing strings
|
||||||
|
{
|
||||||
|
"y",
|
||||||
|
"yy"
|
||||||
|
}
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void test_failure_missing_root() {
|
static void test_failure_missing_root() {
|
||||||
@ -875,6 +885,36 @@ static void test_failure_left_recursion() {
|
|||||||
fprintf(stderr, " ✅︎ Passed\n");
|
fprintf(stderr, " ✅︎ Passed\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void test_failure_missing_root_symbol() {
|
||||||
|
fprintf(stderr, "⚫ Testing missing root symbol:\n");
|
||||||
|
|
||||||
|
const std::string grammar_str = R"""(
|
||||||
|
root ::= "foobar"
|
||||||
|
)""";
|
||||||
|
|
||||||
|
llama_grammar * failure_result = build_grammar_with_root(grammar_str, "nonexistent");
|
||||||
|
assert(failure_result == nullptr);
|
||||||
|
|
||||||
|
fprintf(stderr, " ✅︎ Passed\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void test_custom_root_symbol_check() {
|
||||||
|
fprintf(stderr, "⚫ Testing custom root symbol check:\n");
|
||||||
|
|
||||||
|
const std::string custom_root_grammar_str = R"""(
|
||||||
|
foobar ::= "foobar"
|
||||||
|
)""";
|
||||||
|
|
||||||
|
llama_grammar * failure_result = build_grammar_with_root(custom_root_grammar_str, "root");
|
||||||
|
assert(failure_result == nullptr);
|
||||||
|
|
||||||
|
llama_grammar * success_result = build_grammar_with_root(custom_root_grammar_str, "foobar");
|
||||||
|
assert(success_result != nullptr);
|
||||||
|
llama_grammar_free_impl(success_result);
|
||||||
|
|
||||||
|
fprintf(stderr, " ✅︎ Passed\n");
|
||||||
|
}
|
||||||
|
|
||||||
static void test_json_schema() {
|
static void test_json_schema() {
|
||||||
// Note that this is similar to the regular grammar tests,
|
// Note that this is similar to the regular grammar tests,
|
||||||
// but we convert each json schema to a grammar before parsing.
|
// but we convert each json schema to a grammar before parsing.
|
||||||
@ -1448,6 +1488,8 @@ int main() {
|
|||||||
test_failure_missing_root();
|
test_failure_missing_root();
|
||||||
test_failure_missing_reference();
|
test_failure_missing_reference();
|
||||||
test_failure_left_recursion();
|
test_failure_left_recursion();
|
||||||
|
test_failure_missing_root_symbol();
|
||||||
|
test_custom_root_symbol_check();
|
||||||
test_json_schema();
|
test_json_schema();
|
||||||
fprintf(stdout, "All tests passed.\n");
|
fprintf(stdout, "All tests passed.\n");
|
||||||
return 0;
|
return 0;
|
||||||
|
|||||||
@ -141,6 +141,10 @@ int main()
|
|||||||
root ::= "a"{,}"
|
root ::= "a"{,}"
|
||||||
)""");
|
)""");
|
||||||
|
|
||||||
|
verify_failure(R"""(
|
||||||
|
root ::= (((((([^x]*){0,99}){0,99}){0,99}){0,99}){0,99}){0,99}
|
||||||
|
)""");
|
||||||
|
|
||||||
verify_failure(R"""(
|
verify_failure(R"""(
|
||||||
root ::= "a"{,10}"
|
root ::= "a"{,10}"
|
||||||
)""");
|
)""");
|
||||||
|
|||||||
@ -125,25 +125,27 @@ int main()
|
|||||||
|
|
||||||
std::vector<std::vector<llama_grammar_element>> expected_stacks = {
|
std::vector<std::vector<llama_grammar_element>> expected_stacks = {
|
||||||
{
|
{
|
||||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
{LLAMA_GRETYPE_CHAR, 61},
|
||||||
|
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||||
|
{LLAMA_GRETYPE_CHAR, 40},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{LLAMA_GRETYPE_CHAR, 61},
|
||||||
|
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||||
|
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||||
|
{LLAMA_GRETYPE_CHAR, 48},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{LLAMA_GRETYPE_CHAR, 61},
|
||||||
|
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||||
|
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||||
|
{LLAMA_GRETYPE_CHAR, 48},
|
||||||
|
},
|
||||||
|
{
|
||||||
{LLAMA_GRETYPE_CHAR, 61},
|
{LLAMA_GRETYPE_CHAR, 61},
|
||||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||||
{LLAMA_GRETYPE_CHAR, 97},
|
{LLAMA_GRETYPE_CHAR, 97},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
|
||||||
{LLAMA_GRETYPE_CHAR, 61},
|
|
||||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
|
||||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
|
||||||
{LLAMA_GRETYPE_CHAR, 48},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
|
||||||
{LLAMA_GRETYPE_CHAR, 61},
|
|
||||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
|
||||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
|
||||||
{LLAMA_GRETYPE_CHAR, 48},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||||
{LLAMA_GRETYPE_CHAR, 61},
|
{LLAMA_GRETYPE_CHAR, 61},
|
||||||
@ -151,26 +153,24 @@ int main()
|
|||||||
{LLAMA_GRETYPE_CHAR, 40},
|
{LLAMA_GRETYPE_CHAR, 40},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||||
|
{LLAMA_GRETYPE_CHAR, 61},
|
||||||
|
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||||
|
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||||
|
{LLAMA_GRETYPE_CHAR, 48},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||||
|
{LLAMA_GRETYPE_CHAR, 61},
|
||||||
|
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||||
|
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||||
|
{LLAMA_GRETYPE_CHAR, 48},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||||
{LLAMA_GRETYPE_CHAR, 61},
|
{LLAMA_GRETYPE_CHAR, 61},
|
||||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||||
{LLAMA_GRETYPE_CHAR, 97},
|
{LLAMA_GRETYPE_CHAR, 97},
|
||||||
},
|
|
||||||
{
|
|
||||||
{LLAMA_GRETYPE_CHAR, 61},
|
|
||||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
|
||||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
|
||||||
{LLAMA_GRETYPE_CHAR, 48},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
{LLAMA_GRETYPE_CHAR, 61},
|
|
||||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
|
||||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
|
||||||
{LLAMA_GRETYPE_CHAR, 48},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
{LLAMA_GRETYPE_CHAR, 61},
|
|
||||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
|
||||||
{LLAMA_GRETYPE_CHAR, 40},
|
|
||||||
}};
|
}};
|
||||||
|
|
||||||
auto index = 0;
|
auto index = 0;
|
||||||
@ -197,9 +197,9 @@ int main()
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<llama_grammar_candidate> next_candidates;
|
std::vector<llama_grammar_candidate> next_candidates;
|
||||||
next_candidates.resize(24);
|
next_candidates.resize(23);
|
||||||
|
|
||||||
for (size_t i = 0; i < 24; ++i)
|
for (size_t i = 0; i < 23; ++i)
|
||||||
{
|
{
|
||||||
uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
|
uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
|
||||||
cp[0] = 37 + i;
|
cp[0] = 37 + i;
|
||||||
@ -212,7 +212,6 @@ int main()
|
|||||||
{0, 37},
|
{0, 37},
|
||||||
{1, 38},
|
{1, 38},
|
||||||
{2, 39},
|
{2, 39},
|
||||||
{3, 40},
|
|
||||||
{4, 41},
|
{4, 41},
|
||||||
{5, 42},
|
{5, 42},
|
||||||
{6, 43},
|
{6, 43},
|
||||||
@ -270,6 +269,7 @@ int main()
|
|||||||
{0, 37},
|
{0, 37},
|
||||||
{1, 38},
|
{1, 38},
|
||||||
{2, 39},
|
{2, 39},
|
||||||
|
{3, 40},
|
||||||
{4, 41},
|
{4, 41},
|
||||||
{5, 42},
|
{5, 42},
|
||||||
{6, 43},
|
{6, 43},
|
||||||
@ -289,13 +289,11 @@ int main()
|
|||||||
{20, 57},
|
{20, 57},
|
||||||
{21, 58},
|
{21, 58},
|
||||||
{22, 59},
|
{22, 59},
|
||||||
{23, 60},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
{0, 37},
|
{0, 37},
|
||||||
{1, 38},
|
{1, 38},
|
||||||
{2, 39},
|
{2, 39},
|
||||||
{3, 40},
|
|
||||||
{4, 41},
|
{4, 41},
|
||||||
{5, 42},
|
{5, 42},
|
||||||
{6, 43},
|
{6, 43},
|
||||||
@ -353,6 +351,7 @@ int main()
|
|||||||
{0, 37},
|
{0, 37},
|
||||||
{1, 38},
|
{1, 38},
|
||||||
{2, 39},
|
{2, 39},
|
||||||
|
{3, 40},
|
||||||
{4, 41},
|
{4, 41},
|
||||||
{5, 42},
|
{5, 42},
|
||||||
{6, 43},
|
{6, 43},
|
||||||
@ -372,7 +371,6 @@ int main()
|
|||||||
{20, 57},
|
{20, 57},
|
||||||
{21, 58},
|
{21, 58},
|
||||||
{22, 59},
|
{22, 59},
|
||||||
{23, 60},
|
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user