Fix Qwen35 mtp warmup (#1987)

* Use hidden state from prev token from qwen mtp

* Fix Qwen35 MTP warmup

* Clean up + remove an unnecessary, crippling performance hit by not using accept to sample the draft token

* Provide API to get the model arch string

---------

Co-authored-by: SamuelOliveirads <samueloliveira32df@gmail.com>
This commit is contained in:
Kawrakow 2026-06-18 09:03:40 +02:00 committed by GitHub
parent 71af16a6b7
commit f5e5753c32
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 68 additions and 9 deletions

View File

@ -2347,6 +2347,8 @@ void common_speculative_checkpoint_restore(
common_speculative_checkpoint_discard(ckpt, ctx);
}
// Forward declaration; the helper itself is defined later in this file.
static bool mtp_model_uses_recurrent_conditioning(const common_speculative_state_mtp & state);
void common_speculative_commit(
common_speculative * spec,
llama_context * ctx,
@ -2364,6 +2366,7 @@ void common_speculative_commit(
const common_speculative_type spec_type_used = spec->curr_impl != nullptr
? spec->curr_impl->type
: COMMON_SPECULATIVE_TYPE_NONE;
const bool any_rejected = (int) ids.size() - 1 < n_draft;
std::vector<float> mtp_hidden_state_pre;
@ -2552,6 +2555,20 @@ static void mtp_store_target_hidden(
stored.assign(hidden, hidden + width);
}
static bool mtp_model_uses_recurrent_conditioning(const common_speculative_state_mtp & state) {
if (state.ctx_mtp == nullptr) {
return false;
}
const llama_model * model = llama_get_model(state.ctx_mtp);
if (!llama_model_has_recurrent(model)) {
return false;
}
std::string arch{llama_model_arch_string(model)};
return arch == "qwen35" || arch == "qwen35moe";
}
static void mtp_clear_target_hidden(common_speculative_state_mtp & state, llama_seq_id seq_id) {
state.target_hidden_by_seq.erase(seq_id);
state.draft_cache_by_seq.erase(seq_id);
@ -2726,6 +2743,12 @@ int32_t common_speculative_on_target_batch(
return 0;
}
if (features.width != mtp_state->n_embd) {
LOG_ERR("%s: MTP feature width mismatch: got %d expected %d\n",
__func__, features.width, mtp_state->n_embd);
return -1;
}
if (batch.n_seq_id == nullptr || batch.seq_id == nullptr || batch.n_seq_id[0] <= 0 || batch.seq_id[0] == nullptr) {
return -1;
}
@ -2742,7 +2765,6 @@ int32_t common_speculative_on_target_batch(
return -1;
}
const float * first_hidden = hidden_rows_storage.data();
const float * last_hidden = hidden_rows_storage.data() + (size_t) (batch.n_tokens - 1) * features.width;
mtp_store_target_hidden(*mtp_state, seq_id, last_hidden, features.width);
@ -2751,16 +2773,41 @@ int32_t common_speculative_on_target_batch(
return 0;
}
if (is_prompt_warmup) {
if (!llama_set_draft_input_hidden_state_copy(mtp_state->ctx_mtp, hidden_rows_storage.data(), hidden_rows_storage.size())) {
if (!is_prompt_warmup) {
return mtp_accept_batch(*mtp_state, batch, seq_id, hidden_rows_storage.data());
}
const bool uses_shifted_hidden_rows = mtp_model_uses_recurrent_conditioning(*mtp_state);
std::vector<float> previous_hidden_storage;
if (uses_shifted_hidden_rows) {
const auto hidden_it = mtp_state->target_hidden_by_seq.find(seq_id);
if (hidden_it != mtp_state->target_hidden_by_seq.end() && (int32_t) hidden_it->second.size() == features.width) {
previous_hidden_storage = hidden_it->second;
} else {
previous_hidden_storage.assign(features.width, 0.0f);
}
}
const float * conditioned_hidden_rows = hidden_rows_storage.data();
std::vector<float> conditioned_hidden_storage;
if (uses_shifted_hidden_rows) {
conditioned_hidden_storage.resize(hidden_rows_storage.size());
std::copy(previous_hidden_storage.begin(), previous_hidden_storage.end(), conditioned_hidden_storage.begin());
if (batch.n_tokens > 1) {
std::copy(
hidden_rows_storage.begin(),
hidden_rows_storage.begin() + (size_t) (batch.n_tokens - 1) * features.width,
conditioned_hidden_storage.begin() + features.width);
}
conditioned_hidden_rows = conditioned_hidden_storage.data();
}
if (!llama_set_draft_input_hidden_state_copy(mtp_state->ctx_mtp, conditioned_hidden_rows, hidden_rows_storage.size())) {
return -1;
}
const int32_t ret = mtp_update_kv_cache(mtp_state->ctx_mtp, batch, true);
mtp_invalidate_cached_draft(*mtp_state, seq_id);
return ret;
}
return mtp_accept_batch(*mtp_state, batch, seq_id, first_hidden);
}
common_speculative_type common_speculative_current_type(const common_speculative * spec) {

View File

@ -698,6 +698,8 @@ extern "C" {
LLAMA_API bool llama_model_is_split_mode_graph(const struct llama_model * model);
// Returns the architecture name of the model as a C string.
// Returns "unknown" when model is NULL or its architecture has no registered name.
LLAMA_API const char * llama_model_arch_string(const struct llama_model * model);
// Returns 0 on success
LLAMA_API uint32_t llama_model_quantize(
const char * fname_inp,

View File

@ -1,5 +1,6 @@
#include "llama-arch.h"
#include "llama-impl.h"
#include "llama-model.h"
#include <map>
@ -101,6 +102,15 @@ llm_arch llm_arch_from_string(const std::string & name) {
return LLM_ARCH_UNKNOWN;
}
// Looks up the name registered in LLM_ARCH_NAMES for the model's architecture.
//
// model: may be null.
// Returns the registered name, or "unknown" when `model` is null or its
// architecture has no entry in LLM_ARCH_NAMES.
const char * llama_model_arch_string(const struct llama_model * model) {
    const char * arch_name = "unknown";

    if (model != nullptr) {
        const auto entry = LLM_ARCH_NAMES.find(model->arch);
        if (entry != LLM_ARCH_NAMES.end()) {
            arch_name = entry->second;
        }
    }

    return arch_name;
}
static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_GENERAL_TYPE, "general.type" },
{ LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },