From 277fc1d26f52fb7dfa1266c76f51e7c75899469c Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 6 Mar 2026 16:06:51 +0100 Subject: [PATCH] Do not repeat yourself (#1373) * DRY - part 1 * DRY - part 2 * DRY - part 3 * Fix NEON --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda/delta-net.cu | 28 +++++++++++++++-------- ggml/src/ggml.c | 21 ++++++++++++------ ggml/src/iqk/iqk_mul_mat.cpp | 30 ++++++++++++++----------- ggml/src/iqk/iqk_mul_mat.h | 2 +- src/llama-delta-net.cpp | 39 ++++++++------------------------- src/llama-delta-net.h | 2 +- 6 files changed, 61 insertions(+), 61 deletions(-) diff --git a/ggml/src/ggml-cuda/delta-net.cu b/ggml/src/ggml-cuda/delta-net.cu index a0193b55..f0bf36e9 100644 --- a/ggml/src/ggml-cuda/delta-net.cu +++ b/ggml/src/ggml-cuda/delta-net.cu @@ -37,6 +37,8 @@ __global__ void delta_net_recurrent_f32( const float * __restrict__ state_in, // [HEAD_DIM, HEAD_DIM*n_heads, 1, n_seqs] float * __restrict__ dst, // output + new_state concatenated const int64_t n_heads, + const int64_t gqa_ratio, + const int repeat_type, const int64_t n_tokens, const int64_t n_seqs, const int64_t output_offset, // offset where state starts in output @@ -46,6 +48,7 @@ __global__ void delta_net_recurrent_f32( const int sub_head_idx = blockIdx.x % (warps_per_head*n_heads); const int head_idx = sub_head_idx / warps_per_head; const int sub_idx = sub_head_idx % warps_per_head; + const int head_idx_kq = repeat_type == 0 ? head_idx / gqa_ratio : head_idx % (n_heads/gqa_ratio); const int tid = threadIdx.x; // Strides for input tensors (column-major) @@ -53,6 +56,7 @@ __global__ void delta_net_recurrent_f32( const int64_t qkv_stride_token = HEAD_DIM; const int64_t qkv_stride_head = HEAD_DIM * n_tokens; const int64_t qkv_stride_batch = HEAD_DIM * n_tokens * n_heads; + const int64_t qkv_stride_batch_kq = qkv_stride_batch / gqa_ratio; // G/Beta: [n_tokens, 1, n_heads, n_seqs] / [1, n_tokens, n_heads, n_seqs] const int64_t g_stride_head = n_tokens; @@ -66,8 +70,8 @@ __global__ void delta_net_recurrent_f32( const int64_t state_batch_stride = HEAD_DIM * HEAD_DIM * n_heads; // Pointers for this batch/head - const float * q_ptr = q + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; - const float * k_ptr = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * q_ptr = q + batch_idx * qkv_stride_batch_kq + head_idx_kq * qkv_stride_head; + const float * k_ptr = k + batch_idx * qkv_stride_batch_kq + head_idx_kq * qkv_stride_head; const float * v_ptr = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; const float * g_ptr = g + batch_idx * g_stride_batch + head_idx * g_stride_head; const float * beta_ptr = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head; @@ -171,6 +175,8 @@ static void delta_net_f32_cuda( const int64_t head_dim, const int64_t n_tokens, const int64_t n_heads, + const int64_t gqa_ratio, + const int repeat_type, const int64_t n_seqs, const float eps, const int device_id, @@ -193,19 +199,19 @@ static void delta_net_f32_cuda( constexpr int threads_per_block = 256; if (head_dim == 64) { delta_net_recurrent_f32<64, threads_per_block><<>>( - q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps); + q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, eps); } else { delta_net_recurrent_f32<128, threads_per_block><<>>( - q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps); + q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, eps); } } else { constexpr int threads_per_block = 128; if (head_dim == 64) { delta_net_recurrent_f32<64, threads_per_block><<>>( - q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps); + q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, eps); } else { delta_net_recurrent_f32<128, threads_per_block><<>>( - q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps); + q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, eps); } } @@ -226,12 +232,15 @@ void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) const int64_t head_dim = src0->ne[0]; const int64_t n_tokens = src0->ne[1]; - const int64_t n_heads = src0->ne[2]; + const int64_t n_heads = src2->ne[2]; + const int64_t n_heads_kq = src0->ne[2]; const int64_t n_seqs = src0->ne[3]; + GGML_ASSERT(n_heads % n_heads_kq == 0); + const int64_t gqa_ratio = n_heads / n_heads_kq; // Dimension validation // Q/K: [head_dim, n_tokens, n_heads, n_seqs] - GGML_ASSERT(src1->ne[0] == head_dim && src1->ne[1] == n_tokens && src1->ne[2] == n_heads && src1->ne[3] == n_seqs); + GGML_ASSERT(src1->ne[0] == head_dim && src1->ne[1] == n_tokens && src1->ne[2] == n_heads_kq && src1->ne[3] == n_seqs); // V: [head_dim, n_tokens, n_heads, n_seqs] GGML_ASSERT(src2->ne[0] == head_dim && src2->ne[1] == n_tokens && src2->ne[2] == n_heads && src2->ne[3] == n_seqs); // G: [n_tokens, 1, n_heads, n_seqs] @@ -247,6 +256,7 @@ void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) GGML_ASSERT(ggml_nelements(dst) == output_size + state_size); const float eps = 1e-6f; + int repeat_type = dst->op_params[0]; GGML_ASSERT(head_dim <= 256); // Reasonable limit for shared memory @@ -262,7 +272,7 @@ void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) (const float *)src4->data, (const float *)src5->data, (float *)dst->data, - head_dim, n_tokens, n_heads, n_seqs, eps, + head_dim, n_tokens, n_heads, gqa_ratio, repeat_type, n_seqs, eps, device_id, cc, ctx.stream()); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 268a81ba..7faba6bb 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -9903,10 +9903,11 @@ struct ggml_tensor * ggml_delta_net( GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == n_tokens && k->ne[2] == H_k && k->ne[3] == n_seqs); GGML_ASSERT(v->ne[1] == n_tokens && v->ne[3] == n_seqs); - GGML_ASSERT(g->ne[0] == n_tokens && g->ne[1] == 1 && g->ne[2] == H_k && g->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[0] == 1 && beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[3] == n_seqs); + GGML_ASSERT(g->ne[0] == n_tokens && g->ne[1] == 1 && g->ne[2] == H_v && g->ne[3] == n_seqs); + GGML_ASSERT(beta->ne[0] == 1 && beta->ne[1] == n_tokens && beta->ne[2] == H_v && beta->ne[3] == n_seqs); GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); - GGML_ASSERT(H_k == H_v); + //GGML_ASSERT(H_k == H_v); + GGML_ASSERT(H_v % H_k == 0); const int64_t output_size = S_v * H_v * n_tokens * n_seqs; const int64_t state_size = S_v * S_v * H_v * n_seqs; @@ -22548,8 +22549,10 @@ static void ggml_compute_forward_delta_net_f32( const int64_t head_dim = src0->ne[0]; const int64_t n_tokens = src0->ne[1]; - const int64_t n_heads = src0->ne[2]; + const int64_t n_heads = src2->ne[2]; const int64_t n_seqs = src0->ne[3]; + GGML_ASSERT(src2->ne[2] % src0->ne[2] == 0); + const int gqa_ratio = src2->ne[2]/src0->ne[2]; const int64_t output_size = head_dim * n_tokens * n_heads * n_seqs; @@ -22565,7 +22568,9 @@ static void ggml_compute_forward_delta_net_f32( const int ith = params->ith; const int nth = params->nth; - if (iqk_fused_delta_net(head_dim, n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in, + int repeat_type = dst->op_params[0]; + + if (iqk_fused_delta_net(head_dim, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in, out_data, state_out, ith, nth)) { return; } @@ -22584,8 +22589,10 @@ static void ggml_compute_forward_delta_net_f32( for (int64_t h_idx = h_start; h_idx < h_end; ++h_idx) { const int64_t batch_idx = h_idx / n_heads; const int64_t head_idx = h_idx % n_heads; + const int64_t head_idx_kq = repeat_type == 0 ? head_idx / gqa_ratio : head_idx % (n_heads/gqa_ratio); const int64_t qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens); + const int64_t qkv_head_offset_kq = batch_idx * (head_dim * n_tokens * n_heads/gqa_ratio) + head_idx_kq * (head_dim * n_tokens); const int64_t qkv_token_stride = head_dim; const int64_t g_head_offset = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens; const int64_t state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim); @@ -22599,8 +22606,8 @@ static void ggml_compute_forward_delta_net_f32( float * state = state_out + state_head_offset; for (int64_t t = 0; t < n_tokens; ++t) { - const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride; - const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride; + const float * q_t = q_data + qkv_head_offset_kq + t * qkv_token_stride; + const float * k_t = k_data + qkv_head_offset_kq + t * qkv_token_stride; const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride; const float g_val = g_data[g_head_offset + t]; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index de302a1f..2f482524 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1386,7 +1386,7 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k namespace { #ifdef __ARM_NEON template -void iqk_fused_delta_net_neon_impl(int n_heads, int n_tokens, int n_seqs, +void iqk_fused_delta_net_neon_impl(int n_heads, int gqa_ratio, int repeat_type, int n_tokens, int n_seqs, const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data, const float * state_in, float * out_data, float * state_out, int ith, int nth) { const int total_heads = n_heads * n_seqs; @@ -1406,8 +1406,10 @@ void iqk_fused_delta_net_neon_impl(int n_heads, int n_tokens, int n_seqs, for (int h_idx = h_start; h_idx < h_end; ++h_idx) { const int batch_idx = h_idx / n_heads; const int head_idx = h_idx % n_heads; + const int head_idx_kq = repeat_type == 0 ? head_idx / gqa_ratio : head_idx % (n_heads/gqa_ratio); - const int qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens); + const int qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens); + const int qkv_head_offset_kq = batch_idx * (head_dim * n_tokens * n_heads/gqa_ratio) + head_idx_kq * (head_dim * n_tokens); const int qkv_token_stride = head_dim; const int g_head_offset = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens; const int state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim); @@ -1422,8 +1424,8 @@ void iqk_fused_delta_net_neon_impl(int n_heads, int n_tokens, int n_seqs, for (int t = 0; t < n_tokens; ++t) { - const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride; - const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride; + const float * q_t = q_data + qkv_head_offset_kq + t * qkv_token_stride; + const float * k_t = k_data + qkv_head_offset_kq + t * qkv_token_stride; const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride; const float g_val = g_data[g_head_offset + t]; @@ -1492,11 +1494,11 @@ void iqk_fused_delta_net_neon_impl(int n_heads, int n_tokens, int n_seqs, } #endif template -void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs, +void iqk_fused_delta_net_impl(int n_heads, int gqa_ratio, int repeat_type, int n_tokens, int n_seqs, const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data, const float * state_in, float * out_data, float * state_out, int ith, int nth) { #ifdef __ARM_NEON - iqk_fused_delta_net_neon_impl(n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in, out_data, state_out, ith, nth); + iqk_fused_delta_net_neon_impl(n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in, out_data, state_out, ith, nth); return; #endif const int total_heads = n_heads * n_seqs; @@ -1520,8 +1522,10 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs, for (int h_idx = h_start; h_idx < h_end; ++h_idx) { const int batch_idx = h_idx / n_heads; const int head_idx = h_idx % n_heads; + const int head_idx_kq = repeat_type == 0 ? head_idx / gqa_ratio : head_idx % (n_heads/gqa_ratio); - const int qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens); + const int qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens); + const int qkv_head_offset_kq = batch_idx * (head_dim * n_tokens * n_heads/gqa_ratio) + head_idx_kq * (head_dim * n_tokens); const int qkv_token_stride = head_dim; const int g_head_offset = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens; const int state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim); @@ -1535,8 +1539,8 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs, float * state = state_out + state_head_offset; for (int t = 0; t < n_tokens; ++t) { - const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride; - const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride; + const float * q_t = q_data + qkv_head_offset_kq + t * qkv_token_stride; + const float * k_t = k_data + qkv_head_offset_kq + t * qkv_token_stride; const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride; const float g_val = g_data[g_head_offset + t]; @@ -1658,17 +1662,17 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs, } } -bool iqk_fused_delta_net(int head_dim, int n_heads, int n_tokens, int n_seqs, +bool iqk_fused_delta_net(int head_dim, int n_heads, int gqa_ratio, int repeat_type, int n_tokens, int n_seqs, const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data, const float * state_in, float * out_data, float * state_out, int ith, int nth) { if (head_dim != 64 && head_dim != 128) { return false; } if (head_dim == 64) { - iqk_fused_delta_net_impl<64>(n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in, + iqk_fused_delta_net_impl<64>(n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in, out_data, state_out, ith, nth); } else { - iqk_fused_delta_net_impl<128>(n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in, + iqk_fused_delta_net_impl<128>(n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in, out_data, state_out, ith, nth); } return true; @@ -1707,7 +1711,7 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*n return false; } -bool iqk_fused_delta_net(int, int, int, int, +bool iqk_fused_delta_net(int, int, int, int, int, int, const float *, const float *, const float *, const float *, const float *, const float *, float *, float *, int, int) { return false; diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h index 440bc815..1bd999cc 100644 --- a/ggml/src/iqk/iqk_mul_mat.h +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -73,7 +73,7 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, IQK_API void iqk_topk_moe(int n_experts, int n_experts_used, int nrows, const float * logits, float * weights, int32_t * ids, int ith, int nth); -IQK_API bool iqk_fused_delta_net(int head_dim, int n_heads, int n_tokens, int n_seqs, +IQK_API bool iqk_fused_delta_net(int head_dim, int n_heads, int gqa_ratio, int repeat_type, int n_tokens, int n_seqs, const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data, const float * state_in, float * out_data, float * state_out, int ith, int nth); diff --git a/src/llama-delta-net.cpp b/src/llama-delta-net.cpp index 47a6ba0e..3d460a1f 100644 --- a/src/llama-delta-net.cpp +++ b/src/llama-delta-net.cpp @@ -77,7 +77,7 @@ delta_net::~delta_net() = default; std::pair delta_net::build_fused_delta_net(ggml_context * ctx0, ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, - int il, const llm_build_cb & cb) { + int il, const llm_build_cb & cb, int repeat_type) { const int64_t S_k = q->ne[0]; const int64_t H_k = q->ne[1]; @@ -94,7 +94,8 @@ std::pair delta_net::build_fused_delta_net(ggml_co GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs); - GGML_ASSERT(H_k == H_v); + //GGML_ASSERT(H_k == H_v); + GGML_ASSERT(H_v % H_k == 0); cb(q, "q_in", il); cb(k, "k_in", il); @@ -112,8 +113,8 @@ std::pair delta_net::build_fused_delta_net(ggml_co q = ggml_cont_4d(ctx0, q, S_k, n_tokens, H_k, n_seqs); k = ggml_cont_4d(ctx0, k, S_k, n_tokens, H_k, n_seqs); v = ggml_cont_4d(ctx0, v, S_v, n_tokens, H_v, n_seqs); - g = ggml_cont_4d(ctx0, g, n_tokens, 1, H_k, n_seqs); - beta = ggml_cont_4d(ctx0, beta, 1, n_tokens, H_k, n_seqs); + g = ggml_cont_4d(ctx0, g, n_tokens, 1, H_v, n_seqs); + beta = ggml_cont_4d(ctx0, beta, 1, n_tokens, H_v, n_seqs); } ggml_tensor * state_flat = ggml_reshape_4d(ctx0, state, S_v, S_v * H_v, 1, n_seqs); @@ -130,6 +131,7 @@ std::pair delta_net::build_fused_delta_net(ggml_co ggml_tensor * fused_result = ggml_delta_net(ctx0, q, k, v, g, beta, state_flat); cb(fused_result, "delta_net_fused_raw", il); + fused_result->op_params[0] = repeat_type; const int64_t output_size = S_v * H_v * n_tokens * n_seqs; const int64_t state_size = S_v * S_v * H_v * n_seqs; @@ -359,33 +361,10 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_ q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm); k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm); + cb(q_conv, "q_conv_normed", il); + cb(k_conv, "k_conv_normed", il); - if (num_k_heads != num_v_heads) { - GGML_ASSERT(num_v_heads % num_k_heads == 0); - if (model.layers[il].ssm_beta_alpha) { - const int64_t repeat_factor = num_v_heads / num_k_heads; - - ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_tok); - ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_tok); - - ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tok, 1); - ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tok, 1); - cb(q_repeated, "q_repeated", il); - cb(k_repeated, "k_repeated", il); - - q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_tok, 1); - k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_tok, 1); - } else { - q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); - k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); - } - } - - cb(q_conv, "q_conv_predelta", il); - cb(k_conv, "k_conv_predelta", il); - cb(v_conv, "v_conv_predelta", il); - - auto [output, new_state] = build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb); + auto [output, new_state] = build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb, model.layers[il].ssm_beta_alpha ? 0 : 1); cb(output, "attn_output", il); cb(new_state, "new_state", il); diff --git a/src/llama-delta-net.h b/src/llama-delta-net.h index 5dcf18f3..f34cb52c 100644 --- a/src/llama-delta-net.h +++ b/src/llama-delta-net.h @@ -11,7 +11,7 @@ struct delta_net { static std::pair build_fused_delta_net(ggml_context * ctx0, ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, - int il, const llm_build_cb & cb); + int il, const llm_build_cb & cb, int repeat_type); std::pair build_qkvz(ggml_context * ctx0, ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf) const;