mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Fix Qwen35 mtp warmup (#1987)
* Use hidden state from prev token from qwen mtp * Fix Qwen35 MTP warmup * Cleanup + remove unnecessary crippling performance by not using accept to sample draft token * Provide API to get the model arch string --------- Co-authored-by: SamuelOliveirads <samueloliveira32df@gmail.com>
This commit is contained in:
parent
71af16a6b7
commit
f5e5753c32
@ -2347,6 +2347,8 @@ void common_speculative_checkpoint_restore(
|
||||
common_speculative_checkpoint_discard(ckpt, ctx);
|
||||
}
|
||||
|
||||
static bool mtp_model_uses_recurrent_conditioning(const common_speculative_state_mtp & state);
|
||||
|
||||
void common_speculative_commit(
|
||||
common_speculative * spec,
|
||||
llama_context * ctx,
|
||||
@ -2364,6 +2366,7 @@ void common_speculative_commit(
|
||||
const common_speculative_type spec_type_used = spec->curr_impl != nullptr
|
||||
? spec->curr_impl->type
|
||||
: COMMON_SPECULATIVE_TYPE_NONE;
|
||||
|
||||
const bool any_rejected = (int) ids.size() - 1 < n_draft;
|
||||
std::vector<float> mtp_hidden_state_pre;
|
||||
|
||||
@ -2552,6 +2555,20 @@ static void mtp_store_target_hidden(
|
||||
stored.assign(hidden, hidden + width);
|
||||
}
|
||||
|
||||
static bool mtp_model_uses_recurrent_conditioning(const common_speculative_state_mtp & state) {
|
||||
if (state.ctx_mtp == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const llama_model * model = llama_get_model(state.ctx_mtp);
|
||||
if (!llama_model_has_recurrent(model)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string arch{llama_model_arch_string(model)};
|
||||
return arch == "qwen35" || arch == "qwen35moe";
|
||||
}
|
||||
|
||||
static void mtp_clear_target_hidden(common_speculative_state_mtp & state, llama_seq_id seq_id) {
|
||||
state.target_hidden_by_seq.erase(seq_id);
|
||||
state.draft_cache_by_seq.erase(seq_id);
|
||||
@ -2726,6 +2743,12 @@ int32_t common_speculative_on_target_batch(
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (features.width != mtp_state->n_embd) {
|
||||
LOG_ERR("%s: MTP feature width mismatch: got %d expected %d\n",
|
||||
__func__, features.width, mtp_state->n_embd);
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (batch.n_seq_id == nullptr || batch.seq_id == nullptr || batch.n_seq_id[0] <= 0 || batch.seq_id[0] == nullptr) {
|
||||
return -1;
|
||||
}
|
||||
@ -2742,7 +2765,6 @@ int32_t common_speculative_on_target_batch(
|
||||
return -1;
|
||||
}
|
||||
|
||||
const float * first_hidden = hidden_rows_storage.data();
|
||||
const float * last_hidden = hidden_rows_storage.data() + (size_t) (batch.n_tokens - 1) * features.width;
|
||||
mtp_store_target_hidden(*mtp_state, seq_id, last_hidden, features.width);
|
||||
|
||||
@ -2751,16 +2773,41 @@ int32_t common_speculative_on_target_batch(
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (is_prompt_warmup) {
|
||||
if (!llama_set_draft_input_hidden_state_copy(mtp_state->ctx_mtp, hidden_rows_storage.data(), hidden_rows_storage.size())) {
|
||||
return -1;
|
||||
}
|
||||
const int32_t ret = mtp_update_kv_cache(mtp_state->ctx_mtp, batch, true);
|
||||
mtp_invalidate_cached_draft(*mtp_state, seq_id);
|
||||
return ret;
|
||||
if (!is_prompt_warmup) {
|
||||
return mtp_accept_batch(*mtp_state, batch, seq_id, hidden_rows_storage.data());
|
||||
}
|
||||
|
||||
return mtp_accept_batch(*mtp_state, batch, seq_id, first_hidden);
|
||||
const bool uses_shifted_hidden_rows = mtp_model_uses_recurrent_conditioning(*mtp_state);
|
||||
std::vector<float> previous_hidden_storage;
|
||||
if (uses_shifted_hidden_rows) {
|
||||
const auto hidden_it = mtp_state->target_hidden_by_seq.find(seq_id);
|
||||
if (hidden_it != mtp_state->target_hidden_by_seq.end() && (int32_t) hidden_it->second.size() == features.width) {
|
||||
previous_hidden_storage = hidden_it->second;
|
||||
} else {
|
||||
previous_hidden_storage.assign(features.width, 0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
const float * conditioned_hidden_rows = hidden_rows_storage.data();
|
||||
std::vector<float> conditioned_hidden_storage;
|
||||
if (uses_shifted_hidden_rows) {
|
||||
conditioned_hidden_storage.resize(hidden_rows_storage.size());
|
||||
std::copy(previous_hidden_storage.begin(), previous_hidden_storage.end(), conditioned_hidden_storage.begin());
|
||||
if (batch.n_tokens > 1) {
|
||||
std::copy(
|
||||
hidden_rows_storage.begin(),
|
||||
hidden_rows_storage.begin() + (size_t) (batch.n_tokens - 1) * features.width,
|
||||
conditioned_hidden_storage.begin() + features.width);
|
||||
}
|
||||
conditioned_hidden_rows = conditioned_hidden_storage.data();
|
||||
}
|
||||
|
||||
if (!llama_set_draft_input_hidden_state_copy(mtp_state->ctx_mtp, conditioned_hidden_rows, hidden_rows_storage.size())) {
|
||||
return -1;
|
||||
}
|
||||
const int32_t ret = mtp_update_kv_cache(mtp_state->ctx_mtp, batch, true);
|
||||
mtp_invalidate_cached_draft(*mtp_state, seq_id);
|
||||
return ret;
|
||||
}
|
||||
|
||||
common_speculative_type common_speculative_current_type(const common_speculative * spec) {
|
||||
|
||||
@ -698,6 +698,8 @@ extern "C" {
|
||||
|
||||
LLAMA_API bool llama_model_is_split_mode_graph(const struct llama_model * model);
|
||||
|
||||
LLAMA_API const char * llama_model_arch_string(const struct llama_model * model);
|
||||
|
||||
// Returns 0 on success
|
||||
LLAMA_API uint32_t llama_model_quantize(
|
||||
const char * fname_inp,
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#include "llama-arch.h"
|
||||
#include "llama-impl.h"
|
||||
#include "llama-model.h"
|
||||
|
||||
#include <map>
|
||||
|
||||
@ -101,6 +102,15 @@ llm_arch llm_arch_from_string(const std::string & name) {
|
||||
return LLM_ARCH_UNKNOWN;
|
||||
}
|
||||
|
||||
const char * llama_model_arch_string(const struct llama_model * model) {
|
||||
static const char * unknown = "unknown";
|
||||
if (!model) return unknown;
|
||||
if (auto it = LLM_ARCH_NAMES.find(model->arch); it != LLM_ARCH_NAMES.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return unknown;
|
||||
}
|
||||
|
||||
static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_GENERAL_TYPE, "general.type" },
|
||||
{ LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user