opencl: support non-contig rows in norm (#24965)

This commit is contained in:
lhez 2026-06-24 19:21:25 -07:00 committed by GitHub
parent 09cedfd699
commit fdb2c11c70
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 15 deletions

View File

@ -10152,14 +10152,8 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const
float eps; float eps;
memcpy(&eps, dst->op_params, sizeof(float)); memcpy(&eps, dst->op_params, sizeof(float));
const int ne00 = src0 ? src0->ne[0] : 0; GGML_TENSOR_LOCALS(int, ne0, src0, ne);
const int ne01 = src0 ? src0->ne[1] : 0; GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);
const int ne02 = src0 ? src0->ne[2] : 0;
const int ne03 = src0 ? src0->ne[3] : 0;
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
const int nth = MIN(64, ne00); const int nth = MIN(64, ne00);
@ -10173,11 +10167,12 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps)); CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth, NULL)); CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &eps));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float)*nth, NULL));
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1}; size_t local_work_size[] = {(size_t)nth, 1, 1};

View File

@ -24,6 +24,7 @@ kernel void kernel_norm(
int ne01, int ne01,
int ne02, int ne02,
int ne03, int ne03,
ulong nb00,
ulong nb01, ulong nb01,
ulong nb02, ulong nb02,
ulong nb03, ulong nb03,
@ -43,7 +44,8 @@ kernel void kernel_norm(
// parallel sum // parallel sum
sum[get_local_id(0)] = 0.0f; sum[get_local_id(0)] = 0.0f;
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
sum[get_local_id(0)] += x[i00]; // this kernel handles float, nb00/4 translates byte offset to element offset
sum[get_local_id(0)] += x[i00*nb00/4];
} }
// reduce // reduce
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
@ -60,7 +62,8 @@ kernel void kernel_norm(
global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
sum[get_local_id(0)] = 0.0f; sum[get_local_id(0)] = 0.0f;
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
y[i00] = x[i00] - mean; // this kernel handles float, nb00/4 translates byte offset to element offset
y[i00] = x[i00*nb00/4] - mean;
sum[get_local_id(0)] += y[i00] * y[i00]; sum[get_local_id(0)] += y[i00] * y[i00];
} }