expand np guardrail for all mtp types (#1901)

This commit is contained in:
Samuel Oliveira Alves 2026-05-30 10:19:53 -03:00 committed by GitHub
parent 8960c5ba5e
commit 3f40e73c36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -130,6 +130,10 @@ static void server_remove_speculative_stage(common_params_speculative & spec, co
}
}
static bool server_speculative_needs_draft_model(const common_params_speculative & spec) {
return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_DRAFT);
}
static bool server_speculative_has_mtp(const common_params_speculative & spec) {
return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
}
@ -273,17 +277,21 @@ bool server_context::load_model(const gpt_params& params_) {
add_bos_token = llama_should_add_bos_token(model);
has_eos_token = llama_add_eos_token(model) != 1;
if (params_base.has_mtp && params_base.n_parallel > 1) {
LOG_WARNING("MTP is not supported with parallel slots yet, disabling MTP to avoid cross-slot corruption.\n", {
if (params_base.n_parallel > 1 && server_speculative_has_mtp(params_base.speculative)) {
LOG_WARNING("MTP is not supported with parallel slots yet, removing the MTP stage to avoid cross-slot corruption.\n", {
{"n_parallel", params_base.n_parallel},
{"stage_chain", common_speculative_stage_chain_to_str(params_base.speculative)},
});
params_base.has_mtp = false;
if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) {
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
server_remove_speculative_stage(params_base.speculative, COMMON_SPECULATIVE_TYPE_MTP);
if (!server_speculative_needs_draft_model(params_base.speculative)) {
params_base.speculative.model.clear();
params_base.speculative.params.clear();
params_base.speculative.model_dft = nullptr;
}
params_base.speculative.model.clear();
params_base.speculative.params.clear();
params_base.speculative.model_dft = nullptr;
}
bool has_draft_model = !params_base.speculative.model.empty() || !params_base.speculative.params.empty();
@ -470,7 +478,7 @@ void server_context::init() {
bool can_spec = true;
if (!params_base.dry_run) {
can_spec = common_speculative_is_compat(ctx);
}
}
if (!can_spec) {
SRV_WRN("%s", "speculative decoding not supported by this context\n");
}
@ -1656,7 +1664,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
int32_t banbuffer_size = json_value(data, "banbuffer_size", 0);
slot.n_buffer = 0; // Ensure buffer calculation starts fresh for this slot
slot.rewind_count_max = json_value(data, "rewind_count_max", -1);
const auto& banned_strings = data.find("banned_strings");
if (banned_strings != data.end() && banned_strings->is_array()) {
slot.ban_phrases.clear();
@ -2805,7 +2813,7 @@ static size_t load_server_tokens_from_file(const std::string & filename, server
size_t pos = 0;
json token_json;
if (file.is_open()) {
file >> token_json;
file >> token_json;
pos = file.tellg();
file.close();
}
@ -3727,7 +3735,7 @@ bool server_context::create_checkpoint(server_slot & slot) {
slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin());
}
auto & cur = slot.server_cached_prompt.checkpoints.emplace_back();
server_prompt_checkpoint_update(cur, ctx, slot.id, slot.cache_tokens.n_tokens(), pos_min, pos_max, slot.n_past_offset);
@ -4060,7 +4068,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
slot.do_checkpoint = true;
break;
}
}
LOG_VERBOSE("prompt processing progress", {
{"id_slot", slot.id},
@ -4395,15 +4403,15 @@ void server_context::release_slot_after_final_response(server_slot & slot) {
void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) {
int count = 0;
bool released = false;
int32_t start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
for (auto& it : results) {
bool has_next = process_token(it, slot);
// Clean up positional bans for the token we just confirmed/sent
slot.positional_bans.erase(start_pos + count);
count++;
if (!has_next) {
if (slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
@ -4436,7 +4444,7 @@ inline int32_t check_ban_phrase(server_slot& slot) {
std::string string_buffer;
std::vector<size_t> token_offsets;
for (const auto& it : slot.token_buffer) {
token_offsets.push_back(string_buffer.size());
string_buffer += it.text_to_send;
@ -4488,10 +4496,10 @@ inline int32_t check_ban_phrase(server_slot& slot) {
if (found) {
int32_t token_idx = -1;
for (size_t i = 0; i < token_offsets.size(); ++i) {
size_t len = (i == token_offsets.size() - 1)
? string_buffer.size() - token_offsets[i]
size_t len = (i == token_offsets.size() - 1)
? string_buffer.size() - token_offsets[i]
: token_offsets[i+1] - token_offsets[i];
if (best_start >= token_offsets[i] && best_start < token_offsets[i] + len) {
token_idx = (int32_t)i;
break;
@ -4509,7 +4517,7 @@ inline int32_t check_ban_phrase(server_slot& slot) {
inline void rewind_context(server_slot& slot, int32_t ban_pos) {
slot.rewind_count++;
int32_t buffer_start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
int32_t n_keep_buffer = ban_pos - buffer_start_pos;
if (n_keep_buffer < 0) n_keep_buffer = 0;
@ -4518,9 +4526,9 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) {
int32_t n = 0;
for (auto result = slot.token_buffer.begin() + n_keep_buffer; result != slot.token_buffer.end(); result++) {
llama_token banned_tok = result->tok;
if (n == 0) {
LLAMA_LOG_DEBUG("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n",
LLAMA_LOG_DEBUG("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n",
ban_pos, banned_tok, result->text_to_send.c_str());
}
@ -4533,7 +4541,7 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) {
}
int32_t n_rewind_total = (slot.n_past + 1) - ban_pos;
size_t n_keep_cache = 0;
if (ban_pos > 0) {
n_keep_cache = (size_t)(ban_pos - 1);
@ -4546,13 +4554,13 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) {
if (n_keep_cache < slot.cache_tokens.size()) {
slot.sampled = slot.cache_tokens[n_keep_cache];
} else {
slot.sampled = 0;
slot.sampled = 0;
}
// Truncate cache
slot.cache_tokens.keep_first(n_keep_cache);
slot.n_past = slot.cache_tokens.n_tokens();
// Remove from KV cache
llama_kv_cache_seq_rm(slot.ctx, slot.id, slot.cache_tokens.pos_next(slot.n_past), -1);
@ -4590,13 +4598,13 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
// Automatic / Heuristic logic
// Account for strings + regex + regex_ci
size_t total_bans = slot.ban_phrases.size() + slot.ban_regex.size() + slot.ban_regex_ci.size();
// Heuristic: Allow if under 20 OR under 2 * total_bans
// Conversely: Stop if >= 20 AND > 2 * total_bans
if (slot.rewind_count >= 20 && slot.rewind_count > 2 * total_bans) {
allow_rewind = false;
}
}
}
else if (slot.rewind_count_max > 0) {
// Strict limit logic
if (slot.rewind_count >= slot.rewind_count_max) {
@ -4613,7 +4621,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
else if (buffer_full || !next_token) {
slot.rewind_status = false;
slot.rewind_count = 0;
if (!next_token) {
// send all remaining tokens
send_token_results(slot.token_buffer, slot);
@ -4625,7 +4633,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
}
else {
// buffer the result, wait for more tokens to validate string
slot.sampled = result.tok;
slot.sampled = result.tok;
}
}