Disable K Hadamard transform if K-head size is not a power of 2

This commit is contained in:
Kawrakow 2026-05-26 07:19:08 +00:00
parent b4e1d916c5
commit f3e929c25e
2 changed files with 34 additions and 10 deletions

View File

@ -515,6 +515,10 @@ struct llama_model {
return tensor_overrides;
}
bool is_mla_model() const {
return arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4;
}
size_t cache_size(int il, ggml_type type_k, ggml_type type_v, uint32_t kv_size, int mla_attn, int n_seq_max, bool flash_attn) const;
void set_tensor_overrides(const llama_model_params& params);

View File

@ -798,7 +798,7 @@ static bool llama_kv_cache_init(
}
}
bool is_mla_attn = model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA || model.arch == LLM_ARCH_MISTRAL4;
bool is_mla_attn = model.is_mla_model();
bool split_cache = false;
bool replicate_mla = false;
@ -2111,7 +2111,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
// general kv
LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str());
if (model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA || model.arch == LLM_ARCH_MISTRAL4) {
if (model.is_mla_model()) {
LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv);
@ -2241,7 +2241,7 @@ static void llm_requantize_output_tensor(llama_model & model, ggml_type new_type
}
static void llm_prepare_mla(llama_model & model, int mla) {
if (model.arch != LLM_ARCH_DEEPSEEK2 && model.arch != LLM_ARCH_GLM_DSA && model.arch != LLM_ARCH_MISTRAL4) return;
if (!model.is_mla_model()) return;
const auto& hparams = model.hparams;
const int n_layer = model.layers.size();
int n_to_compute = 0;
@ -2815,7 +2815,7 @@ static void llm_prepare_mla(llama_model & model, int mla) {
// skips the runtime cache_nope un-Hadamard. Math identity by H^T H = I.
static void llm_apply_khad_pretransform(llama_model & model) {
if (model.khad_pretransformed) return;
if (model.arch != LLM_ARCH_DEEPSEEK2 && model.arch != LLM_ARCH_GLM_DSA && model.arch != LLM_ARCH_MISTRAL4) return;
if (!model.is_mla_model()) return;
// High-enough bpw to survive one quant->F32->H->quant roundtrip within PPL noise.
// Cliff is ~2.7 bpw: IQ3_XXS (3.06) sits at +0.05 noise edge; IQ2_XS (2.31) drifts +0.20.
@ -3066,7 +3066,7 @@ static std::pair<std::vector<double>, double> get_layer_sizes(const llama_model_
ggml_tensor * wkv_b = nullptr;
};
std::vector<mla_tensors> mla_tensors;
bool has_mla = model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA || model.arch == LLM_ARCH_MISTRAL4;
bool has_mla = model.is_mla_model();
if (has_mla) {
mla_tensors.resize(n_layer);
}
@ -3443,7 +3443,7 @@ static bool llm_load_tensors(
model.main_gpu = device_count - 1;
}
if (model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA || model.arch == LLM_ARCH_MISTRAL4) {
if (model.is_mla_model()) {
if (model.n_gpu_layers > 0 && model.n_gpu_layers < model.hparams.n_layer && mla_attn != 3) {
LLAMA_LOG_WARN("=============================================================================\n");
LLAMA_LOG_WARN("MLA models with ngl < n_layer and split mode graph do not work with mla = %d\n", mla_attn);
@ -3918,7 +3918,7 @@ static bool llm_load_tensors(
}
}
if ((model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA || model.arch == LLM_ARCH_MISTRAL4)) {
if (model.is_mla_model()) {
// -sm graph/attn needs wk_b->extra populated; run prepare even under dry-run.
const bool graph_mode = (model.split_mode == LLAMA_SPLIT_MODE_GRAPH ||
model.split_mode == LLAMA_SPLIT_MODE_ATTN);
@ -5889,7 +5889,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
}
static bool get_can_shift(struct llama_context & lctx) {
bool no_shift = lctx.model.arch == LLM_ARCH_DEEPSEEK2 || lctx.model.arch == LLM_ARCH_GLM_DSA; // not supported due to MLA
bool no_shift = lctx.model.is_mla_model();
no_shift = no_shift || lctx.model.hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE;
return !no_shift;
}
@ -6648,11 +6648,31 @@ struct llama_context * llama_init_from_model(
LLAMA_LOG_WARN("%s: there is no point in Hadamard transforms with not quantized K-cache. Turning K-cache Hadamard off\n", __func__);
params.k_cache_hadamard = false;
}
if (params.k_cache_hadamard) {
for (int il = 0; il < model->hparams.n_layer; ++il) {
int n_head_k = model->hparams.n_embd_head_k(il);
if ((n_head_k & ~(n_head_k - 1)) != n_head_k) {
LLAMA_LOG_WARN("%s: K-head size %d in layer %d is not a power of 2. Turning K-cache Hadamard off\n",
__func__, n_head_k, il);
params.k_cache_hadamard = false;
}
}
}
if (params.v_cache_hadamard && !ggml_is_quantized(params.type_v)) {
LLAMA_LOG_WARN("%s: there is no point in Hadamard transforms with not quantized V-cache. Turning V-cache Hadamard off\n", __func__);
params.v_cache_hadamard = false;
}
if (params.v_cache_hadamard) {
for (int il = 0; il < model->hparams.n_layer; ++il) {
int n_head_v = model->hparams.n_embd_head_v(il);
if ((n_head_v & ~(n_head_v - 1)) != n_head_v) {
LLAMA_LOG_WARN("%s: V-head size %d in layer %d is not a power of 2. Turning V-cache Hadamard off\n",
__func__, n_head_v, il);
params.v_cache_hadamard = false;
}
}
}
llama_context * ctx = new llama_context(*model);
@ -6777,7 +6797,7 @@ struct llama_context * llama_init_from_model(
params.seed = time(NULL);
}
if (model->arch != LLM_ARCH_DEEPSEEK2 && model->arch != LLM_ARCH_GLM_DSA && model->arch != LLM_ARCH_MISTRAL4 && cparams.mla_attn != 0) {
if (!model->is_mla_model() && cparams.mla_attn != 0) {
cparams.mla_attn = 0;
} else {
if (model->n_gpu_layers > 0 && model->n_gpu_layers < model->hparams.n_layer && cparams.mla_attn != 3) {
@ -6810,7 +6830,7 @@ struct llama_context * llama_init_from_model(
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
if (model->arch == LLM_ARCH_DEEPSEEK2 || model->arch == LLM_ARCH_GLM_DSA || model->arch == LLM_ARCH_MISTRAL4) {
if (model->is_mla_model()) {
LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
}
LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch);