diff --git a/common/chat.cpp b/common/chat.cpp index ded8440e66..cee6ad650a 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -90,41 +90,93 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const return text; } -std::vector common_chat_split_by_role(const std::string & prompt, const std::vector & delims) { - if (delims.empty() || prompt.empty()) { - return {}; +common_chat_role common_chat_role_from_string(const std::string & role) { + if (role == "system") { return COMMON_CHAT_ROLE_SYSTEM; } + if (role == "assistant") { return COMMON_CHAT_ROLE_ASSISTANT; } + if (role == "user") { return COMMON_CHAT_ROLE_USER; } + if (role == "tool") { return COMMON_CHAT_ROLE_TOOL; } + return COMMON_CHAT_ROLE_UNKNOWN; +} + +const char * common_chat_role_to_string(common_chat_role role) { + switch (role) { + case COMMON_CHAT_ROLE_SYSTEM: return "system"; + case COMMON_CHAT_ROLE_ASSISTANT: return "assistant"; + case COMMON_CHAT_ROLE_USER: return "user"; + case COMMON_CHAT_ROLE_TOOL: return "tool"; + case COMMON_CHAT_ROLE_UNKNOWN: return ""; + } + return ""; +} + +json common_chat_msg_delimiters::to_json() const { + json result = json::array(); + for (const auto & d : delimiters) { + result.push_back({ + { "role", common_chat_role_to_string(d.role) }, + { "delimiter", d.delimiter }, + }); + } + return result; +} + +common_chat_msg_delimiters common_chat_msg_delimiters_parse(const json & delimiters) { + common_chat_msg_delimiters result; + + if (!delimiters.is_array()) { + return result; } - auto parser = build_peg_parser([&](common_peg_parser_builder & p) { - std::vector all_delims; - std::vector tagged_messages; - - all_delims.reserve(delims.size()); - tagged_messages.reserve(delims.size()); - for (const auto & d : delims) { - all_delims.push_back(d.delimiter); + result.delimiters.reserve(delimiters.size()); + for (const auto & d : delimiters) { + if (!d.is_object()) { + continue; } - - auto any_delim = p.until_one_of(all_delims); - for (const auto & d : delims) { - tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim)); - } - - return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end(); - }); - - common_peg_parse_context ctx(prompt); - const auto result = parser.parse(ctx); - if (!result.success()) { - return {}; + result.delimiters.push_back({ + common_chat_role_from_string(d.value("role", std::string())), + d.value("delimiter", std::string()), + }); } - std::vector spans; - ctx.ast.visit(result, [&](const common_peg_ast_node & node) { - if (!node.tag.empty()) { - spans.push_back({ node.tag, node.start, node.end - node.start }); + return result; +} + +void common_chat_msg_delimiters::tokenize(const llama_vocab * vocab) { + for (auto & d : delimiters) { + d.tokens = common_tokenize(vocab, d.delimiter, false, true); + } +} + +common_chat_msg_spans common_chat_msg_delimiters::split(const llama_tokens & tokens, const std::map & skips) const { + std::vector> matches; + + auto skip = skips.begin(); + for (size_t i = 0; i < tokens.size();) { + if (skip != skips.end() && i == skip->first) { + i += skip->second; + ++skip; + continue; } - }); + for (const auto & d : delimiters) { + if (i + d.tokens.size() > tokens.size()) { + continue; + } + if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) { + matches.emplace_back(d.role, i); + break; + } + } + i++; + } + + matches.emplace_back(COMMON_CHAT_ROLE_UNKNOWN, tokens.size()); + + common_chat_msg_spans spans; + for (size_t i = 0; i + 1 < matches.size(); i++) { + const auto & curr = matches[i]; + const auto & next = matches[i + 1]; + spans.add(curr.first, curr.second, next.second - curr.second); + } return spans; } @@ -1081,13 +1133,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp data.prompt = prompt; data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages); - data.message_spans = common_chat_split_by_role(prompt, { - { "assistant", "<|start|>assistant" }, - { "user", "<|start|>user" }, - { "system", "<|start|>developer" }, - { "system", "<|start|>system" }, - { "tool", "<|start|>functions" }, - }); + data.message_delimiters = { + { COMMON_CHAT_ROLE_ASSISTANT, "<|start|>assistant" }, + { COMMON_CHAT_ROLE_USER, "<|start|>user" }, + { COMMON_CHAT_ROLE_SYSTEM, "<|start|>developer" }, + { COMMON_CHAT_ROLE_SYSTEM, "<|start|>system" }, + { COMMON_CHAT_ROLE_TOOL, "<|start|>functions" }, + }; data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = true; @@ -1228,10 +1280,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ data.prompt += data.generation_prompt; } - data.message_spans = common_chat_split_by_role(data.prompt, { - { "user", "<|turn>user\n" }, - { "assistant", "<|turn>model\n" }, - }); + data.message_delimiters = { + { COMMON_CHAT_ROLE_USER, "<|turn>user" }, + { COMMON_CHAT_ROLE_ASSISTANT, "<|turn>model" }, + }; data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4; data.supports_thinking = true; @@ -2030,15 +2082,15 @@ static common_chat_params common_chat_params_init_cohere2moe(const common_chat_t RESULT_START, RESULT_END, }; - // Split the rendered prompt into per-role message spans. Tool results are rendered with the + // Declare per-role message delimiters. Tool results are rendered with the // system token followed by <|START_TOOL_RESULT|>, so the "tool" delimiter must be listed before // the plain "system" one (it is a strict superset, and the role split tries delimiters in order). - data.message_spans = common_chat_split_by_role(data.prompt, { - { "assistant", GEN_PREFIX }, - { "user", TURN_START + USER }, - { "tool", TURN_START + SYSTEM + RESULT_START }, - { "system", TURN_START + SYSTEM }, - }); + data.message_delimiters = { + { COMMON_CHAT_ROLE_ASSISTANT, GEN_PREFIX }, + { COMMON_CHAT_ROLE_USER, TURN_START + USER }, + { COMMON_CHAT_ROLE_TOOL, TURN_START + SYSTEM + RESULT_START }, + { COMMON_CHAT_ROLE_SYSTEM, TURN_START + SYSTEM }, + }; auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; @@ -2526,17 +2578,15 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ autoparser.analyze_template(tmpl); auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser); - std::vector delimiters; + common_chat_msg_delimiters delimiters; if (!autoparser.assistant_start.empty()) { - delimiters.push_back({ "assistant", autoparser.assistant_start }); + delimiters.add(COMMON_CHAT_ROLE_ASSISTANT, autoparser.assistant_start); } if (!autoparser.user_start.empty()) { - delimiters.push_back({ "user", autoparser.user_start }); + delimiters.add(COMMON_CHAT_ROLE_USER, autoparser.user_start); } - if (!delimiters.empty()) { - auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters); - } + auto_params.message_delimiters = std::move(delimiters); auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE; if (auto_params.supports_thinking) { diff --git a/common/chat.h b/common/chat.h index 5659cd42a0..7898f1623f 100644 --- a/common/chat.h +++ b/common/chat.h @@ -143,15 +143,75 @@ struct common_chat_msg_diff { } }; +enum common_chat_role { + COMMON_CHAT_ROLE_UNKNOWN, + COMMON_CHAT_ROLE_SYSTEM, + COMMON_CHAT_ROLE_ASSISTANT, + COMMON_CHAT_ROLE_USER, + COMMON_CHAT_ROLE_TOOL +}; + +common_chat_role common_chat_role_from_string(const std::string & role); +const char * common_chat_role_to_string(common_chat_role role); + struct common_chat_msg_span { - std::string role; + common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN; std::size_t pos = 0; std::size_t len = 0; + + bool valid() const { + return role != COMMON_CHAT_ROLE_UNKNOWN; + } +}; + +struct common_chat_msg_spans { + std::vector spans; + + void add(common_chat_role role, size_t pos, size_t len) { + spans.push_back({ role, pos, len }); + } + + bool is_user_start(int32_t pos) const { + for (auto it = spans.begin(); it != spans.end(); ++it) { + if (it->role == COMMON_CHAT_ROLE_USER && pos == (int32_t) it->pos) { + return true; + } + } + return false; + } + + int32_t last_user_message_pos() const { + for (auto it = spans.rbegin(); it != spans.rend(); ++it) { + if (it->role == COMMON_CHAT_ROLE_USER) { + return (int32_t) it->pos; + } + } + return -1; + } }; struct common_chat_msg_delimiter { - std::string role; - std::string delimiter; + common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN; + std::string delimiter; + llama_tokens tokens = {}; +}; + +struct common_chat_msg_delimiters { + std::vector delimiters; + + common_chat_msg_delimiters() = default; + common_chat_msg_delimiters(std::initializer_list delims) : delimiters(delims) {} + + void add(common_chat_role role, const std::string & delimiter) { + delimiters.push_back({ role, delimiter }); + } + + void tokenize(const llama_vocab * vocab); + + // split tokens into message spans. skips maps a start index to a length of a region to jump over without matching + common_chat_msg_spans split(const llama_tokens & tokens, const std::map & skips = {}) const; + + nlohmann::ordered_json to_json() const; }; struct common_chat_tool { @@ -219,7 +279,7 @@ struct common_chat_params { std::vector preserved_tokens; std::vector additional_stops; std::string parser; - std::vector message_spans; + common_chat_msg_delimiters message_delimiters; }; // per-message parsing syntax @@ -325,5 +385,4 @@ struct common_chat_prompt_preset { common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates); -std::vector common_chat_split_by_role(const std::string & prompt, const std::vector & delims); - +common_chat_msg_delimiters common_chat_msg_delimiters_parse(const nlohmann::ordered_json & delimiters); diff --git a/common/common.h b/common/common.h index f2f2202ec2..75a6036a0f 100644 --- a/common/common.h +++ b/common/common.h @@ -609,7 +609,7 @@ struct common_params { bool cache_prompt = true; // whether to enable prompt caching bool cache_idle_slots = true; // save and clear idle slots upon starting a new task int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot - int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints + int32_t checkpoint_min_step = 8192; // minimum spacing between context checkpoints int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. std::string hostname = "127.0.0.1"; diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 30aa35e137..c38aed8cfe 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -1562,37 +1562,112 @@ static void test_msgs_oaicompat_json_conversion() { } } -static void test_split_by_role() { +static void test_msg_token_delimiters_split() { LOG_DBG("%s\n", __func__); + // Delimiters that share a leading token, distinguished by the second token, + // to exercise the per-position token matching. + const common_chat_msg_delimiters delims = { + { { COMMON_CHAT_ROLE_USER, "", { 10, 11 } }, + { COMMON_CHAT_ROLE_ASSISTANT, "", { 10, 12 } } } + }; + // Empty inputs - assert_equals(0, common_chat_split_by_role("", {}).size()); - assert_equals(0, common_chat_split_by_role("hello", {}).size()); - assert_equals(0, common_chat_split_by_role("", { { "user", "<|user|>" } }).size()); + assert_equals(0, common_chat_msg_delimiters{}.split({}).spans.size()); + assert_equals(0, common_chat_msg_delimiters{}.split({ 10, 11 }).spans.size()); + assert_equals(0, delims.split({}).spans.size()); - // Multi-role conversation, no leading/trailing content + // No delimiters match -> no spans + assert_equals(0, delims.split({ 100, 101, 102 }).spans.size()); + + // Multi-role conversation: HiHelloBye { - const std::string prompt = "<|user|>Hi<|assistant|>Hello<|user|>Bye"; - const auto splits = common_chat_split_by_role(prompt, { - { "user", "<|user|>" }, - { "assistant", "<|assistant|>" }, - }); - assert_equals(3, splits.size()); + const llama_tokens tokens = { + 10, 11, // + 100, 101, // Hi + 10, 12, // + 200, 201, 202, // Hello + 10, 11, // + 300, 301, // Bye + }; - assert_equals("user", splits[0].role); - assert_equals(0, splits[0].pos); - assert_equals(10, splits[0].len); - assert_equals("<|user|>Hi", prompt.substr(splits[0].pos, splits[0].len)); + const auto result = delims.split(tokens); + const auto & spans = result.spans; + assert_equals(3, spans.size()); - assert_equals("assistant", splits[1].role); - assert_equals(10, splits[1].pos); - assert_equals(18, splits[1].len); - assert_equals("<|assistant|>Hello", prompt.substr(splits[1].pos, splits[1].len)); + assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role); + assert_equals(0, spans[0].pos); + assert_equals(4, spans[0].len); - assert_equals("user", splits[2].role); - assert_equals(28, splits[2].pos); - assert_equals(11, splits[2].len); - assert_equals("<|user|>Bye", prompt.substr(splits[2].pos, splits[2].len)); + assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role); + assert_equals(4, spans[1].pos); + assert_equals(5, spans[1].len); + + assert_equals(COMMON_CHAT_ROLE_USER, spans[2].role); + assert_equals(9, spans[2].pos); + assert_equals(4, spans[2].len); + + // is_user_start() is true at the token position where a user span begins + assert_equals(true, result.is_user_start(0)); + assert_equals(false, result.is_user_start(4)); // assistant span + assert_equals(true, result.is_user_start(9)); + } + + // Content before the first delimiter is not captured as a span + { + const llama_tokens tokens = { + 500, 501, // leading content (dropped) + 10, 11, // + 100, // Hi + }; + + const auto spans = delims.split(tokens).spans; + assert_equals(1, spans.size()); + assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role); + assert_equals(2, spans[0].pos); + assert_equals(3, spans[0].len); + } + + // Skipped regions (media chunks) are jumped over but still count as span content + { + const llama_tokens tokens = { + 10, 11, // + LLAMA_TOKEN_NULL, // media chunk (3 tokens) + LLAMA_TOKEN_NULL, + LLAMA_TOKEN_NULL, + 100, // Hi + 10, 12, // + }; + + const std::map skips = { { 2, 3 } }; + + const auto spans = delims.split(tokens, skips).spans; + assert_equals(2, spans.size()); + + assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role); + assert_equals(0, spans[0].pos); + assert_equals(6, spans[0].len); + + assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role); + assert_equals(6, spans[1].pos); + assert_equals(2, spans[1].len); + } + + // A delimiter sequence inside a skipped region is not matched + { + const llama_tokens tokens = { + 10, 11, // + 10, 12, // skipped region that happens to contain delimiter tokens + 100, // Hi + }; + + const std::map skips = { { 2, 2 } }; + + const auto spans = delims.split(tokens, skips).spans; + assert_equals(1, spans.size()); + assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role); + assert_equals(0, spans[0].pos); + assert_equals(5, spans[0].len); } } @@ -5857,7 +5932,7 @@ int main(int argc, char ** argv) { { test_msg_diffs_compute(); test_msgs_oaicompat_json_conversion(); - test_split_by_role(); + test_msg_token_delimiters_split(); test_tools_oaicompat_json_conversion(); test_convert_responses_to_chatcmpl(); test_developer_role_to_system_workaround(); diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index e412b94c5c..ac291d359a 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -518,6 +518,14 @@ size_t server_tokens::get_common_prefix(const server_tokens & b) const { return max_idx; // all tokens are equal } +common_chat_msg_spans server_tokens::find_message_spans(const common_chat_msg_delimiters & delims) const { + std::map skips; + for (const auto & it : map_idx_to_media) { + skips[it.first] = mtmd_input_chunk_get_n_tokens(it.second.get()); + } + return delims.split(tokens, skips); +} + bool server_tokens::validate(const struct llama_context * ctx) const { const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); @@ -1104,15 +1112,7 @@ json oaicompat_chat_params_parse( llama_params["chat_parser"] = chat_params.parser; } - llama_params["message_spans"] = json::array(); - - for (const auto & span : chat_params.message_spans) { - llama_params["message_spans"].push_back({ - { "role", span.role }, - { "pos", span.pos }, - { "len", span.len }, - }); - } + llama_params["message_delimiters"] = chat_params.message_delimiters.to_json(); // Reasoning budget: pass parameters through to sampling layer { diff --git a/tools/server/server-common.h b/tools/server/server-common.h index efd31733b0..c0eaec6b02 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -218,6 +218,9 @@ public: size_t get_common_prefix(const server_tokens & b) const; + // split the tokens into message spans, skipping over media chunks + common_chat_msg_spans find_message_spans(const common_chat_msg_delimiters & delims) const; + // make sure all text tokens are within the vocab range bool validate(const struct llama_context * ctx) const; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 0a25b414ed..ca91449d26 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -3436,8 +3436,8 @@ private: has_mtmd = true; } - const int32_t n_before_user = slot.task->params.n_before_user; - const bool n_before_user_known = n_before_user > 0; + const auto & spans = slot.task->params.message_spans; + const auto last_user_pos = spans.last_user_message_pos(); // add prompt tokens for processing in the current batch while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.size() < n_batch) { @@ -3466,10 +3466,8 @@ private: slot.n_prompt_tokens_processed++; - // stop the prompt batch exactly before the latest user input, so a checkpoint - // can be created after the previous messages - if (n_before_user_known && - slot.prompt.n_tokens() == n_before_user) { + // stop the prompt batch exactly before a user message + if (spans.is_user_start(slot.prompt.n_tokens())) { break; } @@ -3498,8 +3496,13 @@ private: // the number of tokens added to the batch for the current slot const auto n_tokens_cur = batch.size() - n_tokens_prev; + const auto n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur; + const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch; + const bool is_user_start = spans.is_user_start(n_tokens_start); + const bool is_last_user_message = n_tokens_start == last_user_pos; + // entire prompt has been processed if (slot.prompt.n_tokens() == slot.task->n_tokens()) { slot.state = SLOT_STATE_DONE_PROMPT; @@ -3514,8 +3517,9 @@ private: slot.init_sampler(); } else { - // skip ordinary mid-prompt checkpoints - if (!n_before_user_known && !near_prompt_end) { + // skip ordinary mid-prompt checkpoints, unless the batch starts a user + // message or we are near the end of the prompt + if (!is_user_start && !near_prompt_end) { do_checkpoint = false; } } @@ -3523,29 +3527,6 @@ private: const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id); - // checkpoints are created before the current batch is decoded, so - // their token position is the batch start rather than the prompt end - const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur; - - { - const bool is_on_user = - n_before_user_known && - n_tokens_start == n_before_user; - - const bool is_after_user = - n_before_user_known && - n_tokens_start > n_before_user; - - const bool is_allowed = - !n_before_user_known || - is_on_user || - (is_after_user && near_prompt_end); - - if (do_checkpoint && !is_allowed) { - do_checkpoint = false; - } - } - // nothing to checkpoint yet // TODO: is this check needed? if (do_checkpoint && pos_min < 0) { @@ -3555,8 +3536,8 @@ private: // do not checkpoint after mtmd chunks do_checkpoint = do_checkpoint && !has_mtmd; - // no need to create checkpoints that are too close together - do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step); + // no need to create checkpoints that are too close together, unless it's the last user message + do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || is_last_user_message || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step); SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max); // note: we create the checkpoint before calling llama_decode(), so the current batch is not @@ -4055,54 +4036,6 @@ void server_context::set_state_callback(server_state_callback_t callback) { }); } -// compute the number of tokens before the last user message in the prompt -static int32_t prompt_get_n_before_user( - const json & message_spans, - const std::string & prompt, - const std::vector & files, - const llama_vocab * vocab, - mtmd_context * mctx) { - int32_t result = -1; - int32_t byte_pos = -1; - - for (const auto & span : message_spans) { - const std::string role = json_value(span, "role", std::string()); - - if (role == "user") { - byte_pos = json_value(span, "pos", -1); - } - } - - if (byte_pos >= 0) { - GGML_ASSERT((size_t) byte_pos <= prompt.size()); - - const std::string prefix = prompt.substr(0, (size_t) byte_pos); - - const std::string marker = get_media_marker(); - size_t n_prefix_media = 0; - for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) { - n_prefix_media++; - } - - GGML_ASSERT(n_prefix_media <= files.size()); - - if (mctx != nullptr && n_prefix_media > 0) { - // TODO: this makes a copy - avoid it - std::vector prefix_files(files.begin(), files.begin() + n_prefix_media); - - result = (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size(); - } else { - result = (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size(); - } - - SRV_TRC("message_spans: last user message: byte_pos=%d, media=%zu, n_before_user=%d\n", - byte_pos, n_prefix_media, result); - } - - return result; -} - - // // server_routes // @@ -4150,6 +4083,10 @@ std::unique_ptr server_routes::handle_completions_impl( // tasks.reserve(inputs.size()); // TODO: this is inaccurate due to child tasks + // message delimiters for checkpointing + auto delimiters = common_chat_msg_delimiters_parse(json_value(data, "message_delimiters", json::array())); + delimiters.tokenize(ctx_server.vocab); + for (size_t i = 0; i < inputs.size(); i++) { server_task task = server_task(type); @@ -4163,16 +4100,7 @@ std::unique_ptr server_routes::handle_completions_impl( meta->logit_bias_eog, data); - const auto message_spans = json_value(data, "message_spans", json::array()); - if (prompt.is_string() && message_spans.is_array()) { - task.params.n_before_user = - prompt_get_n_before_user( - message_spans, - prompt.get(), - files, - ctx_server.vocab, - ctx_server.mctx); - } + task.params.message_spans = task.tokens.find_message_spans(delimiters); task.id_slot = json_value(data, "id_slot", -1); diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 299c279d7d..293bdf053a 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -62,9 +62,6 @@ struct task_params { int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled) - // number of prompt tokens before the latest user message - int32_t n_before_user = -1; - int64_t t_max_prompt_ms = -1; // TODO: implement int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit @@ -92,6 +89,9 @@ struct task_params { // per-request parameters for chat parsing common_chat_parser_params chat_parser_params; + // message spans for checkpointing + common_chat_msg_spans message_spans; + // Embeddings int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)