From 063d9c156e816ae3cf62db01f429a07a099afe97 Mon Sep 17 00:00:00 2001 From: Aldehir Rojas Date: Sat, 20 Jun 2026 21:15:06 -0500 Subject: [PATCH] common/peg : refactor until gbnf grammar generation (#24839) * common/peg : refactor until gbnf grammar into an ac automaton * cont : add a test with multiple strings * cont : pad state with 0s so rules line up * cont : clean up comments * cont : use set everywhere * cont : inline state num string padding * cont : add a ref to PR * cont : fix regression in server-tools.cpp --- common/peg-parser.cpp | 194 +++++++++++++--------- common/peg-parser.h | 4 +- tests/peg-parser/test-gbnf-generation.cpp | 80 ++++++++- tools/server/server-tools.cpp | 1 + 4 files changed, 199 insertions(+), 80 deletions(-) diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp index ff0d24d43f..506b902451 100644 --- a/common/peg-parser.cpp +++ b/common/peg-parser.cpp @@ -6,13 +6,14 @@ #include "unicode.h" #include +#include #include #include #include #include #include +#include #include -#include // Trick to catch missing branches template @@ -88,40 +89,7 @@ struct trie { return match_result{match_result::NO_MATCH}; } - struct prefix_and_next { - std::vector prefix; - std::vector next_chars; - }; - - std::vector collect_prefix_and_next() { - std::vector prefix; - std::vector result; - collect_prefix_and_next(0, prefix, result); - return result; - } - private: - void collect_prefix_and_next(size_t index, std::vector & prefix, std::vector & out) { - if (!nodes[index].is_word) { - if (!nodes[index].children.empty()) { - std::vector chars; - chars.reserve(nodes[index].children.size()); - for (const auto & p : nodes[index].children) { - chars.push_back(p.first); - } - out.emplace_back(prefix_and_next{prefix, chars}); - } - } - - for (const auto & p : nodes[index].children) { - uint32_t ch = p.first; - auto child = p.second; - prefix.push_back(ch); - collect_prefix_and_next(child, prefix, out); - prefix.pop_back(); - } - } - size_t create_node() { size_t index = nodes.size(); nodes.emplace_back(); @@ -153,6 +121,65 @@ struct trie { } }; +// Aho-Corasick automaton +struct aho_corasick { + trie t; + std::vector fail; // failure links + std::vector order; // states in BFS order + std::vector terminal; // match states (directly or via a suffix link) + std::set alphabet; // every character with a transition + + aho_corasick(const std::vector & strings) : t(strings) { + const auto & nodes = t.nodes; + const size_t n = nodes.size(); + + fail.assign(n, 0); + order.reserve(n); + + std::deque queue{ 0 }; + while (!queue.empty()) { + size_t u = queue.front(); + queue.pop_front(); + order.push_back(u); + for (const auto & [ch, v] : nodes[u].children) { + if (u != 0) { + size_t f = fail[u]; + while (f && nodes[f].children.find(ch) == nodes[f].children.end()) { + f = fail[f]; + } + auto it = nodes[f].children.find(ch); + fail[v] = (it != nodes[f].children.end() && it->second != v) ? it->second : 0; + } + queue.push_back(v); + } + } + + terminal.assign(n, false); + for (size_t u : order) { + terminal[u] = nodes[u].is_word || (u != 0 && terminal[fail[u]]); + } + + for (const auto & node : nodes) { + for (const auto & [ch, v] : node.children) { + alphabet.insert(ch); + } + } + } + + size_t num_states() const { return t.nodes.size(); } + bool is_terminal(size_t s) const { return terminal[s]; } + + // follow failure links until a transition on `ch` exists. + size_t next(size_t state, uint32_t ch) const { + const auto & nodes = t.nodes; + while (state && nodes[state].children.find(ch) == nodes[state].children.end()) { + state = fail[state]; + } + auto it = nodes[state].children.find(ch); + return it != nodes[state].children.end() ? it->second : 0; + } +}; + static std::pair parse_hex_escape(const std::string & str, size_t pos, int hex_count) { if (pos + hex_count > str.length()) { return {0, 0}; @@ -992,12 +1019,12 @@ void common_peg_arena::resolve_refs() { } std::string common_peg_arena::dump(common_peg_parser_id id) const { - std::unordered_set visited; + std::set visited; return dump_impl(id, visited); } std::string common_peg_arena::dump_impl(common_peg_parser_id id, - std::unordered_set & visited) const { + std::set & visited) const { // Check for cycles if (visited.count(id)) { return "[cycle]"; @@ -1502,61 +1529,74 @@ static std::string gbnf_escape_char_class(uint32_t c) { return std::string(buf); } -static std::string gbnf_excluding_pattern(const std::vector & strings) { - trie matcher(strings); - auto pieces = matcher.collect_prefix_and_next(); +// GBNF grammar matching strings that contain no string in `strings` as a +// substring. Emits the complement of an Aho-Corasick automaton DFA and returns +// the start state rule name. +// +// ref: https://github.com/ggml-org/llama.cpp/pull/24839 +static std::string gbnf_excluding_grammar(const common_grammar_builder & builder, + const std::string & prefix, + const std::vector & strings) { + aho_corasick ac(strings); - std::string pattern; - std::string trailing; // optional proper-prefix of a delimiter, allowed only at the very end - for (size_t i = 0; i < pieces.size(); ++i) { - if (i > 0) { - pattern += " | "; + auto state_name = [&](size_t s) -> std::string { + if (s == 0) { + return prefix; } + std::string num = std::to_string(s); + num = num.size() == 1 ? ("0" + num) : num; + return prefix + "-" + num; + }; - const auto & pre = pieces[i].prefix; - const auto & chars = pieces[i].next_chars; - - std::string cls; - cls.reserve(chars.size()); + auto char_class = [](const std::vector & chars, bool negate) { + std::string s = negate ? "[^" : "["; for (uint32_t ch : chars) { - cls += gbnf_escape_char_class(ch); + s += gbnf_escape_char_class(ch); + } + return s + "]"; + }; + + for (size_t q = 0; q < ac.num_states(); q++) { + if (ac.is_terminal(q)) { + continue; // match states are dropped } - if (!pre.empty()) { - std::string pre_literal = gbnf_format_literal(common_unicode_cpts_to_utf8(pre)); - pattern += pre_literal + " [^" + cls + "]"; - // Each interior alternative consumes a delimiter-prefix plus a disambiguating - // char, so the repetition alone cannot match a value that *ends* on a proper - // prefix of a delimiter (e.g. a trailing "\n" when the delimiter is - // "\n\n"). The runtime until() (greedy first-match) accepts such - // values, so without this the grammar would reject input the parser accepts. - // Allow the value to terminate on any proper prefix as an optional tail. - // This makes the grammar a slight superset of the runtime language (a value - // may end on the longest prefix, which greedy first-match would not itself - // produce); harmless for constrained generation, which only needs to admit - // every runtime-valid string. - if (!trailing.empty()) { - trailing += " | "; + std::map> buckets; + std::vector excluded; + for (uint32_t c : ac.alphabet) { + size_t d = ac.next(q, c); + if (ac.is_terminal(d)) { + excluded.push_back(c); // completes a forbidden string -> omit + } else if (d != 0) { + buckets[d].push_back(c); // specific non-root destination + excluded.push_back(c); } - trailing += pre_literal; - } else { - pattern += "[^" + cls + "]"; } + + std::string rhs = "|"; // every state is accepting + for (const auto & [d, chars] : buckets) { + rhs += " " + char_class(chars, false) + " " + state_name(d) + " |"; + } + rhs += " " + char_class(excluded, true) + " " + state_name(0); + + builder.add_rule(state_name(q), rhs); } - std::string result = "(" + pattern + ")*"; - if (!trailing.empty()) { - result += " (" + trailing + ")?"; + // An empty delimiter makes the start state terminal. Emit an entry rule + // that matches nothing so the returned reference stays valid. + if (ac.is_terminal(0)) { + builder.add_rule(prefix, "|"); } - return result; + + return state_name(0); } -static std::unordered_set collect_reachable_rules( +static std::set collect_reachable_rules( const common_peg_arena & arena, const common_peg_parser_id & rule ) { - std::unordered_set reachable; - std::unordered_set visited; + std::set reachable; + std::set visited; std::function visit = [&](common_peg_parser_id id) { const auto & parser = arena.get(id); @@ -1765,7 +1805,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo if (p.delimiters.empty()) { return ".*"; } - return gbnf_excluding_pattern(p.delimiters); + return gbnf_excluding_grammar(builder, "until-" + std::to_string(id), p.delimiters); } else if constexpr (std::is_same_v) { if (schema_delegates(p)) { return to_gbnf(p.child); @@ -1789,7 +1829,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo }; // Collect reachable rules - std::unordered_set reachable_rules; + std::set reachable_rules; if (lazy) { // Collect rules reachable from trigger rules diff --git a/common/peg-parser.h b/common/peg-parser.h index b6bb05214b..132173a64c 100644 --- a/common/peg-parser.h +++ b/common/peg-parser.h @@ -3,8 +3,8 @@ #include #include +#include #include -#include #include #include #include @@ -335,7 +335,7 @@ class common_peg_arena { friend class common_peg_parser_builder; private: - std::string dump_impl(common_peg_parser_id id, std::unordered_set & visited) const; + std::string dump_impl(common_peg_parser_id id, std::set & visited) const; common_peg_parser_id add_parser(common_peg_parser_variant parser); void add_rule(const std::string & name, common_peg_parser_id id); diff --git a/tests/peg-parser/test-gbnf-generation.cpp b/tests/peg-parser/test-gbnf-generation.cpp index 00111e6a19..45d692ca60 100644 --- a/tests/peg-parser/test-gbnf-generation.cpp +++ b/tests/peg-parser/test-gbnf-generation.cpp @@ -129,8 +129,86 @@ void test_gbnf_generation(testing &t) { }); assert_gbnf_equal(t, R"""( - root ::= ([^<] | "<" [^/] | "])* ("<" | "] until-0 + )""", gbnf); + }); + + t.test("until grammar overlapping delimiter", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.until("\n\n"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= until-0 + space ::= | " " | "\n"{1,2} [ \t]{0,20} + until-0 ::= | [\n] until-0-01 | [^\n] until-0 + until-0-01 ::= | [\n] until-0-01 | [<] until-0-02 | [^\n<] until-0 + until-0-02 ::= | [\n] until-0-01 | [/] until-0-03 | [^\n/] until-0 + until-0-03 ::= | [\n] until-0-01 | [p] until-0-04 | [^\np] until-0 + until-0-04 ::= | [\n] until-0-01 | [a] until-0-05 | [^\na] until-0 + until-0-05 ::= | [\n] until-0-01 | [r] until-0-06 | [^\nr] until-0 + until-0-06 ::= | [\n] until-0-01 | [a] until-0-07 | [^\na] until-0 + until-0-07 ::= | [\n] until-0-01 | [m] until-0-08 | [^\nm] until-0 + until-0-08 ::= | [\n] until-0-01 | [e] until-0-09 | [^\ne] until-0 + until-0-09 ::= | [\n] until-0-01 | [t] until-0-10 | [^\nt] until-0 + until-0-10 ::= | [\n] until-0-01 | [e] until-0-11 | [^\ne] until-0 + until-0-11 ::= | [\n] until-0-01 | [r] until-0-12 | [^\nr] until-0 + until-0-12 ::= | [\n] until-0-01 | [>] until-0-13 | [^\n>] until-0 + until-0-13 ::= | [^\n] until-0 + )""", gbnf); + }); + + // DeepSeek-V3.2 tag prefix. The DSML token (|DSML|) embeds U+FF5C, + // so the delimiter mixes ASCII and multi-byte codepoints. + t.test("until grammar unicode delimiter", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.until("<|DSML|"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= until-0 + space ::= | " " | "\n"{1,2} [ \t]{0,20} + until-0 ::= | [<] until-0-01 | [^<] until-0 + until-0-01 ::= | [<] until-0-01 | [\uFF5C] until-0-02 | [^<\uFF5C] until-0 + until-0-02 ::= | [<] until-0-01 | [D] until-0-03 | [^ #include #include +#include namespace fs = std::filesystem;