fix: add missing __syncthreads in delta net CUDA kernel (#1649)

The token loop reads sK[] in the state update (bottom of loop) but has
no barrier before the next iteration overwrites sK[] (top of loop).
Without an explicit memory fence, hardware/compiler reordering can
cause non-deterministic reads from shared memory.

Per review: with this barrier in place, the prior __syncthreads() after
the cross-warp reduction and the one immediately after loop exit are
both redundant. The new barrier is a full block-level fence that also
orders all_sum1/all_sum2 reads vs. the next iteration's writes, and
every thread reaches it before leaving the loop. Both redundant
barriers removed.

No performance impact — GPU utilization is 31-33% during inference,
bottlenecked by CPU MoE expert computation, not the CUDA kernel.

Co-authored-by: Mark Alonzo <mark.alonzo@outlook.com>
This commit is contained in:
markaalonzo 2026-04-17 15:45:46 -04:00 committed by GitHub
parent 64234e3c4e
commit 52efa12fda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -141,8 +141,6 @@ __global__ void delta_net_recurrent_f32(
sum1 += all_sum1[i*WARP_SIZE_S + row];
sum2 += all_sum2[i*WARP_SIZE_S + row];
}
// To be honest, I don't understand why we need this sync. But without it I observe results varying from run to run
__syncthreads();
//float sv_new = beta_val * (v_ptr[t * qkv_stride_token + row_out] - sum1 * decay);
float sv_new = beta_val * (v_ptr[t * vnb1 + row_out] - sum1 * decay);
@ -157,8 +155,13 @@ __global__ void delta_net_recurrent_f32(
state_local[i] = new_state_val;
}
// Barrier required: (a) sK reads in the state update above must complete
// before next iteration overwrites sK at the top of the loop, and (b) this
// single barrier also orders all_sum1/all_sum2 reads above vs. the next
// iteration's writes — subsuming the prior barriers after the cross-warp
// reduction and after the loop exit.
__syncthreads();
}
__syncthreads();
// Copy the final state to its destination
for (int i = 0; i < HEAD_DIM/num_warps; ++i) {
int col = num_warps*i + col_idx_0;