metal: add GDN partial rollback

Extend the gated delta net kernel to store intermediate states for
partial rollback support on the Metal backend.

- Add K (snapshot slot count) as a function constant
- Read input state from slot 0 of the 3D state tensor
- Write intermediate states to different slots during token loop
- For K=1, maintain backward-compatible single-slot behavior

Ref: 8c05923630

Assisted-by: llama.cpp:local pi
This commit is contained in:
Georgi Gerganov 2026-05-14 10:24:09 +03:00
parent 8c05923630
commit 6eb6d84e46
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 41 additions and 10 deletions

View File

@ -590,6 +590,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(
const int ne20 = op->src[2]->ne[0]; // S_v
const int ne21 = op->src[2]->ne[1]; // H
const int ne30 = op->src[3]->ne[0]; // G
// state is src[5], 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
const int K = op->src[5]->ne[1];
const int nsg = op->src[2]->ne[0]/32;
@ -598,7 +600,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(
GGML_ASSERT(ne20 % 32 == 0);
snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg);
snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30);
snprintf(name, 256, "%s_ne20=%d_ne30=%d_K=%d", base, ne20, ne30, K);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
@ -606,6 +608,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(
ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0);
ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1);
ggml_metal_cv_set_int16(cv, K, FC_GATED_DELTA_NET + 2);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

View File

@ -2531,6 +2531,7 @@ kernel void kernel_rwkv_wkv7_f32(
constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]];
#if 1
template<short NSG>
@ -2552,17 +2553,21 @@ kernel void kernel_gated_delta_net_impl(
const uint tx = tpitg.x;
const uint ty = tpitg.y;
const uint i23 = tgpig.z; // B
const uint i21 = tgpig.y; // H
const uint i20 = tgpig.x*NSG + ty;
const uint i23 = tgpig.z; // B (n_seqs)
const uint i21 = tgpig.y; // H (head)
const uint i20 = tgpig.x*NSG + ty; // row within S_v
const uint i01 = i21 % args.ne01;
const uint i11 = i21 % args.ne11;
const float scale = 1.0f / sqrt((float)S_v);
const uint K = FC_gated_delta_net_K;
// input state layout (D, K, n_seqs) with D = S_v*S_v*H: per-seq stride is K*D; we read slot 0.
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v;
device const float * s_ptr = (device const float *) (s) + state_in_base;
float ls[NSG];
@ -2580,6 +2585,17 @@ kernel void kernel_gated_delta_net_impl(
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
// snapshot slot mapping: target_slot = t - shift, with shift = n_tokens - K.
// When n_tokens >= K, the last K tokens land in slots 0..K-1. When n_tokens < K, only the
// last n_tokens slots are written; earlier slots are left untouched (caller-owned).
const int shift = (int)args.ne22 - (int)K;
// output state base offset: after attention scores
const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23;
// output state per-slot size: S_v * S_v * H * n_seqs
const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23;
// per-(seq,head) offset within a slot
const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
for (short t = 0; t < args.ne22; t++) {
float s_k = 0.0f;
@ -2627,13 +2643,25 @@ kernel void kernel_gated_delta_net_impl(
b_ptr += args.ne21;
g_ptr += args.ne21*G;
if (K > 1u) {
const int target_slot = (int)t - shift;
if (target_slot >= 0 && target_slot < (int)K) {
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is] = ls[j];
}
}
}
}
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is] = ls[j];
if (K == 1u) {
device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is] = ls[j];
}
}
#undef S_v