mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Do not repeat yourself (#1373)
* DRY - part 1 * DRY - part 2 * DRY - part 3 * Fix NEON --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
parent
082addead2
commit
277fc1d26f
@ -37,6 +37,8 @@ __global__ void delta_net_recurrent_f32(
|
||||
const float * __restrict__ state_in, // [HEAD_DIM, HEAD_DIM*n_heads, 1, n_seqs]
|
||||
float * __restrict__ dst, // output + new_state concatenated
|
||||
const int64_t n_heads,
|
||||
const int64_t gqa_ratio,
|
||||
const int repeat_type,
|
||||
const int64_t n_tokens,
|
||||
const int64_t n_seqs,
|
||||
const int64_t output_offset, // offset where state starts in output
|
||||
@ -46,6 +48,7 @@ __global__ void delta_net_recurrent_f32(
|
||||
const int sub_head_idx = blockIdx.x % (warps_per_head*n_heads);
|
||||
const int head_idx = sub_head_idx / warps_per_head;
|
||||
const int sub_idx = sub_head_idx % warps_per_head;
|
||||
const int head_idx_kq = repeat_type == 0 ? head_idx / gqa_ratio : head_idx % (n_heads/gqa_ratio);
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
// Strides for input tensors (column-major)
|
||||
@ -53,6 +56,7 @@ __global__ void delta_net_recurrent_f32(
|
||||
const int64_t qkv_stride_token = HEAD_DIM;
|
||||
const int64_t qkv_stride_head = HEAD_DIM * n_tokens;
|
||||
const int64_t qkv_stride_batch = HEAD_DIM * n_tokens * n_heads;
|
||||
const int64_t qkv_stride_batch_kq = qkv_stride_batch / gqa_ratio;
|
||||
|
||||
// G/Beta: [n_tokens, 1, n_heads, n_seqs] / [1, n_tokens, n_heads, n_seqs]
|
||||
const int64_t g_stride_head = n_tokens;
|
||||
@ -66,8 +70,8 @@ __global__ void delta_net_recurrent_f32(
|
||||
const int64_t state_batch_stride = HEAD_DIM * HEAD_DIM * n_heads;
|
||||
|
||||
// Pointers for this batch/head
|
||||
const float * q_ptr = q + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head;
|
||||
const float * k_ptr = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head;
|
||||
const float * q_ptr = q + batch_idx * qkv_stride_batch_kq + head_idx_kq * qkv_stride_head;
|
||||
const float * k_ptr = k + batch_idx * qkv_stride_batch_kq + head_idx_kq * qkv_stride_head;
|
||||
const float * v_ptr = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head;
|
||||
const float * g_ptr = g + batch_idx * g_stride_batch + head_idx * g_stride_head;
|
||||
const float * beta_ptr = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head;
|
||||
@ -171,6 +175,8 @@ static void delta_net_f32_cuda(
|
||||
const int64_t head_dim,
|
||||
const int64_t n_tokens,
|
||||
const int64_t n_heads,
|
||||
const int64_t gqa_ratio,
|
||||
const int repeat_type,
|
||||
const int64_t n_seqs,
|
||||
const float eps,
|
||||
const int device_id,
|
||||
@ -193,19 +199,19 @@ static void delta_net_f32_cuda(
|
||||
constexpr int threads_per_block = 256;
|
||||
if (head_dim == 64) {
|
||||
delta_net_recurrent_f32<64, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
|
||||
q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, eps);
|
||||
} else {
|
||||
delta_net_recurrent_f32<128, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
|
||||
q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, eps);
|
||||
}
|
||||
} else {
|
||||
constexpr int threads_per_block = 128;
|
||||
if (head_dim == 64) {
|
||||
delta_net_recurrent_f32<64, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
|
||||
q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, eps);
|
||||
} else {
|
||||
delta_net_recurrent_f32<128, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
|
||||
q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, eps);
|
||||
}
|
||||
}
|
||||
|
||||
@ -226,12 +232,15 @@ void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||
|
||||
const int64_t head_dim = src0->ne[0];
|
||||
const int64_t n_tokens = src0->ne[1];
|
||||
const int64_t n_heads = src0->ne[2];
|
||||
const int64_t n_heads = src2->ne[2];
|
||||
const int64_t n_heads_kq = src0->ne[2];
|
||||
const int64_t n_seqs = src0->ne[3];
|
||||
GGML_ASSERT(n_heads % n_heads_kq == 0);
|
||||
const int64_t gqa_ratio = n_heads / n_heads_kq;
|
||||
|
||||
// Dimension validation
|
||||
// Q/K: [head_dim, n_tokens, n_heads, n_seqs]
|
||||
GGML_ASSERT(src1->ne[0] == head_dim && src1->ne[1] == n_tokens && src1->ne[2] == n_heads && src1->ne[3] == n_seqs);
|
||||
GGML_ASSERT(src1->ne[0] == head_dim && src1->ne[1] == n_tokens && src1->ne[2] == n_heads_kq && src1->ne[3] == n_seqs);
|
||||
// V: [head_dim, n_tokens, n_heads, n_seqs]
|
||||
GGML_ASSERT(src2->ne[0] == head_dim && src2->ne[1] == n_tokens && src2->ne[2] == n_heads && src2->ne[3] == n_seqs);
|
||||
// G: [n_tokens, 1, n_heads, n_seqs]
|
||||
@ -247,6 +256,7 @@ void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||
GGML_ASSERT(ggml_nelements(dst) == output_size + state_size);
|
||||
|
||||
const float eps = 1e-6f;
|
||||
int repeat_type = dst->op_params[0];
|
||||
|
||||
GGML_ASSERT(head_dim <= 256); // Reasonable limit for shared memory
|
||||
|
||||
@ -262,7 +272,7 @@ void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||
(const float *)src4->data,
|
||||
(const float *)src5->data,
|
||||
(float *)dst->data,
|
||||
head_dim, n_tokens, n_heads, n_seqs, eps,
|
||||
head_dim, n_tokens, n_heads, gqa_ratio, repeat_type, n_seqs, eps,
|
||||
device_id, cc,
|
||||
ctx.stream());
|
||||
|
||||
|
||||
@ -9903,10 +9903,11 @@ struct ggml_tensor * ggml_delta_net(
|
||||
|
||||
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == n_tokens && k->ne[2] == H_k && k->ne[3] == n_seqs);
|
||||
GGML_ASSERT(v->ne[1] == n_tokens && v->ne[3] == n_seqs);
|
||||
GGML_ASSERT(g->ne[0] == n_tokens && g->ne[1] == 1 && g->ne[2] == H_k && g->ne[3] == n_seqs);
|
||||
GGML_ASSERT(beta->ne[0] == 1 && beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[3] == n_seqs);
|
||||
GGML_ASSERT(g->ne[0] == n_tokens && g->ne[1] == 1 && g->ne[2] == H_v && g->ne[3] == n_seqs);
|
||||
GGML_ASSERT(beta->ne[0] == 1 && beta->ne[1] == n_tokens && beta->ne[2] == H_v && beta->ne[3] == n_seqs);
|
||||
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
|
||||
GGML_ASSERT(H_k == H_v);
|
||||
//GGML_ASSERT(H_k == H_v);
|
||||
GGML_ASSERT(H_v % H_k == 0);
|
||||
|
||||
const int64_t output_size = S_v * H_v * n_tokens * n_seqs;
|
||||
const int64_t state_size = S_v * S_v * H_v * n_seqs;
|
||||
@ -22548,8 +22549,10 @@ static void ggml_compute_forward_delta_net_f32(
|
||||
|
||||
const int64_t head_dim = src0->ne[0];
|
||||
const int64_t n_tokens = src0->ne[1];
|
||||
const int64_t n_heads = src0->ne[2];
|
||||
const int64_t n_heads = src2->ne[2];
|
||||
const int64_t n_seqs = src0->ne[3];
|
||||
GGML_ASSERT(src2->ne[2] % src0->ne[2] == 0);
|
||||
const int gqa_ratio = src2->ne[2]/src0->ne[2];
|
||||
|
||||
const int64_t output_size = head_dim * n_tokens * n_heads * n_seqs;
|
||||
|
||||
@ -22565,7 +22568,9 @@ static void ggml_compute_forward_delta_net_f32(
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
if (iqk_fused_delta_net(head_dim, n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
|
||||
int repeat_type = dst->op_params[0];
|
||||
|
||||
if (iqk_fused_delta_net(head_dim, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
|
||||
out_data, state_out, ith, nth)) {
|
||||
return;
|
||||
}
|
||||
@ -22584,8 +22589,10 @@ static void ggml_compute_forward_delta_net_f32(
|
||||
for (int64_t h_idx = h_start; h_idx < h_end; ++h_idx) {
|
||||
const int64_t batch_idx = h_idx / n_heads;
|
||||
const int64_t head_idx = h_idx % n_heads;
|
||||
const int64_t head_idx_kq = repeat_type == 0 ? head_idx / gqa_ratio : head_idx % (n_heads/gqa_ratio);
|
||||
|
||||
const int64_t qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
|
||||
const int64_t qkv_head_offset_kq = batch_idx * (head_dim * n_tokens * n_heads/gqa_ratio) + head_idx_kq * (head_dim * n_tokens);
|
||||
const int64_t qkv_token_stride = head_dim;
|
||||
const int64_t g_head_offset = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens;
|
||||
const int64_t state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim);
|
||||
@ -22599,8 +22606,8 @@ static void ggml_compute_forward_delta_net_f32(
|
||||
float * state = state_out + state_head_offset;
|
||||
|
||||
for (int64_t t = 0; t < n_tokens; ++t) {
|
||||
const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride;
|
||||
const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride;
|
||||
const float * q_t = q_data + qkv_head_offset_kq + t * qkv_token_stride;
|
||||
const float * k_t = k_data + qkv_head_offset_kq + t * qkv_token_stride;
|
||||
const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride;
|
||||
|
||||
const float g_val = g_data[g_head_offset + t];
|
||||
|
||||
@ -1386,7 +1386,7 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
|
||||
namespace {
|
||||
#ifdef __ARM_NEON
|
||||
template <int head_dim>
|
||||
void iqk_fused_delta_net_neon_impl(int n_heads, int n_tokens, int n_seqs,
|
||||
void iqk_fused_delta_net_neon_impl(int n_heads, int gqa_ratio, int repeat_type, int n_tokens, int n_seqs,
|
||||
const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
|
||||
const float * state_in, float * out_data, float * state_out, int ith, int nth) {
|
||||
const int total_heads = n_heads * n_seqs;
|
||||
@ -1406,8 +1406,10 @@ void iqk_fused_delta_net_neon_impl(int n_heads, int n_tokens, int n_seqs,
|
||||
for (int h_idx = h_start; h_idx < h_end; ++h_idx) {
|
||||
const int batch_idx = h_idx / n_heads;
|
||||
const int head_idx = h_idx % n_heads;
|
||||
const int head_idx_kq = repeat_type == 0 ? head_idx / gqa_ratio : head_idx % (n_heads/gqa_ratio);
|
||||
|
||||
const int qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
|
||||
const int qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
|
||||
const int qkv_head_offset_kq = batch_idx * (head_dim * n_tokens * n_heads/gqa_ratio) + head_idx_kq * (head_dim * n_tokens);
|
||||
const int qkv_token_stride = head_dim;
|
||||
const int g_head_offset = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens;
|
||||
const int state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim);
|
||||
@ -1422,8 +1424,8 @@ void iqk_fused_delta_net_neon_impl(int n_heads, int n_tokens, int n_seqs,
|
||||
|
||||
|
||||
for (int t = 0; t < n_tokens; ++t) {
|
||||
const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride;
|
||||
const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride;
|
||||
const float * q_t = q_data + qkv_head_offset_kq + t * qkv_token_stride;
|
||||
const float * k_t = k_data + qkv_head_offset_kq + t * qkv_token_stride;
|
||||
const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride;
|
||||
|
||||
const float g_val = g_data[g_head_offset + t];
|
||||
@ -1492,11 +1494,11 @@ void iqk_fused_delta_net_neon_impl(int n_heads, int n_tokens, int n_seqs,
|
||||
}
|
||||
#endif
|
||||
template <int head_dim>
|
||||
void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
|
||||
void iqk_fused_delta_net_impl(int n_heads, int gqa_ratio, int repeat_type, int n_tokens, int n_seqs,
|
||||
const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
|
||||
const float * state_in, float * out_data, float * state_out, int ith, int nth) {
|
||||
#ifdef __ARM_NEON
|
||||
iqk_fused_delta_net_neon_impl<head_dim>(n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in, out_data, state_out, ith, nth);
|
||||
iqk_fused_delta_net_neon_impl<head_dim>(n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in, out_data, state_out, ith, nth);
|
||||
return;
|
||||
#endif
|
||||
const int total_heads = n_heads * n_seqs;
|
||||
@ -1520,8 +1522,10 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
|
||||
for (int h_idx = h_start; h_idx < h_end; ++h_idx) {
|
||||
const int batch_idx = h_idx / n_heads;
|
||||
const int head_idx = h_idx % n_heads;
|
||||
const int head_idx_kq = repeat_type == 0 ? head_idx / gqa_ratio : head_idx % (n_heads/gqa_ratio);
|
||||
|
||||
const int qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
|
||||
const int qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
|
||||
const int qkv_head_offset_kq = batch_idx * (head_dim * n_tokens * n_heads/gqa_ratio) + head_idx_kq * (head_dim * n_tokens);
|
||||
const int qkv_token_stride = head_dim;
|
||||
const int g_head_offset = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens;
|
||||
const int state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim);
|
||||
@ -1535,8 +1539,8 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
|
||||
float * state = state_out + state_head_offset;
|
||||
|
||||
for (int t = 0; t < n_tokens; ++t) {
|
||||
const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride;
|
||||
const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride;
|
||||
const float * q_t = q_data + qkv_head_offset_kq + t * qkv_token_stride;
|
||||
const float * k_t = k_data + qkv_head_offset_kq + t * qkv_token_stride;
|
||||
const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride;
|
||||
|
||||
const float g_val = g_data[g_head_offset + t];
|
||||
@ -1658,17 +1662,17 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
|
||||
}
|
||||
}
|
||||
|
||||
bool iqk_fused_delta_net(int head_dim, int n_heads, int n_tokens, int n_seqs,
|
||||
bool iqk_fused_delta_net(int head_dim, int n_heads, int gqa_ratio, int repeat_type, int n_tokens, int n_seqs,
|
||||
const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
|
||||
const float * state_in, float * out_data, float * state_out, int ith, int nth) {
|
||||
if (head_dim != 64 && head_dim != 128) {
|
||||
return false;
|
||||
}
|
||||
if (head_dim == 64) {
|
||||
iqk_fused_delta_net_impl<64>(n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
|
||||
iqk_fused_delta_net_impl<64>(n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
|
||||
out_data, state_out, ith, nth);
|
||||
} else {
|
||||
iqk_fused_delta_net_impl<128>(n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
|
||||
iqk_fused_delta_net_impl<128>(n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
|
||||
out_data, state_out, ith, nth);
|
||||
}
|
||||
return true;
|
||||
@ -1707,7 +1711,7 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*n
|
||||
return false;
|
||||
}
|
||||
|
||||
bool iqk_fused_delta_net(int, int, int, int,
|
||||
bool iqk_fused_delta_net(int, int, int, int, int, int,
|
||||
const float *, const float *, const float *, const float *, const float *,
|
||||
const float *, float *, float *, int, int) {
|
||||
return false;
|
||||
|
||||
@ -73,7 +73,7 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
|
||||
IQK_API void iqk_topk_moe(int n_experts, int n_experts_used, int nrows, const float * logits,
|
||||
float * weights, int32_t * ids, int ith, int nth);
|
||||
|
||||
IQK_API bool iqk_fused_delta_net(int head_dim, int n_heads, int n_tokens, int n_seqs,
|
||||
IQK_API bool iqk_fused_delta_net(int head_dim, int n_heads, int gqa_ratio, int repeat_type, int n_tokens, int n_seqs,
|
||||
const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
|
||||
const float * state_in, float * out_data, float * state_out, int ith, int nth);
|
||||
|
||||
|
||||
@ -77,7 +77,7 @@ delta_net::~delta_net() = default;
|
||||
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_context * ctx0,
|
||||
ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
|
||||
ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
|
||||
int il, const llm_build_cb & cb) {
|
||||
int il, const llm_build_cb & cb, int repeat_type) {
|
||||
|
||||
const int64_t S_k = q->ne[0];
|
||||
const int64_t H_k = q->ne[1];
|
||||
@ -94,7 +94,8 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
|
||||
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
|
||||
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
|
||||
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
|
||||
GGML_ASSERT(H_k == H_v);
|
||||
//GGML_ASSERT(H_k == H_v);
|
||||
GGML_ASSERT(H_v % H_k == 0);
|
||||
|
||||
cb(q, "q_in", il);
|
||||
cb(k, "k_in", il);
|
||||
@ -112,8 +113,8 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
|
||||
q = ggml_cont_4d(ctx0, q, S_k, n_tokens, H_k, n_seqs);
|
||||
k = ggml_cont_4d(ctx0, k, S_k, n_tokens, H_k, n_seqs);
|
||||
v = ggml_cont_4d(ctx0, v, S_v, n_tokens, H_v, n_seqs);
|
||||
g = ggml_cont_4d(ctx0, g, n_tokens, 1, H_k, n_seqs);
|
||||
beta = ggml_cont_4d(ctx0, beta, 1, n_tokens, H_k, n_seqs);
|
||||
g = ggml_cont_4d(ctx0, g, n_tokens, 1, H_v, n_seqs);
|
||||
beta = ggml_cont_4d(ctx0, beta, 1, n_tokens, H_v, n_seqs);
|
||||
}
|
||||
|
||||
ggml_tensor * state_flat = ggml_reshape_4d(ctx0, state, S_v, S_v * H_v, 1, n_seqs);
|
||||
@ -130,6 +131,7 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
|
||||
|
||||
ggml_tensor * fused_result = ggml_delta_net(ctx0, q, k, v, g, beta, state_flat);
|
||||
cb(fused_result, "delta_net_fused_raw", il);
|
||||
fused_result->op_params[0] = repeat_type;
|
||||
|
||||
const int64_t output_size = S_v * H_v * n_tokens * n_seqs;
|
||||
const int64_t state_size = S_v * S_v * H_v * n_seqs;
|
||||
@ -359,33 +361,10 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
|
||||
q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm);
|
||||
k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm);
|
||||
cb(q_conv, "q_conv_normed", il);
|
||||
cb(k_conv, "k_conv_normed", il);
|
||||
|
||||
if (num_k_heads != num_v_heads) {
|
||||
GGML_ASSERT(num_v_heads % num_k_heads == 0);
|
||||
if (model.layers[il].ssm_beta_alpha) {
|
||||
const int64_t repeat_factor = num_v_heads / num_k_heads;
|
||||
|
||||
ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_tok);
|
||||
ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_tok);
|
||||
|
||||
ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tok, 1);
|
||||
ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tok, 1);
|
||||
cb(q_repeated, "q_repeated", il);
|
||||
cb(k_repeated, "k_repeated", il);
|
||||
|
||||
q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_tok, 1);
|
||||
k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_tok, 1);
|
||||
} else {
|
||||
q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
}
|
||||
}
|
||||
|
||||
cb(q_conv, "q_conv_predelta", il);
|
||||
cb(k_conv, "k_conv_predelta", il);
|
||||
cb(v_conv, "v_conv_predelta", il);
|
||||
|
||||
auto [output, new_state] = build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb);
|
||||
auto [output, new_state] = build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb, model.layers[il].ssm_beta_alpha ? 0 : 1);
|
||||
|
||||
cb(output, "attn_output", il);
|
||||
cb(new_state, "new_state", il);
|
||||
|
||||
@ -11,7 +11,7 @@ struct delta_net {
|
||||
static std::pair<ggml_tensor *, ggml_tensor *> build_fused_delta_net(ggml_context * ctx0,
|
||||
ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
|
||||
ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
|
||||
int il, const llm_build_cb & cb);
|
||||
int il, const llm_build_cb & cb, int repeat_type);
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(ggml_context * ctx0, ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf) const;
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user