diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 99345458..444ed2ce 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -130,6 +130,10 @@ static void server_remove_speculative_stage(common_params_speculative & spec, co } } +static bool server_speculative_needs_draft_model(const common_params_speculative & spec) { + return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_DRAFT); +} + static bool server_speculative_has_mtp(const common_params_speculative & spec) { return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP); } @@ -273,17 +277,21 @@ bool server_context::load_model(const gpt_params& params_) { add_bos_token = llama_should_add_bos_token(model); has_eos_token = llama_add_eos_token(model) != 1; - if (params_base.has_mtp && params_base.n_parallel > 1) { - LOG_WARNING("MTP is not supported with parallel slots yet, disabling MTP to avoid cross-slot corruption.\n", { + if (params_base.n_parallel > 1 && server_speculative_has_mtp(params_base.speculative)) { + LOG_WARNING("MTP is not supported with parallel slots yet, removing the MTP stage to avoid cross-slot corruption.\n", { {"n_parallel", params_base.n_parallel}, + {"stage_chain", common_speculative_stage_chain_to_str(params_base.speculative)}, }); + params_base.has_mtp = false; - if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) { - params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE; + + server_remove_speculative_stage(params_base.speculative, COMMON_SPECULATIVE_TYPE_MTP); + + if (!server_speculative_needs_draft_model(params_base.speculative)) { + params_base.speculative.model.clear(); + params_base.speculative.params.clear(); + params_base.speculative.model_dft = nullptr; } - params_base.speculative.model.clear(); - params_base.speculative.params.clear(); - params_base.speculative.model_dft = nullptr; } bool has_draft_model = !params_base.speculative.model.empty() || !params_base.speculative.params.empty(); @@ -470,7 +478,7 @@ void server_context::init() { bool can_spec = true; if (!params_base.dry_run) { can_spec = common_speculative_is_compat(ctx); - } + } if (!can_spec) { SRV_WRN("%s", "speculative decoding not supported by this context\n"); } @@ -1656,7 +1664,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) int32_t banbuffer_size = json_value(data, "banbuffer_size", 0); slot.n_buffer = 0; // Ensure buffer calculation starts fresh for this slot slot.rewind_count_max = json_value(data, "rewind_count_max", -1); - + const auto& banned_strings = data.find("banned_strings"); if (banned_strings != data.end() && banned_strings->is_array()) { slot.ban_phrases.clear(); @@ -2805,7 +2813,7 @@ static size_t load_server_tokens_from_file(const std::string & filename, server size_t pos = 0; json token_json; if (file.is_open()) { - file >> token_json; + file >> token_json; pos = file.tellg(); file.close(); } @@ -3727,7 +3735,7 @@ bool server_context::create_checkpoint(server_slot & slot) { slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin()); } - + auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(); server_prompt_checkpoint_update(cur, ctx, slot.id, slot.cache_tokens.n_tokens(), pos_min, pos_max, slot.n_past_offset); @@ -4060,7 +4068,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t slot.do_checkpoint = true; break; } - + } LOG_VERBOSE("prompt processing progress", { {"id_slot", slot.id}, @@ -4395,15 +4403,15 @@ void server_context::release_slot_after_final_response(server_slot & slot) { void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) { int count = 0; bool released = false; - + int32_t start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1; for (auto& it : results) { bool has_next = process_token(it, slot); - + // Clean up positional bans for the token we just confirmed/sent slot.positional_bans.erase(start_pos + count); - + count++; if (!has_next) { if (slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) { @@ -4436,7 +4444,7 @@ inline int32_t check_ban_phrase(server_slot& slot) { std::string string_buffer; std::vector token_offsets; - + for (const auto& it : slot.token_buffer) { token_offsets.push_back(string_buffer.size()); string_buffer += it.text_to_send; @@ -4488,10 +4496,10 @@ inline int32_t check_ban_phrase(server_slot& slot) { if (found) { int32_t token_idx = -1; for (size_t i = 0; i < token_offsets.size(); ++i) { - size_t len = (i == token_offsets.size() - 1) - ? string_buffer.size() - token_offsets[i] + size_t len = (i == token_offsets.size() - 1) + ? string_buffer.size() - token_offsets[i] : token_offsets[i+1] - token_offsets[i]; - + if (best_start >= token_offsets[i] && best_start < token_offsets[i] + len) { token_idx = (int32_t)i; break; @@ -4509,7 +4517,7 @@ inline int32_t check_ban_phrase(server_slot& slot) { inline void rewind_context(server_slot& slot, int32_t ban_pos) { slot.rewind_count++; - + int32_t buffer_start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1; int32_t n_keep_buffer = ban_pos - buffer_start_pos; if (n_keep_buffer < 0) n_keep_buffer = 0; @@ -4518,9 +4526,9 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) { int32_t n = 0; for (auto result = slot.token_buffer.begin() + n_keep_buffer; result != slot.token_buffer.end(); result++) { llama_token banned_tok = result->tok; - + if (n == 0) { - LLAMA_LOG_DEBUG("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n", + LLAMA_LOG_DEBUG("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n", ban_pos, banned_tok, result->text_to_send.c_str()); } @@ -4533,7 +4541,7 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) { } int32_t n_rewind_total = (slot.n_past + 1) - ban_pos; - + size_t n_keep_cache = 0; if (ban_pos > 0) { n_keep_cache = (size_t)(ban_pos - 1); @@ -4546,13 +4554,13 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) { if (n_keep_cache < slot.cache_tokens.size()) { slot.sampled = slot.cache_tokens[n_keep_cache]; } else { - slot.sampled = 0; + slot.sampled = 0; } // Truncate cache slot.cache_tokens.keep_first(n_keep_cache); slot.n_past = slot.cache_tokens.n_tokens(); - + // Remove from KV cache llama_kv_cache_seq_rm(slot.ctx, slot.id, slot.cache_tokens.pos_next(slot.n_past), -1); @@ -4590,13 +4598,13 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_ // Automatic / Heuristic logic // Account for strings + regex + regex_ci size_t total_bans = slot.ban_phrases.size() + slot.ban_regex.size() + slot.ban_regex_ci.size(); - + // Heuristic: Allow if under 20 OR under 2 * total_bans // Conversely: Stop if >= 20 AND > 2 * total_bans if (slot.rewind_count >= 20 && slot.rewind_count > 2 * total_bans) { allow_rewind = false; } - } + } else if (slot.rewind_count_max > 0) { // Strict limit logic if (slot.rewind_count >= slot.rewind_count_max) { @@ -4613,7 +4621,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_ else if (buffer_full || !next_token) { slot.rewind_status = false; slot.rewind_count = 0; - + if (!next_token) { // send all remaining tokens send_token_results(slot.token_buffer, slot); @@ -4625,7 +4633,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_ } else { // buffer the result, wait for more tokens to validate string - slot.sampled = result.tok; + slot.sampled = result.tok; } }