mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
sycl : fix failed unit test (ut) cases of norm (#25044)
This commit is contained in:
parent
0b6529d818
commit
9bebfcb4bc
@ -2,8 +2,10 @@
|
||||
#include "ggml-sycl/common.hpp"
|
||||
#include "ggml-sycl/presets.hpp"
|
||||
|
||||
static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
||||
const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
|
||||
static void norm_f32(const float* x, float* dst, const int ncols,
|
||||
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
|
||||
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
|
||||
const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
|
||||
|
||||
const int nrows = item_ct1.get_group_range(2);
|
||||
const int nchannels = item_ct1.get_group_range(1);
|
||||
@ -16,16 +18,16 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
|
||||
const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
|
||||
const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
|
||||
const auto src_offset = calculate_offset<3>({src_stride_sample, src_stride_channel, src_stride_row}, {sample, channel, row});
|
||||
const auto dst_offset = calculate_offset<3>({dst_stride_sample, dst_stride_channel, dst_stride_row}, {sample, channel, row});
|
||||
|
||||
x += strided_offset;
|
||||
dst += packed_offset;
|
||||
x += src_offset;
|
||||
dst += dst_offset;
|
||||
|
||||
sycl::float2 mean_var = sycl::float2(0.f, 0.f);
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[col];
|
||||
const float xi = x[col * src_stride_col];
|
||||
mean_var.x() += xi;
|
||||
mean_var.y() += xi * xi;
|
||||
}
|
||||
@ -54,7 +56,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
|
||||
const float inv_std = sycl::rsqrt(var + eps);
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[col] = (x[col] - mean) * inv_std;
|
||||
dst[col * dst_stride_col] = (x[col * src_stride_col] - mean) * inv_std;
|
||||
}
|
||||
}
|
||||
|
||||
@ -145,8 +147,10 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
|
||||
}
|
||||
}
|
||||
|
||||
static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
||||
const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
|
||||
static void rms_norm_f32(const float* x, float* dst, const int ncols,
|
||||
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
|
||||
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
|
||||
const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
|
||||
|
||||
const int nrows = item_ct1.get_group_range(2);
|
||||
const int nchannels = item_ct1.get_group_range(1);
|
||||
@ -160,17 +164,17 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
|
||||
const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
|
||||
const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
|
||||
const auto src_offset = calculate_offset<3>({src_stride_sample, src_stride_channel, src_stride_row}, {sample, channel, row});
|
||||
const auto dst_offset = calculate_offset<3>({dst_stride_sample, dst_stride_channel, dst_stride_row}, {sample, channel, row});
|
||||
|
||||
x += strided_offset;
|
||||
dst += packed_offset;
|
||||
x += src_offset;
|
||||
dst += dst_offset;
|
||||
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[col];
|
||||
const float xi = x[col * src_stride_col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
@ -198,14 +202,15 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
|
||||
const float scale = sycl::rsqrt(mean + eps);
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[col] = scale * x[col];
|
||||
dst[col * dst_stride_col] = scale * x[col * src_stride_col];
|
||||
}
|
||||
}
|
||||
|
||||
template<int warp_size>
|
||||
static void l2_norm_f32(const float * x, float * dst, const int ncols,
|
||||
const int64_t stride_row, const int64_t stride_channel,
|
||||
const int64_t stride_sample, const float eps,
|
||||
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel,
|
||||
const int64_t src_stride_sample, const int64_t dst_stride_col, const int64_t dst_stride_row,
|
||||
const int64_t dst_stride_channel, const int64_t dst_stride_sample, const float eps,
|
||||
const sycl::nd_item<3>& item_ct1, float* s_sum, const int block_size) {
|
||||
const int nrows = item_ct1.get_group_range(2);
|
||||
const int nchannels = item_ct1.get_group_range(1);
|
||||
@ -215,13 +220,13 @@ static void l2_norm_f32(const float * x, float * dst, const int ncols,
|
||||
const int sample = item_ct1.get_group(0);
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
|
||||
x += sample*stride_sample + channel*stride_channel + row*stride_row;
|
||||
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
|
||||
x += sample*src_stride_sample + channel*src_stride_channel + row*src_stride_row;
|
||||
dst += sample*dst_stride_sample + channel*dst_stride_channel + row*dst_stride_row;
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[col];
|
||||
const float xi = x[col * src_stride_col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
@ -229,12 +234,13 @@ static void l2_norm_f32(const float * x, float * dst, const int ncols,
|
||||
const float scale = sycl::rsqrt(sycl::fmax(tmp, eps * eps));
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[col] = scale * x[col];
|
||||
dst[col * dst_stride_col] = scale * x[col * src_stride_col];
|
||||
}
|
||||
}
|
||||
|
||||
static void norm_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
|
||||
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
|
||||
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
|
||||
const float eps, queue_ptr stream, int device) {
|
||||
|
||||
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
|
||||
@ -245,7 +251,10 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
|
||||
sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
|
||||
norm_f32(x, dst, ncols,
|
||||
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
|
||||
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
|
||||
eps, item_ct1, nullptr, WARP_SIZE);
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -265,7 +274,10 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
|
||||
sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
norm_f32(x, dst, ncols,
|
||||
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
|
||||
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
|
||||
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -319,7 +331,9 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
||||
}
|
||||
|
||||
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
|
||||
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
|
||||
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
|
||||
const float eps, queue_ptr stream, int device) {
|
||||
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
||||
|
||||
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
|
||||
@ -330,7 +344,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
|
||||
sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
|
||||
rms_norm_f32(x, dst, ncols,
|
||||
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
|
||||
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
|
||||
eps, item_ct1, nullptr, WARP_SIZE);
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -350,7 +367,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
|
||||
sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
rms_norm_f32(x, dst, ncols,
|
||||
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
|
||||
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
|
||||
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -363,9 +383,14 @@ static void l2_norm_f32_sycl(const float * x,
|
||||
const int nrows,
|
||||
const int nchannels,
|
||||
const int nsamples,
|
||||
const int64_t stride_row,
|
||||
const int64_t stride_channel,
|
||||
const int64_t stride_sample,
|
||||
const int64_t src_stride_col,
|
||||
const int64_t src_stride_row,
|
||||
const int64_t src_stride_channel,
|
||||
const int64_t src_stride_sample,
|
||||
const int64_t dst_stride_col,
|
||||
const int64_t dst_stride_row,
|
||||
const int64_t dst_stride_channel,
|
||||
const int64_t dst_stride_sample,
|
||||
const float eps,
|
||||
queue_ptr stream,
|
||||
int device) {
|
||||
@ -379,7 +404,10 @@ static void l2_norm_f32_sycl(const float * x,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(warp_size)]] {
|
||||
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
|
||||
l2_norm_f32<warp_size>(x, dst, ncols,
|
||||
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
|
||||
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
|
||||
eps, item_ct1,
|
||||
nullptr, warp_size);
|
||||
});
|
||||
});
|
||||
@ -398,7 +426,9 @@ static void l2_norm_f32_sycl(const float * x,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(warp_size)]] {
|
||||
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample,
|
||||
l2_norm_f32<warp_size>(x, dst, ncols,
|
||||
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
|
||||
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
|
||||
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
@ -421,12 +451,20 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
const size_t ts0 = ggml_type_size(src0->type);
|
||||
GGML_ASSERT(nb00 == ts0);
|
||||
const int64_t s01 = nb01 / ts0;
|
||||
const int64_t s02 = nb02 / ts0;
|
||||
const int64_t s03 = nb03 / ts0;
|
||||
const size_t tdst = ggml_type_size(dst->type);
|
||||
GGML_ASSERT(nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0);
|
||||
GGML_ASSERT(nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0);
|
||||
const int64_t ss0 = nb00 / ts0;
|
||||
const int64_t ss1 = nb01 / ts0;
|
||||
const int64_t ss2 = nb02 / ts0;
|
||||
const int64_t ss3 = nb03 / ts0;
|
||||
const int64_t ds0 = nb0 / tdst;
|
||||
const int64_t ds1 = nb1 / tdst;
|
||||
const int64_t ds2 = nb2 / tdst;
|
||||
const int64_t ds3 = nb3 / tdst;
|
||||
|
||||
norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
|
||||
norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03,
|
||||
ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, main_stream, ctx.device);
|
||||
}
|
||||
|
||||
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
@ -465,11 +503,19 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
const size_t ts0 = ggml_type_size(src0->type);
|
||||
GGML_ASSERT(nb00 == ts0);
|
||||
const int64_t s01 = nb01 / ts0;
|
||||
const int64_t s02 = nb02 / ts0;
|
||||
const int64_t s03 = nb03 / ts0;
|
||||
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
|
||||
const size_t tdst = ggml_type_size(dst->type);
|
||||
GGML_ASSERT(nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0);
|
||||
GGML_ASSERT(nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0);
|
||||
const int64_t ss0 = nb00 / ts0;
|
||||
const int64_t ss1 = nb01 / ts0;
|
||||
const int64_t ss2 = nb02 / ts0;
|
||||
const int64_t ss3 = nb03 / ts0;
|
||||
const int64_t ds0 = nb0 / tdst;
|
||||
const int64_t ds1 = nb1 / tdst;
|
||||
const int64_t ds2 = nb2 / tdst;
|
||||
const int64_t ds3 = nb3 / tdst;
|
||||
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03,
|
||||
ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, main_stream, ctx.device);
|
||||
}
|
||||
|
||||
void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
@ -644,13 +690,21 @@ void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
const size_t ts0 = ggml_type_size(src0->type);
|
||||
GGML_ASSERT(nb00 == ts0);
|
||||
const int64_t s01 = nb01 / ts0;
|
||||
const int64_t s02 = nb02 / ts0;
|
||||
const int64_t s03 = nb03 / ts0;
|
||||
const size_t tdst = ggml_type_size(dst->type);
|
||||
GGML_ASSERT(nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0);
|
||||
GGML_ASSERT(nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0);
|
||||
const int64_t ss0 = nb00 / ts0;
|
||||
const int64_t ss1 = nb01 / ts0;
|
||||
const int64_t ss2 = nb02 / ts0;
|
||||
const int64_t ss3 = nb03 / ts0;
|
||||
const int64_t ds0 = nb0 / tdst;
|
||||
const int64_t ds1 = nb1 / tdst;
|
||||
const int64_t ds2 = nb2 / tdst;
|
||||
const int64_t ds3 = nb3 / tdst;
|
||||
|
||||
/* Supports both WARP_SIZE and WARP_32_SIZE in code;
|
||||
choose based on the hardware for better performance.
|
||||
*/
|
||||
l2_norm_f32_sycl<WARP_SIZE>(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream, ctx.device);
|
||||
l2_norm_f32_sycl<WARP_SIZE>(src0_d, dst_d, ne00, ne01, ne02, ne03,
|
||||
ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, stream, ctx.device);
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user