From f5e5753c325d00d4869dd668a0a82a9c5bd757ab Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 18 Jun 2026 09:03:40 +0200 Subject: [PATCH] Fix Qwen35 mtp warmup (#1987) * Use hidden state from prev token from qwen mtp * Fix Qwen35 MTP warmup * Cleanup + remove unnecessary crippling performance by not using accept to sample draft token * Provide API to gtet the model arch string --------- Co-authored-by: SamuelOliveirads --- common/speculative.cpp | 65 ++++++++++++++++++++++++++++++++++++------ include/llama.h | 2 ++ src/llama-arch.cpp | 10 +++++++ 3 files changed, 68 insertions(+), 9 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 82819c9b..9b956c10 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -2347,6 +2347,8 @@ void common_speculative_checkpoint_restore( common_speculative_checkpoint_discard(ckpt, ctx); } +static bool mtp_model_uses_recurrent_conditioning(const common_speculative_state_mtp & state); + void common_speculative_commit( common_speculative * spec, llama_context * ctx, @@ -2364,6 +2366,7 @@ void common_speculative_commit( const common_speculative_type spec_type_used = spec->curr_impl != nullptr ? spec->curr_impl->type : COMMON_SPECULATIVE_TYPE_NONE; + const bool any_rejected = (int) ids.size() - 1 < n_draft; std::vector mtp_hidden_state_pre; @@ -2552,6 +2555,20 @@ static void mtp_store_target_hidden( stored.assign(hidden, hidden + width); } +static bool mtp_model_uses_recurrent_conditioning(const common_speculative_state_mtp & state) { + if (state.ctx_mtp == nullptr) { + return false; + } + + const llama_model * model = llama_get_model(state.ctx_mtp); + if (!llama_model_has_recurrent(model)) { + return false; + } + + std::string arch{llama_model_arch_string(model)}; + return arch == "qwen35" || arch == "qwen35moe"; +} + static void mtp_clear_target_hidden(common_speculative_state_mtp & state, llama_seq_id seq_id) { state.target_hidden_by_seq.erase(seq_id); state.draft_cache_by_seq.erase(seq_id); @@ -2726,6 +2743,12 @@ int32_t common_speculative_on_target_batch( return 0; } + if (features.width != mtp_state->n_embd) { + LOG_ERR("%s: MTP feature width mismatch: got %d expected %d\n", + __func__, features.width, mtp_state->n_embd); + return -1; + } + if (batch.n_seq_id == nullptr || batch.seq_id == nullptr || batch.n_seq_id[0] <= 0 || batch.seq_id[0] == nullptr) { return -1; } @@ -2742,7 +2765,6 @@ int32_t common_speculative_on_target_batch( return -1; } - const float * first_hidden = hidden_rows_storage.data(); const float * last_hidden = hidden_rows_storage.data() + (size_t) (batch.n_tokens - 1) * features.width; mtp_store_target_hidden(*mtp_state, seq_id, last_hidden, features.width); @@ -2751,16 +2773,41 @@ int32_t common_speculative_on_target_batch( return 0; } - if (is_prompt_warmup) { - if (!llama_set_draft_input_hidden_state_copy(mtp_state->ctx_mtp, hidden_rows_storage.data(), hidden_rows_storage.size())) { - return -1; - } - const int32_t ret = mtp_update_kv_cache(mtp_state->ctx_mtp, batch, true); - mtp_invalidate_cached_draft(*mtp_state, seq_id); - return ret; + if (!is_prompt_warmup) { + return mtp_accept_batch(*mtp_state, batch, seq_id, hidden_rows_storage.data()); } - return mtp_accept_batch(*mtp_state, batch, seq_id, first_hidden); + const bool uses_shifted_hidden_rows = mtp_model_uses_recurrent_conditioning(*mtp_state); + std::vector previous_hidden_storage; + if (uses_shifted_hidden_rows) { + const auto hidden_it = mtp_state->target_hidden_by_seq.find(seq_id); + if (hidden_it != mtp_state->target_hidden_by_seq.end() && (int32_t) hidden_it->second.size() == features.width) { + previous_hidden_storage = hidden_it->second; + } else { + previous_hidden_storage.assign(features.width, 0.0f); + } + } + + const float * conditioned_hidden_rows = hidden_rows_storage.data(); + std::vector conditioned_hidden_storage; + if (uses_shifted_hidden_rows) { + conditioned_hidden_storage.resize(hidden_rows_storage.size()); + std::copy(previous_hidden_storage.begin(), previous_hidden_storage.end(), conditioned_hidden_storage.begin()); + if (batch.n_tokens > 1) { + std::copy( + hidden_rows_storage.begin(), + hidden_rows_storage.begin() + (size_t) (batch.n_tokens - 1) * features.width, + conditioned_hidden_storage.begin() + features.width); + } + conditioned_hidden_rows = conditioned_hidden_storage.data(); + } + + if (!llama_set_draft_input_hidden_state_copy(mtp_state->ctx_mtp, conditioned_hidden_rows, hidden_rows_storage.size())) { + return -1; + } + const int32_t ret = mtp_update_kv_cache(mtp_state->ctx_mtp, batch, true); + mtp_invalidate_cached_draft(*mtp_state, seq_id); + return ret; } common_speculative_type common_speculative_current_type(const common_speculative * spec) { diff --git a/include/llama.h b/include/llama.h index 754a0643..1eb40b14 100644 --- a/include/llama.h +++ b/include/llama.h @@ -698,6 +698,8 @@ extern "C" { LLAMA_API bool llama_model_is_split_mode_graph(const struct llama_model * model); + LLAMA_API const char * llama_model_arch_string(const struct llama_model * model); + // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 77226d3f..8fa207ce 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1,5 +1,6 @@ #include "llama-arch.h" #include "llama-impl.h" +#include "llama-model.h" #include @@ -101,6 +102,15 @@ llm_arch llm_arch_from_string(const std::string & name) { return LLM_ARCH_UNKNOWN; } +const char * llama_model_arch_string(const struct llama_model * model) { + static const char * unknown = "unknown"; + if (!model) return unknown; + if (auto it = LLM_ARCH_NAMES.find(model->arch); it != LLM_ARCH_NAMES.end()) { + return it->second; + } + return unknown; +} + static const std::map LLM_KV_NAMES = { { LLM_KV_GENERAL_TYPE, "general.type" }, { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },