diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index 7d427431db..47bb582c30 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -17,6 +17,8 @@ add_library(${TARGET} STATIC server-context.h server-tools.cpp server-tools.h + server-schema.cpp + server-schema.h ) if (BUILD_SHARED_LIBS) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 31280d63c4..aebca306a8 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -4,6 +4,7 @@ #include "server-http.h" #include "server-task.h" #include "server-queue.h" +#include "server-schema.h" #include "build-info.h" #include "common.h" @@ -3825,7 +3826,7 @@ std::unique_ptr server_routes::handle_completions_impl( task.id = rd.get_new_id(); task.tokens = std::move(inputs[i]); - task.params = server_task::params_from_json_cmpl( + task.params = server_schema::eval_llama_cmpl_schema( ctx_server.vocab, params, meta->slot_n_ctx, diff --git a/tools/server/server-schema.cpp b/tools/server/server-schema.cpp new file mode 100644 index 0000000000..d5d747a654 --- /dev/null +++ b/tools/server/server-schema.cpp @@ -0,0 +1,635 @@ +#include "server-schema.h" + +#include "json-schema-to-grammar.h" + +namespace server_schema { + +// +// llama.cpp-specific completion schema +// + +std::vector> make_llama_cmpl_schema(const common_params & params_base, task_params & params) { + std::vector> fields; + auto add = [&](field * f) { + fields.emplace_back(f); + }; + + add((new field_bool("timings_per_token", params.timings_per_token)) + ->set_desc("Include prompt processing and text generation speed information in each response")); + + add((new field_bool("stream", params.stream)) + ->set_desc("Allows receiving each predicted token in real-time instead of waiting for the completion to finish")); + + add((new field_nested("stream_options")) + ->add_subfield((new field_bool("include_usage", params.include_usage)) + ->set_desc("Whether to include usage information in the stream")) + ->set_desc("Additional options for streaming responses")); + + add((new field_bool("cache_prompt", params.cache_prompt)) + ->set_desc("Re-use KV cache from a previous request if possible. This way the common prefix does not have to be re-processed, only the suffix that differs between the requests")); + + add((new field_bool("return_tokens", params.return_tokens)) + ->set_desc("Return the raw generated token ids in the `tokens` field")); + + add((new field_bool("return_progress", params.return_progress)) + ->set_desc("Include prompt processing progress events in stream mode")); + + add((new field_num("n_predict", params.n_predict)) + ->set_hard_limits(-1, INT32_MAX) + ->add_alias("max_completion_tokens") + ->add_alias("max_tokens") + ->set_desc("Set the maximum number of tokens to predict. When 0, no tokens will be generated but the prompt is evaluated into the cache")); + + add((new field_num("n_indent", params.n_indent)) + ->set_hard_limits(0, INT32_MAX) + ->set_desc("Specify the minimum line indentation for the generated text in number of whitespace characters. Useful for code completion tasks")); + + add((new field_num("n_keep", params.n_keep)) + ->set_hard_limits(-1, INT32_MAX) + ->set_desc("Specify the number of tokens from the initial prompt to retain when context size is exceeded. Use -1 to retain all tokens from the prompt")); + + add((new field_num("n_discard", params.n_discard)) + ->set_hard_limits(0, INT32_MAX) + ->set_desc("Number of tokens after n_keep that may be discarded when shifting context (0 = half context)")); + + add((new field_num("n_cmpl", params.n_cmpl)) + ->set_hard_limits(1, params_base.n_parallel) + ->add_alias("n") // alias "n" as fallback (OpenAI completions API) + ->set_desc("Number of completions to generate. If the input has multiple prompts, total outputs will be N prompts times n_cmpl")); + + add((new field_num("n_cache_reuse", params.n_cache_reuse)) + ->set_hard_limits(0, INT32_MAX) + ->set_desc("Min chunk size to attempt reusing from the cache via KV shifting. See --cache-reuse arg")); + + // TODO: implement t_max_prompt_ms + // add((new field_num("t_max_prompt_ms", params.t_max_prompt_ms)) + + add((new field_num("t_max_predict_ms", params.t_max_predict_ms)) + ->set_hard_limits(-1, std::numeric_limits::max()) + ->set_desc("Set a time limit in milliseconds for the prediction phase. The timeout triggers if generation exceeds this time (measured since the first token) and a newline has been generated. Useful for FIM applications")); + + add((new field_json("response_fields")) + ->set_desc("A list of response fields to return. Missing fields are omitted without error. Fields with a slash are unnested (e.g. generation_settings/n_predict moves n_predict to the root)") + ->set_handler([&](field_eval_context & ctx, const json & data) { + ctx.params.response_fields = json_value(data, "response_fields", std::vector()); + })); + + + // + // Sampling params + // + + add((new field_num("top_k", params.sampling.top_k)) + ->set_limits(0, INT32_MAX) + ->set_desc("Limit the next token selection to the K most probable tokens (0 = disabled)")); + + add((new field_num("top_p", params.sampling.top_p)) + ->set_limits(0.0f, 1.0f) + ->set_desc("Limit the next token selection to a subset of tokens with cumulative probability above threshold P (1.0 = disabled)")); + + add((new field_num("min_p", params.sampling.min_p)) + ->set_limits(0.0f, 1.0f) + ->set_desc("The minimum probability for a token to be considered, relative to the probability of the most likely token (0 = disabled)")); + + add((new field_num("top_n_sigma", params.sampling.top_n_sigma)) + ->set_desc("Keep tokens within n standard deviations of the top token logit (< 0 = disabled)")); + + add((new field_num("xtc_probability", params.sampling.xtc_probability)) + ->set_limits(0.0f, 1.0f) + ->set_desc("Set the chance for token removal via XTC sampler (0 = disabled)")); + + add((new field_num("xtc_threshold", params.sampling.xtc_threshold)) + ->set_limits(0.0f, 1.0f) + ->set_desc("Set a minimum probability threshold for tokens to be removed via XTC sampler (> 0.5 disables XTC)")); + + add((new field_num("typical_p", params.sampling.typ_p)) + // ->set_limits(0.0f, 1.0f) // what's the valid range? + ->set_desc("Enable locally typical sampling with parameter p (1.0 = disabled)")); + + add((new field_num("temperature", params.sampling.temp)) + ->set_limits(0.0f, std::numeric_limits::infinity()) + ->set_desc("Adjust the randomness of the generated text (0 = greedy)")); + + add((new field_num("dynatemp_range", params.sampling.dynatemp_range)) + ->set_desc("Dynamic temperature range. The final temperature will be in [temperature - range, temperature + range] (0 = disabled)")); + + add((new field_num("dynatemp_exponent", params.sampling.dynatemp_exponent)) + ->set_desc("Dynamic temperature exponent, controls how entropy maps to temperature")); + + add((new field_num("repeat_last_n", params.sampling.penalty_last_n)) + ->set_hard_limits(-1, INT32_MAX) + ->set_desc("Last n tokens to consider for penalizing repetition (0 = disabled, -1 = ctx-size)")); + + add((new field_num("repeat_penalty", params.sampling.penalty_repeat)) + ->set_desc("Control the repetition of token sequences in the generated text (1.0 = disabled)")); + + add((new field_num("frequency_penalty", params.sampling.penalty_freq)) + ->set_desc("Repeat alpha frequency penalty (0 = disabled)")); + + add((new field_num("presence_penalty", params.sampling.penalty_present)) + ->set_desc("Repeat alpha presence penalty (0 = disabled)")); + + add((new field_num("dry_multiplier", params.sampling.dry_multiplier)) + ->set_desc("Set the DRY (Don't Repeat Yourself) repetition penalty multiplier (0 = disabled)")); + + add((new field_num("dry_base", params.sampling.dry_base)) + ->set_desc("Set the DRY repetition penalty base value (must be >= 1.0, any values < 1.0 will be replaced with the default value)") + ->set_handler([&](field_eval_context & ctx, const json & data) { + float v = data.at("dry_base").get(); + ctx.params.sampling.dry_base = (v < 1.0f) ? params_base.sampling.dry_base : v; + })); + + add((new field_num("dry_allowed_length", params.sampling.dry_allowed_length)) + ->set_hard_limits(0, INT32_MAX) + ->set_desc("Tokens that extend repetition beyond this length receive exponentially increasing penalty: multiplier * base ^ (sequence_length - allowed_length)")); + + add((new field_num("dry_penalty_last_n", params.sampling.dry_penalty_last_n)) + ->set_hard_limits(-1, INT32_MAX) + ->set_desc("How many tokens to scan for repetitions (0 = disabled, -1 = context size)")); + + add((new field_num("mirostat", params.sampling.mirostat)) + ->set_limits(0, 2) + ->set_desc("Enable Mirostat sampling, controlling perplexity during text generation (0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)")); + + add((new field_num("mirostat_tau", params.sampling.mirostat_tau)) + ->set_desc("Set the Mirostat target entropy, parameter tau")); + + add((new field_num("mirostat_eta", params.sampling.mirostat_eta)) + ->set_desc("Set the Mirostat learning rate, parameter eta")); + + add((new field_num("adaptive_target", params.sampling.adaptive_target)) + ->set_limits(-std::numeric_limits::max(), 1.0f) + ->set_desc("Adaptive sampling target entropy (valid range 0.0 to 1.0; negative = disabled)")); + + add((new field_num("adaptive_decay", params.sampling.adaptive_decay)) + ->set_hard_limits(0.0f, 0.99f) + ->set_desc("EMA decay for adaptive sampling; history approximates 1/(1-decay) tokens")); + + // seed is uint32_t; field_num uses int32_t so use a handler + add((new field_num("seed", params.sampling.seed)) + ->set_desc("Set the random number generator (RNG) seed (-1 = random)")); + + add((new field_num("n_probs", params.sampling.n_probs)) + ->add_alias("logprobs") // use "logprobs" if "n_probs" wasn't provided + ->set_desc("If greater than 0, output the probabilities of top N tokens for each generated token")); + + add((new field_num("min_keep", params.sampling.min_keep)) + ->set_hard_limits(0, INT32_MAX) + ->set_desc("If greater than 0, force samplers to return at least N possible tokens")); + + add((new field_bool("backend_sampling", params.sampling.backend_sampling)) + ->set_desc("Use backend sampling instead of llama.cpp sampling")); + + add((new field_bool("post_sampling_probs", params.post_sampling_probs)) + ->set_desc("Return probabilities of top n_probs tokens after applying the sampling chain")); + + // + // Speculative decoding params + // + + // TODO: to keep things simple, we disable speculative parameter adjustments for now +#if 0 + // TODO: for now, be able to adjust only the draft-model based speculative parameters + add((new field_num("speculative.n_max", params.speculative.draft.n_max)) + ->set_hard_limits(0, INT32_MAX) + ->set_desc("Maximum number of tokens to draft during speculative decoding")); + + add((new field_num("speculative.n_min", params.speculative.draft.n_min)) + ->set_hard_limits(0, INT32_MAX) + ->set_desc("Minimum number of draft tokens to use for speculative decoding"); + + add((new field_num("speculative.p_min", params.speculative.draft.p_min)) + ->set_hard_limits(0.0f, 1.0f) + ->set_desc("Minimum speculative decoding probability for draft tokens (0 = greedy)")); + + add((new field_str("speculative.type")) + ->set_desc("Speculative decoding method (for debugging and research purposes)") + ->set_handler([&](field_eval_context & ctx, const json & data) { + ctx.params.speculative.types = { common_speculative_type_from_name(data.at("speculative.type").get()) }; + })); + + add((new field_num("speculative.ngram_size_n", params.speculative.ngram_simple.size_n)) + ->set_desc("Ngram size for lookup in ngram-based speculative decoding")); + + add((new field_num("speculative.ngram_size_m", params.speculative.ngram_simple.size_m)) + ->set_desc("Mgram size for speculative tokens in ngram-based speculative decoding")); + + add((new field_num("speculative.ngram_min_hits", params.speculative.ngram_simple.min_hits)) + ->set_desc("Minimum hits at ngram lookup for mgram to be proposed")); +#endif + + add((new field_json("lora")) + ->set_desc("A list of LoRA adapters to apply to this request. Each entry must have `id` and `scale` fields. Adapters not listed default to scale 0.0") + ->set_handler([&](field_eval_context & ctx, const json & data) { + const auto & lora = data.at("lora"); + if (!lora.is_array()) { + throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); + } + ctx.params.lora = parse_lora_request(lora); + })); + + // sequence breakers for DRY + // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format + // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 + add((new field_json("dry_sequence_breakers")) + ->set_desc("Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted") + ->set_handler([&](field_eval_context & ctx, const json & data) { + ctx.params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); + if (ctx.params.sampling.dry_sequence_breakers.empty()) { + throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); + } + })); + + // handle both "json_schema" and "grammar" + add((new field_json("json_schema")) + ->add_alias("grammar") + ->set_desc("Set a JSON schema (json_schema) or GBNF grammar string (grammar) for constrained generation. json_schema takes precedence if both are provided") + ->set_handler([&](field_eval_context & ctx, const json & data) { + auto & params = ctx.params; + if (data.contains("json_schema") && !data.contains("grammar")) { + try { + auto schema = json_value(data, "json_schema", json::object()); + SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str()); + std::string grammar_str = json_schema_to_grammar(schema); + SRV_DBG("Converted grammar: %s\n", grammar_str.c_str()); + params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, std::move(grammar_str)}; + } catch (const std::exception & e) { + throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); + } + } else { + std::string grammar_str = json_value(data, "grammar", std::string()); + if (!grammar_str.empty()) { + // grammar_type key is set by the server when converting chat template grammars + std::string grammar_type = json_value(data, "grammar_type", std::string()); + if (grammar_type == "tool_calls") { + params.sampling.grammar = {COMMON_GRAMMAR_TYPE_TOOL_CALLS, std::move(grammar_str)}; + } else { + // explicit grammar from the user (API field "grammar") + params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, std::move(grammar_str)}; + } + SRV_DBG("Grammar (%s): %s\n", grammar_type.c_str(), common_grammar_value(params.sampling.grammar).c_str()); + } + } + })); + + add((new field_bool("grammar_lazy", params.sampling.grammar_lazy)) + ->set_desc("Whether to apply grammar constraints lazily, only when triggered (instead of at every step)")); + + // + // Chat parser params + // + + // TODO: change this to string field instead + add((new field_json("chat_format")) + ->set_desc("Chat format used internally by the server") + ->set_handler([&](field_eval_context & ctx, const json & data) { + ctx.params.chat_parser_params.format = static_cast(data.at("chat_format").get()); + SRV_INF("Chat format: %s\n", common_chat_format_name(ctx.params.chat_parser_params.format)); + })); + + add((new field_str("reasoning_format")) + ->set_desc("Reasoning format for chain-of-thought models") + ->set_handler([&](field_eval_context & ctx, const json & data) { + auto reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get()); + ctx.params.chat_parser_params.reasoning_format = reasoning_format; + ctx.params.chat_parser_params.reasoning_in_content = ctx.params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); + })); + + add((new field_str("generation_prompt")) + ->set_desc("Generation prompt appended to the chat template output") + ->set_handler([&](field_eval_context & ctx, const json & data) { + std::string s = data.at("generation_prompt").get(); + ctx.params.chat_parser_params.generation_prompt = s; + ctx.params.sampling.generation_prompt = s; + })); + + add((new field_bool("parse_tool_calls", params.chat_parser_params.parse_tool_calls)) + ->set_desc("Whether to parse tool calls from the generated output")); + + add((new field_str("chat_parser")) + ->set_desc("Chat parser configuration string") + ->set_handler([&](field_eval_context & ctx, const json & data) { + ctx.params.chat_parser_params.parser.load(data.at("chat_parser").get()); + })); + + add((new field_json("continue_final_message")) + ->set_desc("Whether to continue the final message of the chat template") + ->set_handler([&](field_eval_context & ctx, const json & data) { + auto continuation = common_chat_continuation_parse(data.at("continue_final_message")); + ctx.params.chat_parser_params.is_continuation = continuation != COMMON_CHAT_CONTINUATION_NONE; + })); + + add((new field_bool("echo", params.chat_parser_params.echo)) + ->set_desc("Whether to echo the input tokens in the output")); + + // + // Token-level fields (require vocab) + // + + add((new field_json("preserved_tokens")) + ->set_desc("List of token strings that must not be split during tokenization") + ->set_handler([&](field_eval_context & ctx, const json & data) { + GGML_ASSERT(ctx.vocab != nullptr); + for (const auto & t : data.at("preserved_tokens")) { + auto ids = common_tokenize(ctx.vocab, t.get(), false, true); + if (ids.size() == 1) { + ctx.params.sampling.preserved_tokens.insert(ids[0]); + } + } + })); + + add((new field_json("grammar_triggers")) + ->set_desc("List of strings or patterns that trigger grammar-constrained generation") + ->set_handler([&](field_eval_context & ctx, const json & data) { + GGML_ASSERT(ctx.vocab != nullptr); + for (const auto & t : data.at("grammar_triggers")) { + server_grammar_trigger ct(t); + if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { + const auto & word = ct.value.value; + auto ids = common_tokenize(ctx.vocab, word, false, true); + if (ids.size() == 1) { + auto token = ids[0]; + if (std::find(ctx.params.sampling.preserved_tokens.begin(), ctx.params.sampling.preserved_tokens.end(), (llama_token) token) == ctx.params.sampling.preserved_tokens.end()) { + throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); + } + common_grammar_trigger trigger; + trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; + trigger.value = word; + trigger.token = token; + ctx.params.sampling.grammar_triggers.push_back(std::move(trigger)); + } else { + ctx.params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); + } + } else { + ctx.params.sampling.grammar_triggers.emplace_back(std::move(ct.value)); + } + } + if (ctx.params.sampling.grammar_lazy && ctx.params.sampling.grammar_triggers.empty()) { + throw std::runtime_error("Error: no triggers set for lazy grammar!"); + } + })); + + add((new field_bool("reasoning_control", params.sampling.reasoning_control)) + ->set_desc("Create the budget sampler on demand so reasoning can be ended at runtime")); + + add((new field_num("reasoning_budget_tokens", params.sampling.reasoning_budget_tokens)) + ->set_hard_limits(-1, INT32_MAX) + ->set_desc("Number of tokens in the reasoning budget (-1 = disabled)")); + + add((new field_str("reasoning_budget_start_tag")) + ->set_desc("Token string marking the start of the reasoning budget section") + ->set_handler([&](field_eval_context & ctx, const json & data) { + GGML_ASSERT(ctx.vocab != nullptr); + ctx.params.sampling.reasoning_budget_start = common_tokenize(ctx.vocab, data.at("reasoning_budget_start_tag").get(), false, true); + })); + + add((new field_str("reasoning_budget_end_tag")) + ->set_desc("Token string marking the end of the reasoning budget section") + ->set_handler([&](field_eval_context & ctx, const json & data) { + GGML_ASSERT(ctx.vocab != nullptr); + std::string end_tag = data.at("reasoning_budget_end_tag").get(); + ctx.params.sampling.reasoning_budget_end = common_tokenize(ctx.vocab, end_tag, false, true); + })); + + add((new field_str("reasoning_budget_message")) + ->set_desc("Message to prepend to the reasoning budget end tag when forcing it") + ->set_handler([&](field_eval_context & ctx, const json & data) { + GGML_ASSERT(ctx.vocab != nullptr); + std::string end_tag = json_value(data, "reasoning_budget_end_tag", std::string()); + std::string message = data.at("reasoning_budget_message").get(); + ctx.params.sampling.reasoning_budget_forced = common_tokenize(ctx.vocab, message + end_tag, false, true); + })); + + add((new field_json("logit_bias")) + ->set_desc("Modify the likelihood of specific tokens. Accepts an array of [token, bias] pairs or an object mapping token to bias. Use false as bias to ban a token") + ->set_handler([&](field_eval_context & ctx, const json & data) { + GGML_ASSERT(ctx.vocab != nullptr); + ctx.params.sampling.logit_bias.clear(); + const auto & logit_bias = data.at("logit_bias"); + const int n_vocab = llama_vocab_n_tokens(ctx.vocab); + auto parse_bias = [](const json & v, float & bias) -> bool { + if (v.is_number()) { bias = v.get(); return true; } + if (v.is_boolean() && !v.get()) { bias = -INFINITY; return true; } + return false; + }; + if (logit_bias.is_array()) { + for (const auto & el : logit_bias) { + if (!el.is_array() || el.size() != 2) continue; + float bias; + if (!parse_bias(el[1], bias)) continue; + if (el[0].is_number_integer()) { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) ctx.params.sampling.logit_bias.push_back({tok, bias}); + } else if (el[0].is_string()) { + for (auto tok : common_tokenize(ctx.vocab, el[0].get(), false)) + ctx.params.sampling.logit_bias.push_back({tok, bias}); + } + } + } else if (logit_bias.is_object()) { + for (const auto & el : logit_bias.items()) { + float bias; + if (!parse_bias(el.value(), bias)) continue; + char * end; + llama_token tok = strtol(el.key().c_str(), &end, 10); + if (*end == 0) { + if (tok >= 0 && tok < n_vocab) ctx.params.sampling.logit_bias.push_back({tok, bias}); + } else { + for (auto t : common_tokenize(ctx.vocab, el.key(), false)) + ctx.params.sampling.logit_bias.push_back({t, bias}); + } + } + } + })); + + add((new field_bool("ignore_eos", params.sampling.ignore_eos)) + ->set_desc("Ignore the end-of-sequence token and continue generating") + ->set_handler([&](field_eval_context & ctx, const json & data) { + GGML_ASSERT(ctx.logit_bias_eog != nullptr); + ctx.params.sampling.ignore_eos = data.at("ignore_eos").get(); + if (ctx.params.sampling.ignore_eos && ctx.logit_bias_eog) { + ctx.params.sampling.logit_bias.insert( + ctx.params.sampling.logit_bias.end(), + ctx.logit_bias_eog->begin(), ctx.logit_bias_eog->end()); + } + })); + + add((new field_json("stop")) + ->set_desc("Specify stopping strings. Generation stops when one is produced, and the string is not included in the output") + ->set_handler([&](field_eval_context & ctx, const json & data) { + ctx.params.antiprompt.clear(); + const auto & stop = data.at("stop"); + if (stop.is_array()) { + for (const auto & word : stop) { + if (!word.empty()) ctx.params.antiprompt.push_back(word); + } + } else if (stop.is_string()) { + ctx.params.antiprompt.push_back(stop.get()); + } + // fall back to CLI defaults if the request provided no effective stop strings + if (ctx.params.antiprompt.empty()) { + ctx.params.antiprompt = params_base.antiprompt; + } + })); + + add((new field_json("samplers")) + ->set_desc("The order in which samplers are applied. An array of sampler type names, or a single string of sampler chars") + ->set_handler([&](field_eval_context & ctx, const json & data) { + const auto & samplers = data.at("samplers"); + if (samplers.is_array()) { + ctx.params.sampling.samplers = common_sampler_types_from_names(samplers); + } else if (samplers.is_string()) { + ctx.params.sampling.samplers = common_sampler_types_from_chars(samplers.get()); + } + })); + + return fields; +} + +task_params eval_llama_cmpl_schema( + const llama_vocab * vocab, + const common_params & params_base, + const int n_ctx_slot, + const std::vector & logit_bias_eog, + const json & data) { + task_params params; + + // Sampling parameter defaults are loaded from the global server context (but individual requests can still them) + params.sampling = params_base.sampling; + params.speculative = params_base.speculative; + params.n_keep = params_base.n_keep; + params.n_predict = params_base.n_predict; + params.n_cache_reuse = params_base.n_cache_reuse; + params.cache_prompt = params_base.cache_prompt; + params.antiprompt = params_base.antiprompt; + + // enabling this will output extra debug information in the HTTP responses from the server + params.verbose = params_base.verbosity > 9; + + params.chat_parser_params.reasoning_format = params_base.reasoning_format; + + // create context and schema + field_eval_context ctx(params); + ctx.vocab = vocab; + ctx.logit_bias_eog = &logit_bias_eog; + + auto schema = make_llama_cmpl_schema(params_base, params); + + // eval all fields in the schema + for (const auto & f : schema) { + f->eval(ctx, data); + } + + // post-processing + { + if (params.sampling.penalty_last_n == -1) { + // note: should be the slot's context and not the full context, but it's ok + params.sampling.penalty_last_n = n_ctx_slot; + } + + if (params.sampling.dry_penalty_last_n == -1) { + params.sampling.dry_penalty_last_n = n_ctx_slot; + } + + // if "reasoning_format" is not provided, its handler will not be called, we will need to handle it here + auto reasoning_format = params.chat_parser_params.reasoning_format; + params.chat_parser_params.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); + } + + // debugging + { + auto budget = params.sampling.reasoning_budget_tokens; + SRV_DBG("reasoning budget: tokens=%d, generation_prompt='%s', start=%zu toks, end=%zu toks, forced=%zu toks\n", + budget, params.sampling.generation_prompt.c_str(), + params.sampling.reasoning_budget_start.size(), + params.sampling.reasoning_budget_end.size(), + params.sampling.reasoning_budget_forced.size()); + } + + return params; +} + +// +// eval() implementations +// + +static void handle_with_catch(const char * name, std::function func) { + try { + func(); + } catch (const std::exception & e) { + throw std::invalid_argument(string_format("Field '%s': %s", name, e.what())); + } +} + +template +void field_num::eval(field_eval_context & ctx, const json & data) { + for (const auto & n : name) { + if (data.contains(n)) { + handle_with_catch(n, [&]() { + if (custom_handler) { + custom_handler(ctx, data); + } else if (!is_hard_limit) { + val = std::max(min, std::min(max, data.at(n).template get())); + } else { + T tmp = data.at(n).template get(); + if (tmp < min || tmp > max) { + throw std::invalid_argument(std::string("Value must be between ") + std::to_string(min) + " <= value <= " + std::to_string(max) + ", but got " + std::to_string(tmp)); + } + val = tmp; + } + }); + return; + } + } +} + +void field_str::eval(field_eval_context & ctx, const json & data) { + GGML_ASSERT(custom_handler); + for (const auto & n : name) { + if (data.contains(n)) { + handle_with_catch(n, [&]() { + custom_handler(ctx, data); + }); + return; + } + } +} + +void field_bool::eval(field_eval_context & ctx, const json & data) { + for (const auto & n : name) { + if (data.contains(n)) { + handle_with_catch(n, [&]() { + if (custom_handler) { + custom_handler(ctx, data); + } else { + val = data.at(n).get(); + } + }); + return; + } + } +} + +void field_json::eval(field_eval_context & ctx, const json & data) { + GGML_ASSERT(custom_handler); + for (const auto & n : name) { + if (data.contains(n)) { + handle_with_catch(n, [&]() { + custom_handler(ctx, data); + }); + return; + } + } +} + +void field_nested::eval(field_eval_context & ctx, const json & data) { + for (const auto & n : name) { + if (data.contains(n) && data.at(n).is_object()) { + for (auto & f : subfields) { + f->eval(ctx, data.at(n)); + } + return; + } + } +} + +} // namespace server_schema diff --git a/tools/server/server-schema.h b/tools/server/server-schema.h new file mode 100644 index 0000000000..08cf427dc9 --- /dev/null +++ b/tools/server/server-schema.h @@ -0,0 +1,105 @@ +#pragma once + +#include "server-common.h" +#include "server-task.h" + +#include "sampling.h" +#include "speculative.h" + +#include +#include +#include +#include +#include +#include + +namespace server_schema { + +struct field_eval_context { + task_params & params; + const llama_vocab * vocab = nullptr; + const std::vector * logit_bias_eog = nullptr; + field_eval_context(task_params & params) : params(params) {} +}; + +using field_handler = std::function; + +struct field { + std::vector name; + const char * desc = ""; + field_handler custom_handler; + field() = default; + field(const char * n) : name({n}) {} + virtual ~field() = default; + field * set_desc(const char * s) { + desc = s; + return this; + } + // if 'name' is present, use it, otherwise look for aliases following the order they were added + field * add_alias(const char * n) { + name.push_back(n); + return this; + } + field * set_handler(field_handler h) { this->custom_handler = h; return this; } + virtual void eval(field_eval_context & ctx, const json & data) = 0; +}; + +template +struct field_num : public field { + T & val; + T min = std::numeric_limits::lowest(); + T max = std::numeric_limits::max(); + bool is_hard_limit = false; // if true, throw error if the value is invalid + field_num(const char * n, T & val) : field(n), val(val) {} + // limits are inclusive, min <= value <= max + field_num * set_limits(T min, T max) { + this->min = min; + this->max = max; + return this; + } + field_num * set_hard_limits(T min, T max) { + set_limits(min, max); + is_hard_limit = true; + return this; + } + virtual void eval(field_eval_context & ctx, const json & data) override; +}; + +struct field_str : public field { + field_str(const char * n) : field(n) {} + virtual void eval(field_eval_context & ctx, const json & data) override; +}; + +struct field_bool : public field { + bool & val; + field_bool(const char * n, bool & val) : field(n), val(val) {} + virtual void eval(field_eval_context & ctx, const json & data) override; +}; + +struct field_json : public field { + field_json(const char * n) : field(n) {} + virtual void eval(field_eval_context & ctx, const json & data) override; +}; + +struct field_nested : public field { + std::vector> subfields; + field_nested(const char * n) : field(n) {} + field_nested * add_subfield(field * f) { + subfields.emplace_back(std::unique_ptr(f)); + return this; + } + virtual void eval(field_eval_context & ctx, const json & data) override; +}; + +std::vector> make_llama_cmpl_schema( + const common_params & params_base, + task_params & params); + +task_params eval_llama_cmpl_schema( + const llama_vocab * vocab, + const common_params & params_base, + const int n_ctx_slot, + const std::vector & logit_bias_eog, + const json & data); + +} // namespace server_schema diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 72a4bd076a..9ba039c8b8 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -232,396 +232,8 @@ common_chat_msg task_result_state::update_chat_msg( return chat_msg; } -// -// server_task // -task_params server_task::params_from_json_cmpl( - const llama_vocab * vocab, - const common_params & params_base, - const int n_ctx_slot, - const std::vector & logit_bias_eog, - const json & data) { - task_params params; - - // Sampling parameter defaults are loaded from the global server context (but individual requests can still them) - task_params defaults; - defaults.sampling = params_base.sampling; - defaults.speculative = params_base.speculative; - defaults.n_keep = params_base.n_keep; - defaults.n_predict = params_base.n_predict; - defaults.n_cache_reuse = params_base.n_cache_reuse; - defaults.cache_prompt = params_base.cache_prompt; - defaults.antiprompt = params_base.antiprompt; - - // enabling this will output extra debug information in the HTTP responses from the server - params.verbose = params_base.verbosity > 9; - params.timings_per_token = json_value(data, "timings_per_token", false); - - params.stream = json_value(data, "stream", false); - auto stream_opt = json_value(data, "stream_options", json::object()); - params.include_usage = json_value(stream_opt, "include_usage", false); - params.cache_prompt = json_value(data, "cache_prompt", defaults.cache_prompt); - params.return_tokens = json_value(data, "return_tokens", false); - params.return_progress = json_value(data, "return_progress", false); - auto max_tokens = json_value(data, "max_tokens", defaults.n_predict); - params.n_predict = json_value(data, "n_predict", json_value(data, "max_completion_tokens", max_tokens)); - params.n_indent = json_value(data, "n_indent", defaults.n_indent); - params.n_keep = json_value(data, "n_keep", defaults.n_keep); - params.n_discard = json_value(data, "n_discard", defaults.n_discard); - params.n_discard = std::max(0, params.n_discard); - params.n_cmpl = json_value(data, "n_cmpl", json_value(data, "n", 1)); - params.n_cache_reuse = json_value(data, "n_cache_reuse", defaults.n_cache_reuse); - //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement - params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); - params.response_fields = json_value(data, "response_fields", std::vector()); - - params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); - params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); - params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); - params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); - params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); - params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); - params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); - params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); - params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); - params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); - params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); - params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); - params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); - params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); - params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); - params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); - params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); - params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); - params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); - params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); - params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); - params.sampling.adaptive_target = json_value(data, "adaptive_target", defaults.sampling.adaptive_target); - params.sampling.adaptive_decay = json_value(data, "adaptive_decay", defaults.sampling.adaptive_decay); - params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); - params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); - params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); - params.sampling.backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling); - params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); - - params.speculative = defaults.speculative; - - // TODO: to keep things simple, we disable speculative parameter adjustments for now -#if 0 - // TODO: for now, be able to adjust only the draft-model based speculative parameters - params.speculative.draft.n_min = json_value(data, "speculative.n_min", defaults.speculative.draft.n_min); - params.speculative.draft.n_max = json_value(data, "speculative.n_max", defaults.speculative.draft.n_max); - params.speculative.draft.p_min = json_value(data, "speculative.p_min", defaults.speculative.draft.p_min); - - params.speculative.draft.n_min = std::min(params.speculative.draft.n_max, params.speculative.draft.n_min); - params.speculative.draft.n_min = std::max(params.speculative.draft.n_min, 0); - params.speculative.draft.n_max = std::max(params.speculative.draft.n_max, 0); - - // for debugging and research purposes - params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type))); - - params.speculative.ngram_size_n = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n); - params.speculative.ngram_size_m = json_value(data, "speculative.ngram_size_m", defaults.speculative.ngram_size_m); - params.speculative.ngram_min_hits = json_value(data, "speculative.ngram_m_hits", defaults.speculative.ngram_min_hits); - - params.speculative.ngram_size_n = std::max(std::min(1, (int) params.speculative.ngram_size_n), 1024); - params.speculative.ngram_size_m = std::max(std::min(1, (int) params.speculative.ngram_size_m), 1024); - params.speculative.ngram_min_hits = std::max(std::min(1, (int) params.speculative.ngram_min_hits), 1024); -#endif - - // Use OpenAI API logprobs only if n_probs wasn't provided - if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ - params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); - } - - if (data.contains("lora")) { - if (data.at("lora").is_array()) { - params.lora = parse_lora_request(data.at("lora")); - } else { - throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); - } - } else { - params.lora = {}; - } - - // TODO: add more sanity checks for the input parameters - - if (params.sampling.penalty_last_n < -1) { - throw std::runtime_error("Error: repeat_last_n must be >= -1"); - } - - if (params.sampling.dry_penalty_last_n < -1) { - throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); - } - - if (params.sampling.penalty_last_n == -1) { - // note: should be the slot's context and not the full context, but it's ok - params.sampling.penalty_last_n = n_ctx_slot; - } - - if (params.sampling.dry_penalty_last_n == -1) { - params.sampling.dry_penalty_last_n = n_ctx_slot; - } - - if (params.sampling.dry_base < 1.0f) { - params.sampling.dry_base = defaults.sampling.dry_base; - } - - // sequence breakers for DRY - { - // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format - // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 - - if (data.contains("dry_sequence_breakers")) { - params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); - if (params.sampling.dry_sequence_breakers.empty()) { - throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); - } - } - } - - // process "json_schema" and "grammar" - if (data.contains("json_schema") && !data.contains("grammar")) { - try { - auto schema = json_value(data, "json_schema", json::object()); - SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str()); - std::string grammar_str = json_schema_to_grammar(schema); - SRV_DBG("Converted grammar: %s\n", grammar_str.c_str()); - params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, std::move(grammar_str)}; - } catch (const std::exception & e) { - throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); - } - } else { - params.sampling.grammar = defaults.sampling.grammar; - - std::string grammar_str = json_value(data, "grammar", std::string()); - if (!grammar_str.empty()) { - // grammar_type key is set by the server when converting chat template grammars - std::string grammar_type = json_value(data, "grammar_type", std::string()); - if (grammar_type == "tool_calls") { - params.sampling.grammar = {COMMON_GRAMMAR_TYPE_TOOL_CALLS, std::move(grammar_str)}; - } else { - // explicit grammar from the user (API field "grammar") - params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, std::move(grammar_str)}; - } - SRV_DBG("Grammar (%s): %s\n", grammar_type.c_str(), common_grammar_value(params.sampling.grammar).c_str()); - } - params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); - SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false"); - } - - { - auto it = data.find("chat_format"); - if (it != data.end()) { - params.chat_parser_params.format = static_cast(it->get()); - SRV_INF("Chat format: %s\n", common_chat_format_name(params.chat_parser_params.format)); - } else { - params.chat_parser_params.format = defaults.chat_parser_params.format; - } - common_reasoning_format reasoning_format = params_base.reasoning_format; - if (data.contains("reasoning_format")) { - reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get()); - } - params.chat_parser_params.reasoning_format = reasoning_format; - params.chat_parser_params.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); - params.chat_parser_params.generation_prompt = json_value(data, "generation_prompt", std::string()); - params.sampling.generation_prompt = params.chat_parser_params.generation_prompt; - SRV_DBG("Generation prompt: '%s'\n", params.chat_parser_params.generation_prompt.c_str()); - params.chat_parser_params.parse_tool_calls = json_value(data, "parse_tool_calls", false); - if (data.contains("chat_parser")) { - params.chat_parser_params.parser.load(data.at("chat_parser").get()); - } - if (data.contains("continue_final_message")) { - auto continuation = common_chat_continuation_parse(data.at("continue_final_message")); - params.chat_parser_params.is_continuation = continuation != COMMON_CHAT_CONTINUATION_NONE; - } - params.chat_parser_params.echo = json_value(data, "echo", false); - } - - { - const auto preserved_tokens = data.find("preserved_tokens"); - if (preserved_tokens != data.end()) { - for (const auto & t : *preserved_tokens) { - auto ids = common_tokenize(vocab, t.get(), /* add_special= */ false, /* parse_special= */ true); - if (ids.size() == 1) { - SRV_DBG("Preserved token: %d\n", ids[0]); - params.sampling.preserved_tokens.insert(ids[0]); - } else { - // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. - SRV_DBG("Not preserved because more than 1 token: %s\n", t.get().c_str()); - } - } - } - const auto grammar_triggers = data.find("grammar_triggers"); - if (grammar_triggers != data.end()) { - for (const auto & t : *grammar_triggers) { - server_grammar_trigger ct(t); - if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { - const auto & word = ct.value.value; - auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); - if (ids.size() == 1) { - auto token = ids[0]; - if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) { - throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); - } - SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); - common_grammar_trigger trigger; - trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; - trigger.value = word; - trigger.token = token; - params.sampling.grammar_triggers.push_back(std::move(trigger)); - } else { - SRV_DBG("Grammar trigger word: `%s`\n", word.c_str()); - params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); - } - } else { - if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) { - SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str()); - } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) { - SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str()); - } else { - throw std::runtime_error("Unknown grammar trigger type"); - } - params.sampling.grammar_triggers.emplace_back(std::move(ct.value)); - } - } - } - if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) { - throw std::runtime_error("Error: no triggers set for lazy grammar!"); - } - } - - // Parse reasoning budget sampler parameters - { - const int32_t budget = json_value(data, "reasoning_budget_tokens", (int32_t) -1); - const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string()); - const auto end_tag = json_value(data, "reasoning_budget_end_tag", std::string()); - const auto message = json_value(data, "reasoning_budget_message", std::string()); - params.sampling.reasoning_budget_tokens = budget; - params.sampling.reasoning_control = json_value(data, "reasoning_control", false); - - if (!start_tag.empty()) { - params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true); - } - if (!end_tag.empty()) { - params.sampling.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true); - params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true); - - SRV_DBG("reasoning budget: tokens=%d, generation_prompt='%s', start=%zu toks, end=%zu toks, forced=%zu toks\n", - budget, params.sampling.generation_prompt.c_str(), - params.sampling.reasoning_budget_start.size(), - params.sampling.reasoning_budget_end.size(), - params.sampling.reasoning_budget_forced.size()); - } - } - - { - params.sampling.logit_bias.clear(); - - const auto & logit_bias = data.find("logit_bias"); - if (logit_bias != data.end() && logit_bias->is_array()) { - const int n_vocab = llama_vocab_n_tokens(vocab); - for (const auto & el : *logit_bias) { - // TODO: we may want to throw errors here, in case "el" is incorrect - if (el.is_array() && el.size() == 2) { - float bias; - if (el[1].is_number()) { - bias = el[1].get(); - } else if (el[1].is_boolean() && !el[1].get()) { - bias = -INFINITY; - } else { - continue; - } - - if (el[0].is_number_integer()) { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } else if (el[0].is_string()) { - auto toks = common_tokenize(vocab, el[0].get(), false); - for (auto tok : toks) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } - } - } - } else if (logit_bias != data.end() && logit_bias->is_object()) { - const int n_vocab = llama_vocab_n_tokens(vocab); - for (const auto & el : logit_bias->items()) { - float bias; - const auto & key = el.key(); - const auto & value = el.value(); - if (value.is_number()) { - bias = value.get(); - } else if (value.is_boolean() && !value.get()) { - bias = -INFINITY; - } else { - continue; - } - - char *end; - llama_token tok = strtol(key.c_str(), &end, 10); - if (*end == 0) { - if (tok >= 0 && tok < n_vocab) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } else { - auto toks = common_tokenize(vocab, key, false); - for (auto tok : toks) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } - } - } - - params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos); - if (params.sampling.ignore_eos) { - params.sampling.logit_bias.insert( - params.sampling.logit_bias.end(), - logit_bias_eog.begin(), logit_bias_eog.end()); - } - } - - { - params.antiprompt.clear(); - - const auto & stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) { - for (const auto & word : *stop) { - if (!word.empty()) { - params.antiprompt.push_back(word); - } - } - } - // set reverse prompt from cli args if not set in the request - if (params.antiprompt.empty()) { - params.antiprompt = defaults.antiprompt; - } - } - - { - const auto samplers = data.find("samplers"); - if (samplers != data.end()) { - if (samplers->is_array()) { - params.sampling.samplers = common_sampler_types_from_names(*samplers); - } else if (samplers->is_string()){ - params.sampling.samplers = common_sampler_types_from_chars(samplers->get()); - } - } else { - params.sampling.samplers = defaults.sampling.samplers; - } - } - - if (params.n_cmpl > params_base.n_parallel) { - throw std::runtime_error("n_cmpl cannot be greater than the number of slots, please increase -np"); - } - - return params; -} - -// // result_timings // diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 1a03d5f266..299c279d7d 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -210,13 +210,6 @@ struct server_task { } } - static task_params params_from_json_cmpl( - const llama_vocab * vocab, - const common_params & params_base, - const int n_ctx_slot, - const std::vector & logit_bias_eog, - const json & data); - // utility function static std::unordered_set get_list_id(const std::vector & tasks) { std::unordered_set ids(tasks.size());