mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
opencl: support non-contig rows in norm (#24965)
This commit is contained in:
parent
09cedfd699
commit
fdb2c11c70
@ -10152,14 +10152,8 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const
|
|||||||
float eps;
|
float eps;
|
||||||
memcpy(&eps, dst->op_params, sizeof(float));
|
memcpy(&eps, dst->op_params, sizeof(float));
|
||||||
|
|
||||||
const int ne00 = src0 ? src0->ne[0] : 0;
|
GGML_TENSOR_LOCALS(int, ne0, src0, ne);
|
||||||
const int ne01 = src0 ? src0->ne[1] : 0;
|
GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);
|
||||||
const int ne02 = src0 ? src0->ne[2] : 0;
|
|
||||||
const int ne03 = src0 ? src0->ne[3] : 0;
|
|
||||||
|
|
||||||
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
|
|
||||||
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
|
|
||||||
const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
|
|
||||||
|
|
||||||
const int nth = MIN(64, ne00);
|
const int nth = MIN(64, ne00);
|
||||||
|
|
||||||
@ -10173,11 +10167,12 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const
|
|||||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
|
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
|
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
|
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
|
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
|
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
|
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps));
|
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth, NULL));
|
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &eps));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float)*nth, NULL));
|
||||||
|
|
||||||
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
||||||
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
||||||
|
|||||||
@ -24,6 +24,7 @@ kernel void kernel_norm(
|
|||||||
int ne01,
|
int ne01,
|
||||||
int ne02,
|
int ne02,
|
||||||
int ne03,
|
int ne03,
|
||||||
|
ulong nb00,
|
||||||
ulong nb01,
|
ulong nb01,
|
||||||
ulong nb02,
|
ulong nb02,
|
||||||
ulong nb03,
|
ulong nb03,
|
||||||
@ -43,7 +44,8 @@ kernel void kernel_norm(
|
|||||||
// parallel sum
|
// parallel sum
|
||||||
sum[get_local_id(0)] = 0.0f;
|
sum[get_local_id(0)] = 0.0f;
|
||||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||||
sum[get_local_id(0)] += x[i00];
|
// this kernel handles float, nb00/4 translates byte offset to element offset
|
||||||
|
sum[get_local_id(0)] += x[i00*nb00/4];
|
||||||
}
|
}
|
||||||
// reduce
|
// reduce
|
||||||
barrier(CLK_LOCAL_MEM_FENCE);
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||||||
@ -60,7 +62,8 @@ kernel void kernel_norm(
|
|||||||
global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||||
sum[get_local_id(0)] = 0.0f;
|
sum[get_local_id(0)] = 0.0f;
|
||||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||||
y[i00] = x[i00] - mean;
|
// this kernel handles float, nb00/4 translates byte offset to element offset
|
||||||
|
y[i00] = x[i00*nb00/4] - mean;
|
||||||
sum[get_local_id(0)] += y[i00] * y[i00];
|
sum[get_local_id(0)] += y[i00] * y[i00];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user