mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
fix: add missing __syncthreads in delta net CUDA kernel (#1649)
The token loop reads sK[] in the state update (bottom of loop) but has no barrier before the next iteration overwrites sK[] (top of loop). Without an explicit memory fence, hardware/compiler reordering can cause non-deterministic reads from shared memory. Per review: with this barrier in place, the prior __syncthreads() after the cross-warp reduction and the one immediately after loop exit are both redundant. The new barrier is a full block-level fence that also orders all_sum1/all_sum2 reads vs. the next iteration's writes, and every thread reaches it before leaving the loop. Both redundant barriers removed. No performance impact — GPU utilization is 31-33% during inference, bottlenecked by CPU MoE expert computation, not the CUDA kernel. Co-authored-by: Mark Alonzo <mark.alonzo@outlook.com>
This commit is contained in:
parent
64234e3c4e
commit
52efa12fda
@ -141,8 +141,6 @@ __global__ void delta_net_recurrent_f32(
|
||||
sum1 += all_sum1[i*WARP_SIZE_S + row];
|
||||
sum2 += all_sum2[i*WARP_SIZE_S + row];
|
||||
}
|
||||
// To be honest, I don't understand why we need this sync. But without it I observe results varying from run to run
|
||||
__syncthreads();
|
||||
|
||||
//float sv_new = beta_val * (v_ptr[t * qkv_stride_token + row_out] - sum1 * decay);
|
||||
float sv_new = beta_val * (v_ptr[t * vnb1 + row_out] - sum1 * decay);
|
||||
@ -157,8 +155,13 @@ __global__ void delta_net_recurrent_f32(
|
||||
state_local[i] = new_state_val;
|
||||
}
|
||||
|
||||
// Barrier required: (a) sK reads in the state update above must complete
|
||||
// before next iteration overwrites sK at the top of the loop, and (b) this
|
||||
// single barrier also orders all_sum1/all_sum2 reads above vs. the next
|
||||
// iteration's writes — subsuming the prior barriers after the cross-warp
|
||||
// reduction and after the loop exit.
|
||||
__syncthreads();
|
||||
}
|
||||
__syncthreads();
|
||||
// Copy the final state to its destination
|
||||
for (int i = 0; i < HEAD_DIM/num_warps; ++i) {
|
||||
int col = num_warps*i + col_idx_0;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user