Do not repeat yourself (#1373)

* DRY - part 1

* DRY - part 2

* DRY - part 3

* Fix NEON

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow 2026-03-06 16:06:51 +01:00 committed by GitHub
parent 082addead2
commit 277fc1d26f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 61 additions and 61 deletions

View File

@ -37,6 +37,8 @@ __global__ void delta_net_recurrent_f32(
const float * __restrict__ state_in, // [HEAD_DIM, HEAD_DIM*n_heads, 1, n_seqs]
float * __restrict__ dst, // output + new_state concatenated
const int64_t n_heads,
const int64_t gqa_ratio,
const int repeat_type,
const int64_t n_tokens,
const int64_t n_seqs,
const int64_t output_offset, // offset where state starts in output
@ -46,6 +48,7 @@ __global__ void delta_net_recurrent_f32(
const int sub_head_idx = blockIdx.x % (warps_per_head*n_heads);
const int head_idx = sub_head_idx / warps_per_head;
const int sub_idx = sub_head_idx % warps_per_head;
const int head_idx_kq = repeat_type == 0 ? head_idx / gqa_ratio : head_idx % (n_heads/gqa_ratio);
const int tid = threadIdx.x;
// Strides for input tensors (column-major)
@ -53,6 +56,7 @@ __global__ void delta_net_recurrent_f32(
const int64_t qkv_stride_token = HEAD_DIM;
const int64_t qkv_stride_head = HEAD_DIM * n_tokens;
const int64_t qkv_stride_batch = HEAD_DIM * n_tokens * n_heads;
const int64_t qkv_stride_batch_kq = qkv_stride_batch / gqa_ratio;
// G/Beta: [n_tokens, 1, n_heads, n_seqs] / [1, n_tokens, n_heads, n_seqs]
const int64_t g_stride_head = n_tokens;
@ -66,8 +70,8 @@ __global__ void delta_net_recurrent_f32(
const int64_t state_batch_stride = HEAD_DIM * HEAD_DIM * n_heads;
// Pointers for this batch/head
const float * q_ptr = q + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head;
const float * k_ptr = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head;
const float * q_ptr = q + batch_idx * qkv_stride_batch_kq + head_idx_kq * qkv_stride_head;
const float * k_ptr = k + batch_idx * qkv_stride_batch_kq + head_idx_kq * qkv_stride_head;
const float * v_ptr = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head;
const float * g_ptr = g + batch_idx * g_stride_batch + head_idx * g_stride_head;
const float * beta_ptr = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head;
@ -171,6 +175,8 @@ static void delta_net_f32_cuda(
const int64_t head_dim,
const int64_t n_tokens,
const int64_t n_heads,
const int64_t gqa_ratio,
const int repeat_type,
const int64_t n_seqs,
const float eps,
const int device_id,
@ -193,19 +199,19 @@ static void delta_net_f32_cuda(
constexpr int threads_per_block = 256;
if (head_dim == 64) {
delta_net_recurrent_f32<64, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, eps);
} else {
delta_net_recurrent_f32<128, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, eps);
}
} else {
constexpr int threads_per_block = 128;
if (head_dim == 64) {
delta_net_recurrent_f32<64, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, eps);
} else {
delta_net_recurrent_f32<128, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, eps);
}
}
@ -226,12 +232,15 @@ void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
const int64_t head_dim = src0->ne[0];
const int64_t n_tokens = src0->ne[1];
const int64_t n_heads = src0->ne[2];
const int64_t n_heads = src2->ne[2];
const int64_t n_heads_kq = src0->ne[2];
const int64_t n_seqs = src0->ne[3];
GGML_ASSERT(n_heads % n_heads_kq == 0);
const int64_t gqa_ratio = n_heads / n_heads_kq;
// Dimension validation
// Q/K: [head_dim, n_tokens, n_heads, n_seqs]
GGML_ASSERT(src1->ne[0] == head_dim && src1->ne[1] == n_tokens && src1->ne[2] == n_heads && src1->ne[3] == n_seqs);
GGML_ASSERT(src1->ne[0] == head_dim && src1->ne[1] == n_tokens && src1->ne[2] == n_heads_kq && src1->ne[3] == n_seqs);
// V: [head_dim, n_tokens, n_heads, n_seqs]
GGML_ASSERT(src2->ne[0] == head_dim && src2->ne[1] == n_tokens && src2->ne[2] == n_heads && src2->ne[3] == n_seqs);
// G: [n_tokens, 1, n_heads, n_seqs]
@ -247,6 +256,7 @@ void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
GGML_ASSERT(ggml_nelements(dst) == output_size + state_size);
const float eps = 1e-6f;
int repeat_type = dst->op_params[0];
GGML_ASSERT(head_dim <= 256); // Reasonable limit for shared memory
@ -262,7 +272,7 @@ void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
(const float *)src4->data,
(const float *)src5->data,
(float *)dst->data,
head_dim, n_tokens, n_heads, n_seqs, eps,
head_dim, n_tokens, n_heads, gqa_ratio, repeat_type, n_seqs, eps,
device_id, cc,
ctx.stream());

View File

@ -9903,10 +9903,11 @@ struct ggml_tensor * ggml_delta_net(
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == n_tokens && k->ne[2] == H_k && k->ne[3] == n_seqs);
GGML_ASSERT(v->ne[1] == n_tokens && v->ne[3] == n_seqs);
GGML_ASSERT(g->ne[0] == n_tokens && g->ne[1] == 1 && g->ne[2] == H_k && g->ne[3] == n_seqs);
GGML_ASSERT(beta->ne[0] == 1 && beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[3] == n_seqs);
GGML_ASSERT(g->ne[0] == n_tokens && g->ne[1] == 1 && g->ne[2] == H_v && g->ne[3] == n_seqs);
GGML_ASSERT(beta->ne[0] == 1 && beta->ne[1] == n_tokens && beta->ne[2] == H_v && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
GGML_ASSERT(H_k == H_v);
//GGML_ASSERT(H_k == H_v);
GGML_ASSERT(H_v % H_k == 0);
const int64_t output_size = S_v * H_v * n_tokens * n_seqs;
const int64_t state_size = S_v * S_v * H_v * n_seqs;
@ -22548,8 +22549,10 @@ static void ggml_compute_forward_delta_net_f32(
const int64_t head_dim = src0->ne[0];
const int64_t n_tokens = src0->ne[1];
const int64_t n_heads = src0->ne[2];
const int64_t n_heads = src2->ne[2];
const int64_t n_seqs = src0->ne[3];
GGML_ASSERT(src2->ne[2] % src0->ne[2] == 0);
const int gqa_ratio = src2->ne[2]/src0->ne[2];
const int64_t output_size = head_dim * n_tokens * n_heads * n_seqs;
@ -22565,7 +22568,9 @@ static void ggml_compute_forward_delta_net_f32(
const int ith = params->ith;
const int nth = params->nth;
if (iqk_fused_delta_net(head_dim, n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
int repeat_type = dst->op_params[0];
if (iqk_fused_delta_net(head_dim, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
out_data, state_out, ith, nth)) {
return;
}
@ -22584,8 +22589,10 @@ static void ggml_compute_forward_delta_net_f32(
for (int64_t h_idx = h_start; h_idx < h_end; ++h_idx) {
const int64_t batch_idx = h_idx / n_heads;
const int64_t head_idx = h_idx % n_heads;
const int64_t head_idx_kq = repeat_type == 0 ? head_idx / gqa_ratio : head_idx % (n_heads/gqa_ratio);
const int64_t qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
const int64_t qkv_head_offset_kq = batch_idx * (head_dim * n_tokens * n_heads/gqa_ratio) + head_idx_kq * (head_dim * n_tokens);
const int64_t qkv_token_stride = head_dim;
const int64_t g_head_offset = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens;
const int64_t state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim);
@ -22599,8 +22606,8 @@ static void ggml_compute_forward_delta_net_f32(
float * state = state_out + state_head_offset;
for (int64_t t = 0; t < n_tokens; ++t) {
const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride;
const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride;
const float * q_t = q_data + qkv_head_offset_kq + t * qkv_token_stride;
const float * k_t = k_data + qkv_head_offset_kq + t * qkv_token_stride;
const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride;
const float g_val = g_data[g_head_offset + t];

View File

@ -1386,7 +1386,7 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
namespace {
#ifdef __ARM_NEON
template <int head_dim>
void iqk_fused_delta_net_neon_impl(int n_heads, int n_tokens, int n_seqs,
void iqk_fused_delta_net_neon_impl(int n_heads, int gqa_ratio, int repeat_type, int n_tokens, int n_seqs,
const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
const float * state_in, float * out_data, float * state_out, int ith, int nth) {
const int total_heads = n_heads * n_seqs;
@ -1406,8 +1406,10 @@ void iqk_fused_delta_net_neon_impl(int n_heads, int n_tokens, int n_seqs,
for (int h_idx = h_start; h_idx < h_end; ++h_idx) {
const int batch_idx = h_idx / n_heads;
const int head_idx = h_idx % n_heads;
const int head_idx_kq = repeat_type == 0 ? head_idx / gqa_ratio : head_idx % (n_heads/gqa_ratio);
const int qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
const int qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
const int qkv_head_offset_kq = batch_idx * (head_dim * n_tokens * n_heads/gqa_ratio) + head_idx_kq * (head_dim * n_tokens);
const int qkv_token_stride = head_dim;
const int g_head_offset = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens;
const int state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim);
@ -1422,8 +1424,8 @@ void iqk_fused_delta_net_neon_impl(int n_heads, int n_tokens, int n_seqs,
for (int t = 0; t < n_tokens; ++t) {
const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride;
const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride;
const float * q_t = q_data + qkv_head_offset_kq + t * qkv_token_stride;
const float * k_t = k_data + qkv_head_offset_kq + t * qkv_token_stride;
const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride;
const float g_val = g_data[g_head_offset + t];
@ -1492,11 +1494,11 @@ void iqk_fused_delta_net_neon_impl(int n_heads, int n_tokens, int n_seqs,
}
#endif
template <int head_dim>
void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
void iqk_fused_delta_net_impl(int n_heads, int gqa_ratio, int repeat_type, int n_tokens, int n_seqs,
const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
const float * state_in, float * out_data, float * state_out, int ith, int nth) {
#ifdef __ARM_NEON
iqk_fused_delta_net_neon_impl<head_dim>(n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in, out_data, state_out, ith, nth);
iqk_fused_delta_net_neon_impl<head_dim>(n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in, out_data, state_out, ith, nth);
return;
#endif
const int total_heads = n_heads * n_seqs;
@ -1520,8 +1522,10 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
for (int h_idx = h_start; h_idx < h_end; ++h_idx) {
const int batch_idx = h_idx / n_heads;
const int head_idx = h_idx % n_heads;
const int head_idx_kq = repeat_type == 0 ? head_idx / gqa_ratio : head_idx % (n_heads/gqa_ratio);
const int qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
const int qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
const int qkv_head_offset_kq = batch_idx * (head_dim * n_tokens * n_heads/gqa_ratio) + head_idx_kq * (head_dim * n_tokens);
const int qkv_token_stride = head_dim;
const int g_head_offset = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens;
const int state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim);
@ -1535,8 +1539,8 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
float * state = state_out + state_head_offset;
for (int t = 0; t < n_tokens; ++t) {
const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride;
const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride;
const float * q_t = q_data + qkv_head_offset_kq + t * qkv_token_stride;
const float * k_t = k_data + qkv_head_offset_kq + t * qkv_token_stride;
const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride;
const float g_val = g_data[g_head_offset + t];
@ -1658,17 +1662,17 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
}
}
bool iqk_fused_delta_net(int head_dim, int n_heads, int n_tokens, int n_seqs,
bool iqk_fused_delta_net(int head_dim, int n_heads, int gqa_ratio, int repeat_type, int n_tokens, int n_seqs,
const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
const float * state_in, float * out_data, float * state_out, int ith, int nth) {
if (head_dim != 64 && head_dim != 128) {
return false;
}
if (head_dim == 64) {
iqk_fused_delta_net_impl<64>(n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
iqk_fused_delta_net_impl<64>(n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
out_data, state_out, ith, nth);
} else {
iqk_fused_delta_net_impl<128>(n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
iqk_fused_delta_net_impl<128>(n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
out_data, state_out, ith, nth);
}
return true;
@ -1707,7 +1711,7 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*n
return false;
}
bool iqk_fused_delta_net(int, int, int, int,
bool iqk_fused_delta_net(int, int, int, int, int, int,
const float *, const float *, const float *, const float *, const float *,
const float *, float *, float *, int, int) {
return false;

View File

@ -73,7 +73,7 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
IQK_API void iqk_topk_moe(int n_experts, int n_experts_used, int nrows, const float * logits,
float * weights, int32_t * ids, int ith, int nth);
IQK_API bool iqk_fused_delta_net(int head_dim, int n_heads, int n_tokens, int n_seqs,
IQK_API bool iqk_fused_delta_net(int head_dim, int n_heads, int gqa_ratio, int repeat_type, int n_tokens, int n_seqs,
const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
const float * state_in, float * out_data, float * state_out, int ith, int nth);

View File

@ -77,7 +77,7 @@ delta_net::~delta_net() = default;
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_context * ctx0,
ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
int il, const llm_build_cb & cb) {
int il, const llm_build_cb & cb, int repeat_type) {
const int64_t S_k = q->ne[0];
const int64_t H_k = q->ne[1];
@ -94,7 +94,8 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
GGML_ASSERT(H_k == H_v);
//GGML_ASSERT(H_k == H_v);
GGML_ASSERT(H_v % H_k == 0);
cb(q, "q_in", il);
cb(k, "k_in", il);
@ -112,8 +113,8 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
q = ggml_cont_4d(ctx0, q, S_k, n_tokens, H_k, n_seqs);
k = ggml_cont_4d(ctx0, k, S_k, n_tokens, H_k, n_seqs);
v = ggml_cont_4d(ctx0, v, S_v, n_tokens, H_v, n_seqs);
g = ggml_cont_4d(ctx0, g, n_tokens, 1, H_k, n_seqs);
beta = ggml_cont_4d(ctx0, beta, 1, n_tokens, H_k, n_seqs);
g = ggml_cont_4d(ctx0, g, n_tokens, 1, H_v, n_seqs);
beta = ggml_cont_4d(ctx0, beta, 1, n_tokens, H_v, n_seqs);
}
ggml_tensor * state_flat = ggml_reshape_4d(ctx0, state, S_v, S_v * H_v, 1, n_seqs);
@ -130,6 +131,7 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
ggml_tensor * fused_result = ggml_delta_net(ctx0, q, k, v, g, beta, state_flat);
cb(fused_result, "delta_net_fused_raw", il);
fused_result->op_params[0] = repeat_type;
const int64_t output_size = S_v * H_v * n_tokens * n_seqs;
const int64_t state_size = S_v * S_v * H_v * n_seqs;
@ -359,33 +361,10 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm);
k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm);
cb(q_conv, "q_conv_normed", il);
cb(k_conv, "k_conv_normed", il);
if (num_k_heads != num_v_heads) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
if (model.layers[il].ssm_beta_alpha) {
const int64_t repeat_factor = num_v_heads / num_k_heads;
ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_tok);
ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_tok);
ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tok, 1);
ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tok, 1);
cb(q_repeated, "q_repeated", il);
cb(k_repeated, "k_repeated", il);
q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_tok, 1);
k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_tok, 1);
} else {
q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
}
}
cb(q_conv, "q_conv_predelta", il);
cb(k_conv, "k_conv_predelta", il);
cb(v_conv, "v_conv_predelta", il);
auto [output, new_state] = build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb);
auto [output, new_state] = build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb, model.layers[il].ssm_beta_alpha ? 0 : 1);
cb(output, "attn_output", il);
cb(new_state, "new_state", il);

View File

@ -11,7 +11,7 @@ struct delta_net {
static std::pair<ggml_tensor *, ggml_tensor *> build_fused_delta_net(ggml_context * ctx0,
ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
int il, const llm_build_cb & cb);
int il, const llm_build_cb & cb, int repeat_type);
std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(ggml_context * ctx0, ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf) const;