From aefb8bdd997a6e88538004b5ffdb8678fbab0d5e Mon Sep 17 00:00:00 2001 From: David Young <1213472+davidsyoung@users.noreply.github.com> Date: Thu, 21 May 2026 05:29:15 +0100 Subject: [PATCH] MLA TP -khad: ggml_dequant_hadamard fused op + wv_b/wk_b_pp Hadamard fold (#1852) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml: ggml_dequant_hadamard fused op for MLA -khad path Adds a new ggml op that fuses (ggml_cast -> F32) + (ggml_hadamard) into a single kernel. Reads a quantized (or F16/F32) source and produces a per- Hadamard-block F32 chunk with the inverse transform applied, without materializing a full-size F32 intermediate buffer. Motivation: the MLA pp_opt path in build_deepseek2.cpp un-encodes the H-applied cache_nope view at every PP call. Today that runs as a cast (quant -> F32) followed by a separate ggml_hadamard kernel, costing two full-size F32 passes per layer per rank per call. Fusing them halves the bandwidth on the un-encode and removes one kernel launch. CUDA kernels in dequant_hadamard.cu lift the Walsh-Hadamard butterfly from hadamard.cu and dequant helpers from dequantize.cuh: * qr=1 layout (q8_0): consecutive dequant pair, stage 1 fused with load * qr=2 layout (q4_0 / q4_1 / q5_0 / q5_1 / q6_0 / iq4_nl): dequant pair at stride qk/2, explicit stage 1 after sync * F16 has a dedicated kernel * F32 source falls back to the standalone Hadamard op CPU impl in iqk_cpu_ops.cpp composes the existing type_traits.to_float dequant with fast_ht for graph completeness. nh in {64, 128, 256, 512}. * MLA-TP: Hadamard pretransform of wv_b/wk_b_pp for -khad Fold the 64-block orthonormal Hadamard into wv_b and wk_b_pp once at context init so the pp_opt mul_mats consume the K cache in its on-disk encoded basis. The per-PP-call cache_nope un-Hadamard is then skipped (rope half still un-applied — it goes to FA via concat, no wk_b multiply). Math is identity by H^T H = I: mul_mat(H@wv_b, H@cache) = wv_b^T @ cache. For mla=2/3 absorb, composes correctly with the existing post-FA ggml_hadamard(kqv_compressed, 64). All-or-nothing across layers under a castable type-allowlist (excludes 1-3 bpw IQ types whose requant blows up beyond PPL noise). Models with ineligible weights fall back to the runtime un-Hadamard path unchanged. Composes with the fused ggml_dequant_hadamard op (prior commit): with the fold active only the rope half still runs the runtime transform, via the fused kernel. * MLA-TP: fix TG with -khad after wv_b/wk_b_pp fold The absorb branch of build_deepseek2_tp_attention applies ggml_hadamard to kqv_compressed after FA, then multiplies by wv_b. Pre-fold this was needed because wv_b was un-encoded; with the wv_b fold (prior commit) the mul_mat already expects H-encoded kqv_compressed: mul_mat(H @ wv_b, kqv_encoded) = wv_b^T @ H @ H @ kqv_unencoded = wv_b^T @ kqv_unencoded (H @ H = I) Skip the post-FA hadamard when model.khad_pretransformed is set so the two H applications cancel instead of double-applying. Affects the absorb branch: TG (n_tokens=1), short-context PP (n_kv < 1024), and models without wk_b_pp. Long-context PP goes through the pp_opt branch and is unrelated/unchanged. Reported by @ikawrakow on PR 1852. Verified across mla={1,2,3} x khad={on,off} x -ctk={q8_0,q4_0} on GLM-4.7-Flash IQ5_K and the unsloth IQ4_XS variant ik used to reproduce. * ggml_hadamard: accept F16 and quant sources; drop GGML_OP_DEQUANT_HADAMARD Per @ikawrakow review on PR 1852: subsume the per-source-type dispatch into the existing GGML_OP_HADAMARD instead of carrying a separate enum entry, op constructor, and standalone files. ggml_hadamard's API is unchanged from the call-site perspective. The constructor's F32-only assertion is dropped; ggml_cuda_op_hadamard and iqk_hadamard now dispatch internally: - F32 source: existing F32 butterfly (unchanged) - F16 source: dedicated kernel - q8_0 / q4_0 / q4_1 / q5_0 / q5_1 / q6_0 / iq4_nl: fused dequant + butterfly kernel (lifted from the deleted dequant_hadamard.cu) - CPU side composes traits.to_float with fast_ht Net diff: -80 lines. Removes dequant_hadamard.{cu,cuh}, the enum entry, op table rows, ggml_dequant_hadamard constructor, dispatch cases, and the DEQUANT_HADAMARD supports_op block. Verified clean build + TG smoke (mla=3 +khad q8 on GLM-4.7-Flash-IQ4_XS, same coherent output as prior commit on feat/dequant-hadamard). --- ggml/include/ggml.h | 1 + ggml/src/ggml-cuda.cu | 22 +++- ggml/src/ggml-cuda/hadamard.cu | 203 +++++++++++++++++++++++++++------ ggml/src/ggml.c | 1 - ggml/src/iqk/iqk_cpu_ops.cpp | 30 ++++- src/graphs/build_deepseek2.cpp | 26 ++--- src/llama-model.h | 3 + src/llama.cpp | 128 +++++++++++++++++++++ 8 files changed, 356 insertions(+), 58 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index b5d60241..43b4ee5f 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1115,6 +1115,7 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + // Source may be F32, F16, or a supported quantized type; output is always F32. GGML_API struct ggml_tensor * ggml_hadamard( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index bc69b10c..efe5db79 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -4786,9 +4786,25 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_FAKE_CPY: case GGML_OP_ARGMAX: return true; - case GGML_OP_HADAMARD: - return (op->op_params[0] == 64 || op->op_params[0] == 128 || op->op_params[0] == 256 || op->op_params[0] == 512) - && op->ne[0] % op->op_params[0] == 0 && op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_HADAMARD: { + if (!(op->op_params[0] == 64 || op->op_params[0] == 128 || op->op_params[0] == 256 || op->op_params[0] == 512)) return false; + if (op->ne[0] % op->op_params[0] != 0) return false; + if (op->type != GGML_TYPE_F32) return false; + switch (op->src[0]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q6_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + return false; + } + } case GGML_OP_DUP: case GGML_OP_REPEAT: case GGML_OP_CONCAT: diff --git a/ggml/src/ggml-cuda/hadamard.cu b/ggml/src/ggml-cuda/hadamard.cu index e1d8dc13..14b3a543 100644 --- a/ggml/src/ggml-cuda/hadamard.cu +++ b/ggml/src/ggml-cuda/hadamard.cu @@ -1,46 +1,142 @@ #include "hadamard.cuh" +#include "dequantize.cuh" + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#include +#include +#include +#include +#include +static inline int popcount(uint32_t x) { return __popcnt(x); } +#else +static inline int popcount(uint32_t x) { return __builtin_popcount(x); } +#endif + +template +static __device__ __forceinline__ void hadamard_butterfly(float * ys, int tid, float & scale) { + constexpr float ksqrt2 = 0.707106781f; + #pragma unroll + for (int h = 2; h < nh; h <<= 1) { + __syncthreads(); + const int ii = tid/h, jj = tid%h; + const int j = 2*h*ii + jj; + const float u = ys[j], v = ys[j+h]; + ys[j+0] = u + v; + ys[j+h] = u - v; + scale *= ksqrt2; + } +} template static __global__ void hadamard_f32(const char * src, char * dst, int ne0, size_t nb01, size_t nb02, size_t nb03, size_t nb1, size_t nb2, size_t nb3) { constexpr float ksqrt2 = 0.707106781f; - - int nc = ne0/nh; - int ii1 = blockIdx.x; - int i1 = ii1 / nc; - int ic = ii1 % nc; - int i2 = blockIdx.y; - int i3 = blockIdx.z; - - int tid = threadIdx.x; + const int nc = ne0/nh; + const int ii1 = blockIdx.x; + const int i1 = ii1 / nc; + const int ic = ii1 % nc; + const int i2 = blockIdx.y; + const int i3 = blockIdx.z; + const int tid = threadIdx.x; const float * x = (const float *)((const char *)src + i1*nb01 + i2*nb02 + i3*nb03) + ic*nh; float * y = ( float *)((const char *)dst + i1*nb1 + i2*nb2 + i3*nb3) + ic*nh; __shared__ float ys[nh]; - ys[2*tid+0] = x[2*tid+0] + x[2*tid+1]; ys[2*tid+1] = x[2*tid+0] - x[2*tid+1]; - float scale = ksqrt2; -#pragma unroll - for (int h = 2; h < nh; h <<= 1) { - __syncthreads(); - int ii = tid/h, jj = tid%h; - int j = 2*h*ii+jj; - float u = ys[j], v = ys[j+h]; - ys[j+0] = u + v; - ys[j+h] = u - v; - scale *= ksqrt2; - } + hadamard_butterfly(ys, tid, scale); __syncthreads(); y[2*tid+0] = ys[2*tid+0] * scale; y[2*tid+1] = ys[2*tid+1] * scale; } +template +static __global__ void hadamard_f16(const char * src, char * dst, int ne0, + size_t nb01, size_t nb02, size_t nb03, size_t nb1, size_t nb2, size_t nb3) { + + constexpr float ksqrt2 = 0.707106781f; + const int nc = ne0/nh; + const int ii1 = blockIdx.x; + const int i1 = ii1 / nc; + const int ic = ii1 % nc; + const int i2 = blockIdx.y; + const int i3 = blockIdx.z; + const int tid = threadIdx.x; + + const half * x = (const half *)((const char *)src + i1*nb01 + i2*nb02 + i3*nb03) + ic*nh; + float * y = ( float *)((const char *)dst + i1*nb1 + i2*nb2 + i3*nb3) + ic*nh; + + __shared__ float ys[nh]; + const float a = __half2float(x[2*tid + 0]); + const float b = __half2float(x[2*tid + 1]); + ys[2*tid + 0] = a + b; + ys[2*tid + 1] = a - b; + float scale = ksqrt2; + + hadamard_butterfly(ys, tid, scale); + + __syncthreads(); + y[2*tid + 0] = ys[2*tid + 0] * scale; + y[2*tid + 1] = ys[2*tid + 1] * scale; +} + +template +static __global__ void hadamard_quant(const char * src, char * dst, int ne0, + size_t nb01, size_t nb02, size_t nb03, size_t nb1, size_t nb2, size_t nb3) { + + constexpr float ksqrt2 = 0.707106781f; + const int nc = ne0/nh; + const int ii1 = blockIdx.x; + const int i1 = ii1 / nc; + const int ic = ii1 % nc; + const int i2 = blockIdx.y; + const int i3 = blockIdx.z; + const int tid = threadIdx.x; + + const void * row_src = (const char *)src + i1*nb01 + i2*nb02 + i3*nb03; + float * y = (float *)((const char *)dst + i1*nb1 + i2*nb2 + i3*nb3) + ic*nh; + + __shared__ float ys[nh]; + float scale = ksqrt2; + + if (!qr2) { + const int abs_off = ic*nh + 2*tid; + const int ib = abs_off / qk; + const int iqs = abs_off % qk; + dfloat2 v; + dequant(row_src, ib, iqs, v); + ys[2*tid + 0] = (float)v.x + (float)v.y; + ys[2*tid + 1] = (float)v.x - (float)v.y; + } else { + constexpr int qk_half = qk/2; + const int b = tid / qk_half; + const int iqs = tid % qk_half; + const int ib = ic*(nh/qk) + b; + dfloat2 v; + dequant(row_src, ib, iqs, v); + ys[b*qk + iqs + 0 ] = (float)v.x; + ys[b*qk + iqs + qk_half] = (float)v.y; + __syncthreads(); + const float a = ys[2*tid + 0]; + const float c = ys[2*tid + 1]; + __syncthreads(); + ys[2*tid + 0] = a + c; + ys[2*tid + 1] = a - c; + } + + hadamard_butterfly(ys, tid, scale); + + __syncthreads(); + y[2*tid + 0] = ys[2*tid + 0] * scale; + y[2*tid + 1] = ys[2*tid + 1] * scale; +} + static void hadamard_f32_cuda(int nh, const char * x, char * y, int ne0, int ne1, int ne2, int ne3, size_t nb01, size_t nb02, size_t nb03, size_t nb1, size_t nb2, size_t nb3, cudaStream_t stream) { int nc = ne0/nh; @@ -55,29 +151,64 @@ static void hadamard_f32_cuda(int nh, const char * x, char * y, int ne0, int ne1 } } -#if defined(_MSC_VER) -#pragma warning(disable: 4244 4267) // possible loss of data -#include -#include -#include -#include -#include -static inline int popcount(uint32_t x) { return __popcnt(x); } -#else -static inline int popcount(uint32_t x) { return __builtin_popcount(x); } -#endif +#define LAUNCH_HADAMARD_F16(NH) \ + hadamard_f16<<>>( \ + (const char *)src->data, (char *)dst->data, src->ne[0], \ + src->nb[1], src->nb[2], src->nb[3], dst->nb[1], dst->nb[2], dst->nb[3]) +#define DISPATCH_HADAMARD_F16_NH \ + switch (nh) { \ + case 64: LAUNCH_HADAMARD_F16( 64); break; \ + case 128: LAUNCH_HADAMARD_F16(128); break; \ + case 256: LAUNCH_HADAMARD_F16(256); break; \ + case 512: LAUNCH_HADAMARD_F16(512); break; \ + default: GGML_ABORT("Unsupported Hadamard block size"); \ + } + +#define LAUNCH_HADAMARD_QUANT(NH, DEQUANT, QK, QR2) \ + hadamard_quant<<>>( \ + (const char *)src->data, (char *)dst->data, src->ne[0], \ + src->nb[1], src->nb[2], src->nb[3], dst->nb[1], dst->nb[2], dst->nb[3]) + +#define DISPATCH_HADAMARD_QUANT_NH(DEQUANT, QK, QR2) \ + switch (nh) { \ + case 64: LAUNCH_HADAMARD_QUANT( 64, DEQUANT, QK, QR2); break; \ + case 128: LAUNCH_HADAMARD_QUANT(128, DEQUANT, QK, QR2); break; \ + case 256: LAUNCH_HADAMARD_QUANT(256, DEQUANT, QK, QR2); break; \ + case 512: LAUNCH_HADAMARD_QUANT(512, DEQUANT, QK, QR2); break; \ + default: GGML_ABORT("Unsupported Hadamard block size"); \ + } void ggml_cuda_op_hadamard(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src = dst->src[0]; - GGML_ASSERT(src->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(ggml_are_same_shape(src, dst)); - int nh = dst->op_params[0]; + const int nh = dst->op_params[0]; GGML_ASSERT(dst->ne[0]%nh == 0); GGML_ASSERT(nh > 1 && popcount(nh) == 1); - hadamard_f32_cuda(nh, (const char *)src->data, (char *)dst->data, src->ne[0], src->ne[1], src->ne[2], src->ne[3], - src->nb[1], src->nb[2], src->nb[3], dst->nb[1], dst->nb[2], dst->nb[3], ctx.stream()); + cudaStream_t stream = ctx.stream(); + + if (src->type == GGML_TYPE_F32) { + hadamard_f32_cuda(nh, (const char *)src->data, (char *)dst->data, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + src->nb[1], src->nb[2], src->nb[3], dst->nb[1], dst->nb[2], dst->nb[3], stream); + return; + } + + dim3 num_blocks((src->ne[0]/nh) * src->ne[1], src->ne[2], src->ne[3]); + + switch (src->type) { + case GGML_TYPE_F16: DISPATCH_HADAMARD_F16_NH; break; + case GGML_TYPE_Q8_0: DISPATCH_HADAMARD_QUANT_NH(dequantize_q8_0, QK8_0, false); break; + case GGML_TYPE_Q4_0: DISPATCH_HADAMARD_QUANT_NH(dequantize_q4_0, QK4_0, true); break; + case GGML_TYPE_Q4_1: DISPATCH_HADAMARD_QUANT_NH(dequantize_q4_1, QK4_1, true); break; + case GGML_TYPE_Q5_0: DISPATCH_HADAMARD_QUANT_NH(dequantize_q5_0, QK5_0, true); break; + case GGML_TYPE_Q5_1: DISPATCH_HADAMARD_QUANT_NH(dequantize_q5_1, QK5_1, true); break; + case GGML_TYPE_Q6_0: DISPATCH_HADAMARD_QUANT_NH(dequantize_q6_0, QK6_0, true); break; + case GGML_TYPE_IQ4_NL: DISPATCH_HADAMARD_QUANT_NH(dequantize_iq4_nl, QK4_NL, true); break; + default: + GGML_ABORT("hadamard: unsupported source type"); + } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 91ab4c90..d6097649 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6255,7 +6255,6 @@ struct ggml_tensor * ggml_hadamard( struct ggml_tensor * a, int n) { - GGML_ASSERT(a->type == GGML_TYPE_F32); // will not bother implementing for other data types GGML_ASSERT(n > 1); // no point in Hadamard transforms with less than 2 elements GGML_ASSERT(a->ne[0] % n == 0); GGML_ASSERT(popcount(n) == 1); // must be a power of 2 diff --git a/ggml/src/iqk/iqk_cpu_ops.cpp b/ggml/src/iqk/iqk_cpu_ops.cpp index 7b87e88d..56dedc42 100644 --- a/ggml/src/iqk/iqk_cpu_ops.cpp +++ b/ggml/src/iqk/iqk_cpu_ops.cpp @@ -521,7 +521,6 @@ void fast_ht(int n, T * values) { void iqk_hadamard(struct ggml_tensor * dst, int ith, int nth) { auto src = dst->src[0]; - GGML_ASSERT(src->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(ggml_are_same_shape(src, dst)); int nh = dst->op_params[0]; @@ -530,20 +529,41 @@ void iqk_hadamard(struct ggml_tensor * dst, int ith, int nth) { int nc = dst->ne[0]/nh; int nr = ggml_nrows(dst) * nc; - int npt = (nr + nth - 1)/nth; int first = npt*ith; int last = std::min(first + npt, nr); + if (src->type == GGML_TYPE_F32) { + for (int ir = first; ir < last; ++ir) { + int i3 = ir / (dst->ne[1] * dst->ne[2] * nc); + int i2 = (ir - i3*dst->ne[1] * dst->ne[2] * nc)/(dst->ne[1] * nc); + int i1 = (ir - i3*dst->ne[1] * dst->ne[2] * nc - i2*dst->ne[1]*nc)/nc; + int ic = (ir - i3*dst->ne[1] * dst->ne[2] * nc - i2*dst->ne[1]*nc - i1*nc); + + auto x = (const float *)((const char *)src->data + i3*src->nb[3] + i2*src->nb[2] + i1*src->nb[1]) + ic*nh; + auto y = ( float *)(( char *)dst->data + i3*dst->nb[3] + i2*dst->nb[2] + i1*dst->nb[1]) + ic*nh; + std::memcpy(y, x, nh*sizeof(float)); + fast_ht(nh, y); + } + return; + } + + auto traits = ggml_internal_get_type_traits(src->type); + GGML_ASSERT(traits.to_float != nullptr); + const size_t blck_size = traits.blck_size; + const size_t type_size = traits.type_size; + GGML_ASSERT(blck_size > 0 && (nh % blck_size == 0 || blck_size % nh == 0)); + for (int ir = first; ir < last; ++ir) { int i3 = ir / (dst->ne[1] * dst->ne[2] * nc); int i2 = (ir - i3*dst->ne[1] * dst->ne[2] * nc)/(dst->ne[1] * nc); int i1 = (ir - i3*dst->ne[1] * dst->ne[2] * nc - i2*dst->ne[1]*nc)/nc; int ic = (ir - i3*dst->ne[1] * dst->ne[2] * nc - i2*dst->ne[1]*nc - i1*nc); - auto x = (const float *)((const char *)src->data + i3*src->nb[3] + i2*src->nb[2] + i1*src->nb[1]) + ic*nh; - auto y = ( float *)(( char *)dst->data + i3*dst->nb[3] + i2*dst->nb[2] + i1*dst->nb[1]) + ic*nh; - std::memcpy(y, x, nh*sizeof(float)); + const char * x_row = (const char *)src->data + i3*src->nb[3] + i2*src->nb[2] + i1*src->nb[1]; + const size_t offset = ((size_t)ic * nh / blck_size) * type_size; + float * y = (float *)((char *)dst->data + i3*dst->nb[3] + i2*dst->nb[2] + i1*dst->nb[1]) + ic*nh; + traits.to_float(x_row + offset, y, nh); fast_ht(nh, y); } } diff --git a/src/graphs/build_deepseek2.cpp b/src/graphs/build_deepseek2.cpp index c6156860..9d9bc558 100644 --- a/src/graphs/build_deepseek2.cpp +++ b/src/graphs/build_deepseek2.cpp @@ -174,19 +174,15 @@ ggml_tensor * llm_build_context::build_deepseek2_tp_attention( row_size_cache, cache_local->nb[2], 0); cb(kv_cache_rope_view, "kv_cache_rope_pp", il_id); - // Hadamard cache was applied per 64-block during write; un-Hadamard the - // read views so the materialize mul_mats see the original latents. Hadamard - // requires F32 input, so dequantize the cache views first when the cache is - // quantized. Hadamard is its own inverse (the impl handles the scale). + // Un-Hadamard the cache views via the fused dequant+hadamard kernel. + // When khad_pretransformed is set, H was folded into wv_b/wk_b_pp at init, + // so the cache_nope un-Hadamard is skipped (rope half still goes to FA via + // concat — no wk_b multiply, no H to fold into). if (cparams.k_cache_hadamard) { - ggml_tensor * kn_f32 = kv_cache_nope->type == GGML_TYPE_F32 - ? kv_cache_nope - : ggml_cast(ctx0, kv_cache_nope, GGML_TYPE_F32); - ggml_tensor * kr_f32 = kv_cache_rope_view->type == GGML_TYPE_F32 - ? kv_cache_rope_view - : ggml_cast(ctx0, kv_cache_rope_view, GGML_TYPE_F32); - kv_cache_nope = ggml_hadamard(ctx0, kn_f32, 64); - kv_cache_rope_view = ggml_hadamard(ctx0, kr_f32, 64); + kv_cache_rope_view = ggml_hadamard(ctx0, kv_cache_rope_view, 64); + if (!model.khad_pretransformed) { + kv_cache_nope = ggml_hadamard(ctx0, kv_cache_nope, 64); + } } // CUDA quantized-cache + REPEAT/CONCAT/CPY has known issues, so force F16 here. @@ -285,7 +281,11 @@ ggml_tensor * llm_build_context::build_deepseek2_tp_attention( if (use_f32_attn_precision) { ggml_flash_attn_ext_set_prec(kqv_compressed, GGML_PREC_F32); } - if (cparams.k_cache_hadamard) { + // When khad_pretransformed is set, H is folded into wv_b. FA leaves + // kqv_compressed in the H-encoded basis; the mul_mat(H@wv_b, kqv_encoded) + // below collapses to wv_b^T @ kqv_unencoded by H @ H = I. Skip the + // post-FA un-encode so the fold composes correctly. + if (cparams.k_cache_hadamard && !model.khad_pretransformed) { kqv_compressed = ggml_hadamard(ctx0, kqv_compressed, 64); } kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); diff --git a/src/llama-model.h b/src/llama-model.h index 14e2d44a..d2972feb 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -496,6 +496,9 @@ struct llama_model { bool tensor_overrides; + // Set by llm_apply_khad_pretransform once H is folded into wv_b/wk_b_pp. + bool khad_pretransformed = false; + ~llama_model(); size_t max_nodes(int n_tokens) const { diff --git a/src/llama.cpp b/src/llama.cpp index 3d773d16..7651ec35 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2809,6 +2809,123 @@ static void llm_prepare_mla(llama_model & model, int mla) { ggml_free(ctx); } +// Fold the 64-block Hadamard into wv_b/wk_b_pp at init; build_deepseek2.cpp then +// skips the runtime cache_nope un-Hadamard. Math identity by H^T H = I. +static void llm_apply_khad_pretransform(llama_model & model) { + if (model.khad_pretransformed) return; + if (model.arch != LLM_ARCH_DEEPSEEK2 && model.arch != LLM_ARCH_GLM_DSA && model.arch != LLM_ARCH_MISTRAL4) return; + + // High-enough bpw to survive one quant->F32->H->quant roundtrip within PPL noise. + // Cliff is ~2.7 bpw: IQ3_XXS (3.06) sits at +0.05 noise edge; IQ2_XS (2.31) drifts +0.20. + // Below-cliff and unmeasured types fall back to the runtime cache_nope path. + auto castable = [](ggml_type t) { + return t == GGML_TYPE_F32 || t == GGML_TYPE_F16 || t == GGML_TYPE_BF16 || + t == GGML_TYPE_Q4_0 || t == GGML_TYPE_Q4_1 || + t == GGML_TYPE_Q5_0 || t == GGML_TYPE_Q5_1 || + t == GGML_TYPE_Q6_0 || t == GGML_TYPE_Q8_0 || + t == GGML_TYPE_IQ4_NL || + t == GGML_TYPE_Q4_K || t == GGML_TYPE_Q5_K || t == GGML_TYPE_Q6_K || + t == GGML_TYPE_IQ4_K || t == GGML_TYPE_IQ5_K || + t == GGML_TYPE_IQ4_KS|| t == GGML_TYPE_IQ4_KSS|| + t == GGML_TYPE_IQ5_KS; + }; + + const auto & hparams = model.hparams; + const int64_t kv_lora_rank = hparams.n_lora_kv; + if (kv_lora_rank <= 0 || kv_lora_rank % 64 != 0) return; + + // All-or-nothing: a partially-folded model would consume H-applied cache + // through un-folded weights and produce wrong values. + bool all_eligible = true; + int n_layers_with_pp = 0; + for (auto & l : model.layers) { + if (!l.wv_b || !l.wk_b_pp) continue; + ++n_layers_with_pp; + ggml_type tv = l.wv_b->type; + ggml_type tk = l.wk_b_pp->type; + if (!castable(tv) || !castable(tk)) { + all_eligible = false; + break; + } + } + if (!all_eligible || n_layers_with_pp == 0) { + LLAMA_LOG_INFO("============ %s: skipping (no eligible wv_b/wk_b_pp; n_pp=%d)\n", + __func__, n_layers_with_pp); + return; + } + + auto fold_tensor = [&](ggml_tensor * t) -> bool { + if (!t) return true; + if (t->ne[0] != kv_lora_rank) return false; + + const size_t nbytes = ggml_nbytes(t); + std::vector host_in(nbytes); + ggml_backend_tensor_get(t, host_in.data(), 0, nbytes); + + ggml_init_params ip{ ggml_tensor_overhead()*32 + ggml_graph_overhead(), nullptr, true }; + auto ctx = ggml_init(ip); + auto graph = ggml_new_graph(ctx); + + ggml_tensor src_host = *t; + src_host.data = host_in.data(); + src_host.op = GGML_OP_NONE; + for (int j = 0; j < GGML_MAX_SRC; ++j) src_host.src[j] = nullptr; + src_host.buffer = nullptr; + src_host.extra = nullptr; + src_host.view_src = nullptr; + src_host.view_offs = 0; + + const size_t n_f32_bytes = (size_t)ggml_nelements(t) * sizeof(float); + std::vector f32_buf_a(n_f32_bytes); + std::vector f32_buf_b(n_f32_bytes); + std::vector out_buf(nbytes); + + auto src_f32 = ggml_cast(ctx, &src_host, GGML_TYPE_F32); + src_f32->data = f32_buf_a.data(); + auto had = ggml_hadamard(ctx, src_f32, 64); + had->data = f32_buf_b.data(); + auto out_q = ggml_cast(ctx, had, t->type); + out_q->data = out_buf.data(); + + ggml_build_forward_expand(graph, out_q); + + std::vector work_data; + auto plan = ggml_graph_plan(graph, std::thread::hardware_concurrency()/2); + if (plan.work_size > work_data.size()) work_data.resize(plan.work_size); + plan.work_data = work_data.data(); + bool ok = (ggml_graph_compute(graph, &plan) == GGML_STATUS_SUCCESS); + ggml_free(ctx); + if (!ok) return false; + + ggml_backend_tensor_set(t, out_buf.data(), 0, nbytes); + return true; + }; + + auto fold_split_or_single = [&](ggml_tensor * full) -> bool { + if (!full) return true; + if (full->extra) { + auto split = (const ggml_split_tensor_t *)full->extra; + for (int id = 0; id < split->n_device; ++id) { + if (!split->splits[id]) continue; + if (!fold_tensor(split->splits[id])) return false; + } + return true; + } + return fold_tensor(full); + }; + + int n_folded = 0; + for (auto & l : model.layers) { + if (!l.wv_b || !l.wk_b_pp) continue; + if (!fold_split_or_single(l.wv_b)) { LLAMA_LOG_ERROR("%s: failed to fold wv_b\n", __func__); return; } + if (!fold_split_or_single(l.wk_b_pp)){ LLAMA_LOG_ERROR("%s: failed to fold wk_b_pp\n",__func__); return; } + ++n_folded; + } + + model.khad_pretransformed = (n_folded > 0); + LLAMA_LOG_INFO("============ %s: folded H into wv_b/wk_b_pp on %d layers\n", __func__, n_folded); +} + static void llm_scale_gate_inp_s(llama_model & model, bool uses_mmap) { auto & hparams = model.hparams; printf("%s: n_embd = %d\n", __func__, hparams.n_embd); @@ -6503,6 +6620,17 @@ struct llama_context * llama_init_from_model( cparams.graph_reuse = params.graph_reuse; cparams.k_cache_hadamard = params.k_cache_hadamard; cparams.v_cache_hadamard = params.v_cache_hadamard; + // Folding H into wv_b/wk_b_pp permanently mutates the model; a later context + // on the same model with khad=false would consume an H-applied weight against + // an un-applied cache and produce wrong values. Force khad=true to keep math + // consistent for the rest of the model's lifetime. + if (model->khad_pretransformed && !cparams.k_cache_hadamard) { + LLAMA_LOG_WARN("%s: model has Hadamard-folded wv_b/wk_b_pp; forcing k_cache_hadamard=true\n", __func__); + cparams.k_cache_hadamard = true; + } + if (cparams.k_cache_hadamard) { + llm_apply_khad_pretransform(*model); + } cparams.split_mode_graph_scheduling = params.split_mode_graph_scheduling; //cparams.split_mode_f16 = params.split_mode_f16; cparams.scheduler_async = params.scheduler_async;