From 9bebfcb4bc8b12a316e96ae03f33671eac1e72fd Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Sat, 27 Jun 2026 17:13:43 +0800 Subject: [PATCH] sycl : fix failed ut cases of norm (#25044) --- ggml/src/ggml-sycl/norm.cpp | 150 ++++++++++++++++++++++++------------ 1 file changed, 102 insertions(+), 48 deletions(-) diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index 09fce1280a..c4472e4bd6 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -2,8 +2,10 @@ #include "ggml-sycl/common.hpp" #include "ggml-sycl/presets.hpp" -static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, - const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) { +static void norm_f32(const float* x, float* dst, const int ncols, + const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample, + const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample, + const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) { const int nrows = item_ct1.get_group_range(2); const int nchannels = item_ct1.get_group_range(1); @@ -16,16 +18,16 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t const int tid = item_ct1.get_local_id(2); const int nwarps = nthreads / WARP_SIZE; - const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row}); - const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row}); + const auto src_offset = calculate_offset<3>({src_stride_sample, src_stride_channel, src_stride_row}, {sample, channel, row}); + const auto dst_offset = calculate_offset<3>({dst_stride_sample, dst_stride_channel, dst_stride_row}, {sample, channel, row}); - x += strided_offset; - dst += packed_offset; + x += src_offset; + dst += dst_offset; sycl::float2 mean_var = sycl::float2(0.f, 0.f); for (int col = tid; col < ncols; col += block_size) { - const float xi = x[col]; + const float xi = x[col * src_stride_col]; mean_var.x() += xi; mean_var.y() += xi * xi; } @@ -54,7 +56,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t const float inv_std = sycl::rsqrt(var + eps); for (int col = tid; col < ncols; col += block_size) { - dst[col] = (x[col] - mean) * inv_std; + dst[col * dst_stride_col] = (x[col * src_stride_col] - mean) * inv_std; } } @@ -145,8 +147,10 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con } } -static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, - const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) { +static void rms_norm_f32(const float* x, float* dst, const int ncols, + const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample, + const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample, + const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) { const int nrows = item_ct1.get_group_range(2); const int nchannels = item_ct1.get_group_range(1); @@ -160,17 +164,17 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6 const int tid = item_ct1.get_local_id(2); const int nwarps = nthreads / WARP_SIZE; - const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row}); - const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row}); + const auto src_offset = calculate_offset<3>({src_stride_sample, src_stride_channel, src_stride_row}, {sample, channel, row}); + const auto dst_offset = calculate_offset<3>({dst_stride_sample, dst_stride_channel, dst_stride_row}, {sample, channel, row}); - x += strided_offset; - dst += packed_offset; + x += src_offset; + dst += dst_offset; float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += block_size) { - const float xi = x[col]; + const float xi = x[col * src_stride_col]; tmp += xi * xi; } @@ -198,14 +202,15 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6 const float scale = sycl::rsqrt(mean + eps); for (int col = tid; col < ncols; col += block_size) { - dst[col] = scale * x[col]; + dst[col * dst_stride_col] = scale * x[col * src_stride_col]; } } template static void l2_norm_f32(const float * x, float * dst, const int ncols, - const int64_t stride_row, const int64_t stride_channel, - const int64_t stride_sample, const float eps, + const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, + const int64_t src_stride_sample, const int64_t dst_stride_col, const int64_t dst_stride_row, + const int64_t dst_stride_channel, const int64_t dst_stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, const int block_size) { const int nrows = item_ct1.get_group_range(2); const int nchannels = item_ct1.get_group_range(1); @@ -215,13 +220,13 @@ static void l2_norm_f32(const float * x, float * dst, const int ncols, const int sample = item_ct1.get_group(0); const int tid = item_ct1.get_local_id(2); - x += sample*stride_sample + channel*stride_channel + row*stride_row; - dst += ((sample*nchannels + channel)*nrows + row)*ncols; + x += sample*src_stride_sample + channel*src_stride_channel + row*src_stride_row; + dst += sample*dst_stride_sample + channel*dst_stride_channel + row*dst_stride_row; float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += block_size) { - const float xi = x[col]; + const float xi = x[col * src_stride_col]; tmp += xi * xi; } @@ -229,12 +234,13 @@ static void l2_norm_f32(const float * x, float * dst, const int ncols, const float scale = sycl::rsqrt(sycl::fmax(tmp, eps * eps)); for (int col = tid; col < ncols; col += block_size) { - dst[col] = scale * x[col]; + dst[col * dst_stride_col] = scale * x[col * src_stride_col]; } } static void norm_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, - const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, + const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample, + const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample, const float eps, queue_ptr stream, int device) { const sycl::range<3> global_dims(nsamples, nchannels, nrows); @@ -245,7 +251,10 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i sycl::nd_range<3>(global_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE); + norm_f32(x, dst, ncols, + src_stride_col, src_stride_row, src_stride_channel, src_stride_sample, + dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample, + eps, item_ct1, nullptr, WARP_SIZE); }); }); } @@ -265,7 +274,10 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i sycl::nd_range<3>(global_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); + norm_f32(x, dst, ncols, + src_stride_col, src_stride_row, src_stride_channel, src_stride_sample, + dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample, + eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); }); }); } @@ -319,7 +331,9 @@ static void group_norm_f32_sycl(const float* x, float* dst, } static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples, - const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) { + const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample, + const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample, + const float eps, queue_ptr stream, int device) { // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); const sycl::range<3> global_dims(nsamples, nchannels, nrows); @@ -330,7 +344,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const sycl::nd_range<3>(global_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE); + rms_norm_f32(x, dst, ncols, + src_stride_col, src_stride_row, src_stride_channel, src_stride_sample, + dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample, + eps, item_ct1, nullptr, WARP_SIZE); }); }); } @@ -350,7 +367,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const sycl::nd_range<3>(global_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); + rms_norm_f32(x, dst, ncols, + src_stride_col, src_stride_row, src_stride_channel, src_stride_sample, + dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample, + eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); }); }); } @@ -363,9 +383,14 @@ static void l2_norm_f32_sycl(const float * x, const int nrows, const int nchannels, const int nsamples, - const int64_t stride_row, - const int64_t stride_channel, - const int64_t stride_sample, + const int64_t src_stride_col, + const int64_t src_stride_row, + const int64_t src_stride_channel, + const int64_t src_stride_sample, + const int64_t dst_stride_col, + const int64_t dst_stride_row, + const int64_t dst_stride_channel, + const int64_t dst_stride_sample, const float eps, queue_ptr stream, int device) { @@ -379,7 +404,10 @@ static void l2_norm_f32_sycl(const float * x, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { - l2_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, + l2_norm_f32(x, dst, ncols, + src_stride_col, src_stride_row, src_stride_channel, src_stride_sample, + dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample, + eps, item_ct1, nullptr, warp_size); }); }); @@ -398,7 +426,9 @@ static void l2_norm_f32_sycl(const float * x, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { - l2_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, + l2_norm_f32(x, dst, ncols, + src_stride_col, src_stride_row, src_stride_channel, src_stride_sample, + dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); }); }); @@ -421,12 +451,20 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { memcpy(&eps, dst->op_params, sizeof(float)); GGML_ASSERT(eps >= 0.0f); const size_t ts0 = ggml_type_size(src0->type); - GGML_ASSERT(nb00 == ts0); - const int64_t s01 = nb01 / ts0; - const int64_t s02 = nb02 / ts0; - const int64_t s03 = nb03 / ts0; + const size_t tdst = ggml_type_size(dst->type); + GGML_ASSERT(nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0); + GGML_ASSERT(nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0); + const int64_t ss0 = nb00 / ts0; + const int64_t ss1 = nb01 / ts0; + const int64_t ss2 = nb02 / ts0; + const int64_t ss3 = nb03 / ts0; + const int64_t ds0 = nb0 / tdst; + const int64_t ds1 = nb1 / tdst; + const int64_t ds2 = nb2 / tdst; + const int64_t ds3 = nb3 / tdst; - norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device); + norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, + ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, main_stream, ctx.device); } void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { @@ -465,11 +503,19 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_TENSOR_UNARY_OP_LOCALS const size_t ts0 = ggml_type_size(src0->type); - GGML_ASSERT(nb00 == ts0); - const int64_t s01 = nb01 / ts0; - const int64_t s02 = nb02 / ts0; - const int64_t s03 = nb03 / ts0; - rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device); + const size_t tdst = ggml_type_size(dst->type); + GGML_ASSERT(nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0); + GGML_ASSERT(nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0); + const int64_t ss0 = nb00 / ts0; + const int64_t ss1 = nb01 / ts0; + const int64_t ss2 = nb02 / ts0; + const int64_t ss3 = nb03 / ts0; + const int64_t ds0 = nb0 / tdst; + const int64_t ds1 = nb1 / tdst; + const int64_t ds2 = nb2 / tdst; + const int64_t ds3 = nb3 / tdst; + rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, + ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, main_stream, ctx.device); } void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { @@ -644,13 +690,21 @@ void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { GGML_ASSERT(eps >= 0.0f); const size_t ts0 = ggml_type_size(src0->type); - GGML_ASSERT(nb00 == ts0); - const int64_t s01 = nb01 / ts0; - const int64_t s02 = nb02 / ts0; - const int64_t s03 = nb03 / ts0; + const size_t tdst = ggml_type_size(dst->type); + GGML_ASSERT(nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0); + GGML_ASSERT(nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0); + const int64_t ss0 = nb00 / ts0; + const int64_t ss1 = nb01 / ts0; + const int64_t ss2 = nb02 / ts0; + const int64_t ss3 = nb03 / ts0; + const int64_t ds0 = nb0 / tdst; + const int64_t ds1 = nb1 / tdst; + const int64_t ds2 = nb2 / tdst; + const int64_t ds3 = nb3 / tdst; /*support both WARP_SIZE or WARP_32_SIZE in code choose by hardware for better performance */ - l2_norm_f32_sycl(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream, ctx.device); + l2_norm_f32_sycl(src0_d, dst_d, ne00, ne01, ne02, ne03, + ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, stream, ctx.device); }