diff --git a/common/chat.cpp b/common/chat.cpp index 7ac7870c..60ef46b8 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1933,6 +1933,172 @@ static common_chat_params common_chat_params_init_deepseek_v3_2(const common_cha return data; } /* Builds the chat parameters for templates that use the ]<]minimax[>[ namespace marker. It renders the prompt through common_chat_template_direct_apply_impl, saves a PEG parser that covers the generation prompt, optional reasoning, content and tool calls, and - only when a JSON response schema or usable tools are present - builds a grammar plus a word trigger on FC_START. Parameters: tmpl is the chat template to apply; inputs carries the tools, tool_choice, json_schema and reasoning settings read below. Returns the populated common_chat_params. */ +static common_chat_params common_chat_params_init_minimax_m3(const common_chat_template & tmpl, + const autoparser::generation_params & inputs) { + common_chat_params data; + + data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.supports_thinking = true; + data.thinking_start_tag = ""; + data.thinking_end_tag = ""; /* NOTE(review): the thinking tags here and the suffixes appended to NS below are empty strings in this view, which looks like markup lost during extraction - confirm the real tag text against the upstream file before relying on it. */ + + const std::string NS = "]<]minimax[>["; + const std::string THINK_START = ""; + const std::string THINK_END = ""; + const std::string FC_START = NS + ""; + const std::string FC_END = NS + ""; + const std::string INVOKE_END = NS + ""; + + data.preserved_tokens = { + NS, + "", + "", + THINK_START, + THINK_END, + }; + + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object(); + auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; + /* A grammar is only built for a JSON response schema, or for tools when tool_choice is not NONE. */ auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE); + + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START); + auto end = p.end(); + + auto reasoning = p.eps(); + /* With thinking enabled the text up to THINK_END is captured through p.reasoning(); with thinking disabled the same span is still matched but is not wrapped in p.reasoning(). Without reasoning extraction the rule stays p.eps(). */ if (extract_reasoning && inputs.enable_thinking) { + reasoning = p.optional(p.optional(p.literal(THINK_START)) + p.reasoning(p.until(THINK_END)) + THINK_END); + } else if (extract_reasoning) { + reasoning = p.optional(p.optional(p.literal(THINK_START)) + p.until(THINK_END) + p.literal(THINK_END)); + } + + /* Response-format mode: a json fenced block whose body goes through p.schema() with inputs.json_schema; this branch returns before any tool parsing is set up. */ if (has_response_format) { + auto response_format = p.rule("response-format", + p.literal("```json") + p.space() + + p.content(p.schema(p.json(), "response-format-schema",
inputs.json_schema)) + + p.space() + p.literal("```")); + return generation_prompt + reasoning + response_format + end; + } + + /* No usable tools: everything after the reasoning section is plain content. */ if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { + return generation_prompt + reasoning + p.content(p.rest()) + end; + } + + /* Each declared function contributes one named alternative to tool_choice. */ auto tool_choice = p.choice(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto params = function.contains("parameters") ? function.at("parameters") : json::object(); + const auto & props = params.contains("properties") ? params.at("properties") : json::object(); + + /* NOTE(review): the element types of this set and of the two vectors below are not visible in this view (template arguments appear to have been stripped) - confirm against the upstream file. */ std::set required; + if (params.contains("required")) { + params.at("required").get_to(required); + } + + auto schema_info = common_schema_info(); + schema_info.resolve_refs(params); + + std::vector required_parsers; + std::vector optional_parsers; + for (const auto & [param_name, param_schema] : props.items()) { + bool is_required = required.find(param_name) != required.end(); + bool is_string = schema_info.resolves_to_string(param_schema); + + const std::string p_close = NS + ""; + + /* Parameters that resolve to a string are read raw up to p_close; all others go through p.schema() with the parameter schema. */ auto arg = p.tool_arg( + p.tool_arg_open( + p.literal(NS + "<") + + p.tool_arg_name(p.literal(param_name)) + + p.literal(">")) + + (is_string + ? 
p.tool_arg_string_value(p.until(p_close)) + : p.tool_arg_json_value(p.schema(p.json(), + "tool-" + name + "-arg-" + param_name + "-schema", + param_schema, false))) + + p.tool_arg_close(p.literal(p_close))); + + auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg); + if (is_required) { + required_parsers.push_back(named_arg); + } else { + optional_parsers.push_back(named_arg); + } + } + + /* Required arguments are chained in the order the loop above collected them, with p.space() between neighbours. */ common_peg_parser args_seq = p.eps(); + for (size_t i = 0; i < required_parsers.size(); i++) { + if (i > 0) { + args_seq = args_seq + p.space(); + } + args_seq = args_seq + required_parsers[i]; + } + + /* Optional arguments follow as a repeated choice, so they may appear in any order; the (0, -1) bounds presumably mean zero or more with no upper limit - confirm against p.repeat. */ if (!optional_parsers.empty()) { + common_peg_parser any_opt = p.choice(); + for (const auto & opt : optional_parsers) { + any_opt |= opt; + } + args_seq = args_seq + p.repeat(p.space() + any_opt, 0, -1); + } + + auto func_parser = p.tool( + p.tool_open(p.literal(NS + "")) + + p.space() + args_seq + p.space() + + p.tool_close(p.literal(INVOKE_END))); + + tool_choice |= p.rule("tool-" + name, func_parser); + }); + + auto require_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + common_peg_parser tool_calls = p.eps(); + /* Parallel mode accepts one or more tool calls between FC_START and FC_END; otherwise exactly one. */ if (inputs.parallel_tool_calls) { + tool_calls = p.trigger_rule("tool-call", + p.literal(FC_START) + p.space() + tool_choice + + p.zero_or_more(p.space() + tool_choice) + p.space() + p.literal(FC_END)); + } else { + tool_calls = p.trigger_rule("tool-call", + p.literal(FC_START) + p.space() + tool_choice + p.space() + p.literal(FC_END)); + } + + /* Unless tool_choice is REQUIRED the whole tool-call section may be absent. */ if (!require_tools) { + tool_calls = p.optional(tool_calls); + } + + auto content_before_tools = p.content(p.until(FC_START)); + return generation_prompt + reasoning + p.space() + content_before_tools + tool_calls + end; + }); + + data.parser = parser.save(); + + if (include_grammar) { + /* The grammar is lazy unless output is forced: a response format, or tools with tool_choice REQUIRED. */ data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED)); + data.grammar = build_grammar([&](const common_grammar_builder & builder) { +
foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + /* resolve_refs is applied to a local copy of each tool parameter schema. */ auto schema = function.contains("parameters") ? function.at("parameters") : json::object(); + builder.resolve_refs(schema); + }); + if (has_response_format) { + auto schema = inputs.json_schema; + builder.resolve_refs(schema); + } + parser.build_grammar(builder, data.grammar_lazy); + }); + + /* The only grammar trigger registered here is the FC_START word. */ data.grammar_triggers = { + { COMMON_GRAMMAR_TRIGGER_TYPE_WORD, FC_START }, + }; + } + + return data; +} + // Cohere2 MoE (a.k.a. "North Code") parser. // // The assistant turn is fully marker-wrapped: @@ -2329,6 +2495,13 @@ std::optional common_chat_try_specialized_template( return common_chat_params_init_gigachat_v3(tmpl, params); } + if (src.find("]<]minimax[>[") != std::string::npos && + src.find("") != std::string::npos && + src.find(" common_chat_templates_get_caps(const common_chat_tem GGML_ASSERT(chat_templates->template_default != nullptr); return chat_templates->template_default->caps.to_map(); } - diff --git a/models/templates/MiniMax-M3.jinja b/models/templates/MiniMax-M3.jinja deleted file mode 100644 index ff04e450..00000000 --- a/models/templates/MiniMax-M3.jinja +++ /dev/null @@ -1,169 +0,0 @@ -{# MiniMax-M3 override. - Keep MiniMax-M2's PEG-compatible tool-call wrapper, but use M3 thinking tags.
#} -{%- set toolcall_begin_token = '' -%} -{%- set toolcall_end_token = '' -%} -{%- set think_begin_token = '' -%} -{%- set think_end_token = '' -%} - -{#- Tool Rendering Functions ============================================== -#} -{%- macro render_tool_namespace(namespace_name, tool_list) -%} -{%- for tool in tool_list -%} -{{ tool.function | tojson(ensure_ascii=False) }} -{% endfor -%} -{%- endmacro -%} - -{%- macro visible_text(content) -%} - {%- if content is string -%} - {{ content }} - {%- elif content is iterable and content is not mapping -%} - {%- for item in content -%} - {%- if item is mapping and item.type == 'text' -%} - {{- item.text }} - {%- elif item is string -%} - {{- item }} - {%- endif -%} - {%- endfor -%} - {%- else -%} - {{- content }} - {%- endif -%} -{%- endmacro -%} - -{#- System Message Construction ============================================ -#} -{%- macro build_system_message(system_message) -%} - {%- if system_message and system_message.content -%} - {{- visible_text(system_message.content) }} - {%- else -%} - {%- if model_identity is not defined -%} - {%- set model_identity = "You are a helpful assistant." 
-%} - {%- endif -%} - {{- model_identity }} - {%- endif -%} - - {%- if system_message and system_message.current_date -%} - {{- '\n' ~ 'Current date: ' + system_message.current_date }} - {%- endif -%} - {%- if system_message and system_message.current_location -%} - {{- '\n' ~ 'Current location: ' + system_message.current_location }} - {%- endif -%} -{%- endmacro -%} - -{#- Main Template Logic ===================================================== -#} -{%- set system_message = none -%} -{%- set conversation_messages = messages -%} -{%- if messages and messages[0].role == "system" -%} - {%- set system_message = messages[0] -%} - {%- set conversation_messages = messages[1:] -%} -{%- endif -%} - -{#- Get the last user message turn, for interleaved thinking -#} -{%- set ns = namespace(last_user_index=-1) %} -{% for m in conversation_messages %} - {%- if m.role == 'user' %} - {% set ns.last_user_index = loop.index0 -%} - {%- endif %} -{%- endfor %} - -{#- Render system message -#} -{{- ']~!b[' ~ ']~b]system' ~ '\n' }} -{{- build_system_message(system_message) }} - -{#- Render tools if available -#} -{%- if tools -%} - {{- '\n\n' ~ '# Tools' ~ '\n' ~ 'You may call one or more tools to assist with the user query.\nHere are the tools available in JSONSchema format:' ~ '\n' }} - {{- '\n' ~ '' ~ '\n' }} - {{- render_tool_namespace("functions", tools) }} - {{- '' ~ '\n\n' }} - {{- 'When making tool calls, use XML format to invoke tools and pass parameters:' ~ '\n' }} - {{- '\n' ~ toolcall_begin_token }} - - param-value-1 - param-value-2 - ... 
- - {{- '\n' ~ toolcall_end_token }} -{%- endif -%} - -{{- '[e~[\n' }} - -{#- Render messages -#} -{%- set last_tool_call = namespace(name=none) -%} -{%- for message in conversation_messages -%} - {%- if message.role == 'assistant' -%} - {{- ']~b]ai' ~ '\n' }} - - {%- set reasoning_content = '' %} - {%- set content = visible_text(message.content) %} - {%- if message.reasoning_content is string %} - {%- set reasoning_content = message.reasoning_content %} - {%- else %} - {%- if think_end_token in content %} - {%- set reasoning_content = content.split(think_end_token)[0].strip('\n').split(think_begin_token)[-1].strip('\n') %} - {%- set content = content.split(think_end_token)[-1].strip('\n') %} - {%- elif '' in content %} - {%- set reasoning_content = content.split('')[0].strip('\n').split('')[-1].strip('\n') %} - {%- set content = content.split('')[-1].strip('\n') %} - {%- endif %} - {%- endif %} - {%- if reasoning_content and loop.index0 > ns.last_user_index -%} - {{- think_begin_token ~ '\n' ~ reasoning_content ~ '\n' ~ think_end_token ~ '\n\n' }} - {%- endif -%} - {%- if content -%} - {{- content }} - {%- endif -%} - - {%- if message.tool_calls -%} - {{- '\n' ~ toolcall_begin_token ~ '\n' }} - {%- for tool_call in message.tool_calls -%} - {%- if tool_call.function %} - {%- set tool_call = tool_call.function %} - {%- endif %} - {{- '' }} - {% set _args = tool_call.arguments %} - {%- for k, v in _args.items() %} - {{- '' }} - {{- v | tojson(ensure_ascii=False) if v is not string else v }} - {{- '' }} - {% endfor %} - {{- '' ~ '\n' }} - {%- endfor -%} - - {{- toolcall_end_token}} - {%- set last_tool_call.name = message.tool_calls[-1].function.name -%} - {%- else -%} - {%- set last_tool_call.name = none -%} - {%- endif -%} - {{- '[e~[' ~ '\n' }} - - {%- elif message.role == 'tool' -%} - {%- if last_tool_call.name is none -%} - {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }} - {%- endif -%} - {%- if 
loop.first or (conversation_messages[loop.index0 - 1].role != 'tool') -%} - {{- ']~b]tool' }} - {%- endif -%} - {%- if message.content is string -%} - {{- '\n' }} - {{- message.content }} - {{- '' }} - {%- else -%} - {%- for tr in message.content -%} - {{- '\n' }} - {{- tr.output if tr.output is defined else (tr.text if tr.type == 'text' and tr.text is defined else tr) }} - {{- '\n' }} - {%- endfor -%} - {%- endif -%} - {%- if loop.last or (conversation_messages[loop.index0 + 1].role != 'tool') -%} - {{- '[e~[\n' -}} - {%- endif -%} - - {%- elif message.role == 'user' -%} - {{- ']~b]user' ~ '\n' }} - {{- visible_text(message.content) }} - {{- '[e~[' ~ '\n' }} - {%- endif -%} -{%- endfor -%} - -{#- Generation prompt -#} -{%- if add_generation_prompt -%} -{{- ']~b]ai' ~ '\n' ~ think_begin_token ~ '\n' }} -{%- endif -%} diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 6cc13213..a7514b52 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -23,6 +23,7 @@ using json = nlohmann::ordered_json; static int main_automated_tests(void); /* Forward declaration: the test added further down calls simple_msg ahead of its definition later in the file. */ +static common_chat_msg simple_msg(const std::string & role, const std::string & content); static void run_multiple(const std::string& dir_path, bool stop_on_first_failure, const json& input, bool use_common = false); static void run_single(const std::string& contents, json input, bool use_common = false, const std::string & output_path = ""); @@ -225,7 +226,6 @@ static std::string normalize_newlines(const std::string & s) { #endif } - static std::string format_using_common( const std::string & template_str, const std::string & bos_token, @@ -243,6 +243,87 @@ static std::string format_using_common( return output; } /* Test for PEG-native parsing of tool calls written with the ]<]minimax[>[ marker. It applies an inline template with two tools (special_function taking an integer arg1, python taking a string code) with thinking and parallel tool calls enabled, asserts that the resulting format is COMMON_CHAT_FORMAT_PEG_NATIVE with a non-empty saved parser, loads that parser into a common_peg_arena, parses a canned model output as a complete (non-partial) message, and asserts that the reasoning text is Calling both, that content is empty, and that both tool calls come back with the expected names and JSON arguments. */ +static void test_minimax_m3_native_tool_parser(void) { + /* NOTE(review): several quoted tag literals in this template and in the canned output below are empty in this view, which looks like markup lost during extraction - confirm against the upstream test. */ const std::string template_str = R"( +{%- set ns_token = ']<]minimax[>[' -%} +{%- set toolcall_begin_token = ns_token ~ '' -%} +{%- set toolcall_end_token = ns_token ~ '' -%} +{%- for message in messages -%} +{{-
message.role ~ ': ' ~ message.content ~ '\n' -}} +{%- endfor -%} +{%- if tools -%} +{{- toolcall_begin_token ~ ns_token ~ '' ~ ns_token ~ '' ~ toolcall_end_token -}} +{%- endif -%} +{%- if add_generation_prompt -%} +{{- '' -}} +{%- endif -%} +)"; + + common_chat_templates_inputs inputs; + inputs.use_jinja = true; + inputs.messages = { simple_msg("user", "Call a tool") }; + inputs.add_generation_prompt = true; + inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + inputs.enable_thinking = true; + inputs.parallel_tool_calls = true; + inputs.tools = { + common_chat_tool{ + /* .name = */ "special_function", + /* .description = */ "Test function", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "arg1": { "type": "integer" } + }, + "required": ["arg1"] + })", + }, + common_chat_tool{ + /* .name = */ "python", + /* .description = */ "Run Python", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "code": { "type": "string" } + }, + "required": ["code"] + })", + }, + }; + + common_chat_templates_ptr tmpls(common_chat_templates_init(/* model= */ nullptr, template_str)); + auto params = common_chat_templates_apply(tmpls.get(), inputs); + + assert(params.format == COMMON_CHAT_FORMAT_PEG_NATIVE); + assert(!params.parser.empty()); + + common_peg_arena arena; + arena.load(params.parser); + + common_chat_parser_params parser_params(params); + parser_params.parser = arena; + + const std::string output = + "Calling both\n" + "]<]minimax[>[\n" + "]<]minimax[>[" + "]<]minimax[>[1]<]minimax[>[" + "]<]minimax[>[\n" + "]<]minimax[>[" + "]<]minimax[>[print('hey')]<]minimax[>[" + "]<]minimax[>[\n" + "]<]minimax[>["; + + auto msg = common_chat_parse(output, /* is_partial = */ false, parser_params); + assert(msg.reasoning_content == "Calling both"); + assert(msg.content.empty()); + assert(msg.tool_calls.size() == 2); + assert(msg.tool_calls[0].name == "special_function"); + assert(json::parse(msg.tool_calls[0].arguments) == json::parse(R"({"arg1":
1})")); + assert(msg.tool_calls[1].name == "python"); + assert(json::parse(msg.tool_calls[1].arguments) == json::parse("{\"code\": \"print('hey')\"}")); +} + // skip libcommon, use direct jinja engine static jinja::value_string format_using_direct_engine( @@ -336,6 +417,8 @@ static common_chat_msg simple_msg(const std::string & role, const std::string & int main_automated_tests(void) { // jinja::enable_debug(true); + test_minimax_m3_native_tool_parser(); + std::vector conversation { {"system", "You are a helpful assistant"}, {"user", "Hello"},