diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp index 453559a4..6021fc4e 100644 --- a/common/chat-auto-parser-generator.cpp +++ b/common/chat-auto-parser-generator.cpp @@ -136,10 +136,10 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co if (!end.empty()) { if (!start.empty()) { // Standard tag-based: optional(reasoning) - return p.optional(start + p.reasoning(p.until(end)) + end + p.space()); + return p.optional(p.optspace(start) + p.reasoning(p.until(trim_whitespace(end))) + p.optspace(end)); } // Delimiter-style (empty start) - return p.optional(p.reasoning(p.until(end)) + end + p.space()); + return p.optional(p.reasoning(p.until(trim_whitespace(end))) + p.optspace(end)); } } @@ -186,7 +186,6 @@ common_peg_parser analyze_tools::build_parser(parser_build_context & ctx) const common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_context & ctx) const { auto & p = ctx.p; const auto & inputs = ctx.inputs; - bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED; // Build effective field names with dot notation if function_field is set std::string name_field = format.name_field; @@ -225,8 +224,7 @@ common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_cont tool_start = format.per_call_start; } - return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(p.until(tool_start)))) + tools_parser + - p.end(); + return ctx.reasoning_parser + p.optional(p.content(p.until(tool_start))) + tools_parser + p.end(); } common_peg_parser analyze_tools::build_func_parser(common_chat_peg_builder & p, const std::string & name, @@ -270,7 +268,6 @@ common_peg_parser analyze_tools::build_func_parser(common_chat_peg_builder & p, common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context & ctx) const { auto & p = ctx.p; const auto & inputs = ctx.inputs; - bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED; common_peg_parser tool_choice = p.choice(); @@ -336,14 +333,12 @@ common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context std::string trigger_marker = !format.section_start.empty() ? format.section_start : format.per_call_start; auto content_before_tools = trigger_marker.empty() ? p.eps() : p.until(trigger_marker); - return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) + tool_calls + - p.end(); + return ctx.reasoning_parser + p.optional(p.content(content_before_tools)) + tool_calls + p.end(); } common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_context & ctx) const { auto & p = ctx.p; const auto & inputs = ctx.inputs; - bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED; auto until_suffix = p.rule("until-suffix", p.until(arguments.value_suffix)); @@ -374,9 +369,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte arguments.name_suffix) + arguments.value_prefix + (schema_info.resolves_to_string(param_schema) ? - p.tool_arg_string_value(p.schema(until_suffix, - "tool-" + name + "-arg-" + param_name + "-schema", - param_schema, true)) : + p.tool_arg_string_value(until_suffix) : p.tool_arg_json_value(p.schema( p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) + p.space()) + @@ -471,8 +464,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte std::string trigger_marker = !format.section_start.empty() ? format.section_start : format.per_call_start; auto content_before_tools = trigger_marker.empty() ? p.eps() : p.until(trigger_marker); - return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) + tool_calls + - p.end(); + return ctx.reasoning_parser + p.optional(p.content(content_before_tools)) + tool_calls + p.end(); } } // namespace autoparser diff --git a/common/chat-diff-analyzer.cpp b/common/chat-diff-analyzer.cpp index 2f0bd14a..9c7c9678 100644 --- a/common/chat-diff-analyzer.cpp +++ b/common/chat-diff-analyzer.cpp @@ -296,7 +296,7 @@ void analyze_reasoning::compare_reasoning_presence() { return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())) + p.rest()); }); auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) { - return p.tag("pre", p.marker() + p.space()) + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest(); + return p.tag("pre", p.marker() + p.space()) + p.literal(reasoning_content) + p.tag("post", (p.space() + p.marker() + p.space())) + p.rest(); }); // try the more aggressive parse first, if it fails, fall back to the delimiter one auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B); @@ -306,11 +306,11 @@ void analyze_reasoning::compare_reasoning_presence() { if (result.result.success()) { if (!result.tags["pre"].empty() && !result.tags["post"].empty()) { mode = reasoning_mode::TAG_BASED; - start = trim_leading_whitespace(result.tags["pre"]); - end = trim_trailing_whitespace(result.tags["post"]); + start = result.tags["pre"]; + end = result.tags["post"]; } else if (!result.tags["post"].empty()) { mode = reasoning_mode::TAG_BASED; - end = trim_trailing_whitespace(result.tags["post"]); + end = result.tags["post"]; } } } @@ -342,7 +342,7 @@ void analyze_reasoning::compare_thinking_enabled() { if (left_trimmed.empty() && !diff.right.empty()) { if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) { if (start.empty()) { - start = trim_leading_whitespace(diff.right); + start = diff.right; mode = reasoning_mode::TAG_BASED; } } @@ -353,7 +353,7 @@ void analyze_reasoning::compare_thinking_enabled() { if (seg.size() >= 2 && seg[seg.size() - 1].value == left_trimmed && seg[seg.size() - 2].type == segment_type::MARKER) { start = seg[seg.size() - 2].value; } - end = trim_trailing_whitespace(diff.left); + end = diff.left; mode = reasoning_mode::TAG_BASED; } } @@ -445,14 +445,14 @@ void analyze_reasoning::compare_reasoning_scope() { auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B); if (result.result.success()) { start = result.tags["pre"]; - end = trim_trailing_whitespace(result.tags["post"]); + end = result.tags["post"]; } else { auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) { return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space()))); }); result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B); if (result.result.success()) { - end = trim_trailing_whitespace(result.tags["post"]); + end = result.tags["post"]; } else { LOG_DBG(ANSI_ORANGE "%s: Unable to extract reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__); mode = reasoning_mode::NONE; diff --git a/common/chat-peg-parser.cpp b/common/chat-peg-parser.cpp index 56eb567d..79274feb 100644 --- a/common/chat-peg-parser.cpp +++ b/common/chat-peg-parser.cpp @@ -358,35 +358,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) { if (is_potential_container) { value_content = normalize_container_value(value_content); } - - // Try to parse as JSON value (number, bool, null, object, array) - try { - ordered_json parsed = ordered_json::parse(value_content); - if (parsed.is_string()) { - // Don't add closing quote yet (added by arg_close) for monotonic streaming - std::string escaped = parsed.dump(); - if (!escaped.empty() && escaped.back() == '"') { - escaped.pop_back(); - } - value_to_add = escaped; - closing_quote_pending = true; - } else { - // Non-string values: use raw content to preserve whitespace for monotonicity - value_to_add = value_content; - } - } catch (...) { - if (node.is_partial && is_potential_container) { - // Partial container: pass through the already-normalized content - value_to_add = value_content; - } else { - // Not valid JSON - treat as string value - if (!closing_quote_pending) { - value_to_add = "\""; - closing_quote_pending = true; - } - value_to_add += escape_json_string_inner(value_content); - } - } + value_to_add += value_content; } args_target() += value_to_add; @@ -816,6 +788,32 @@ common_peg_parser common_chat_peg_builder::prefix(const std::string & s, const s return literal(s.substr(0, s.rfind(delimiter))); } +common_peg_parser common_chat_peg_builder::optspace(const std::string & tag) { + auto parser = eps(); + size_t end_of_prefix_space = tag.size(); + size_t start_of_suffix_space = tag.size(); + for (size_t i = 0; i < tag.size(); i++) { + if (!std::isspace(tag[i])) { + end_of_prefix_space = i; + break; + } + } + for (size_t i = tag.size(); i > 0; i--) { + if (!std::isspace(tag[i - 1])) { + start_of_suffix_space = i; + break; + } + } + for (size_t i = 0; i < end_of_prefix_space; i++) { + parser += optional(literal(std::string(1, tag[i]))); + } + parser += literal(tag.substr(end_of_prefix_space, start_of_suffix_space - end_of_prefix_space)); + for (size_t i = start_of_suffix_space; i < tag.size(); i++) { + parser += optional(literal(std::string(1, tag[i]))); + } + return parser; +} + common_peg_parser common_chat_peg_builder::standard_json_tools( const std::string & section_start, const std::string & section_end, diff --git a/common/chat-peg-parser.h b/common/chat-peg-parser.h index 1ea3eb7e..be92f17d 100644 --- a/common/chat-peg-parser.h +++ b/common/chat-peg-parser.h @@ -90,12 +90,15 @@ class common_chat_peg_builder : public common_peg_parser_builder { // Use for schema-declared string types - won't be treated as potential JSON container common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); } - common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_VALUE, p)); } + common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return tag(TOOL_ARG_VALUE, p); } // Return a parser that parses the prefix of a string, up to a given delimiter. common_peg_parser prefix(const std::string & s, const std::string & delimiter = {}); + // Return a parser that parses all elements of tag, but leading and trailing spaces are optional + common_peg_parser optspace(const std::string & tag); + // Legacy-compatible helper for building standard JSON tool calls // Used by tests and manual parsers // name_key/args_key: JSON key names for function name and arguments diff --git a/common/chat.cpp b/common/chat.cpp index ed1c0e54..ea3e2fd9 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -81,7 +81,7 @@ json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const { if (!content.empty()) { jmsg["content"] = content; } else if (!content_parts.empty()) { - if (concat_typed_text) { + if (concat_typed_text || contains_media()) { std::string text; bool last_was_media_marker = false; // join parts with newline, do not add newline before or after media markers @@ -398,6 +398,25 @@ json common_chat_msgs_to_json_oaicompat(const std::vector & msg return render_message_to_json(msgs, c); } +json common_chat_tools_to_json_oaicompat(const std::vector & tools) { + if (tools.empty()) { + return json(); + } + + auto result = json::array(); + for (const auto & tool : tools) { + result.push_back({ + { "type", "function" }, + { "function", { + { "name", tool.name }, + { "description", tool.description }, + { "parameters", json::parse(tool.parameters) }, + }}, + }); + } + return result; +} + std::vector common_chat_tools_parse_oaicompat(const json & tools) { std::vector result; @@ -433,56 +452,6 @@ std::vector common_chat_tools_parse_oaicompat(const json & too return result; } -json common_chat_tools_to_json_oaicompat(const std::vector & tools) { - if (tools.empty()) { - return json(); - } - - auto result = json::array(); - for (const auto & tool : tools) { - result.push_back({ - { "type", "function" }, - { "function", - { - { "name", tool.name }, - { "description", tool.description }, - { "parameters", json::parse(tool.parameters) }, - } }, - }); - } - return result; -} - -json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) { - json delta = json::object(); - if (!diff.reasoning_content_delta.empty()) { - delta["reasoning_content"] = diff.reasoning_content_delta; - } - if (!diff.content_delta.empty()) { - delta["content"] = diff.content_delta; - } - if (diff.tool_call_index != std::string::npos) { - json tool_call; - tool_call["index"] = diff.tool_call_index; - if (!diff.tool_call_delta.id.empty()) { - tool_call["id"] = diff.tool_call_delta.id; - tool_call["type"] = "function"; - } - if (!diff.tool_call_delta.name.empty() || !diff.tool_call_delta.arguments.empty()) { - json function = json::object(); - if (!diff.tool_call_delta.name.empty()) { - function["name"] = diff.tool_call_delta.name; - } - if (!diff.tool_call_delta.arguments.empty()) { - function["arguments"] = diff.tool_call_delta.arguments; - } - tool_call["function"] = function; - } - delta["tool_calls"] = json::array({ tool_call }); - } - return delta; -} - bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { if (use_jinja) { try { @@ -2268,22 +2237,38 @@ std::optional common_chat_try_specialized_template( return std::nullopt; } +static std::string common_chat_templates_generation_prompt(const common_chat_template & tmpl, const autoparser::generation_params & inputs) { + autoparser::generation_params params = inputs; + params.add_generation_prompt = false; + std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params); + params.add_generation_prompt = true; + std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params); + + size_t prefix_len = 0; + size_t min_size = std::min(no_gen_prompt.size(), gen_prompt.size()); + while (prefix_len < min_size && no_gen_prompt[prefix_len] == gen_prompt[prefix_len]) { + prefix_len++; + } + return gen_prompt.substr(prefix_len); +} + static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls, const struct common_chat_templates_inputs & inputs) { autoparser::generation_params params; params.tools = common_chat_tools_to_json_oaicompat(inputs.tools); const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use ? *tmpls->template_tool_use : *tmpls->template_default; - const auto & src = tmpl.source(); - const auto & caps = tmpl.original_caps(); - params.messages = render_message_to_json(inputs.messages, tmpl.original_caps()); - params.tool_choice = inputs.tool_choice; - params.reasoning_format = inputs.reasoning_format; - params.enable_thinking = inputs.enable_thinking; - params.grammar = inputs.grammar; - params.now = inputs.now; - params.add_bos = tmpls->add_bos; - params.add_eos = tmpls->add_eos; + const auto & src = tmpl.source(); + const auto & caps = tmpl.original_caps(); + params.messages = render_message_to_json(inputs.messages, tmpl.original_caps()); + params.tool_choice = inputs.tool_choice; + params.reasoning_format = inputs.reasoning_format; + params.enable_thinking = inputs.enable_thinking; + params.grammar = inputs.grammar; + params.now = inputs.now; + params.add_generation_prompt = inputs.add_generation_prompt; + params.add_bos = tmpls->add_bos; + params.add_eos = tmpls->add_eos; if (src.find("<|channel|>") == std::string::npos) { // map developer to system for all models except for GPT-OSS @@ -2305,14 +2290,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ workaround::func_args_not_string(params.messages); } - params.add_generation_prompt = false; - std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params); - params.add_generation_prompt = true; - std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params); - auto diff = calculate_diff_split(no_gen_prompt, gen_prompt); - params.generation_prompt = diff.right + diff.suffix; - - params.add_generation_prompt = inputs.add_generation_prompt; + params.generation_prompt = common_chat_templates_generation_prompt(tmpl, params); params.extra_context = common_chat_extra_context(); for (auto el : inputs.chat_template_kwargs) { @@ -2366,8 +2344,8 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser); auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE; if (auto_params.supports_thinking) { - auto_params.thinking_start_tag = autoparser.reasoning.start; - auto_params.thinking_end_tag = autoparser.reasoning.end; + auto_params.thinking_start_tag = trim_whitespace(autoparser.reasoning.start); + auto_params.thinking_end_tag = trim_whitespace(autoparser.reasoning.end); } auto_params.generation_prompt = params.generation_prompt; common_peg_arena arena; diff --git a/common/chat.h b/common/chat.h index 03eb0cc0..9e450e17 100644 --- a/common/chat.h +++ b/common/chat.h @@ -94,6 +94,15 @@ struct common_chat_msg { tool_name.empty() && tool_call_id.empty(); } + bool contains_media() const { + for (const auto & part : content_parts) { + if (part.type == "media_marker") { + return true; + } + } + return false; + } + void set_tool_call_ids(std::vector & ids_cache, const std::function & gen_tool_call_id) { for (auto i = 0u; i < tool_calls.size(); i++) { @@ -254,14 +263,13 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates * // Parses a JSON array of messages in OpenAI's chat completion API format. std::vector common_chat_msgs_parse_oaicompat(const nlohmann::ordered_json & messages); +std::vector common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools); + // DEPRECATED: only used in tests nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text = false); -std::vector common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools); nlohmann::ordered_json common_chat_tools_to_json_oaicompat(const std::vector & tools); -nlohmann::ordered_json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff); - // get template caps, useful for reporting to server /props endpoint std::map common_chat_templates_get_caps(const common_chat_templates * chat_templates); diff --git a/common/jinja/caps.cpp b/common/jinja/caps.cpp index ec207a53..ead86476 100644 --- a/common/jinja/caps.cpp +++ b/common/jinja/caps.cpp @@ -1,4 +1,3 @@ -#include "log.h" #include "value.h" #include "runtime.h" #include "caps.h" diff --git a/common/jinja/runtime.h b/common/jinja/runtime.h index 3ca5f175..b6f4a6ab 100644 --- a/common/jinja/runtime.h +++ b/common/jinja/runtime.h @@ -106,10 +106,16 @@ struct statement { size_t pos; // position in source, for debugging virtual ~statement() = default; virtual std::string type() const { return "Statement"; } + // execute_impl must be overridden by derived classes - virtual value execute_impl(context &) { throw std::runtime_error("cannot exec " + type()); } + virtual value execute_impl(context &) { throw_exec_error(); } // execute is the public method to execute a statement with error handling value execute(context &); + +private: + [[noreturn]] void throw_exec_error() const { + throw std::runtime_error("cannot exec " + type()); + } }; // Type Checking Utilities @@ -143,7 +149,7 @@ struct program : public statement { program() = default; explicit program(statements && body) : body(std::move(body)) {} std::string type() const override { return "Program"; } - value execute_impl(context &) override { + [[noreturn]] value execute_impl(context &) override { throw std::runtime_error("Cannot execute program directly, use jinja::runtime instead"); } }; @@ -195,7 +201,7 @@ struct break_statement : public statement { } }; - value execute_impl(context &) override { + [[noreturn]] value execute_impl(context &) override { throw break_statement::signal(); } }; @@ -209,7 +215,7 @@ struct continue_statement : public statement { } }; - value execute_impl(context &) override { + [[noreturn]] value execute_impl(context &) override { throw continue_statement::signal(); } }; @@ -509,7 +515,7 @@ struct slice_expression : public expression { chk_type(this->step_expr); } std::string type() const override { return "SliceExpression"; } - value execute_impl(context &) override { + [[noreturn]] value execute_impl(context &) override { throw std::runtime_error("must be handled by MemberExpression"); } }; diff --git a/common/jinja/value.cpp b/common/jinja/value.cpp index 8e86a715..0b79098c 100644 --- a/common/jinja/value.cpp +++ b/common/jinja/value.cpp @@ -590,6 +590,10 @@ static bool string_endswith(const std::string & str, const std::string & suffix) return str.compare(str.length() - suffix.length(), suffix.length(), suffix) == 0; } +[[noreturn]] static value string_join_not_implemented(const func_args &) { + throw not_implemented_exception("String join builtin not implemented"); +} + const func_builtins & value_string_t::get_builtins() const { static const func_builtins builtins = { {"default", default_value}, @@ -851,9 +855,7 @@ const func_builtins & value_string_t::get_builtins() const { res->val_str.mark_input_based_on(val_input->as_string()); return res; }}, - {"join", [](const func_args &) -> value { - throw not_implemented_exception("String join builtin not implemented"); - }}, + {"join", string_join_not_implemented}, }; return builtins; } @@ -884,6 +886,9 @@ const func_builtins & value_bool_t::get_builtins() const { return builtins; } +[[noreturn]] static value array_unique_not_implemented(const func_args &) { + throw not_implemented_exception("Array unique builtin not implemented"); +} const func_builtins & value_array_t::get_builtins() const { static const func_builtins builtins = { @@ -1084,13 +1089,14 @@ const func_builtins & value_array_t::get_builtins() const { std::reverse(arr.begin(), arr.end()); return is_val(val) ? mk_val(std::move(arr)) : mk_val(std::move(arr)); }}, - {"unique", [](const func_args &) -> value { - throw not_implemented_exception("Array unique builtin not implemented"); - }}, + {"unique", array_unique_not_implemented}, }; return builtins; } +[[noreturn]] static value object_join_not_implemented(const func_args &) { + throw not_implemented_exception("object join not implemented"); +} const func_builtins & value_object_t::get_builtins() const { if (!has_builtins) { @@ -1183,9 +1189,7 @@ const func_builtins & value_object_t::get_builtins() const { }); return result; }}, - {"join", [](const func_args &) -> value { - throw not_implemented_exception("object join not implemented"); - }}, + {"join", object_join_not_implemented}, }; return builtins; } diff --git a/common/jinja/value.h b/common/jinja/value.h index 7d164588..5cf85e4f 100644 --- a/common/jinja/value.h +++ b/common/jinja/value.h @@ -129,27 +129,25 @@ struct value_t { // Note: only for debugging and error reporting purposes virtual std::string type() const { return ""; } - virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an int value"); } - virtual double as_float() const { throw std::runtime_error(type() + " is not a float value"); } - virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); } - virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); } - virtual const std::vector & as_array() const { throw std::runtime_error(type() + " is not an array value"); } - virtual const std::vector> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); } - virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); } + virtual int64_t as_int() const { throw_type_error("is not an int value"); } + virtual double as_float() const { throw_type_error("is not a float value"); } + virtual string as_string() const { throw_type_error("is not a string value"); } + virtual bool as_bool() const { throw_type_error("is not a bool value"); } + virtual const std::vector & as_array() const { throw_type_error("is not an array value"); } + virtual const std::vector> & as_ordered_object() const { throw_type_error("is not an object value"); } + virtual value invoke(const func_args &) const { throw_type_error("is not a function value"); } virtual bool is_none() const { return false; } virtual bool is_undefined() const { return false; } - virtual const func_builtins & get_builtins() const { - throw std::runtime_error("No builtins available for type " + type()); - } + virtual const func_builtins & get_builtins() const { throw_type_error("has no builtins"); } - virtual bool has_key(const value &) { throw std::runtime_error(type() + " is not an object value"); } - virtual void insert(const value & /* key */, const value & /* val */) { throw std::runtime_error(type() + " is not an object value"); } - virtual value & at(const value & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); } - virtual value & at(const value & /* key */) { throw std::runtime_error(type() + " is not an object value"); } - virtual value & at(const std::string & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); } - virtual value & at(const std::string & /* key */) { throw std::runtime_error(type() + " is not an object value"); } - virtual value & at(int64_t /* idx */, value & /* default_val */) { throw std::runtime_error(type() + " is not an array value"); } - virtual value & at(int64_t /* idx */) { throw std::runtime_error(type() + " is not an array value"); } + virtual bool has_key(const value &) { throw_type_error("is not an object value"); } + virtual void insert(const value & /* key */, const value & /* val */) { throw_type_error("is not an object value"); } + virtual value & at(const value & /* key */, value & /* default_val */) { throw_type_error("is not an object value"); } + virtual value & at(const value & /* key */) { throw_type_error("is not an object value"); } + virtual value & at(const std::string & /* key */, value & /* default_val */) { throw_type_error("is not an object value"); } + virtual value & at(const std::string & /* key */) { throw_type_error("is not an object value"); } + virtual value & at(int64_t /* idx */, value & /* default_val */) { throw_type_error("is not an array value"); } + virtual value & at(int64_t /* idx */) { throw_type_error("is not an array value"); } virtual bool is_numeric() const { return false; } virtual bool is_hashable() const { return false; } @@ -163,6 +161,11 @@ struct value_t { // Note: only for debugging purposes virtual std::string as_repr() const { return as_string().str(); } +private: + [[noreturn]] void throw_type_error(const char* expected) const { + throw std::runtime_error(type() + " " + expected); + } + protected: virtual bool equivalent(const value_t &) const = 0; virtual bool nonequal(const value_t & other) const { return !equivalent(other); } diff --git a/common/reasoning-budget.cpp b/common/reasoning-budget.cpp index 8f814d9e..d0a9f83c 100644 --- a/common/reasoning-budget.cpp +++ b/common/reasoning-budget.cpp @@ -122,6 +122,20 @@ static void common_reasoning_budget_accept(common_reasoning_budget_ctx * smpl, l } break; case REASONING_BUDGET_DONE: + // Re-arm on a new start tag: some models emit multiple blocks +// per response, and each should get a fresh budget window. + if (ctx->start_matcher.advance(token)) { + ctx->state = REASONING_BUDGET_COUNTING; + ctx->remaining = ctx->budget; + ctx->end_matcher.reset(); + LOG_INF("reasoning-budget: re-activated on new start tag, budget=%d tokens\n", ctx->budget); + + if (ctx->remaining <= 0) { + ctx->state = REASONING_BUDGET_FORCING; + ctx->force_pos = 0; + LOG_INF("reasoning-budget: budget=0, forcing immediately\n"); + } + } break; } } @@ -167,13 +181,7 @@ static struct common_reasoning_budget_ctx * common_reasoning_budget_init_state( static struct common_reasoning_budget_ctx * common_reasoning_budget_clone(const struct common_reasoning_budget_ctx * smpl) { const auto * ctx = (const common_reasoning_budget_ctx *)smpl; - return common_reasoning_budget_init_state( - ctx->vocab, - ctx->start_matcher.tokens, - ctx->end_matcher.tokens, - ctx->forced_tokens, - ctx->budget, - ctx->state); + return new common_reasoning_budget_ctx(*ctx); } static void common_reasoning_budget_free(struct common_reasoning_budget_ctx * smpl) { @@ -220,34 +228,6 @@ static common_reasoning_budget_ctx * common_reasoning_budget_init_state( } struct common_reasoning_budget_ctx * common_reasoning_budget_init( - const struct llama_vocab * vocab, - const std::vector & start_tokens, - const std::vector & end_tokens, - const std::vector & forced_tokens, - int32_t budget, - const std::vector & prefill_tokens) { - // Determine initial state from prefill: COUNTING if the prefill begins with - // the start sequence but does not also contain the end sequence after it. - common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE; - if (!prefill_tokens.empty() && !start_tokens.empty() && - prefill_tokens.size() >= start_tokens.size() && - std::equal(start_tokens.begin(), start_tokens.end(), prefill_tokens.begin())) { - initial_state = REASONING_BUDGET_COUNTING; - // If the end sequence also follows the start in the prefill, reasoning - // was opened and immediately closed — stay IDLE. - if (!end_tokens.empty() && - prefill_tokens.size() >= start_tokens.size() + end_tokens.size()) { - auto end_start = prefill_tokens.end() - (ptrdiff_t)end_tokens.size(); - if (end_start >= prefill_tokens.begin() + (ptrdiff_t)start_tokens.size() && - std::equal(end_tokens.begin(), end_tokens.end(), end_start)) { - initial_state = REASONING_BUDGET_IDLE; - } - } - } - return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state); -} - -common_reasoning_budget_ctx * common_reasoning_budget_init( const struct llama_vocab * vocab, const std::vector & start_tokens, const std::vector & end_tokens, diff --git a/common/reasoning-budget.h b/common/reasoning-budget.h index 17778ecf..d363e053 100644 --- a/common/reasoning-budget.h +++ b/common/reasoning-budget.h @@ -29,28 +29,15 @@ enum common_reasoning_budget_state { // end_tokens - token sequence for natural deactivation // forced_tokens - token sequence forced when budget expires // budget - max tokens allowed in the reasoning block -// prefill_tokens - tokens already present in the prompt (generation prompt); -// used to determine the initial state: COUNTING if they begin -// with start_tokens (but don't also end with end_tokens), -// IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING. +// initial_state - initial state // -struct common_reasoning_budget_ctx * common_reasoning_budget_init( - const struct llama_vocab * vocab, - const std::vector & start_tokens, - const std::vector & end_tokens, - const std::vector & forced_tokens, - int32_t budget, - const std::vector & prefill_tokens = {}); - -// Variant that takes an explicit initial state (used by tests and clone). -// COUNTING with budget <= 0 is promoted to FORCING. struct common_reasoning_budget_ctx * common_reasoning_budget_init( const struct llama_vocab * vocab, const std::vector & start_tokens, const std::vector & end_tokens, const std::vector & forced_tokens, int32_t budget, - common_reasoning_budget_state initial_state); + common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE); common_reasoning_budget_state common_reasoning_budget_get_state(const common_reasoning_budget_ctx * smpl); diff --git a/common/sampling.cpp b/common/sampling.cpp index 5c0a860d..ad5b7fae 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -93,33 +93,36 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co result->grammar_root = "root"; } + // Compute prefill tokens from the generation prompt + std::vector prefill_tokens; + if (!params.generation_prompt.empty()) { + GGML_ASSERT(vocab != nullptr); + auto tokens = common_tokenize(vocab, params.generation_prompt, false, true); + for (size_t i = 0; i < tokens.size(); i++) { + std::string piece = common_token_to_piece(vocab, tokens[i], true); + if (i == 0 && std::isspace(piece[0]) && !std::isspace(params.generation_prompt[0])) { + // Some tokenizers will add a space before the first special token, need to exclude + continue; + } + LOG_DBG("%s: prefill token: %d = %s\n", __func__, tokens[i], piece.c_str()); + prefill_tokens.push_back(tokens[i]); + } + } + // Feed generation prompt tokens to the grammar sampler so it advances past // tokens the template already placed in the prompt. // Only applies to output-format and tool-call grammars; user-supplied grammars must not be prefilled. - std::vector prefill_tokens; - if (!params.generation_prompt.empty() && common_grammar_needs_prefill(params.grammar)) { - GGML_ASSERT(vocab != nullptr); - prefill_tokens = common_tokenize(vocab, params.generation_prompt, false, true); - if (!prefill_tokens.empty()) { - std::string first_token = common_token_to_piece(vocab, prefill_tokens[0], true); - if (std::isspace(first_token[0]) && !std::isspace(params.generation_prompt[0])) { - // Some tokenizers will add a space before the first special token, need to remove - prefill_tokens = std::vector(prefill_tokens.begin() + 1, prefill_tokens.end()); + if (grmr && !params.grammar_lazy && common_grammar_needs_prefill(params.grammar)) { + try { + for (const auto & token : prefill_tokens) { + llama_grammar_accept_impl(*grmr, vocab, nullptr, token); + LOG_DBG("%s: grammar accepted prefill token (%d)\n", __func__, token); } } - - if (grmr && !params.grammar_lazy) { - try { - for (const auto & token : prefill_tokens) { - llama_grammar_accept_impl(*grmr, vocab, nullptr, token); - LOG_DBG("%s: accepted prefill token (%d)\n", __func__, token); - } - } - catch (std::exception & e) { - LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__, - common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str()); - throw e; - } + catch (std::exception & e) { + LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__, + common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str()); + throw e; } } @@ -130,8 +133,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co params.reasoning_budget_start, params.reasoning_budget_end, params.reasoning_budget_forced, - params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens, - prefill_tokens); + params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens); + + for (const auto & token : prefill_tokens) { + common_reasoning_budget_accept(result->rbudget, token); + LOG_DBG("%s: reasoning-budget accepted prefill token (%d)\n", __func__, token); + } } llama_sampling_set_rng_seed(result, params.seed); @@ -689,19 +696,19 @@ void common_sampler_accept( struct common_sampler * ctx_sampling, struct llama_context * ctx_main, llama_token token, - bool accept_grammar) { + bool is_generated) { if (ctx_sampling->prev.size() > 0) { ctx_sampling->prev.erase(ctx_sampling->prev.begin()); } ctx_sampling->prev.push_back(token); // grammar_should_apply() checks the reasoning budget state, so calculate this before we accept - accept_grammar = accept_grammar && grammar_should_apply(ctx_sampling); - if (ctx_sampling->rbudget) { + const auto accept_grammar = is_generated && grammar_should_apply(ctx_sampling); + if (ctx_sampling->rbudget && is_generated) { common_reasoning_budget_accept(ctx_sampling->rbudget, token); } - if (ctx_sampling->grammar != NULL && accept_grammar) { + if (ctx_sampling->grammar && accept_grammar) { llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, token); } if (ctx_sampling->smpl) { diff --git a/common/sampling.h b/common/sampling.h index c38e00fc..bdaf5ee5 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -316,11 +316,12 @@ llama_token_data_array llama_sampling_prepare( bool apply_grammar = true, std::vector * original_logits = nullptr); +// if is_generated is true, the token is accepted by the sampling chain, the reasoning budget sampler, and the grammar sampler void common_sampler_accept( struct common_sampler * ctx_sampling, struct llama_context * ctx_main, llama_token id, - bool apply_grammar); + bool is_generated); // returns at least 1 token, up to draft.size() // access the internal list of current candidate tokens diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index d3e4ecd5..91858cb3 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -12,7 +12,8 @@ endif() set(TARGET_SRCS server.cpp - # httplib.h + server-chat.cpp + server-chat.h server-task.cpp server-task.h server-queue.cpp diff --git a/examples/server/server-chat.cpp b/examples/server/server-chat.cpp new file mode 100644 index 00000000..9afb9499 --- /dev/null +++ b/examples/server/server-chat.cpp @@ -0,0 +1,558 @@ +#include "server-chat.h" +#include "server-common.h" + +#include + +json server_chat_convert_responses_to_chatcmpl(const json& response_body) { + if (!response_body.contains("input")) { + throw std::runtime_error("'input' is required"); + } + if (!json_value(response_body, "previous_response_id", std::string{}).empty()) { + throw std::runtime_error("ik_llama.cpp does not support 'previous_response_id'."); + } + + const json input_value = response_body.at("input"); + json chatcmpl_body = response_body; + chatcmpl_body.erase("input"); + std::vector chatcmpl_messages; + + if (response_body.contains("instructions")) { + chatcmpl_messages.push_back({ + {"role", "system"}, + {"content", json_value(response_body, "instructions", std::string())}, + }); + chatcmpl_body.erase("instructions"); + } + + if (input_value.is_string()) { + chatcmpl_messages.push_back({ + {"role", "user"}, + {"content", input_value}, + }); + } + else if (input_value.is_array()) { + static auto exists_and_is_array = [](const json& j, const char* key) -> bool { + return j.contains(key) && j.at(key).is_array(); + }; + static auto exists_and_is_string = [](const json& j, const char* key) -> bool { + return j.contains(key) && j.at(key).is_string(); + }; + + for (json item : input_value) { + if (exists_and_is_string(item, "content")) { + item["content"] = json::array({ + json{ + {"text", item.at("content")}, + {"type", "input_text"}, + } + }); + } + + if (exists_and_is_array(item, "content") && + exists_and_is_string(item, "role") && + (item.at("role") == "user" || item.at("role") == "system" || item.at("role") == "developer") + ) { + std::vector chatcmpl_content; + + for (const json& input_item : item.at("content")) { + const std::string type = json_value(input_item, "type", std::string()); + + if (type == "input_text") { + if (!input_item.contains("text")) { + throw std::runtime_error("'Input text' requires 'text'"); + } + chatcmpl_content.push_back({ + {"text", input_item.at("text")}, + {"type", "text"}, + }); + } + else if (type == "input_image") { + if (!input_item.contains("image_url")) { + throw std::runtime_error("'image_url' is required"); + } + chatcmpl_content.push_back({ + {"image_url", json{ + {"url", input_item.at("image_url")}, + }}, + {"type", "image_url"}, + }); + } + else if (type == "input_file") { + throw std::runtime_error("'input_file' is not supported by ik_llama.cpp at this moment"); + } + else { + throw std::runtime_error("'type' must be one of 'input_text', 'input_image', or 'input_file'"); + } + } + + if (item.contains("type")) { + item.erase("type"); + } + if (item.contains("status")) { + item.erase("status"); + } + item["content"] = chatcmpl_content; + + chatcmpl_messages.push_back(item); + } + else if (exists_and_is_array(item, "content") && + exists_and_is_string(item, "role") && + item.at("role") == "assistant" && + exists_and_is_string(item, "type") && + item.at("type") == "message" + ) { + std::vector chatcmpl_content; + + for (const auto& output_text : item.at("content")) { + const std::string type = json_value(output_text, "type", std::string()); + if (type != "output_text") { + throw std::runtime_error("'type' must be 'output_text'"); + } + if (!exists_and_is_string(output_text, "text")) { + throw std::runtime_error("'Output text' requires 'text'"); + } + chatcmpl_content.push_back({ + {"text", output_text.at("text")}, + {"type", "text"}, + }); + } + + item.erase("status"); + item.erase("type"); + item["content"] = chatcmpl_content; + chatcmpl_messages.push_back(item); + } + else if (exists_and_is_string(item, "arguments") && + exists_and_is_string(item, "call_id") && + exists_and_is_string(item, "name") && + exists_and_is_string(item, "type") && + item.at("type") == "function_call" + ) { + json msg = json{ + {"role", "assistant"}, + {"tool_calls", json::array({json{ + {"function", json{ + {"arguments", item.at("arguments")}, + {"name", item.at("name")}, + }}, + {"id", item.at("call_id")}, + {"type", "function"}, + }})}, + }; + + if (!chatcmpl_messages.empty() && chatcmpl_messages.back().contains("reasoning_content")) { + msg["reasoning_content"] = chatcmpl_messages.back().at("reasoning_content"); + chatcmpl_messages.pop_back(); + } + chatcmpl_messages.push_back(msg); + } + else if (exists_and_is_string(item, "call_id") && + (exists_and_is_string(item, "output") || exists_and_is_array(item, "output")) && + exists_and_is_string(item, "type") && + item.at("type") == "function_call_output" + ) { + if (item.at("output").is_string()) { + chatcmpl_messages.push_back(json{ + {"content", item.at("output")}, + {"role", "tool"}, + {"tool_call_id", item.at("call_id")}, + }); + } + else { + json chatcmpl_outputs = item.at("output"); + for (json& chatcmpl_output : chatcmpl_outputs) { + if (!chatcmpl_output.contains("type") || chatcmpl_output.at("type") != "input_text") { + throw std::runtime_error("Output of tool call should be 'Input text'"); + } + chatcmpl_output["type"] = "text"; + } + chatcmpl_messages.push_back(json{ + {"content", chatcmpl_outputs}, + {"role", "tool"}, + {"tool_call_id", item.at("call_id")}, + }); + } + } + else if (exists_and_is_array(item, "summary") && + exists_and_is_string(item, "type") && + item.at("type") == "reasoning") { + if (!exists_and_is_array(item, "content")) { + throw std::runtime_error("item['content'] is not an array"); + } + if (item.at("content").empty()) { + throw std::runtime_error("item['content'] is empty"); + } + if (!exists_and_is_string(item.at("content")[0], "text")) { + throw std::runtime_error("item['content']['text'] is not a string"); + } + + chatcmpl_messages.push_back(json{ + {"role", "assistant"}, + {"content", json::array()}, + {"reasoning_content", item.at("content")[0].at("text")}, + }); + } + else { + throw std::runtime_error("Cannot determine type of 'item'"); + } + } + } + else { + throw std::runtime_error("'input' must be a string or array of objects"); + } + + chatcmpl_messages.erase(std::remove_if( + chatcmpl_messages.begin(), + chatcmpl_messages.end(), + [](const json& x) { + return x.contains("role") && + x.at("role") == "assistant" && + x.contains("content") && + x.at("content") == json::array() && + x.contains("reasoning_content"); + }), + chatcmpl_messages.end()); + + chatcmpl_body["messages"] = chatcmpl_messages; + + if (response_body.contains("tools")) { + if (!response_body.at("tools").is_array()) { + throw std::runtime_error("'tools' must be an array of objects"); + } + std::vector chatcmpl_tools; + for (json resp_tool : response_body.at("tools")) { + json chatcmpl_tool; + + if (json_value(resp_tool, "type", std::string()) != "function") { + throw std::runtime_error("'type' of tool must be 'function'"); + } + resp_tool.erase("type"); + chatcmpl_tool["type"] = "function"; + + if (!resp_tool.contains("strict")) { + resp_tool["strict"] = true; + } + chatcmpl_tool["function"] = resp_tool; + chatcmpl_tools.push_back(chatcmpl_tool); + } + chatcmpl_body.erase("tools"); + chatcmpl_body["tools"] = chatcmpl_tools; + } + + if (response_body.contains("max_output_tokens")) { + chatcmpl_body.erase("max_output_tokens"); + chatcmpl_body["max_tokens"] = response_body["max_output_tokens"]; + } + + return chatcmpl_body; +} + +json server_chat_convert_anthropic_to_oai(const json & body) { + json oai_body; + + // Convert system prompt + json oai_messages = json::array(); + auto system_param = json_value(body, "system", json()); + if (!system_param.is_null()) { + std::string system_content; + + if (system_param.is_string()) { + system_content = system_param.get(); + } else if (system_param.is_array()) { + for (const auto & block : system_param) { + if (json_value(block, "type", std::string()) == "text") { + std::string content_block = json_value(block, "text", std::string()); + if (!string_starts_with(content_block, "x-anthropic-")) { + system_content += content_block; + } + } + } + } + + oai_messages.push_back({ + {"role", "system"}, + {"content", system_content} + }); + } + + // Convert messages + if (!body.contains("messages")) { + throw std::runtime_error("'messages' is required"); + } + const json & messages = body.at("messages"); + if (messages.is_array()) { + for (const auto & msg : messages) { + std::string role = json_value(msg, "role", std::string()); + + if (!msg.contains("content")) { + if (role == "assistant") { + continue; + } + oai_messages.push_back(msg); + continue; + } + + const json & content = msg.at("content"); + + if (content.is_string()) { + oai_messages.push_back(msg); + continue; + } + + if (!content.is_array()) { + oai_messages.push_back(msg); + continue; + } + + json tool_calls = json::array(); + json converted_content = json::array(); + json tool_results = json::array(); + std::string reasoning_content; + bool has_tool_calls = false; + + for (const auto & block : content) { + std::string type = json_value(block, "type", std::string()); + + if (type == "text") { + converted_content.push_back(block); + } else if (type == "thinking") { + reasoning_content += json_value(block, "thinking", std::string()); + } else if (type == "image") { + json source = json_value(block, "source", json::object()); + std::string source_type = json_value(source, "type", std::string()); + + if (source_type == "base64") { + std::string media_type = json_value(source, "media_type", std::string("image/jpeg")); + std::string data = json_value(source, "data", std::string()); + std::ostringstream ss; + ss << "data:" << media_type << ";base64," << data; + + converted_content.push_back({ + {"type", "image_url"}, + {"image_url", { + {"url", ss.str()} + }} + }); + } else if (source_type == "url") { + std::string url = json_value(source, "url", std::string()); + converted_content.push_back({ + {"type", "image_url"}, + {"image_url", { + {"url", url} + }} + }); + } + } else if (type == "tool_use") { + tool_calls.push_back({ + {"id", json_value(block, "id", std::string())}, + {"type", "function"}, + {"function", { + {"name", json_value(block, "name", std::string())}, + {"arguments", json_value(block, "input", json::object()).dump()} + }} + }); + has_tool_calls = true; + } else if (type == "tool_result") { + std::string tool_use_id = json_value(block, "tool_use_id", std::string()); + + auto result_content = json_value(block, "content", json()); + std::string result_text; + if (result_content.is_string()) { + result_text = result_content.get(); + } else if (result_content.is_array()) { + for (const auto & c : result_content) { + if (json_value(c, "type", std::string()) == "text") { + result_text += json_value(c, "text", std::string()); + } + } + } + + tool_results.push_back({ + {"role", "tool"}, + {"tool_call_id", tool_use_id}, + {"content", result_text} + }); + } + } + + if (!converted_content.empty() || has_tool_calls || !reasoning_content.empty()) { + json new_msg = { {"role", role} }; + if (!converted_content.empty()) { + new_msg["content"] = converted_content; + } else if (has_tool_calls || !reasoning_content.empty()) { + new_msg["content"] = ""; + } + if (!tool_calls.empty()) { + new_msg["tool_calls"] = tool_calls; + } + if (!reasoning_content.empty()) { + new_msg["reasoning_content"] = reasoning_content; + } + oai_messages.push_back(new_msg); + } + + for (const auto & tool_msg : tool_results) { + oai_messages.push_back(tool_msg); + } + } + } + + oai_body["messages"] = oai_messages; + + // Convert tools + if (body.contains("tools")) { + const json & tools = body.at("tools"); + if (tools.is_array()) { + json oai_tools = json::array(); + for (const auto & tool : tools) { + oai_tools.push_back({ + {"type", "function"}, + {"function", { + {"name", json_value(tool, "name", std::string())}, + {"description", json_value(tool, "description", std::string())}, + {"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()} + }} + }); + } + oai_body["tools"] = oai_tools; + } + } + + // Convert tool_choice + if (body.contains("tool_choice")) { + const json & tc = body.at("tool_choice"); + if (tc.is_object()) { + std::string type = json_value(tc, "type", std::string()); + if (type == "auto") { + oai_body["tool_choice"] = "auto"; + } else if (type == "any" || type == "tool") { + oai_body["tool_choice"] = "required"; + } + } + } + + // Convert stop_sequences to stop + if (body.contains("stop_sequences")) { + oai_body["stop"] = body.at("stop_sequences"); + } + + // Handle max_tokens (required in Anthropic, but we're permissive) + if (body.contains("max_tokens")) { + oai_body["max_tokens"] = body.at("max_tokens"); + } else { + oai_body["max_tokens"] = 4096; + } + + // Pass through common params + for (const auto & key : { "temperature", "top_p", "top_k", "stream" }) { + if (body.contains(key)) { + oai_body[key] = body.at(key); + } + } + + // Handle Anthropic-specific thinking param + if (body.contains("thinking")) { + json thinking = json_value(body, "thinking", json::object()); + std::string thinking_type = json_value(thinking, "type", std::string()); + if (thinking_type == "enabled") { + int budget_tokens = json_value(thinking, "budget_tokens", 10000); + oai_body["thinking_budget_tokens"] = budget_tokens; + } + } + + // Handle Anthropic-specific metadata param + if (body.contains("metadata")) { + json metadata = json_value(body, "metadata", json::object()); + std::string user_id = json_value(metadata, "user_id", std::string()); + if (!user_id.empty()) { + oai_body["__metadata_user_id"] = user_id; + } + } + + return oai_body; +} + + +json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) { + json delta = json::object(); + if (!diff.reasoning_content_delta.empty()) { + delta["reasoning_content"] = diff.reasoning_content_delta; + } + if (!diff.content_delta.empty()) { + delta["content"] = diff.content_delta; + } + if (diff.tool_call_index != std::string::npos) { + json tool_call; + tool_call["index"] = diff.tool_call_index; + if (!diff.tool_call_delta.id.empty()) { + tool_call["id"] = diff.tool_call_delta.id; + tool_call["type"] = "function"; + } + if (!diff.tool_call_delta.name.empty() || !diff.tool_call_delta.arguments.empty()) { + json function = json::object(); + if (!diff.tool_call_delta.name.empty()) { + function["name"] = diff.tool_call_delta.name; + } + if (!diff.tool_call_delta.arguments.empty()) { + function["arguments"] = diff.tool_call_delta.arguments; + } + tool_call["function"] = function; + } + delta["tool_calls"] = json::array({ tool_call }); + } + return delta; +} + +json convert_transcriptions_to_chatcmpl( + const json & inp_body, + const std::map & in_files, + std::vector & out_files) { + // TODO @ngxson : this function may need to be improved in the future + // handle input files + out_files.clear(); + auto it = in_files.find("file"); + if (it != in_files.end()) { + out_files.push_back(it->second); + } else { + throw std::invalid_argument("No input file found for transcription"); + } + + // handle input data + std::string prompt = json_value(inp_body, "prompt", std::string()); + std::string language = json_value(inp_body, "language", std::string()); + std::string response_format = json_value(inp_body, "response_format", std::string("json")); + if (response_format != "json") { + throw std::invalid_argument("Only 'json' response_format is supported for transcription"); + } + if (prompt.empty()) { + prompt = "Transcribe audio to text"; + } + if (!language.empty()) { + prompt += string_format(" (language: %s)", language.c_str()); + } + prompt += get_media_marker(); + + json chatcmpl_body = inp_body; // copy all fields + chatcmpl_body["messages"] = json::array({ + { + {"role", "user"}, + {"content", prompt}, + }, + }); + + // because input from form-data, everything is string, we need to correct the types here + std::string stream = json_value(inp_body, "stream", std::string("false")); + chatcmpl_body["stream"] = stream == "true"; + + if (inp_body.contains("max_tokens")) { + std::string inp = inp_body["max_tokens"].get(); + chatcmpl_body["max_tokens"] = std::stoul(inp); + } + + if (inp_body.contains("temperature")) { + std::string inp = inp_body["temperature"].get(); + chatcmpl_body["temperature"] = std::stof(inp); + } + + return chatcmpl_body; +} diff --git a/examples/server/server-chat.h b/examples/server/server-chat.h new file mode 100644 index 00000000..ecb8907c --- /dev/null +++ b/examples/server/server-chat.h @@ -0,0 +1,24 @@ +// Chat conversion functions for server (Responses API, Anthropic API, OAI streaming diffs) + +#pragma once + +#include "chat.h" +#include "server-common.h" + +#include + +using json = nlohmann::ordered_json; + +// Convert OpenAI Responses API format to OpenAI Chat Completions API format +json server_chat_convert_responses_to_chatcmpl(const json & body); + +// Convert Anthropic Messages API format to OpenAI Chat Completions API format +json server_chat_convert_anthropic_to_oai(const json & body); + +// convert OpenAI transcriptions API format to OpenAI Chat Completions API format +json convert_transcriptions_to_chatcmpl( + const json & body, + const std::map & in_files, + std::vector & out_files); + +json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff); diff --git a/examples/server/server-common.cpp b/examples/server/server-common.cpp index 0aa26e96..3ec73450 100644 --- a/examples/server/server-common.cpp +++ b/examples/server/server-common.cpp @@ -159,6 +159,18 @@ std::string gen_tool_call_id() { return random_string(); } +const char * get_media_marker() { + static const std::string marker = []() { + // allow user to pin a reproducible marker via env var + const char * env = getenv("LLAMA_MEDIA_MARKER"); + if (env && env[0] != '\0') { + return std::string(env); + } + return std::string("<__media_") + random_string() + "__>"; + }(); + return marker.c_str(); +} + // // other common utils // @@ -920,474 +932,7 @@ json oaicompat_chat_params_parse( return llama_params; } -json convert_responses_to_chatcmpl(const json& response_body) { - if (!response_body.contains("input")) { - throw std::runtime_error("'input' is required"); - } - if (!json_value(response_body, "previous_response_id", std::string{}).empty()) { - throw std::runtime_error("ik_llama.cpp does not support 'previous_response_id'."); - } - const json input_value = response_body.at("input"); - json chatcmpl_body = response_body; - chatcmpl_body.erase("input"); - std::vector chatcmpl_messages; - - if (response_body.contains("instructions")) { - chatcmpl_messages.push_back({ - {"role", "system"}, - {"content", json_value(response_body, "instructions", std::string())}, - }); - chatcmpl_body.erase("instructions"); - } - - if (input_value.is_string()) { - chatcmpl_messages.push_back({ - {"role", "user"}, - {"content", input_value}, - }); - } - else if (input_value.is_array()) { - static auto exists_and_is_array = [](const json& j, const char* key) -> bool { - return j.contains(key) && j.at(key).is_array(); - }; - static auto exists_and_is_string = [](const json& j, const char* key) -> bool { - return j.contains(key) && j.at(key).is_string(); - }; - - for (json item : input_value) { - if (exists_and_is_string(item, "content")) { - item["content"] = json::array({ - json{ - {"text", item.at("content")}, - {"type", "input_text"}, - } - }); - } - - if (exists_and_is_array(item, "content") && - exists_and_is_string(item, "role") && - (item.at("role") == "user" || item.at("role") == "system" || item.at("role") == "developer") - ) { - std::vector chatcmpl_content; - - for (const json& input_item : item.at("content")) { - const std::string type = json_value(input_item, "type", std::string()); - - if (type == "input_text") { - if (!input_item.contains("text")) { - throw std::runtime_error("'Input text' requires 'text'"); - } - chatcmpl_content.push_back({ - {"text", input_item.at("text")}, - {"type", "text"}, - }); - } - else if (type == "input_image") { - if (!input_item.contains("image_url")) { - throw std::runtime_error("'image_url' is required"); - } - chatcmpl_content.push_back({ - {"image_url", json{ - {"url", input_item.at("image_url")}, - }}, - {"type", "image_url"}, - }); - } - else if (type == "input_file") { - throw std::runtime_error("'input_file' is not supported by ik_llama.cpp at this moment"); - } - else { - throw std::runtime_error("'type' must be one of 'input_text', 'input_image', or 'input_file'"); - } - } - - if (item.contains("type")) { - item.erase("type"); - } - if (item.contains("status")) { - item.erase("status"); - } - item["content"] = chatcmpl_content; - - chatcmpl_messages.push_back(item); - } - else if (exists_and_is_array(item, "content") && - exists_and_is_string(item, "role") && - item.at("role") == "assistant" && - exists_and_is_string(item, "type") && - item.at("type") == "message" - ) { - std::vector chatcmpl_content; - - for (const auto& output_text : item.at("content")) { - const std::string type = json_value(output_text, "type", std::string()); - if (type != "output_text") { - throw std::runtime_error("'type' must be 'output_text'"); - } - if (!exists_and_is_string(output_text, "text")) { - throw std::runtime_error("'Output text' requires 'text'"); - } - chatcmpl_content.push_back({ - {"text", output_text.at("text")}, - {"type", "text"}, - }); - } - - item.erase("status"); - item.erase("type"); - item["content"] = chatcmpl_content; - chatcmpl_messages.push_back(item); - } - else if (exists_and_is_string(item, "arguments") && - exists_and_is_string(item, "call_id") && - exists_and_is_string(item, "name") && - exists_and_is_string(item, "type") && - item.at("type") == "function_call" - ) { - json msg = json{ - {"role", "assistant"}, - {"tool_calls", json::array({json{ - {"function", json{ - {"arguments", item.at("arguments")}, - {"name", item.at("name")}, - }}, - {"id", item.at("call_id")}, - {"type", "function"}, - }})}, - }; - - if (!chatcmpl_messages.empty() && chatcmpl_messages.back().contains("reasoning_content")) { - msg["reasoning_content"] = chatcmpl_messages.back().at("reasoning_content"); - chatcmpl_messages.pop_back(); - } - chatcmpl_messages.push_back(msg); - } - else if (exists_and_is_string(item, "call_id") && - (exists_and_is_string(item, "output") || exists_and_is_array(item, "output")) && - exists_and_is_string(item, "type") && - item.at("type") == "function_call_output" - ) { - if (item.at("output").is_string()) { - chatcmpl_messages.push_back(json{ - {"content", item.at("output")}, - {"role", "tool"}, - {"tool_call_id", item.at("call_id")}, - }); - } - else { - json chatcmpl_outputs = item.at("output"); - for (json& chatcmpl_output : chatcmpl_outputs) { - if (!chatcmpl_output.contains("type") || chatcmpl_output.at("type") != "input_text") { - throw std::runtime_error("Output of tool call should be 'Input text'"); - } - chatcmpl_output["type"] = "text"; - } - chatcmpl_messages.push_back(json{ - {"content", chatcmpl_outputs}, - {"role", "tool"}, - {"tool_call_id", item.at("call_id")}, - }); - } - } - else if (exists_and_is_array(item, "summary") && - exists_and_is_string(item, "type") && - item.at("type") == "reasoning") { - if (!exists_and_is_array(item, "content")) { - throw std::runtime_error("item['content'] is not an array"); - } - if (item.at("content").empty()) { - throw std::runtime_error("item['content'] is empty"); - } - if (!exists_and_is_string(item.at("content")[0], "text")) { - throw std::runtime_error("item['content']['text'] is not a string"); - } - - chatcmpl_messages.push_back(json{ - {"role", "assistant"}, - {"content", json::array()}, - {"reasoning_content", item.at("content")[0].at("text")}, - }); - } - else { - throw std::runtime_error("Cannot determine type of 'item'"); - } - } - } - else { - throw std::runtime_error("'input' must be a string or array of objects"); - } - - chatcmpl_messages.erase(std::remove_if( - chatcmpl_messages.begin(), - chatcmpl_messages.end(), - [](const json& x) { - return x.contains("role") && - x.at("role") == "assistant" && - x.contains("content") && - x.at("content") == json::array() && - x.contains("reasoning_content"); - }), - chatcmpl_messages.end()); - - chatcmpl_body["messages"] = chatcmpl_messages; - - if (response_body.contains("tools")) { - if (!response_body.at("tools").is_array()) { - throw std::runtime_error("'tools' must be an array of objects"); - } - std::vector chatcmpl_tools; - for (json resp_tool : response_body.at("tools")) { - json chatcmpl_tool; - - if (json_value(resp_tool, "type", std::string()) != "function") { - throw std::runtime_error("'type' of tool must be 'function'"); - } - resp_tool.erase("type"); - chatcmpl_tool["type"] = "function"; - - if (!resp_tool.contains("strict")) { - resp_tool["strict"] = true; - } - chatcmpl_tool["function"] = resp_tool; - chatcmpl_tools.push_back(chatcmpl_tool); - } - chatcmpl_body.erase("tools"); - chatcmpl_body["tools"] = chatcmpl_tools; - } - - if (response_body.contains("max_output_tokens")) { - chatcmpl_body.erase("max_output_tokens"); - chatcmpl_body["max_tokens"] = response_body["max_output_tokens"]; - } - - return chatcmpl_body; -} - -json convert_anthropic_to_oai(const json & body) { - json oai_body; - - // Convert system prompt - json oai_messages = json::array(); - auto system_param = json_value(body, "system", json()); - if (!system_param.is_null()) { - std::string system_content; - - if (system_param.is_string()) { - system_content = system_param.get(); - } else if (system_param.is_array()) { - for (const auto & block : system_param) { - if (json_value(block, "type", std::string()) == "text") { - std::string content_block = json_value(block, "text", std::string()); - if (!string_starts_with(content_block, "x-anthropic-")) { - system_content += content_block; - } - } - } - } - - oai_messages.push_back({ - {"role", "system"}, - {"content", system_content} - }); - } - - // Convert messages - if (!body.contains("messages")) { - throw std::runtime_error("'messages' is required"); - } - const json & messages = body.at("messages"); - if (messages.is_array()) { - for (const auto & msg : messages) { - std::string role = json_value(msg, "role", std::string()); - - if (!msg.contains("content")) { - if (role == "assistant") { - continue; - } - oai_messages.push_back(msg); - continue; - } - - const json & content = msg.at("content"); - - if (content.is_string()) { - oai_messages.push_back(msg); - continue; - } - - if (!content.is_array()) { - oai_messages.push_back(msg); - continue; - } - - json tool_calls = json::array(); - json converted_content = json::array(); - json tool_results = json::array(); - std::string reasoning_content; - bool has_tool_calls = false; - - for (const auto & block : content) { - std::string type = json_value(block, "type", std::string()); - - if (type == "text") { - converted_content.push_back(block); - } else if (type == "thinking") { - reasoning_content += json_value(block, "thinking", std::string()); - } else if (type == "image") { - json source = json_value(block, "source", json::object()); - std::string source_type = json_value(source, "type", std::string()); - - if (source_type == "base64") { - std::string media_type = json_value(source, "media_type", std::string("image/jpeg")); - std::string data = json_value(source, "data", std::string()); - std::ostringstream ss; - ss << "data:" << media_type << ";base64," << data; - - converted_content.push_back({ - {"type", "image_url"}, - {"image_url", { - {"url", ss.str()} - }} - }); - } else if (source_type == "url") { - std::string url = json_value(source, "url", std::string()); - converted_content.push_back({ - {"type", "image_url"}, - {"image_url", { - {"url", url} - }} - }); - } - } else if (type == "tool_use") { - tool_calls.push_back({ - {"id", json_value(block, "id", std::string())}, - {"type", "function"}, - {"function", { - {"name", json_value(block, "name", std::string())}, - {"arguments", json_value(block, "input", json::object()).dump()} - }} - }); - has_tool_calls = true; - } else if (type == "tool_result") { - std::string tool_use_id = json_value(block, "tool_use_id", std::string()); - - auto result_content = json_value(block, "content", json()); - std::string result_text; - if (result_content.is_string()) { - result_text = result_content.get(); - } else if (result_content.is_array()) { - for (const auto & c : result_content) { - if (json_value(c, "type", std::string()) == "text") { - result_text += json_value(c, "text", std::string()); - } - } - } - - tool_results.push_back({ - {"role", "tool"}, - {"tool_call_id", tool_use_id}, - {"content", result_text} - }); - } - } - - if (!converted_content.empty() || has_tool_calls || !reasoning_content.empty()) { - json new_msg = { {"role", role} }; - if (!converted_content.empty()) { - new_msg["content"] = converted_content; - } else if (has_tool_calls || !reasoning_content.empty()) { - new_msg["content"] = ""; - } - if (!tool_calls.empty()) { - new_msg["tool_calls"] = tool_calls; - } - if (!reasoning_content.empty()) { - new_msg["reasoning_content"] = reasoning_content; - } - oai_messages.push_back(new_msg); - } - - for (const auto & tool_msg : tool_results) { - oai_messages.push_back(tool_msg); - } - } - } - - oai_body["messages"] = oai_messages; - - // Convert tools - if (body.contains("tools")) { - const json & tools = body.at("tools"); - if (tools.is_array()) { - json oai_tools = json::array(); - for (const auto & tool : tools) { - oai_tools.push_back({ - {"type", "function"}, - {"function", { - {"name", json_value(tool, "name", std::string())}, - {"description", json_value(tool, "description", std::string())}, - {"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()} - }} - }); - } - oai_body["tools"] = oai_tools; - } - } - - // Convert tool_choice - if (body.contains("tool_choice")) { - const json & tc = body.at("tool_choice"); - if (tc.is_object()) { - std::string type = json_value(tc, "type", std::string()); - if (type == "auto") { - oai_body["tool_choice"] = "auto"; - } else if (type == "any" || type == "tool") { - oai_body["tool_choice"] = "required"; - } - } - } - - // Convert stop_sequences to stop - if (body.contains("stop_sequences")) { - oai_body["stop"] = body.at("stop_sequences"); - } - - // Handle max_tokens (required in Anthropic, but we're permissive) - if (body.contains("max_tokens")) { - oai_body["max_tokens"] = body.at("max_tokens"); - } else { - oai_body["max_tokens"] = 4096; - } - - // Pass through common params - for (const auto & key : { "temperature", "top_p", "top_k", "stream" }) { - if (body.contains(key)) { - oai_body[key] = body.at(key); - } - } - - // Handle Anthropic-specific thinking param - if (body.contains("thinking")) { - json thinking = json_value(body, "thinking", json::object()); - std::string thinking_type = json_value(thinking, "type", std::string()); - if (thinking_type == "enabled") { - int budget_tokens = json_value(thinking, "budget_tokens", 10000); - oai_body["thinking_budget_tokens"] = budget_tokens; - } - } - - // Handle Anthropic-specific metadata param - if (body.contains("metadata")) { - json metadata = json_value(body, "metadata", json::object()); - std::string user_id = json_value(metadata, "user_id", std::string()); - if (!user_id.empty()) { - oai_body["__metadata_user_id"] = user_id; - } - } - - return oai_body; -} // // tokenizer and input processing utils diff --git a/examples/server/server-common.h b/examples/server/server-common.h index 9036e1f2..d6c54904 100644 --- a/examples/server/server-common.h +++ b/examples/server/server-common.h @@ -194,6 +194,9 @@ std::string gen_chatcmplid(); std::string gen_tool_call_id(); +// get a random marker; note: each time the server restarts, the marker will be different +const char * get_media_marker(); + // // other common utils // @@ -297,12 +300,6 @@ json oaicompat_chat_params_parse( const server_chat_params& opt, std::vector& out_files); -// convert OpenAI Responses API format to OpenAI Chat Completions API format -json convert_responses_to_chatcmpl(const json& body); - -// convert Anthropic Messages API format to OpenAI Chat Completions API format -json convert_anthropic_to_oai(const json & body); - // // tokenizer and input processing utils diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index b91149ce..ca5d426f 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -1,4 +1,5 @@ #include "server-context.h" +#include "server-chat.h" #include "server-common.h" #include "server-task.h" #include "server-queue.h" diff --git a/examples/server/server-task.cpp b/examples/server/server-task.cpp index f1e8e3cb..95287820 100644 --- a/examples/server/server-task.cpp +++ b/examples/server/server-task.cpp @@ -1,4 +1,5 @@ #include "server-task.h" +#include "server-chat.h" json result_timings::to_json() const { json base = { @@ -203,8 +204,8 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat_partial() { }); } - for (const auto& diff : oaicompat_msg_diffs) { - add_delta(common_chat_msg_diff_to_json_oaicompat(diff)); + for (const auto& diff : oaicompat_msg_diffs) { + add_delta(server_chat_msg_diff_to_json_oaicompat(diff)); } if (!deltas.empty()) { @@ -415,7 +416,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() { json { {"finish_reason", nullptr}, {"index", 0}, - {"delta", common_chat_msg_diff_to_json_oaicompat(diff)}, + {"delta", server_chat_msg_diff_to_json_oaicompat(diff)}, }, })}, {"created", t}, diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 24eafcb4..e7e55634 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1,6 +1,7 @@ #pragma warning(disable : 4996) #include "server-context.h" #include "server-common.h" +#include "server-chat.h" #include "chat.h" #include "common.h" @@ -1297,7 +1298,7 @@ int main(int argc, char ** argv) { log_prompt(ctx_server.params_base, json::parse(req.body)); auto body = json::parse(req.body); std::vector files; - json body_parsed = convert_responses_to_chatcmpl(body); + json body_parsed = server_chat_convert_responses_to_chatcmpl(body); json data = oaicompat_chat_params_parse(body_parsed, ctx_server.chat_params, files); handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, @@ -1311,7 +1312,7 @@ int main(int argc, char ** argv) { const auto handle_anthropic_messages = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { std::vector files; log_prompt(ctx_server.params_base, json::parse(req.body)); - json body = convert_anthropic_to_oai(json::parse(req.body)); + json body = server_chat_convert_anthropic_to_oai(json::parse(req.body)); SRV_DBG("%s\n", "Request converted: Anthropic -> OpenAI Chat Completions"); SRV_DBG("converted request: %s\n", body.dump().c_str()); json body_parsed = oaicompat_chat_params_parse( @@ -1330,7 +1331,7 @@ int main(int argc, char ** argv) { const auto handle_anthropic_count_tokens = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { std::vector files; log_prompt(ctx_server.params_base, json::parse(req.body)); - json body = convert_anthropic_to_oai(json::parse(req.body)); + json body = server_chat_convert_anthropic_to_oai(json::parse(req.body)); SRV_DBG("%s\n", "Request converted: Anthropic -> OpenAI Chat Completions"); SRV_DBG("converted request: %s\n", body.dump().c_str()); json body_parsed = oaicompat_chat_params_parse( diff --git a/tests/test-chat-auto-parser.cpp b/tests/test-chat-auto-parser.cpp index bb23b7f2..1d96de71 100644 --- a/tests/test-chat-auto-parser.cpp +++ b/tests/test-chat-auto-parser.cpp @@ -1331,7 +1331,7 @@ static void test_nemotron_reasoning_detection(testing & t) { // Check reasoning markers t.assert_equal("reasoning_start should be '\\n'", "\n", analysis.reasoning.start); - t.assert_equal("reasoning_end should be ''", "", analysis.reasoning.end); + t.assert_equal("reasoning_end should be '\\n\\n'", "\n\n", analysis.reasoning.end); // Check reasoning mode detection // Nemotron uses tag-based reasoning; prefill handles the template's forced markers diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 51b5d9ea..b8b58eee 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -7,6 +7,7 @@ // #include "../src/llama-grammar.h" #include "../src/unicode.h" +#include "../example/server/server-chat.h" #include "chat-auto-parser.h" #include "chat.h" #include "common.h" @@ -541,6 +542,36 @@ static common_chat_tool edit_tool{ })", }; +static common_chat_tool manage_todo_list_tool{ + /* .name = */ "manage_todo_list", + /* .description = */ "Create or update the todo list", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "todos": { + "type": "array", + "description": "List of TODO list items" + } + }, + "required": ["todos"] + })", +}; + +static common_chat_tool run_in_terminal_tool{ + /* .name = */ "run_in_terminal", + /* .description = */ "Run a shell command.", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Shell command to run" + } + }, + "required": ["command"] + })", +}; + static common_chat_tool magic_tool{ /* .name = */ "magic", /* .description = */ "Magic tool that takes a hash", @@ -1378,6 +1409,16 @@ class peg_test_builder { return *this; } + peg_test_builder & tool_choice(common_chat_tool_choice choice) { + tc_.params.tool_choice = choice; + return *this; + } + + peg_test_builder & messages(std::vector messages) { + tc_.params.messages = std::move(messages); + return *this; + } + // Execute the test void run() { // Check template filter @@ -1514,6 +1555,117 @@ static void test_tools_oaicompat_json_conversion() { common_chat_tools_to_json_oaicompat({ special_function_tool }).dump(2)); } +static void test_convert_responses_to_chatcmpl() { + LOG_DBG("%s\n", __func__); + + // Test basic conversion with input messages (user/assistant alternating) + { + json input = json::parse(R"({ + "input": [ + { + "type": "message", + "role": "user", + "content": "hi wassup" + }, + { + "type": "message", + "role": "assistant", + "content": "Hey! 👋 Not much, just here ready to chat. What's up with you? Anything I can help you with today?" + }, + { + "type": "message", + "role": "user", + "content": "hi" + } + ], + "model": "gpt-5-mini", + "stream": false, + "text": {}, + "reasoning": { + "effort": "medium" + } + })"); + + json result = server_chat_convert_responses_to_chatcmpl(input); + + // Verify messages were converted correctly + assert_equals(true, result.contains("messages")); + assert_equals(true, result.at("messages").is_array()); + assert_equals((size_t)3, result.at("messages").size()); + + // Check first message (user) + const auto & msg0 = result.at("messages")[0]; + assert_equals(std::string("user"), msg0.at("role").get()); + assert_equals(true, msg0.at("content").is_array()); + assert_equals(std::string("text"), msg0.at("content")[0].at("type").get()); + assert_equals(std::string("hi wassup"), msg0.at("content")[0].at("text").get()); + + // Check second message (assistant) + const auto & msg1 = result.at("messages")[1]; + assert_equals(std::string("assistant"), msg1.at("role").get()); + assert_equals(true, msg1.at("content").is_array()); + assert_equals(std::string("text"), msg1.at("content")[0].at("type").get()); + assert_equals(std::string("Hey! 👋 Not much, just here ready to chat. What's up with you? Anything I can help you with today?"), msg1.at("content")[0].at("text").get()); + + // Check third message (user) + const auto & msg2 = result.at("messages")[2]; + assert_equals(std::string("user"), msg2.at("role").get()); + assert_equals(true, msg2.at("content").is_array()); + assert_equals(std::string("text"), msg2.at("content")[0].at("type").get()); + assert_equals(std::string("hi"), msg2.at("content")[0].at("text").get()); + + // Verify other fields preserved + assert_equals(std::string("gpt-5-mini"), result.at("model").get()); + assert_equals(false, result.at("stream").get()); + } + + // Test string input + { + json input = json::parse(R"({ + "input": "Hello, world!", + "model": "test-model" + })"); + + json result = server_chat_convert_responses_to_chatcmpl(input); + + assert_equals((size_t)1, result.at("messages").size()); + const auto & msg = result.at("messages")[0]; + assert_equals(std::string("user"), msg.at("role").get()); + assert_equals(std::string("Hello, world!"), msg.at("content").get()); + } + + // Test with instructions (system message) + { + json input = json::parse(R"({ + "input": "Hello", + "instructions": "You are a helpful assistant.", + "model": "test-model" + })"); + + json result = server_chat_convert_responses_to_chatcmpl(input); + + assert_equals((size_t)2, result.at("messages").size()); + const auto & sys_msg = result.at("messages")[0]; + assert_equals(std::string("system"), sys_msg.at("role").get()); + assert_equals(std::string("You are a helpful assistant."), sys_msg.at("content").get()); + } + + // Test with max_output_tokens conversion + { + json input = json::parse(R"({ + "input": "Hello", + "model": "test-model", + "max_output_tokens": 100 + })"); + + json result = server_chat_convert_responses_to_chatcmpl(input); + + assert_equals(true, result.contains("max_tokens")); + assert_equals(false, result.contains("max_output_tokens")); + assert_equals(100, result.at("max_tokens").get()); + } +} + static void test_template_output_peg_parsers(bool detailed_debug) { LOG_DBG("%s\n", __func__); @@ -1530,22 +1682,16 @@ static void test_template_output_peg_parsers(bool detailed_debug) { // Qwen3.5 (basically same as Nemotron, but keeping separate tests just in case) auto tst = peg_tester("models/templates/Qwen3.5-4B.jinja", detailed_debug); - tst.test("I'm\nthinkingHello, world!\nWhat's up?") + tst.test("I'm\nthinking\n\n\nHello, world!\nWhat's up?") .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .enable_thinking(true) .expect(message_assist_thoughts) .run(); - tst.test("I'm\nthinking\n\nHello, world!\nWhat's up?") + tst.test("I'm\nthinking\n\n\nHello, world!\nWhat's up?") .enable_thinking(true) .reasoning_format(COMMON_REASONING_FORMAT_NONE) - .expect_content("\nI'm\nthinking\n\nHello, world!\nWhat's up?") - .run(); - - tst.test("I'm\nthinking\n\nHello, world!\nWhat's up?") - .enable_thinking(true) - .reasoning_format(COMMON_REASONING_FORMAT_AUTO) - .expect(message_assist_thoughts) + .expect_content("\nI'm\nthinking\n\n\nHello, world!\nWhat's up?") .run(); tst.test( @@ -1561,7 +1707,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .run(); tst.test( - "I'm\nthinking\n\n" + "I'm\nthinking\n\n\n" "\n" "\n" "\n1\n\n" @@ -1619,7 +1765,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { tst.test( "I need to output the invoice details in JSON\n" - "\n" + "\n\n" R"({"amount": 123.45, "date": "2025-12-03"})") .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .enable_thinking(true) @@ -1639,7 +1785,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { "hello()\n" "\n" "\n" - "\n" + "\n\n\n" "\n" "\n" "\n" @@ -1649,23 +1795,23 @@ static void test_template_output_peg_parsers(bool detailed_debug) { "hello()\n" "\n" "\n" - "" - ) + "") .enable_thinking(true) .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .tools({ python_tool }) - .expect_reasoning("Let's call a tool: \n" - "\n" - "\n" - "def hello():\n" - " print(\"Not the real call!\")\n" - "\n" - "hello()\n" - "\n" - "\n" - "") + .expect_reasoning( + "Let's call a tool: \n" + "\n" + "\n" + "def hello():\n" + " print(\"Not the real call!\")\n" + "\n" + "hello()\n" + "\n" + "\n" + "") .expect_tool_calls({ { "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} }, }) @@ -1694,6 +1840,219 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .tools({ empty_args_tool_no_properties }) .expect(message_with_tool_calls("empty_args_no_props", "{}")) .run(); + + // Edge cases when reasoning traces are not sent + tst.test( + "\n\n\n\n" + "\n" + "\n" + "\n1\n\n" + "\n" + "") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ + special_function_tool + }) + .expect_reasoning("\n\n") + .expect_tool_calls({ { "special_function", "{\"arg1\": 1}", "" } }) + .run(); + + tst.test( + "\n\n" + "\n" + "\n" + "\n1\n\n" + "\n" + "") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ + special_function_tool + }) + .expect_reasoning("") + .expect_tool_calls({ { "special_function", "{\"arg1\": 1}", "" } }) + .run(); + + tst.test( + "\n\n" + "\n" + "\n" + "\n" + "pwd\n" + "\n" + "\n" + "") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .tools({ + run_in_terminal_tool + }) + .expect_tool_calls({ + { "run_in_terminal", R"({"command": "pwd"})", {} }, + }) + .run(); + + tst.test( + "\n\n" + "Let me inspect the current directory.\n" + "\n" + "\n" + "\n" + "pwd\n" + "\n" + "\n" + "") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .tools({ + run_in_terminal_tool + }) + .expect_content("Let me inspect the current directory.\n") + .expect_tool_calls({ + { "run_in_terminal", R"({"command": "pwd"})", {} }, + }) + .run(); + + tst.test( + "\n\n" + "Let me inspect the current directory.\n" + "\n" + "\n" + "\n" + "pwd\n" + "\n" + "\n" + "") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .tools({ + run_in_terminal_tool + }) + .tool_choice(COMMON_CHAT_TOOL_CHOICE_REQUIRED) + .expect_content("Let me inspect the current directory.\n") + .expect_tool_calls({ + { "run_in_terminal", R"({"command": "pwd"})", {} }, + }) + .run(); + + tst.test( + "I should inspect the directory.\n" + "\n\n" + "Let me inspect it now.\n" + "\n" + "\n" + "\n" + "pwd\n" + "\n" + "\n" + "") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .tools({ + run_in_terminal_tool + }) + .expect_reasoning("I should inspect the directory.") + .expect_content("Let me inspect it now.\n") + .expect_tool_calls({ + { "run_in_terminal", R"({"command": "pwd"})", {} }, + }) + .run(); + + tst.test( + "I might call later, but I am still thinking.\n" + "\n\n" + "Final answer without tools.") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .tools({ run_in_terminal_tool }) + .expect_reasoning("I might call later, but I am still thinking.") + .expect_content("Final answer without tools.") + .run(); + + { + common_chat_msg user_start; + user_start.role = "user"; + user_start.content = "Create a todo list, then inspect the repository."; + + common_chat_msg assistant_todos = + simple_assist_msg("", "", "manage_todo_list", + R"({"todos":[{"item":"Inspect repository","selected":false}]})", "call_todos"); + + common_chat_msg tool_result; + tool_result.role = "tool"; + tool_result.content = "Successfully wrote todo list"; + tool_result.tool_call_id = "call_todos"; + + common_chat_msg user_continue; + user_continue.role = "user"; + user_continue.content = "Proceed."; + + tst.test( + "I need to run a terminal command.\n" + "\n\n" + "\n" + "\n" + "\n" + "pwd\n" + "\n" + "\n" + "") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .tools({ + manage_todo_list_tool, run_in_terminal_tool + }) + .messages({ user_start, assistant_todos, tool_result, user_continue }) + .expect_reasoning("I need to run a terminal command.") + .expect_tool_calls({ + { "run_in_terminal", R"({"command": "pwd"})", {} }, + }) + .run(); + + tst.test( + "I need to run a terminal command.\n" + "\n\n" + "Let me inspect the current directory.\n" + "\n" + "\n" + "\n" + "pwd\n" + "\n" + "\n" + "") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .tools({ + manage_todo_list_tool, run_in_terminal_tool + }) + .tool_choice(COMMON_CHAT_TOOL_CHOICE_REQUIRED) + .messages({ user_start, assistant_todos, tool_result, user_continue }) + .expect_reasoning("I need to run a terminal command.") + .expect_content("Let me inspect the current directory.\n") + .expect_tool_calls({ + { "run_in_terminal", R"({"command": "pwd"})", {} }, + }) + .run(); + + tst.test( + "\n\n" + "\n" + "\n" + "\n" + "pwd\n" + "\n" + "\n" + "") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .tools({ + manage_todo_list_tool, run_in_terminal_tool + }) + .messages({ user_start, assistant_todos, tool_result, user_continue }) + .expect_tool_calls({ + { "run_in_terminal", R"({"command": "pwd"})", {} }, + }) + .run(); + } } { @@ -1882,7 +2241,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { "hello()\n" "\n" "\n" - "\n" + "\n\n" "\n" "\n" "\n" @@ -3335,7 +3694,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .run(); // Tool call with reasoning (enable_thinking=true) - tst.test("I'm\nthinking\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}") + tst.test("I'm\nthinking\n\n\n\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}") .enable_thinking(true) .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .tools({ special_function_tool }) @@ -3359,7 +3718,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .run(); // Tool call with reasoning and content - tst.test("I need to call a function" + tst.test("I need to call a function\n\n\n" "Let me check the time.\n{\"name\": \"get_time\", \"arguments\": {\"city\": \"XYZCITY\"}}") .enable_thinking(true) .reasoning_format(COMMON_REASONING_FORMAT_AUTO) @@ -3386,7 +3745,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { // fake tool call marker in reasoning tst.test( - "Let me think about \n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 2}} hmm" + "Let me think about \n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 2}} hmm\n\n\n" "\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}") .enable_thinking(true) .reasoning_format(COMMON_REASONING_FORMAT_AUTO) @@ -3414,11 +3773,11 @@ static void test_template_output_peg_parsers(bool detailed_debug) { // Format: value { auto tst = peg_tester("models/templates/MiniMax-M2.jinja", detailed_debug); - tst.test("Hello, world!\nWhat's up?").enable_thinking(true).reasoning_format(COMMON_REASONING_FORMAT_AUTO).expect(message_assist).run(); + tst.test("\n\n\nHello, world!\nWhat's up?").enable_thinking(true).reasoning_format(COMMON_REASONING_FORMAT_AUTO).expect(message_assist).run(); - tst.test("I'm\nthinkingHello, world!\nWhat's up?").enable_thinking(true).reasoning_format(COMMON_REASONING_FORMAT_AUTO).expect(message_assist_thoughts).run(); + tst.test("I'm\nthinking\n\n\nHello, world!\nWhat's up?").enable_thinking(true).reasoning_format(COMMON_REASONING_FORMAT_AUTO).expect(message_assist_thoughts).run(); - tst.test("Let's call a tool:\n\n\n"). + tst.test("Let's call a tool:\n\n\n\n\n\n"). enable_thinking(true). reasoning_format(COMMON_REASONING_FORMAT_AUTO). tools({ empty_args_tool }). @@ -3426,7 +3785,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { run(); tst.test( - "\n\n\n\n\n\n1\n\n") .tools({ special_function_tool }) .expect(message_assist_call) @@ -3579,6 +3938,51 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .run(); } + // Reka Edge + { + auto tst = peg_tester("models/templates/Reka-Edge.jinja", detailed_debug); + tst.test("Hello, world!\nWhat's up?") + .enable_thinking(false) + .expect(message_assist) + .run(); + tst.test("I'm\nthinking\n\n\nHello, world!\nWhat's up?") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .expect(message_assist_thoughts) + .run(); + tst.test("\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n") + .enable_thinking(false) + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + tst.test("Hello, world!\nWhat's up?\n\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n") + .enable_thinking(false) + .tools({ special_function_tool }) + .expect(message_assist_call_content) + .run(); + tst.test("I'm\nthinking\n\n\n\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ special_function_tool }) + .expect(message_assist_call_thoughts) + .run(); + tst.test("\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n\n\n{\"name\": \"special_function_with_opt\", \"arguments\": {\"arg1\": 1, \"arg2\": 2}}\n") + .enable_thinking(false) + .parallel_tool_calls(true) + .tools({ special_function_tool, special_function_tool_with_optional_param }) + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", {} }, + { "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} }, + }) + .run(); + tst.test("\n{\"name\": \"special_function\", \"arguments\": {\"arg") + .enable_thinking(false) + .tools({ special_function_tool }) + .is_partial(true) + .expect(message_assist_call_cutoff_args) + .run(); + } + // Apriel 1.5 { auto tst = peg_tester("models/templates/unsloth-Apriel-1.5.jinja", detailed_debug); @@ -3833,7 +4237,8 @@ static void test_template_output_peg_parsers(bool detailed_debug) { { auto tst = peg_tester("models/templates/StepFun3.5-Flash.jinja", detailed_debug); - tst.test("I was thinking\nNow I'm not."). + + tst.test("I was thinking\n\nNow I'm not."). enable_thinking(true). reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK). expect_reasoning("I was thinking"). @@ -4181,7 +4586,7 @@ int main(int argc, char ** argv) { bool detailed_debug = false; bool only_run_filtered = false; - // Check for --template flag + // Check for --template and --detailed flags for (int i = 1; i < argc; i++) { std::string arg = argv[i]; if (arg == "--template" && i + 1 < argc) { @@ -4206,7 +4611,20 @@ int main(int argc, char ** argv) { } #ifndef _WIN32 - if (argc > 1) { + // Check if any argument is a .jinja file (for template format detection mode) + bool has_jinja_files = false; + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + if (arg == "--detailed") { + continue; + } + if (arg.size() >= 6 && arg.rfind(".jinja") == arg.size() - 6) { + has_jinja_files = true; + break; + } + } + + if (has_jinja_files) { common_chat_templates_inputs inputs; common_chat_msg msg; msg.role = "user"; @@ -4239,6 +4657,7 @@ int main(int argc, char ** argv) { test_msg_diffs_compute(); test_msgs_oaicompat_json_conversion(); test_tools_oaicompat_json_conversion(); + test_convert_responses_to_chatcmpl(); test_developer_role_to_system_workaround(); test_template_output_peg_parsers(detailed_debug); std::cout << "\n[chat] All tests passed!" << '\n';