mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
common/peg : refactor until gbnf grammar generation (#24839)
* common/peg : refactor until gbnf grammar into an ac automaton * cont : add a test with multiple strings * cont : pad state with 0s so rules line up * cont : clean up comments * cont : use set everywhere * cont : inline state num string padding * cont : add a ref to PR * cont : fix regression in server-tools.cpp
This commit is contained in:
parent
c57607016a
commit
063d9c156e
@ -6,13 +6,14 @@
|
||||
#include "unicode.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <deque>
|
||||
#include <initializer_list>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <regex>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
#include <unordered_set>
|
||||
|
||||
// Trick to catch missing branches
|
||||
template <typename T>
|
||||
@ -88,40 +89,7 @@ struct trie {
|
||||
return match_result{match_result::NO_MATCH};
|
||||
}
|
||||
|
||||
struct prefix_and_next {
|
||||
std::vector<uint32_t> prefix;
|
||||
std::vector<uint32_t> next_chars;
|
||||
};
|
||||
|
||||
std::vector<prefix_and_next> collect_prefix_and_next() {
|
||||
std::vector<uint32_t> prefix;
|
||||
std::vector<prefix_and_next> result;
|
||||
collect_prefix_and_next(0, prefix, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
void collect_prefix_and_next(size_t index, std::vector<uint32_t> & prefix, std::vector<prefix_and_next> & out) {
|
||||
if (!nodes[index].is_word) {
|
||||
if (!nodes[index].children.empty()) {
|
||||
std::vector<uint32_t> chars;
|
||||
chars.reserve(nodes[index].children.size());
|
||||
for (const auto & p : nodes[index].children) {
|
||||
chars.push_back(p.first);
|
||||
}
|
||||
out.emplace_back(prefix_and_next{prefix, chars});
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & p : nodes[index].children) {
|
||||
uint32_t ch = p.first;
|
||||
auto child = p.second;
|
||||
prefix.push_back(ch);
|
||||
collect_prefix_and_next(child, prefix, out);
|
||||
prefix.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
size_t create_node() {
|
||||
size_t index = nodes.size();
|
||||
nodes.emplace_back();
|
||||
@ -153,6 +121,65 @@ struct trie {
|
||||
}
|
||||
};
|
||||
|
||||
// Aho-Corasick automaton
|
||||
struct aho_corasick {
|
||||
trie t;
|
||||
std::vector<size_t> fail; // failure links
|
||||
std::vector<size_t> order; // states in BFS order
|
||||
std::vector<bool> terminal; // match states (directly or via a suffix link)
|
||||
std::set<uint32_t> alphabet; // every character with a transition
|
||||
|
||||
aho_corasick(const std::vector<std::string> & strings) : t(strings) {
|
||||
const auto & nodes = t.nodes;
|
||||
const size_t n = nodes.size();
|
||||
|
||||
fail.assign(n, 0);
|
||||
order.reserve(n);
|
||||
|
||||
std::deque<size_t> queue{ 0 };
|
||||
while (!queue.empty()) {
|
||||
size_t u = queue.front();
|
||||
queue.pop_front();
|
||||
order.push_back(u);
|
||||
for (const auto & [ch, v] : nodes[u].children) {
|
||||
if (u != 0) {
|
||||
size_t f = fail[u];
|
||||
while (f && nodes[f].children.find(ch) == nodes[f].children.end()) {
|
||||
f = fail[f];
|
||||
}
|
||||
auto it = nodes[f].children.find(ch);
|
||||
fail[v] = (it != nodes[f].children.end() && it->second != v) ? it->second : 0;
|
||||
}
|
||||
queue.push_back(v);
|
||||
}
|
||||
}
|
||||
|
||||
terminal.assign(n, false);
|
||||
for (size_t u : order) {
|
||||
terminal[u] = nodes[u].is_word || (u != 0 && terminal[fail[u]]);
|
||||
}
|
||||
|
||||
for (const auto & node : nodes) {
|
||||
for (const auto & [ch, v] : node.children) {
|
||||
alphabet.insert(ch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t num_states() const { return t.nodes.size(); }
|
||||
bool is_terminal(size_t s) const { return terminal[s]; }
|
||||
|
||||
// follow failure links until a transition on `ch` exists.
|
||||
size_t next(size_t state, uint32_t ch) const {
|
||||
const auto & nodes = t.nodes;
|
||||
while (state && nodes[state].children.find(ch) == nodes[state].children.end()) {
|
||||
state = fail[state];
|
||||
}
|
||||
auto it = nodes[state].children.find(ch);
|
||||
return it != nodes[state].children.end() ? it->second : 0;
|
||||
}
|
||||
};
|
||||
|
||||
static std::pair<uint32_t, size_t> parse_hex_escape(const std::string & str, size_t pos, int hex_count) {
|
||||
if (pos + hex_count > str.length()) {
|
||||
return {0, 0};
|
||||
@ -992,12 +1019,12 @@ void common_peg_arena::resolve_refs() {
|
||||
}
|
||||
|
||||
std::string common_peg_arena::dump(common_peg_parser_id id) const {
|
||||
std::unordered_set<common_peg_parser_id> visited;
|
||||
std::set<common_peg_parser_id> visited;
|
||||
return dump_impl(id, visited);
|
||||
}
|
||||
|
||||
std::string common_peg_arena::dump_impl(common_peg_parser_id id,
|
||||
std::unordered_set<common_peg_parser_id> & visited) const {
|
||||
std::set<common_peg_parser_id> & visited) const {
|
||||
// Check for cycles
|
||||
if (visited.count(id)) {
|
||||
return "[cycle]";
|
||||
@ -1502,61 +1529,74 @@ static std::string gbnf_escape_char_class(uint32_t c) {
|
||||
return std::string(buf);
|
||||
}
|
||||
|
||||
static std::string gbnf_excluding_pattern(const std::vector<std::string> & strings) {
|
||||
trie matcher(strings);
|
||||
auto pieces = matcher.collect_prefix_and_next();
|
||||
// GBNF grammar matching strings that contain no string in `strings` as a
|
||||
// substring. Emits the complement of an Aho-Corasick automaton DFA and returns
|
||||
// the start state rule name.
|
||||
//
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/24839
|
||||
static std::string gbnf_excluding_grammar(const common_grammar_builder & builder,
|
||||
const std::string & prefix,
|
||||
const std::vector<std::string> & strings) {
|
||||
aho_corasick ac(strings);
|
||||
|
||||
std::string pattern;
|
||||
std::string trailing; // optional proper-prefix of a delimiter, allowed only at the very end
|
||||
for (size_t i = 0; i < pieces.size(); ++i) {
|
||||
if (i > 0) {
|
||||
pattern += " | ";
|
||||
auto state_name = [&](size_t s) -> std::string {
|
||||
if (s == 0) {
|
||||
return prefix;
|
||||
}
|
||||
std::string num = std::to_string(s);
|
||||
num = num.size() == 1 ? ("0" + num) : num;
|
||||
return prefix + "-" + num;
|
||||
};
|
||||
|
||||
const auto & pre = pieces[i].prefix;
|
||||
const auto & chars = pieces[i].next_chars;
|
||||
|
||||
std::string cls;
|
||||
cls.reserve(chars.size());
|
||||
auto char_class = [](const std::vector<uint32_t> & chars, bool negate) {
|
||||
std::string s = negate ? "[^" : "[";
|
||||
for (uint32_t ch : chars) {
|
||||
cls += gbnf_escape_char_class(ch);
|
||||
s += gbnf_escape_char_class(ch);
|
||||
}
|
||||
return s + "]";
|
||||
};
|
||||
|
||||
for (size_t q = 0; q < ac.num_states(); q++) {
|
||||
if (ac.is_terminal(q)) {
|
||||
continue; // match states are dropped
|
||||
}
|
||||
|
||||
if (!pre.empty()) {
|
||||
std::string pre_literal = gbnf_format_literal(common_unicode_cpts_to_utf8(pre));
|
||||
pattern += pre_literal + " [^" + cls + "]";
|
||||
// Each interior alternative consumes a delimiter-prefix plus a disambiguating
|
||||
// char, so the repetition alone cannot match a value that *ends* on a proper
|
||||
// prefix of a delimiter (e.g. a trailing "\n" when the delimiter is
|
||||
// "\n</parameter>\n"). The runtime until() (greedy first-match) accepts such
|
||||
// values, so without this the grammar would reject input the parser accepts.
|
||||
// Allow the value to terminate on any proper prefix as an optional tail.
|
||||
// This makes the grammar a slight superset of the runtime language (a value
|
||||
// may end on the longest prefix, which greedy first-match would not itself
|
||||
// produce); harmless for constrained generation, which only needs to admit
|
||||
// every runtime-valid string.
|
||||
if (!trailing.empty()) {
|
||||
trailing += " | ";
|
||||
}
|
||||
trailing += pre_literal;
|
||||
} else {
|
||||
pattern += "[^" + cls + "]";
|
||||
std::map<size_t, std::vector<uint32_t>> buckets;
|
||||
std::vector<uint32_t> excluded;
|
||||
for (uint32_t c : ac.alphabet) {
|
||||
size_t d = ac.next(q, c);
|
||||
if (ac.is_terminal(d)) {
|
||||
excluded.push_back(c); // completes a forbidden string -> omit
|
||||
} else if (d != 0) {
|
||||
buckets[d].push_back(c); // specific non-root destination
|
||||
excluded.push_back(c);
|
||||
}
|
||||
}
|
||||
|
||||
std::string result = "(" + pattern + ")*";
|
||||
if (!trailing.empty()) {
|
||||
result += " (" + trailing + ")?";
|
||||
std::string rhs = "|"; // every state is accepting
|
||||
for (const auto & [d, chars] : buckets) {
|
||||
rhs += " " + char_class(chars, false) + " " + state_name(d) + " |";
|
||||
}
|
||||
return result;
|
||||
rhs += " " + char_class(excluded, true) + " " + state_name(0);
|
||||
|
||||
builder.add_rule(state_name(q), rhs);
|
||||
}
|
||||
|
||||
static std::unordered_set<std::string> collect_reachable_rules(
|
||||
// An empty delimiter makes the start state terminal. Emit an entry rule
|
||||
// that matches nothing so the returned reference stays valid.
|
||||
if (ac.is_terminal(0)) {
|
||||
builder.add_rule(prefix, "|");
|
||||
}
|
||||
|
||||
return state_name(0);
|
||||
}
|
||||
|
||||
static std::set<std::string> collect_reachable_rules(
|
||||
const common_peg_arena & arena,
|
||||
const common_peg_parser_id & rule
|
||||
) {
|
||||
std::unordered_set<std::string> reachable;
|
||||
std::unordered_set<std::string> visited;
|
||||
std::set<std::string> reachable;
|
||||
std::set<std::string> visited;
|
||||
|
||||
std::function<void(common_peg_parser_id)> visit = [&](common_peg_parser_id id) {
|
||||
const auto & parser = arena.get(id);
|
||||
@ -1765,7 +1805,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
if (p.delimiters.empty()) {
|
||||
return ".*";
|
||||
}
|
||||
return gbnf_excluding_pattern(p.delimiters);
|
||||
return gbnf_excluding_grammar(builder, "until-" + std::to_string(id), p.delimiters);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
|
||||
if (schema_delegates(p)) {
|
||||
return to_gbnf(p.child);
|
||||
@ -1789,7 +1829,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
};
|
||||
|
||||
// Collect reachable rules
|
||||
std::unordered_set<std::string> reachable_rules;
|
||||
std::set<std::string> reachable_rules;
|
||||
|
||||
if (lazy) {
|
||||
// Collect rules reachable from trigger rules
|
||||
|
||||
@ -3,8 +3,8 @@
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <functional>
|
||||
@ -335,7 +335,7 @@ class common_peg_arena {
|
||||
friend class common_peg_parser_builder;
|
||||
|
||||
private:
|
||||
std::string dump_impl(common_peg_parser_id id, std::unordered_set<common_peg_parser_id> & visited) const;
|
||||
std::string dump_impl(common_peg_parser_id id, std::set<common_peg_parser_id> & visited) const;
|
||||
|
||||
common_peg_parser_id add_parser(common_peg_parser_variant parser);
|
||||
void add_rule(const std::string & name, common_peg_parser_id id);
|
||||
|
||||
@ -129,8 +129,86 @@ void test_gbnf_generation(testing &t) {
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= ([^<] | "<" [^/] | "</" [^t] | "</t" [^a] | "</ta" [^g] | "</tag" [^>])* ("<" | "</" | "</t" | "</ta" | "</tag")?
|
||||
root ::= until-0
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
until-0 ::= | [<] until-0-01 | [^<] until-0
|
||||
until-0-01 ::= | [<] until-0-01 | [/] until-0-02 | [^/<] until-0
|
||||
until-0-02 ::= | [<] until-0-01 | [t] until-0-03 | [^<t] until-0
|
||||
until-0-03 ::= | [<] until-0-01 | [a] until-0-04 | [^<a] until-0
|
||||
until-0-04 ::= | [<] until-0-01 | [g] until-0-05 | [^<g] until-0
|
||||
until-0-05 ::= | [<] until-0-01 | [^<>] until-0
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("until grammar overlapping delimiter", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.until("\n</parameter>\n");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= until-0
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
until-0 ::= | [\n] until-0-01 | [^\n] until-0
|
||||
until-0-01 ::= | [\n] until-0-01 | [<] until-0-02 | [^\n<] until-0
|
||||
until-0-02 ::= | [\n] until-0-01 | [/] until-0-03 | [^\n/] until-0
|
||||
until-0-03 ::= | [\n] until-0-01 | [p] until-0-04 | [^\np] until-0
|
||||
until-0-04 ::= | [\n] until-0-01 | [a] until-0-05 | [^\na] until-0
|
||||
until-0-05 ::= | [\n] until-0-01 | [r] until-0-06 | [^\nr] until-0
|
||||
until-0-06 ::= | [\n] until-0-01 | [a] until-0-07 | [^\na] until-0
|
||||
until-0-07 ::= | [\n] until-0-01 | [m] until-0-08 | [^\nm] until-0
|
||||
until-0-08 ::= | [\n] until-0-01 | [e] until-0-09 | [^\ne] until-0
|
||||
until-0-09 ::= | [\n] until-0-01 | [t] until-0-10 | [^\nt] until-0
|
||||
until-0-10 ::= | [\n] until-0-01 | [e] until-0-11 | [^\ne] until-0
|
||||
until-0-11 ::= | [\n] until-0-01 | [r] until-0-12 | [^\nr] until-0
|
||||
until-0-12 ::= | [\n] until-0-01 | [>] until-0-13 | [^\n>] until-0
|
||||
until-0-13 ::= | [^\n] until-0
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
// DeepSeek-V3.2 tag prefix. The DSML token (|DSML|) embeds U+FF5C,
|
||||
// so the delimiter mixes ASCII and multi-byte codepoints.
|
||||
t.test("until grammar unicode delimiter", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.until("<|DSML|");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= until-0
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
until-0 ::= | [<] until-0-01 | [^<] until-0
|
||||
until-0-01 ::= | [<] until-0-01 | [\uFF5C] until-0-02 | [^<\uFF5C] until-0
|
||||
until-0-02 ::= | [<] until-0-01 | [D] until-0-03 | [^<D] until-0
|
||||
until-0-03 ::= | [<] until-0-01 | [S] until-0-04 | [^<S] until-0
|
||||
until-0-04 ::= | [<] until-0-01 | [M] until-0-05 | [^<M] until-0
|
||||
until-0-05 ::= | [<] until-0-01 | [L] until-0-06 | [^<L] until-0
|
||||
until-0-06 ::= | [<] until-0-01 | [^<\uFF5C] until-0
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("until grammar multiple delimiters", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.until_one_of({"ab", "cd", "ef"});
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= until-0
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
until-0 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^ace] until-0
|
||||
until-0-01 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^abce] until-0
|
||||
until-0-03 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^acde] until-0
|
||||
until-0-05 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^acef] until-0
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
#include <cstring>
|
||||
#include <climits>
|
||||
#include <algorithm>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user