mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
metal: add GDN partial rollback
Extend the gated delta net kernel to store intermediate states for
partial rollback support on the Metal backend.
- Add K (snapshot slot count) as a function constant
- Read input state from slot 0 of the 3D state tensor
- Write intermediate states to different slots during token loop
- For K=1, maintain backward-compatible single-slot behavior
Ref: 8c05923630
Assisted-by: llama.cpp:local pi
This commit is contained in:
parent
8c05923630
commit
6eb6d84e46
@ -590,6 +590,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(
|
||||
const int ne20 = op->src[2]->ne[0]; // S_v
|
||||
const int ne21 = op->src[2]->ne[1]; // H
|
||||
const int ne30 = op->src[3]->ne[0]; // G
|
||||
// state is src[5], 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
|
||||
const int K = op->src[5]->ne[1];
|
||||
|
||||
const int nsg = op->src[2]->ne[0]/32;
|
||||
|
||||
@ -598,7 +600,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(
|
||||
GGML_ASSERT(ne20 % 32 == 0);
|
||||
|
||||
snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg);
|
||||
snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30);
|
||||
snprintf(name, 256, "%s_ne20=%d_ne30=%d_K=%d", base, ne20, ne30, K);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
@ -606,6 +608,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(
|
||||
|
||||
ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0);
|
||||
ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1);
|
||||
ggml_metal_cv_set_int16(cv, K, FC_GATED_DELTA_NET + 2);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
|
||||
@ -2531,6 +2531,7 @@ kernel void kernel_rwkv_wkv7_f32(
|
||||
|
||||
constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
|
||||
constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
|
||||
constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]];
|
||||
|
||||
#if 1
|
||||
template<short NSG>
|
||||
@ -2552,17 +2553,21 @@ kernel void kernel_gated_delta_net_impl(
|
||||
const uint tx = tpitg.x;
|
||||
const uint ty = tpitg.y;
|
||||
|
||||
const uint i23 = tgpig.z; // B
|
||||
const uint i21 = tgpig.y; // H
|
||||
const uint i20 = tgpig.x*NSG + ty;
|
||||
const uint i23 = tgpig.z; // B (n_seqs)
|
||||
const uint i21 = tgpig.y; // H (head)
|
||||
const uint i20 = tgpig.x*NSG + ty; // row within S_v
|
||||
|
||||
const uint i01 = i21 % args.ne01;
|
||||
const uint i11 = i21 % args.ne11;
|
||||
|
||||
const float scale = 1.0f / sqrt((float)S_v);
|
||||
|
||||
const uint K = FC_gated_delta_net_K;
|
||||
|
||||
// input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0.
|
||||
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
|
||||
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
||||
const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
||||
device const float * s_ptr = (device const float *) (s) + state_in_base;
|
||||
|
||||
float ls[NSG];
|
||||
|
||||
@ -2580,6 +2585,17 @@ kernel void kernel_gated_delta_net_impl(
|
||||
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
|
||||
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
|
||||
|
||||
// snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last
|
||||
// n_tokens slots are written; earlier slots are left untouched (caller-owned).
|
||||
const int shift = (int)args.ne22 - (int)K;
|
||||
|
||||
// output state base offset: after attention scores
|
||||
const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23;
|
||||
// output state per-slot size: S_v * S_v * H * n_seqs
|
||||
const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23;
|
||||
// per-(seq,head) offset within a slot
|
||||
const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
||||
|
||||
for (short t = 0; t < args.ne22; t++) {
|
||||
float s_k = 0.0f;
|
||||
|
||||
@ -2627,13 +2643,25 @@ kernel void kernel_gated_delta_net_impl(
|
||||
|
||||
b_ptr += args.ne21;
|
||||
g_ptr += args.ne21*G;
|
||||
|
||||
if (K > 1u) {
|
||||
const int target_slot = (int)t - shift;
|
||||
if (target_slot >= 0 && target_slot < (int)K) {
|
||||
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
dst_state[is] = ls[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
dst_state[is] = ls[j];
|
||||
if (K == 1u) {
|
||||
device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
dst_state[is] = ls[j];
|
||||
}
|
||||
}
|
||||
|
||||
#undef S_v
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user