Nexes the Elder 2d3ecd5e19
Fix minor CUDA discrepancies (part 2) (#2015)
* fix: wrong tensor index in BF16 fused RMS norm add path (norm.cu:1039)

The BF16 branch of ggml_cuda_op_fused_rms_rms_add used dst->src[2]->data
for the second weight pointer, but should have used dst->src[3]->data.
As a result, the F32 weights were read from the wrong tensor (the BF16 input).

The F32 and F16 branches both correctly reference src[3], and the
assertions at lines 1013-1015 confirm src[3] is the F32 weight tensor.

* fix: off-by-one bounds check in 7 dmmv kernels (row > nrows -> row >= nrows)

Seven K-quant dequantize_mul_mat_vec kernels used row > nrows for bounds
checking instead of row >= nrows. Since rows are 0-indexed (0..nrows-1),
the check let row == nrows through, allowing an out-of-bounds memory
write whenever the launch grid produces a row index equal to nrows.

The templated dequantize_mul_mat_vec<type> kernel at line 667 already used
the correct row >= nrows pattern.

* fix: typo in function name iqk_mul_mat_vec_q_kerne -> iqk_mul_mat_vec_q_kernel

The function name in iqk_mmvq_templates.cuh was truncated (missing its trailing 'l').

* fix: print actual split_dim value in set_tensor error message (ggml-cuda.cu)

The fprintf call passed extra->split_dim == 0, which evaluates to a boolean
0 or 1, instead of the actual split dimension value. When this fatal error
was hit for an unsupported split_dim, the user could not tell which value
caused the problem.

* fix: wrong src index in gate bias stride for fused up-gate MoE path

ggml_cuda_add_id for the gate bias used dst->src[4]->nb[1] as the stride
argument instead of dst->src[5]->nb[1]. This was a copy-paste error from
the up-bias code (lines 3220-3224) where src[4] is correct. If src[4]
and src[5] have different strides, the bias addition produces incorrect
results.

* fix: wrong row count for gate projection MMQ in fused up-gate MoE path

ggml_cuda_op_mul_mat_q for the gate projection (src0_2) used
src0_1->ne[1] as row_high instead of src0_2->ne[1]. This copy-paste
error processes the wrong number of rows whenever the up and gate
projections have different row counts. The gemv path (line ~3563)
already used src0_2->ne[1] correctly.
2026-06-23 14:03:22 +02:00

1056 lines
40 KiB
Plaintext

#include "norm.cuh"
// Row-wise layer norm: dst[row, col] = (x[row, col] - mean(x[row, :])) * rsqrt(var(x[row, :]) + eps).
// Launch layout (see norm_f32_cuda): gridDim.x == nrows, blockDim == (block_size, 1, 1),
// block_size a multiple of WARP_SIZE (one s_sum slot per warp, at most 32 warps).
// x may be float or half; statistics are accumulated in float and dst is always float.
template <int block_size, typename T>
static __global__ void norm_f32(const T * x, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    // .x accumulates sum(x), .y accumulates sum(x^2)
    float2 mean_var = make_float2(0.f, 0.f);
    for (int col = tid; col < ncols; col += block_size) {
        const float xi = (float)x[row*ncols + col];
        mean_var.x += xi;
        mean_var.y += xi * xi;
    }

    // sum up partial sums: within each warp, then across warps through shared memory
    mean_var = warp_reduce_sum(mean_var);
    if (block_size > WARP_SIZE) {
        __shared__ float2 s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = mean_var;
        }
        __syncthreads();
        // only block_size/WARP_SIZE slots were written above; every other lane contributes zero
        mean_var = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : make_float2(0.f, 0.f);
        mean_var = warp_reduce_sum(mean_var);
    }

    const float mean    = mean_var.x / ncols;
    const float var     = mean_var.y / ncols - mean * mean;
    const float inv_std = rsqrtf(var + eps);

    // dst is float: store the float result directly. Casting through T would needlessly
    // round the normalized value to half precision when T == half.
    for (int col = tid; col < ncols; col += block_size) {
        dst[row*ncols + col] = ((float)x[row*ncols + col] - mean) * inv_std;
    }
}
// Layer norm fused with a per-column weight vector c:
//   dst[row, col] = (x[row, col] - mean(x[row, :])) * rsqrt(var(x[row, :]) + eps) * c[col]
// Launch: gridDim.x == nrows, blockDim == (block_size, 1, 1), block_size a multiple of WARP_SIZE
// (fused_rms_norm_f32_cuda uses 32..224 in steps of 32, 256 and 1024).
// T is the storage type of x; for block_q8_0 a row is ncols/QK8_0 consecutive blocks.
template <int block_size, typename T>
static __global__ void fused_norm_f32(const T * x, const float * c, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    // .x accumulates sum(x), .y accumulates sum(x^2)
    float2 mean_var = make_float2(0.f, 0.f);
    if constexpr (std::is_same_v<T, block_q8_0>) {
        static_assert(block_size % QK8_0 == 0);
        auto xr = x + (row*ncols)/QK8_0;
        for (int col = tid; col < ncols; col += block_size) {
            const float xi = (float)xr[col / QK8_0].d * xr[col / QK8_0].qs[col % QK8_0];
            mean_var.x += xi;
            mean_var.y += xi * xi;
        }
    } else {
        for (int col = tid; col < ncols; col += block_size) {
            const float xi = (float)x[row*ncols + col];
            mean_var.x += xi;
            mean_var.y += xi * xi;
        }
    }

    // sum up partial sums: within each warp, then across warps through shared memory
    mean_var = warp_reduce_sum(mean_var);
    if (block_size > WARP_SIZE) {
        __shared__ float2 s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = mean_var;
        }
        __syncthreads();
        // Only block_size/WARP_SIZE entries of s_sum were written above. For the smaller block
        // sizes used by the launcher (64..256) the remaining entries are uninitialized shared
        // memory, so they must be replaced by zeros rather than summed.
        mean_var = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : make_float2(0.f, 0.f);
        mean_var = warp_reduce_sum(mean_var);
    }

    const float mean    = mean_var.x / ncols;
    const float var     = mean_var.y / ncols - mean * mean;
    const float inv_std = rsqrtf(var + eps);

    if constexpr (std::is_same_v<T, block_q8_0>) {
        static_assert(block_size % QK8_0 == 0);
        auto xr = x + (row*ncols)/QK8_0;
        for (int col = tid; col < ncols; col += block_size) {
            dst[row*ncols + col] = ((float)xr[col/QK8_0].d*xr[col/QK8_0].qs[col%QK8_0] - mean) * inv_std * c[col];
        }
    } else {
        for (int col = tid; col < ncols; col += block_size) {
            dst[row*ncols + col] = ((float)x[row*ncols + col] - mean) * inv_std * c[col];
        }
    }
}
// Group norm over contiguous spans: block blockIdx.x normalizes the elements
// [blockIdx.x*group_size, min((blockIdx.x + 1)*group_size, ne_elements)) of x to zero mean and
// unit variance (with eps), writing the result to dst.
// Launch: gridDim.x == number of groups, blockDim == (block_size, 1, 1), block_size a multiple of WARP_SIZE.
// NOTE(review): ggml_cuda_op_group_norm sizes groups as ne0*ne1*ceil(ne2/num_groups); when ne2 is
// not a multiple of num_groups and ne3 > 1, these spans do not line up with per-sample groups.
template <int block_size>
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
    const int group_start = blockIdx.x * group_size;
    int end = group_start + group_size;
    if (end >= ne_elements) {
        end = ne_elements;
    }
    // The last group may be cut short by ne_elements, so the statistics are taken over the
    // elements that actually exist. If n <= 0 the loops below do nothing and nothing is written.
    const int n     = end - group_start;
    const int start = group_start + threadIdx.x;

    float tmp = 0.0f; // partial sum for thread in warp
    for (int j = start; j < end; j += block_size) {
        tmp += x[j];
    }
    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        // only block_size/WARP_SIZE slots were written; every other lane contributes zero
        tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f;
        tmp = warp_reduce_sum(tmp);
    }
    const float mean = tmp / n;

    // second pass: center the data (stored in dst) and accumulate the squared deviations
    tmp = 0.0f;
    for (int j = start; j < end; j += block_size) {
        const float xi = x[j] - mean;
        dst[j] = xi;
        tmp += xi * xi;
    }
    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        // a separate shared array from the one used for the mean reduction above
        __shared__ float s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f;
        tmp = warp_reduce_sum(tmp);
    }
    const float variance = tmp / n;
    const float scale = rsqrtf(variance + eps);

    // each thread rescales exactly the dst elements it wrote in the second pass
    for (int j = start; j < end; j += block_size) {
        dst[j] *= scale;
    }
}
// RMS norm of each contiguous row: dst[row, col] = x[row, col] * rsqrt(mean(x[row, :]^2) + eps).
// Launch: gridDim.x == nrows, blockDim == (block_size, 1, 1), block_size a multiple of WARP_SIZE.
// Any ncols is accepted: every column access is bounded by col < ncols.
template <int block_size>
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
    const int     row   = blockIdx.x*blockDim.y + threadIdx.y;
    const float * x_row = x   + row*ncols;
    float       * y_row = dst + row*ncols;

    // this thread's partial sum of squares
    float sum_sq = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        const float v = x_row[col];
        sum_sq += v * v;
    }

    // block-wide reduction: warp shuffles first, then one partial per warp via shared memory
    sum_sq = warp_reduce_sum(sum_sq);
    if (block_size > WARP_SIZE) {
        __shared__ float warp_sums[32];
        const int lane = threadIdx.x % WARP_SIZE;
        if (lane == 0) {
            warp_sums[threadIdx.x / WARP_SIZE] = sum_sq;
        }
        __syncthreads();
        sum_sq = 0.0f;
        if (lane < block_size/WARP_SIZE) {
            sum_sq = warp_sums[lane];
        }
        sum_sq = warp_reduce_sum(sum_sq);
    }

    const float scale = rsqrtf(sum_sq / ncols + eps);
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        y_row[col] = scale * x_row[col];
    }
}
// L2-normalizes each contiguous row: dst[row, :] = x[row, :] * rsqrt(max(sum(x[row, :]^2), eps^2)).
// Launch: gridDim.x == nrows, blockDim == (block_size, 1, 1), block_size a multiple of WARP_SIZE.
template <int block_size>
static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
    const int     row   = blockIdx.x * blockDim.y + threadIdx.y;
    const float * x_row = x   + row * ncols;
    float       * y_row = dst + row * ncols;

    float sum_sq = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        const float v = x_row[col];
        sum_sq += v * v;
    }

    // warp-level reduction, then combine the per-warp partials through shared memory
    sum_sq = warp_reduce_sum(sum_sq);
    if (block_size > WARP_SIZE) {
        __shared__ float warp_sums[32];
        const int lane = threadIdx.x % WARP_SIZE;
        if (lane == 0) {
            warp_sums[threadIdx.x / WARP_SIZE] = sum_sq;
        }
        __syncthreads();
        sum_sq = 0.0f;
        if (lane < block_size / WARP_SIZE) {
            sum_sq = warp_sums[lane];
        }
        sum_sq = warp_reduce_sum(sum_sq);
    }

    // clamping the squared norm at eps^2 keeps the scale finite for all-zero rows
    const float scale = rsqrtf(fmaxf(sum_sq, eps * eps));
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        y_row[col] = scale * x_row[col];
    }
}
// RMS norm for a non-contiguous F32 source. Grid: (nrows, nchannels, nsamples), one block per row;
// blockDim == (block_size, 1, 1) with block_size == WARP_SIZE or 1024. x is addressed through the
// given element strides, while dst is written densely as [nsamples][nchannels][nrows][ncols].
template <int block_size>
static __global__ void rms_norm_f32_nc(
        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
        const int64_t stride_sample, const float eps) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;
    const int row       = blockIdx.x;
    const int channel   = blockIdx.y;
    const int sample    = blockIdx.z;

    const float * src = x + sample*stride_sample + channel*stride_channel + row*stride_row;
    float       * out = dst + ((sample*nchannels + channel)*nrows + row)*ncols;

    float sum_sq = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        const float v = src[col];
        sum_sq += v * v;
    }

    // block-wide reduction: warp shuffles, then one partial per warp through shared memory
    sum_sq = warp_reduce_sum(sum_sq);
    if constexpr (block_size > WARP_SIZE) {
        static_assert(block_size == 1024, "unexpected block_size");
        __shared__ float warp_sums[32];
        const int lane = threadIdx.x % WARP_SIZE;
        if (lane == 0) {
            warp_sums[threadIdx.x / WARP_SIZE] = sum_sq;
        }
        __syncthreads();
        sum_sq = warp_reduce_sum(warp_sums[lane]);
    }

    const float scale = rsqrtf(sum_sq / ncols + eps);
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        out[col] = scale * src[col];
    }
}
// L2 norm for a non-contiguous F32 source. Grid: (nrows, nchannels, nsamples), one block per row;
// blockDim == (block_size, 1, 1) with block_size == WARP_SIZE or 1024. x is read through the given
// element strides; dst is written densely as [nsamples][nchannels][nrows][ncols].
template <int block_size>
static __global__ void l2_norm_f32_nc(
        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
        const int64_t stride_sample, const float eps) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;
    const int row       = blockIdx.x;
    const int channel   = blockIdx.y;
    const int sample    = blockIdx.z;

    const float * src = x + sample * stride_sample + channel * stride_channel + row * stride_row;
    float       * out = dst + ((sample * nchannels + channel) * nrows + row) * ncols;

    float sum_sq = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        const float v = src[col];
        sum_sq += v * v;
    }

    sum_sq = warp_reduce_sum(sum_sq);
    if constexpr (block_size > WARP_SIZE) {
        static_assert(block_size == 1024, "unexpected block_size");
        __shared__ float warp_sums[32];
        const int lane = threadIdx.x % WARP_SIZE;
        if (lane == 0) {
            warp_sums[threadIdx.x / WARP_SIZE] = sum_sq;
        }
        __syncthreads();
        sum_sq = warp_reduce_sum(warp_sums[lane]);
    }

    // clamp at eps^2 so an all-zero row does not produce an infinite scale
    const float scale = rsqrtf(fmaxf(sum_sq, eps * eps));
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        out[col] = scale * src[col];
    }
}
// Weighted RMS norm of one contiguous row per block:
//   dst[row, col] = y[col] * x[row, col] * rsqrt(mean(x[row, :]^2) + eps)
// src_t may be float, half, nv_bfloat16 or block_q8_0 (then a row is ncols/QK8_0 blocks).
// Launch: gridDim.x == nrows, blockDim == (block_size, 1, 1), block_size a multiple of WARP_SIZE.
template <int block_size, typename src_t>
static __global__ void fused_rms_norm_f32(const src_t * x, const float * y, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    float * out = dst + row*ncols;

    // pass 1: this thread's partial sum of squares, converting each element to float
    float sum_sq = 0.0f;
    if constexpr (std::is_same_v<src_t, block_q8_0>) {
        static_assert(block_size % QK8_0 == 0);
        const block_q8_0 * blocks = x + (row*ncols)/QK8_0;
        for (int col = threadIdx.x; col < ncols; col += block_size) {
            const block_q8_0 & blk = blocks[col / QK8_0];
            const float v = (float)blk.d * blk.qs[col % QK8_0];
            sum_sq += v * v;
        }
    } else {
        const src_t * vals = x + row*ncols;
        for (int col = threadIdx.x; col < ncols; col += block_size) {
            float v;
            if constexpr (std::is_same_v<src_t, nv_bfloat16>) {
                v = __bfloat162float(vals[col]);
            } else {
                v = (float)vals[col];
            }
            sum_sq += v * v;
        }
    }

    // block-wide reduction (warps with no written slot contribute zero)
    sum_sq = warp_reduce_sum(sum_sq);
    if (block_size > WARP_SIZE) {
        __shared__ float warp_sums[32];
        const int lane = threadIdx.x % WARP_SIZE;
        if (lane == 0) {
            warp_sums[threadIdx.x / WARP_SIZE] = sum_sq;
        }
        __syncthreads();
        sum_sq = 0.0f;
        if (lane < block_size/WARP_SIZE) {
            sum_sq = warp_sums[lane];
        }
        sum_sq = warp_reduce_sum(sum_sq);
    }

    // pass 2: scale, weight and store (multiplication order kept as scale * y * value)
    const float scale = rsqrtf(sum_sq / ncols + eps);
    if constexpr (std::is_same_v<src_t, block_q8_0>) {
        const block_q8_0 * blocks = x + (row*ncols)/QK8_0;
        for (int col = threadIdx.x; col < ncols; col += block_size) {
            const block_q8_0 & blk = blocks[col / QK8_0];
            out[col] = scale * y[col] * (float)blk.d * blk.qs[col % QK8_0];
        }
    } else if constexpr (std::is_same_v<src_t, nv_bfloat16>) {
        const src_t * vals = x + row*ncols;
        for (int col = threadIdx.x; col < ncols; col += block_size) {
            out[col] = scale * y[col] * __bfloat162float(vals[col]);
        }
    } else {
        const src_t * vals = x + row*ncols;
        for (int col = threadIdx.x; col < ncols; col += block_size) {
            out[col] = scale * y[col] * (float)vals[col];
        }
    }
}
// Weighted RMS norm for a non-contiguous source: out[col] = y[col] * x[col] * rsqrt(mean(x^2) + eps)
// per row. Grid: (nrows, nchannels, nsamples), one block per row; blockDim == (block_size, 1, 1)
// with block_size == WARP_SIZE or 1024. x is read via element strides; dst is written densely as
// [nsamples][nchannels][nrows][ncols].
template <int block_size, typename src_t>
static __global__ void fused_rms_norm_f32_nc(
        const src_t * x, const float * y, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
        const int64_t stride_sample, const float eps) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;
    const int row       = blockIdx.x;
    const int channel   = blockIdx.y;
    const int sample    = blockIdx.z;

    const src_t * src = x + sample*stride_sample + channel*stride_channel + row*stride_row;
    float       * out = dst + ((sample*nchannels + channel)*nrows + row)*ncols;

    float sum_sq = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        const float v = (float)src[col];
        sum_sq += v * v;
    }

    // block-wide reduction; with block_size == 1024 every one of the 32 slots is written
    sum_sq = warp_reduce_sum(sum_sq);
    if constexpr (block_size > WARP_SIZE) {
        static_assert(block_size == 1024, "unexpected block_size");
        __shared__ float warp_sums[32];
        const int lane = threadIdx.x % WARP_SIZE;
        if (lane == 0) {
            warp_sums[threadIdx.x / WARP_SIZE] = sum_sq;
        }
        __syncthreads();
        sum_sq = warp_reduce_sum(warp_sums[lane]);
    }

    const float scale = rsqrtf(sum_sq / ncols + eps);
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        out[col] = scale * y[col] * (float)src[col];
    }
}
// Launches norm_f32 with one block per row: 1024 threads for rows of at least 1024 columns,
// a single warp otherwise. ncols must be a multiple of WARP_SIZE.
template <typename T>
static void norm_f32_cuda(const T * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols >= 1024) {
        const dim3 block(1024, 1, 1);
        norm_f32<1024, T><<<nrows, block, 0, stream>>>(x, dst, ncols, eps);
        return;
    }
    const dim3 block(WARP_SIZE, 1, 1);
    norm_f32<WARP_SIZE, T><<<nrows, block, 0, stream>>>(x, dst, ncols, eps);
}
// Launches group_norm_f32 with one block per group: 1024 threads for groups of at least
// 1024 elements, a single warp otherwise.
static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
    if (group_size >= 1024) {
        const dim3 block(1024, 1, 1);
        group_norm_f32<1024><<<num_groups, block, 0, stream>>>(x, dst, group_size, ne_elements, eps);
    } else {
        const dim3 block(WARP_SIZE, 1, 1);
        group_norm_f32<WARP_SIZE><<<num_groups, block, 0, stream>>>(x, dst, group_size, ne_elements, eps);
    }
}
// Launches rms_norm_f32 with one block per contiguous row: 1024 threads when ncols >= 1024,
// 256 threads otherwise. There is deliberately no ncols % WARP_SIZE requirement, because the
// kernel bounds every column access with col < ncols.
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    constexpr int kBlockSize = 256;
    if (ncols >= 1024) {
        const dim3 block(1024, 1, 1);
        rms_norm_f32<1024><<<nrows, block, 0, stream>>>(x, dst, ncols, eps);
        return;
    }
    const dim3 block(kBlockSize, 1, 1);
    rms_norm_f32<kBlockSize><<<nrows, block, 0, stream>>>(x, dst, ncols, eps);
}
// Launches rms_norm_f32_nc with one block per (row, channel, sample): 1024 threads per block
// when ncols >= 1024, a single warp otherwise.
static void rms_norm_f32_nc_cuda(
        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 grid(nrows, nchannels, nsamples);
    if (ncols >= 1024) {
        const dim3 block(1024, 1, 1);
        rms_norm_f32_nc<1024><<<grid, block, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    } else {
        const dim3 block(WARP_SIZE, 1, 1);
        rms_norm_f32_nc<WARP_SIZE><<<grid, block, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    }
}
// Launches l2_norm_f32 with one block per contiguous row: 1024 threads when ncols >= 1024,
// 256 threads otherwise.
// Any ncols is accepted. The kernel bounds every column access with col < ncols, so there is no
// need for the former ncols % WARP_SIZE assert. This matches rms_norm_f32_cuda and the
// non-contiguous l2_norm_f32_nc_cuda path, neither of which requires it.
static void l2_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    constexpr int kBlockSize = 256;
    if (ncols < 1024) {
        const dim3 block_dims(kBlockSize, 1, 1);
        l2_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        l2_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
    }
}
// Launches l2_norm_f32_nc with one block per (row, channel, sample): 1024 threads per block
// when ncols >= 1024, a single warp otherwise.
static void l2_norm_f32_nc_cuda(
        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 grid(nrows, nchannels, nsamples);
    if (ncols >= 1024) {
        const dim3 block(1024, 1, 1);
        l2_norm_f32_nc<1024><<<grid, block, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
        return;
    }
    const dim3 block(WARP_SIZE, 1, 1);
    l2_norm_f32_nc<WARP_SIZE><<<grid, block, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
// Launches the weighted norm kernel for one compile-time block size, with one block of
// block_size threads per row: fused_norm_f32 when is_norm is set, fused_rms_norm_f32 otherwise.
template <int block_size, typename src_t>
static void fused_rms_norm_f32_launch_bs(const src_t * x, const float * y, float * dst,
        const int ncols, const int nrows, const float eps, bool is_norm, cudaStream_t stream) {
    const dim3 block_dims(block_size, 1, 1);
    if (is_norm) {
        fused_norm_f32<block_size><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
    } else {
        fused_rms_norm_f32<block_size><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
    }
}

// Host launcher for the weighted (layer or RMS) norm of contiguous rows.
// Block size selection: 1024 threads for ncols >= 1024, 256 threads for 256 <= ncols < 1024,
// and exactly ncols threads for narrower rows (224 for any width not listed in the switch).
// ncols must be a multiple of WARP_SIZE.
template <typename src_t>
static void fused_rms_norm_f32_cuda(const src_t * x, const float * y, float * dst,
        const int ncols, const int nrows, const float eps, bool is_norm, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols >= 1024) {
        fused_rms_norm_f32_launch_bs<1024>(x, y, dst, ncols, nrows, eps, is_norm, stream);
    } else if (ncols >= 256) {
        fused_rms_norm_f32_launch_bs< 256>(x, y, dst, ncols, nrows, eps, is_norm, stream);
    } else {
        switch (ncols) {
            case  32: fused_rms_norm_f32_launch_bs< 32>(x, y, dst, ncols, nrows, eps, is_norm, stream); break;
            case  64: fused_rms_norm_f32_launch_bs< 64>(x, y, dst, ncols, nrows, eps, is_norm, stream); break;
            case  96: fused_rms_norm_f32_launch_bs< 96>(x, y, dst, ncols, nrows, eps, is_norm, stream); break;
            case 128: fused_rms_norm_f32_launch_bs<128>(x, y, dst, ncols, nrows, eps, is_norm, stream); break;
            case 160: fused_rms_norm_f32_launch_bs<160>(x, y, dst, ncols, nrows, eps, is_norm, stream); break;
            case 192: fused_rms_norm_f32_launch_bs<192>(x, y, dst, ncols, nrows, eps, is_norm, stream); break;
            default : fused_rms_norm_f32_launch_bs<224>(x, y, dst, ncols, nrows, eps, is_norm, stream); break;
        }
    }
}
// Launches fused_rms_norm_f32_nc with one block per (row, channel, sample): 1024 threads per
// block when ncols >= 1024, a single warp otherwise.
template <typename src_t>
static void fused_rms_norm_f32_nc_cuda(
        const src_t * x, const float * y, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 grid(nrows, nchannels, nsamples);
    if (ncols >= 1024) {
        const dim3 block(1024, 1, 1);
        fused_rms_norm_f32_nc<1024><<<grid, block, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
        return;
    }
    const dim3 block(WARP_SIZE, 1, 1);
    fused_rms_norm_f32_nc<WARP_SIZE><<<grid, block, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
// Layer norm of each row of a contiguous F32 or F16 tensor into an F32 dst.
// eps is read from the first float of dst->op_params.
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const void * src_data = src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    const int64_t ncols = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    if (src0->type == GGML_TYPE_F16) {
        norm_f32_cuda((const half *)src_data, dst_d, ncols, nrows, eps, stream);
    } else {
        norm_f32_cuda((const float *)src_data, dst_d, ncols, nrows, eps, stream);
    }
}
// Group norm of a contiguous F32 tensor. op_params[0] holds the number of groups and
// op_params[1] the float eps. Each group spans ne0*ne1*ceil(ne2/num_groups) consecutive
// elements, and num_groups*ne3 groups are launched.
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    const int num_groups = dst->op_params[0];
    float eps;
    memcpy(&eps, dst->op_params + 1, sizeof(float));

    const int group_size   = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
    const int total_groups = num_groups * src0->ne[3];
    const int ne_elements  = ggml_nelements(src0);

    group_norm_f32_cuda(src0_d, dst_d, total_groups, eps, group_size, ne_elements, stream);
}
// RMS norm of src0 into dst (both F32); eps is the first float of dst->op_params.
// Contiguous inputs use the row kernel. Otherwise src0 must have unit element stride, and its
// row/channel/sample strides (in elements) are passed to the non-contiguous kernel.
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    const int64_t ne00 = src0->ne[0];

    if (!ggml_is_contiguous(src0)) {
        const size_t ts0 = ggml_type_size(src0->type);
        GGML_ASSERT(src0->nb[0] == ts0);
        rms_norm_f32_nc_cuda(src0_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3],
                src0->nb[1] / ts0, src0->nb[2] / ts0, src0->nb[3] / ts0, eps, stream);
        return;
    }
    rms_norm_f32_cuda(src0_d, dst_d, ne00, ggml_nrows(src0), eps, stream);
}
// L2 norm of each row of src0 into dst (both F32); eps is the first float of dst->op_params.
// Non-contiguous inputs must have unit element stride and go through the strided kernel.
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    float eps = 0.0f;
    memcpy(&eps, dst->op_params, sizeof(float));
    const int64_t ne00 = src0->ne[0];

    if (!ggml_is_contiguous(src0)) {
        const size_t ts0 = ggml_type_size(src0->type);
        GGML_ASSERT(src0->nb[0] == ts0);
        l2_norm_f32_nc_cuda(src0_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3],
                src0->nb[1] / ts0, src0->nb[2] / ts0, src0->nb[3] / ts0, eps, stream);
        return;
    }
    l2_norm_f32_cuda(src0_d, dst_d, ne00, ggml_nrows(src0), eps, stream);
}
// Weighted norm: dst = norm(src0) * src1 when is_norm, otherwise rms_norm(src0) * src1, where src1
// is a single F32 row broadcast over all rows of src0. eps is the first float of dst->op_params.
// src0 may be F32, F16, BF16, or (contiguous only) Q8_0. The non-contiguous path supports only
// the RMS variant.
// Without a weight tensor (src[1] == nullptr), this falls back to the unweighted op that
// matches is_norm.
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst, bool is_norm) {
    if (!dst->src[1]) {
        // Honor is_norm: the old fallback always ran RMS norm, which computes the wrong
        // result for the (layer) norm variant.
        if (is_norm) {
            ggml_cuda_op_norm(ctx, dst);
        } else {
            ggml_cuda_op_rms_norm(ctx, dst);
        }
        return;
    }
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const float * src0_d = (const float *)src0->data;
    const float * src1_d = (const float *)src1->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();
    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 ||
               (ggml_is_contiguous(src0) && src0->type == GGML_TYPE_Q8_0));
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
    GGML_ASSERT(src0->ne[0] == src1->ne[0]);
    GGML_ASSERT(ggml_nrows(src1) == 1);
    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    const int64_t ne00 = src0->ne[0];
    if (ggml_is_contiguous(src0)) {
        const int64_t nrows = ggml_nrows(src0);
        if (src0->type == GGML_TYPE_F32) {
            fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, is_norm, stream);
        } else if (src0->type == GGML_TYPE_F16) {
            fused_rms_norm_f32_cuda((const half *)src0_d, src1_d, dst_d, ne00, nrows, eps, is_norm, stream);
        } else if (src0->type == GGML_TYPE_Q8_0) {
            fused_rms_norm_f32_cuda((const block_q8_0 *)src0_d, src1_d, dst_d, ne00, nrows, eps, is_norm, stream);
        } else {
            fused_rms_norm_f32_cuda((const nv_bfloat16 *)src0_d, src1_d, dst_d, ne00, nrows, eps, is_norm, stream);
        }
    } else {
        if (is_norm) {
            GGML_ABORT("Non-contiguous norm is not implemented");
        }
        // strided path: elements must be packed, row/channel/sample strides are converted to element units
        auto ts0 = ggml_type_size(src0->type);
        GGML_ASSERT(src0->nb[0] == ts0);
        auto s01 = src0->nb[1] / ts0;
        auto s02 = src0->nb[2] / ts0;
        auto s03 = src0->nb[3] / ts0;
        if (src0->type == GGML_TYPE_F32) {
            fused_rms_norm_f32_nc_cuda(src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
        } else if (src0->type == GGML_TYPE_BF16) {
            fused_rms_norm_f32_nc_cuda((const nv_bfloat16 *)src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
        } else {
            fused_rms_norm_f32_nc_cuda((const half *)src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
        }
    }
}
// Fused residual add + weighted RMS norm over contiguous rows:
//   dst_add[row, :] = a[row, :] + b[row, :]
//   dst[row, col]   = c[col] * dst_add[row, col] * rsqrt(mean(dst_add[row, :]^2) + eps)
// Launch: gridDim.x == nrows, blockDim == (block_size, 1, 1), block_size a multiple of WARP_SIZE.
// Each thread re-reads only the dst_add elements it wrote itself.
template <int block_size>
static __global__ void fused_add_rms_norm_f32(const float * a, const float * b, const float * c,
        float * dst_add, float * dst, const int ncols, const float eps) {
    const int row    = blockIdx.x*blockDim.y + threadIdx.y;
    const int offset = row*ncols;

    float sum_sq = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        const float s = a[offset + col] + b[offset + col];
        dst_add[offset + col] = s;
        sum_sq += s * s;
    }

    sum_sq = warp_reduce_sum(sum_sq);
    if (block_size > WARP_SIZE) {
        __shared__ float warp_sums[32];
        const int lane = threadIdx.x % WARP_SIZE;
        if (lane == 0) {
            warp_sums[threadIdx.x / WARP_SIZE] = sum_sq;
        }
        __syncthreads();
        sum_sq = 0.0f;
        if (lane < block_size/WARP_SIZE) {
            sum_sq = warp_sums[lane];
        }
        sum_sq = warp_reduce_sum(sum_sq);
    }

    const float scale = rsqrtf(sum_sq / ncols + eps);
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        dst[offset + col] = scale * c[col] * dst_add[offset + col];
    }
}
// Fused double residual add + weighted RMS norm over contiguous rows:
//   dst_add[row, :] = a1[row, :] + a2[row, :] + b[row, :]
//   dst[row, col]   = c[col] * dst_add[row, col] * rsqrt(mean(dst_add[row, :]^2) + eps)
// Launch: gridDim.x == nrows, blockDim == (block_size, 1, 1), block_size a multiple of WARP_SIZE.
template <int block_size>
static __global__ void fused_add_add_rms_norm_f32(const float * a1, const float * a2, const float * b, const float * c,
        float * dst_add, float * dst, const int ncols, const float eps) {
    const int row    = blockIdx.x*blockDim.y + threadIdx.y;
    const int offset = row*ncols;

    float sum_sq = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        const float s = a1[offset + col] + a2[offset + col] + b[offset + col];
        dst_add[offset + col] = s;
        sum_sq += s * s;
    }

    sum_sq = warp_reduce_sum(sum_sq);
    if (block_size > WARP_SIZE) {
        __shared__ float warp_sums[32];
        const int lane = threadIdx.x % WARP_SIZE;
        if (lane == 0) {
            warp_sums[threadIdx.x / WARP_SIZE] = sum_sq;
        }
        __syncthreads();
        sum_sq = 0.0f;
        if (lane < block_size/WARP_SIZE) {
            sum_sq = warp_sums[lane];
        }
        sum_sq = warp_reduce_sum(sum_sq);
    }

    const float scale = rsqrtf(sum_sq / ncols + eps);
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        dst[offset + col] = scale * c[col] * dst_add[offset + col];
    }
}
// Launches fused_add_rms_norm_f32 with one block per row: 1024 threads when ncols >= 1024,
// 256 threads otherwise. ncols must be a multiple of WARP_SIZE.
static void fused_add_rms_norm_f32_cuda(const float * a, const float * b, const float * c, float * dst_add, float * dst,
        const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols >= 1024) {
        const dim3 block(1024, 1, 1);
        fused_add_rms_norm_f32<1024><<<nrows, block, 0, stream>>>(a, b, c, dst_add, dst, ncols, eps);
        return;
    }
    const dim3 block(256, 1, 1);
    fused_add_rms_norm_f32<256><<<nrows, block, 0, stream>>>(a, b, c, dst_add, dst, ncols, eps);
}
// Fused add + weighted RMS norm. The sum add->src[0] + add->src[1] is written into add->data
// (which must alias dst->src[0]), and dst = rms_norm(sum) * dst->src[1], where dst->src[1] is a
// single F32 row. All tensors are F32; the add operands and src0 are contiguous and the same shape.
void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const float * weight = (const float *)src1->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    // the add result is materialized in place of src0
    GGML_ASSERT(add->data == src0->data);
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(add->src[0]));
    GGML_ASSERT(ggml_is_contiguous(add->src[1]));
    GGML_ASSERT(ggml_are_same_shape(add->src[0], add->src[1]));
    GGML_ASSERT(ggml_are_same_shape(add->src[0], src0));
    GGML_ASSERT(add->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(add->src[1]->type == GGML_TYPE_F32);
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
    GGML_ASSERT(src0->ne[0] == src1->ne[0]);
    GGML_ASSERT(ggml_nrows(src1) == 1);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    const float * lhs = (const float *)add->src[0]->data;
    const float * rhs = (const float *)add->src[1]->data;
    float * sum_out   = (float *)add->data;
    fused_add_rms_norm_f32_cuda(lhs, rhs, weight, sum_out, dst_d, src0->ne[0], ggml_nrows(src0), eps, stream);
}
// Launches fused_add_add_rms_norm_f32 with one block per row: 1024 threads when ncols >= 1024,
// 256 threads otherwise. ncols must be a multiple of WARP_SIZE.
static void fused_add_add_rms_norm_f32_cuda(const float * a1, const float * a2, const float * b, const float * c, float * dst_add, float * dst,
        const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols >= 1024) {
        const dim3 block(1024, 1, 1);
        fused_add_add_rms_norm_f32<1024><<<nrows, block, 0, stream>>>(a1, a2, b, c, dst_add, dst, ncols, eps);
        return;
    }
    const dim3 block(256, 1, 1);
    fused_add_add_rms_norm_f32<256><<<nrows, block, 0, stream>>>(a1, a2, b, c, dst_add, dst, ncols, eps);
}
// Fused double add + weighted RMS norm for the chain add1 = a1 + a2, add2 = add1 + b,
// dst = rms_norm(add2) * dst->src[1]. The sum is written into add2->data (which must alias
// dst->src[0]); add1 itself is not materialized by this op.
// The kernel indexes a1 = add1->src[0], a2 = add1->src[1] and b = add2->src[1] as flat
// row-major F32 arrays with src0's shape. The asserts below enforce exactly that, so broadcast
// or strided operands fail loudly instead of being read out of bounds.
void ggml_cuda_op_fused_add_add_rms_norm(ggml_backend_cuda_context & ctx,
        ggml_tensor * add1, ggml_tensor * add2, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const float * src1_d = (const float *)src1->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();
    GGML_ASSERT(add1->data == add2->src[0]->data);
    GGML_ASSERT(add2->data == src0->data);
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(add1->src[0]));
    GGML_ASSERT(ggml_is_contiguous(add1->src[1]));
    GGML_ASSERT(ggml_is_contiguous(add2->src[1]));
    GGML_ASSERT(ggml_are_same_shape(add1->src[0], add1->src[1]));
    GGML_ASSERT(ggml_are_same_shape(add1->src[0], add2->src[1]));
    GGML_ASSERT(ggml_are_same_shape(add1->src[0], src0));
    GGML_ASSERT(add1->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(add1->src[1]->type == GGML_TYPE_F32);
    GGML_ASSERT(add2->src[1]->type == GGML_TYPE_F32);
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
    GGML_ASSERT(src0->ne[0] == src1->ne[0]);
    GGML_ASSERT(ggml_nrows(src1) == 1);
    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    const int64_t ne00 = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);
    fused_add_add_rms_norm_f32_cuda((const float *)add1->src[0]->data, (const float *)add1->src[1]->data, (const float *)add2->src[1]->data,
            src1_d, (float *)add2->data, dst_d, ne00, nrows, eps, stream);
}
// Two weighted RMS norms in one launch. Rows [0, nrows1) come from x1 (row stride nb1 bytes),
// use weights c1 and are written to y1; rows [nrows1, nrows1 + nrows2) come from x2 (row
// stride nb2 bytes), use weights c2 and are written to y2. Outputs are dense with ncols per row.
// Launch: gridDim.x == nrows1 + nrows2, blockDim == (block_size, 1, 1).
template <int block_size>
static __global__ void fused_rms_rms_norm_f32(int ncols, int nrows1, int nrows2, size_t nb1, size_t nb2, float eps,
        const char *x1, const char * x2, const float * c1, const float * c2, float * y1, float * y2) {
    const int  row       = blockIdx.x*blockDim.y + threadIdx.y;
    const bool first     = row < nrows1;
    const int  local_row = first ? row : row - nrows1;

    const float * x_row  = (const float *)(first ? x1 + local_row*nb1 : x2 + local_row*nb2);
    const float * weight = first ? c1 : c2;
    float       * y_row  = (first ? y1 : y2) + local_row*ncols;

    float sum_sq = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        const float v = x_row[col];
        sum_sq += v * v;
    }

    sum_sq = warp_reduce_sum(sum_sq);
    if (block_size > WARP_SIZE) {
        __shared__ float warp_sums[32];
        const int lane = threadIdx.x % WARP_SIZE;
        if (lane == 0) {
            warp_sums[threadIdx.x / WARP_SIZE] = sum_sq;
        }
        __syncthreads();
        sum_sq = 0.0f;
        if (lane < block_size/WARP_SIZE) {
            sum_sq = warp_sums[lane];
        }
        sum_sq = warp_reduce_sum(sum_sq);
    }

    const float scale = rsqrtf(sum_sq / ncols + eps);
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        y_row[col] = scale * weight[col] * x_row[col];
    }
}
// Launches fused_rms_rms_norm_f32 with one block per row of both inputs combined
// (nrows1 + nrows2 blocks). Rows with at least 1024 columns get 1024-thread blocks;
// shorter rows get 256-thread blocks. ncols must be a multiple of WARP_SIZE.
static void fused_rms_rms_norm_f32_cuda(int ncols, int nrows1, int nrows2, size_t nb1, size_t nb2, float eps,
        const char * x1, const char * x2, const float * c1, const float * c2, float * y1, float * y2, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    const int total_rows = nrows1 + nrows2;
    if (ncols >= 1024) {
        fused_rms_rms_norm_f32<1024><<<total_rows, dim3(1024, 1, 1), 0, stream>>>(
                ncols, nrows1, nrows2, nb1, nb2, eps, x1, x2, c1, c2, y1, y2);
        return;
    }
    fused_rms_rms_norm_f32<256><<<total_rows, dim3(256, 1, 1), 0, stream>>>(
            ncols, nrows1, nrows2, nb1, nb2, eps, x1, x2, c1, c2, y1, y2);
}
// Computes two F32 RMS norms in one kernel launch:
//   rms1 = rsqrt(mean(rms1->src[0]^2) + eps) * rms1->src[1] * rms1->src[0]   (per row)
//   rms2 = same for rms2's inputs
// Both ops must be 2-D (ne[2] == ne[3] == 1), share ne[0] and eps, and use a single-row F32 weight.
// The kernel reads input rows using the byte strides passed to it (so src[0] may be a row-strided
// view), but requires packed elements within a row. It writes output rows densely, so the outputs
// must be contiguous.
void ggml_cuda_op_fused_rms_rms_norm([[maybe_unused]] ggml_backend_cuda_context & ctx, [[maybe_unused]] ggml_tensor * rms1, [[maybe_unused]] ggml_tensor * rms2) {
    GGML_ASSERT(rms1->ne[2] == 1 && rms1->ne[3] == 1);
    GGML_ASSERT(rms2->ne[2] == 1 && rms2->ne[3] == 1);
    GGML_ASSERT(rms1->ne[0] == rms2->ne[0]);
    GGML_ASSERT(rms1->type == GGML_TYPE_F32);
    GGML_ASSERT(rms2->type == GGML_TYPE_F32);
    GGML_ASSERT(rms1->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(rms2->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(rms1->src[0]->ne[0] == rms1->src[1]->ne[0]);
    GGML_ASSERT(rms2->src[0]->ne[0] == rms2->src[1]->ne[0]);
    GGML_ASSERT(ggml_nrows(rms1->src[1]) == 1);
    GGML_ASSERT(ggml_nrows(rms2->src[1]) == 1);
    GGML_ASSERT(rms1->src[1]->type == GGML_TYPE_F32);
    GGML_ASSERT(rms2->src[1]->type == GGML_TYPE_F32);
    // The kernel indexes x_row[col] directly: elements within an input row must be packed floats.
    GGML_ASSERT(rms1->src[0]->nb[0] == sizeof(float));
    GGML_ASSERT(rms2->src[0]->nb[0] == sizeof(float));
    // The kernel writes output row r at y + r*ncols, i.e. it assumes dense outputs.
    GGML_ASSERT(ggml_is_contiguous(rms1));
    GGML_ASSERT(ggml_is_contiguous(rms2));
    float eps1, eps2;
    memcpy(&eps1, rms1->op_params, sizeof(float));
    memcpy(&eps2, rms2->op_params, sizeof(float));
    GGML_ASSERT(eps1 == eps2);
    // nb1/nb2 are the row strides used to READ the inputs, so they must come from src[0].
    // The previous code passed rms1->nb[1] / rms2->nb[1] (the outputs' strides). Those only match
    // when the inputs are densely packed; with a row-strided input view the wrong rows were read.
    fused_rms_rms_norm_f32_cuda(rms1->ne[0], rms1->ne[1], rms2->ne[1], rms1->src[0]->nb[1], rms2->src[0]->nb[1], eps1,
            (const char *)rms1->src[0]->data, (const char *)rms2->src[0]->data,
            (const float *)rms1->src[1]->data, (const float *)rms2->src[1]->data,
            (float *)rms1->data, (float *)rms2->data, ctx.stream());
}
// Sum of two independently RMS-normalized rows, written as F32:
//   dst[r][i] = rsqrt(mean(x1[r]^2)+eps)*c1[i]*x1[r][i] + rsqrt(mean(x2[r]^2)+eps)*c2[i]*x2[r][i]
// x1, x2 and dst are addressed as dense rows of ncols elements. src_t values are converted to float
// before use. Row index r = blockIdx.x*blockDim.y + threadIdx.y (the launcher uses blockDim.y == 1).
// There is no bounds check on r; nrows is not read by the kernel.
// The cross-warp stage keeps the two per-warp partials interleaved in a 2*WARP_SIZE shared array.
template <int block_size, typename src_t>
static __global__ void fused_rms_rms_add_f32(int ncols, int nrows, float * dst,
        const src_t * x1, const float * c1, const src_t * x2, const float * c2, float eps) {
    const int r = blockIdx.x*blockDim.y + threadIdx.y;
    const int t = threadIdx.x;

    const src_t * a   = x1  + r*ncols;
    const src_t * b   = x2  + r*ncols;
    float       * out = dst + r*ncols;

    // Per-thread partial sums of squares for both inputs.
    float ssa = 0.0f;
    float ssb = 0.0f;
    for (int i = t; i < ncols; i += block_size) {
        const float va = (float)a[i];
        const float vb = (float)b[i];
        ssa += va * va;
        ssb += vb * vb;
    }

    // Warp reduction, then a cross-warp reduction through shared memory for multi-warp blocks.
    // warp_reduce_sum is assumed to leave the warp-wide sum in every lane.
    ssa = warp_reduce_sum(ssa);
    ssb = warp_reduce_sum(ssb);
    if (block_size > WARP_SIZE) {
        __shared__ float partial[2*WARP_SIZE];
        const int warp = threadIdx.x / WARP_SIZE;
        const int lane = threadIdx.x % WARP_SIZE;
        if (lane == 0) {
            partial[2*warp+0] = ssa;
            partial[2*warp+1] = ssb;
        }
        __syncthreads();
        const bool owns_partial = lane < block_size/WARP_SIZE;
        ssa = owns_partial ? partial[2*lane+0] : 0.0f;
        ssb = owns_partial ? partial[2*lane+1] : 0.0f;
        ssa = warp_reduce_sum(ssa);
        ssb = warp_reduce_sum(ssb);
    }

    const float sa = rsqrtf(ssa / ncols + eps);
    const float sb = rsqrtf(ssb / ncols + eps);

    for (int i = t; i < ncols; i += block_size) {
        out[i] = sa * c1[i] * (float)a[i] + sb * c2[i] * (float)b[i];
    }
}
// Launches fused_rms_rms_add_f32 with one block per row (nrows blocks).
// Rows with at least 1024 columns use 1024-thread blocks; shorter rows use 256-thread blocks.
template <typename src_t>
static void fused_rms_rms_add_f32_cuda(int ncols, int nrows, float * dst,
        const src_t * x1, const float * c1, const src_t * x2, const float * c2,
        float eps, cudaStream_t stream) {
    if (ncols >= 1024) {
        fused_rms_rms_add_f32<1024><<<nrows, dim3(1024, 1, 1), 0, stream>>>(ncols, nrows, dst, x1, c1, x2, c2, eps);
        return;
    }
    fused_rms_rms_add_f32<256><<<nrows, dim3(256, 1, 1), 0, stream>>>(ncols, nrows, dst, x1, c1, x2, c2, eps);
}
// dst = rms_norm(src[0]) * src[1] + rms_norm(src[2]) * src[3], computed in one kernel.
// src[0], src[2] and dst share a shape and are contiguous. src[0]/src[2] have the same type
// (F32, F16 or BF16; anything else aborts). src[1]/src[3] are single F32 weight rows. dst is F32.
// eps is read from the first float of dst->op_params.
void ggml_cuda_op_fused_rms_rms_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    GGML_ASSERT(ggml_are_same_shape(dst->src[0], dst->src[2]));
    GGML_ASSERT(ggml_are_same_shape(dst->src[0], dst));
    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
    GGML_ASSERT(ggml_is_contiguous(dst->src[2]));
    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_nrows(dst->src[1]) == 1 && dst->src[1]->ne[0] == dst->src[0]->ne[0]);
    GGML_ASSERT(ggml_nrows(dst->src[3]) == 1 && dst->src[3]->ne[0] == dst->src[2]->ne[0]);
    GGML_ASSERT(dst->src[0]->type == dst->src[2]->type);
    GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32 && dst->src[3]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    const int ncols = dst->ne[0];
    const int nrows = ggml_nrows(dst);

    float       * out = (float *)dst->data;
    const void  * xa  = dst->src[0]->data;
    const float * wa  = (const float *)dst->src[1]->data;
    const void  * xb  = dst->src[2]->data;
    const float * wb  = (const float *)dst->src[3]->data;

    // Dispatch on the (shared) activation type; weights and output are always F32.
    switch (dst->src[0]->type) {
        case GGML_TYPE_F32:
            fused_rms_rms_add_f32_cuda(ncols, nrows, out,
                    (const float *)xa, wa, (const float *)xb, wb, eps, ctx.stream());
            break;
        case GGML_TYPE_F16:
            fused_rms_rms_add_f32_cuda(ncols, nrows, out,
                    (const half *)xa, wa, (const half *)xb, wb, eps, ctx.stream());
            break;
        case GGML_TYPE_BF16:
            fused_rms_rms_add_f32_cuda(ncols, nrows, out,
                    (const nv_bfloat16 *)xa, wa, (const nv_bfloat16 *)xb, wb, eps, ctx.stream());
            break;
        default:
            GGML_ABORT("Not implemented");
    }
}