diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 3ca663e3..e0593ef5 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #define MAX_REPETITION_THRESHOLD 2000 @@ -462,6 +463,7 @@ const char* llama_grammar_parser::parse_sequence( bool is_nested) { size_t last_sym_start = rule.size(); const char* pos = src; + uint64_t n_prev_rules = 1; // use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used // (though it's technically the same as -1 now) @@ -489,6 +491,18 @@ const char* llama_grammar_parser::parse_sequence( // S' ::= S | llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end()); + // Calculate the total number of rules that will be generated by this repetition + uint64_t total_rules = 1; // Start with 1 for the original rule + if (!no_max && max_times > 0) { + total_rules = max_times; + } else if (min_times > 0) { + total_rules = min_times; + } + + if (n_prev_rules * total_rules >= MAX_REPETITION_THRESHOLD) { + throw std::runtime_error("number of rules that are going to be repeated multiplied by the new repetition exceeds sane defaults, please reduce the number of repetitions or rule complexity"); + } + if (min_times == 0) { rule.resize(last_sym_start); } @@ -517,12 +531,15 @@ const char* llama_grammar_parser::parse_sequence( if (n_opt > 0) { rule.push_back({ LLAMA_GRETYPE_RULE_REF, last_rec_rule_id }); } + n_prev_rules *= total_rules; + GGML_ASSERT(n_prev_rules >= 1); }; while (*pos) { if (*pos == '"') { // literal string pos++; last_sym_start = rule.size(); + n_prev_rules = 1; while (*pos != '"') { if (!*pos) { throw std::runtime_error("unexpected end of input"); @@ -541,6 +558,7 @@ const char* llama_grammar_parser::parse_sequence( start_type = LLAMA_GRETYPE_CHAR_NOT; } last_sym_start = rule.size(); + n_prev_rules = 1; while (*pos != ']') { if (!*pos) { throw std::runtime_error("unexpected end of input"); @@ -571,6 +589,7 @@ const char* llama_grammar_parser::parse_sequence( auto token_pair = parse_token(vocab, pos); const char * token_end = token_pair.second; last_sym_start = rule.size(); + n_prev_rules = 1; rule.push_back({type, token_pair.first}); pos = parse_space(token_end, is_nested); } else if (is_word_char(*pos)) { // rule reference @@ -578,13 +597,15 @@ const char* llama_grammar_parser::parse_sequence( uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); pos = parse_space(name_end, is_nested); last_sym_start = rule.size(); - rule.push_back({ LLAMA_GRETYPE_RULE_REF, ref_rule_id }); - } - else if (*pos == '(') { // grouping - // parse nested alternates into synthesized rule + n_prev_rules = 1; + rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); + } else if (*pos == '(') { // grouping + // parse nested alternates into synthesized rule pos = parse_space(pos + 1, true); + uint32_t n_rules_before = symbol_ids.size(); uint32_t sub_rule_id = generate_symbol_id(rule_name); pos = parse_alternates(pos, rule_name, sub_rule_id, true); + n_prev_rules = std::max(1u, (uint32_t)symbol_ids.size() - n_rules_before); last_sym_start = rule.size(); // output reference to synthesized rule rule.push_back({ LLAMA_GRETYPE_RULE_REF, sub_rule_id }); @@ -595,7 +616,8 @@ const char* llama_grammar_parser::parse_sequence( } else if (*pos == '.') { // any char last_sym_start = rule.size(); - rule.push_back({ LLAMA_GRETYPE_CHAR_ANY, 0 }); + n_prev_rules = 1; + rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); pos = parse_space(pos + 1, is_nested); } else if (*pos == '*') { @@ -858,32 +880,54 @@ static bool llama_grammar_match_token( static void llama_grammar_advance_stack( const llama_grammar_rules & rules, const llama_grammar_stack & stack, - llama_grammar_stacks & new_stacks) { - if (stack.empty()) { - if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { - new_stacks.emplace_back(stack); + llama_grammar_stacks & new_stacks) { + std::vector todo; + todo.push_back(stack); + + auto stack_cmp = [](const llama_grammar_stack & a, const llama_grammar_stack & b) { + return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end(), + [](const llama_grammar_element * pa, const llama_grammar_element * pb) { + return pa < pb; // Compare pointer addresses + } + ); + }; + + std::set seen(stack_cmp); + + while (!todo.empty()) { + llama_grammar_stack curr_stack = std::move(todo.back()); + todo.pop_back(); + + if (seen.find( curr_stack) != seen.end()) { + continue; } - return; - } + seen.insert(curr_stack); - const llama_grammar_element * pos = stack.back(); + if (curr_stack.empty()) { + if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) { + new_stacks.emplace_back(std::move(curr_stack)); + } + continue; + } - switch (pos->type) { + const llama_grammar_element * pos = curr_stack.back(); + + switch (pos->type) { case LLAMA_GRETYPE_RULE_REF: { const size_t rule_id = static_cast(pos->value); const llama_grammar_element * subpos = rules[rule_id].data(); do { // init new stack without the top (pos) - llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + llama_grammar_stack next_stack(curr_stack.begin(), curr_stack.end() - 1); if (!llama_grammar_is_end_of_sequence(pos + 1)) { // if this rule ref is followed by another element, add that to stack - new_stack.push_back(pos + 1); + next_stack.push_back(pos + 1); } if (!llama_grammar_is_end_of_sequence(subpos)) { // if alternate is nonempty, add to stack - new_stack.push_back(subpos); + next_stack.push_back(subpos); } - llama_grammar_advance_stack(rules, new_stack, new_stacks); + todo.push_back(std::move(next_stack)); while (!llama_grammar_is_end_of_sequence(subpos)) { // scan to end of alternate def subpos++; @@ -891,8 +935,8 @@ static void llama_grammar_advance_stack( if (subpos->type == LLAMA_GRETYPE_ALT) { // there's another alternate def of this rule to process subpos++; - } - else { + } + else { break; } } while (true); @@ -903,9 +947,9 @@ static void llama_grammar_advance_stack( case LLAMA_GRETYPE_CHAR_ANY: case LLAMA_GRETYPE_TOKEN: case LLAMA_GRETYPE_TOKEN_NOT: - if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { + if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) { // only add the stack if it's not a duplicate of one we already have - new_stacks.emplace_back(stack); + new_stacks.emplace_back(std::move(curr_stack)); } break; default: @@ -913,6 +957,7 @@ static void llama_grammar_advance_stack( // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on // those GGML_ABORT("fatal error"); + } } } @@ -1198,13 +1243,13 @@ struct llama_grammar* llama_grammar_init_impl( // if there is a grammar, parse it // rules will be empty (default) if there are parse errors if (!parser.parse(grammar_str) || parser.rules.empty()) { - fprintf(stderr, "%s: failed to parse grammar\n", __func__); + LLAMA_LOG_ERROR("failed to parse grammar\n"); return nullptr; } - // Ensure that there is a "root" node. - if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) { - fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); + // Ensure that the grammar contains the start symbol + if (parser.symbol_ids.find(grammar_root) == parser.symbol_ids.end()) { + LLAMA_LOG_ERROR("grammar does not contain a '%s' symbol\n", grammar_root); return nullptr; } @@ -1233,7 +1278,7 @@ struct llama_grammar* llama_grammar_init_impl( continue; } if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) { - LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i); + LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu\n", i); return nullptr; } } diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index b6bf394c..114d42e1 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -15,20 +15,12 @@ using json = nlohmann::ordered_json; -static llama_grammar* build_grammar(const std::string & grammar_str) { - auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); +static llama_grammar * build_grammar_with_root(const std::string & grammar_str, const char * grammar_root) { + return llama_grammar_init_impl(nullptr, grammar_str.c_str(), grammar_root, false, nullptr, 0, nullptr, 0); +} - // Ensure we parsed correctly - assert(!parsed_grammar.rules.empty()); - - // Ensure we have a root node - assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end())); - - std::vector grammar_rules(parsed_grammar.c_rules()); - llama_grammar* grammar = llama_grammar_init( - grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); - - return grammar; +static llama_grammar * build_grammar(const std::string & grammar_str) { + return build_grammar_with_root(grammar_str, "root"); } static bool test_build_grammar_fails(const std::string & grammar_str) { @@ -801,6 +793,24 @@ static void test_quantifiers() { "0xFF 0x12 0xAB 0x00 0x00 0x00", } ); + test_grammar( + "segfault", + // Grammar + R"""( + root ::= ( [x]* )* + )""", + // Passing strings + { + "", + "x", + "xx" + }, + // Failing strings + { + "y", + "yy" + } + ); } static void test_failure_missing_root() { @@ -875,6 +885,36 @@ static void test_failure_left_recursion() { fprintf(stderr, " ✅︎ Passed\n"); } +static void test_failure_missing_root_symbol() { + fprintf(stderr, "⚫ Testing missing root symbol:\n"); + + const std::string grammar_str = R"""( + root ::= "foobar" + )"""; + + llama_grammar * failure_result = build_grammar_with_root(grammar_str, "nonexistent"); + assert(failure_result == nullptr); + + fprintf(stderr, " ✅︎ Passed\n"); +} + +static void test_custom_root_symbol_check() { + fprintf(stderr, "⚫ Testing custom root symbol check:\n"); + + const std::string custom_root_grammar_str = R"""( + foobar ::= "foobar" + )"""; + + llama_grammar * failure_result = build_grammar_with_root(custom_root_grammar_str, "root"); + assert(failure_result == nullptr); + + llama_grammar * success_result = build_grammar_with_root(custom_root_grammar_str, "foobar"); + assert(success_result != nullptr); + llama_grammar_free_impl(success_result); + + fprintf(stderr, " ✅︎ Passed\n"); +} + static void test_json_schema() { // Note that this is similar to the regular grammar tests, // but we convert each json schema to a grammar before parsing. @@ -1448,6 +1488,8 @@ int main() { test_failure_missing_root(); test_failure_missing_reference(); test_failure_left_recursion(); + test_failure_missing_root_symbol(); + test_custom_root_symbol_check(); test_json_schema(); fprintf(stdout, "All tests passed.\n"); return 0; diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp index 68a38639..aab46c5d 100644 --- a/tests/test-grammar-parser.cpp +++ b/tests/test-grammar-parser.cpp @@ -141,6 +141,10 @@ int main() root ::= "a"{,}" )"""); + verify_failure(R"""( + root ::= (((((([^x]*){0,99}){0,99}){0,99}){0,99}){0,99}){0,99} + )"""); + verify_failure(R"""( root ::= "a"{,10}" )"""); diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp index 3814f124..ef4c0c88 100644 --- a/tests/test-llama-grammar.cpp +++ b/tests/test-llama-grammar.cpp @@ -125,25 +125,27 @@ int main() std::vector> expected_stacks = { { - {LLAMA_GRETYPE_RULE_REF, 5}, + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_CHAR, 40}, + }, + { + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_CHAR, 48}, + }, + { + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_CHAR, 48}, + }, + { {LLAMA_GRETYPE_CHAR, 61}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_CHAR, 97}, }, - { - {LLAMA_GRETYPE_RULE_REF, 5}, - {LLAMA_GRETYPE_CHAR, 61}, - {LLAMA_GRETYPE_RULE_REF, 7}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_CHAR, 48}, - }, - { - {LLAMA_GRETYPE_RULE_REF, 5}, - {LLAMA_GRETYPE_CHAR, 61}, - {LLAMA_GRETYPE_RULE_REF, 7}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_CHAR, 48}, - }, { {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_CHAR, 61}, @@ -151,26 +153,24 @@ int main() {LLAMA_GRETYPE_CHAR, 40}, }, { + {LLAMA_GRETYPE_RULE_REF, 5}, + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_CHAR, 48}, + }, + { + {LLAMA_GRETYPE_RULE_REF, 5}, + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_CHAR, 48}, + }, + { + {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_CHAR, 61}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_CHAR, 97}, - }, - { - {LLAMA_GRETYPE_CHAR, 61}, - {LLAMA_GRETYPE_RULE_REF, 7}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_CHAR, 48}, - }, - { - {LLAMA_GRETYPE_CHAR, 61}, - {LLAMA_GRETYPE_RULE_REF, 7}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_CHAR, 48}, - }, - { - {LLAMA_GRETYPE_CHAR, 61}, - {LLAMA_GRETYPE_RULE_REF, 7}, - {LLAMA_GRETYPE_CHAR, 40}, }}; auto index = 0; @@ -197,9 +197,9 @@ int main() } std::vector next_candidates; - next_candidates.resize(24); + next_candidates.resize(23); - for (size_t i = 0; i < 24; ++i) + for (size_t i = 0; i < 23; ++i) { uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point cp[0] = 37 + i; @@ -212,7 +212,6 @@ int main() {0, 37}, {1, 38}, {2, 39}, - {3, 40}, {4, 41}, {5, 42}, {6, 43}, @@ -270,6 +269,7 @@ int main() {0, 37}, {1, 38}, {2, 39}, + {3, 40}, {4, 41}, {5, 42}, {6, 43}, @@ -289,13 +289,11 @@ int main() {20, 57}, {21, 58}, {22, 59}, - {23, 60}, }, { {0, 37}, {1, 38}, {2, 39}, - {3, 40}, {4, 41}, {5, 42}, {6, 43}, @@ -353,6 +351,7 @@ int main() {0, 37}, {1, 38}, {2, 39}, + {3, 40}, {4, 41}, {5, 42}, {6, 43}, @@ -372,7 +371,6 @@ int main() {20, 57}, {21, 58}, {22, 59}, - {23, 60}, }, };