mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
sycl : fix the failed UT cases of conv_3d (#24900)
This commit is contained in:
parent
fdb2c11c70
commit
9c10954865
@ -103,8 +103,8 @@ void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
// allocate packed arrays: A_packed (k x m), B_packed (k x n)
|
||||
ggml_sycl_pool_alloc<float> A_packed_alloc(ctx.pool());
|
||||
ggml_sycl_pool_alloc<float> B_packed_alloc(ctx.pool());
|
||||
A_packed_alloc.alloc((size_t) knl_n_total * patch_total * sizeof(float));
|
||||
B_packed_alloc.alloc((size_t) knl_n_total * oc * sizeof(float));
|
||||
A_packed_alloc.alloc((size_t) knl_n_total * patch_total);
|
||||
B_packed_alloc.alloc((size_t) knl_n_total * oc);
|
||||
|
||||
float * A_packed = A_packed_alloc.get();
|
||||
float * B_packed = B_packed_alloc.get();
|
||||
@ -115,10 +115,16 @@ void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
// Combined kernel: im2col -> pack A, and pack B simultaneously
|
||||
const char * src1_base = (const char *) src1->data;
|
||||
const char * src0_base = (const char *) src0->data;
|
||||
const int64_t src1_nb0 = src1->nb[0];
|
||||
const int64_t src1_nb1 = src1->nb[1];
|
||||
const int64_t src1_nb2 = src1->nb[2];
|
||||
const int64_t src1_nb3 = src1->nb[3];
|
||||
const int64_t src1_w = src1->ne[0];
|
||||
const int64_t src1_h = src1->ne[1];
|
||||
const int64_t src1_d = src1->ne[2];
|
||||
|
||||
const bool src0_is_f32 = (src0->type == GGML_TYPE_F32);
|
||||
|
||||
// Compute correct strides for src0 as (knl_n_total, oc) matrix
|
||||
const int64_t src0_packed_nb0 = kernel_type_size;
|
||||
@ -165,7 +171,7 @@ void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const int64_t sz = dst_z * s2 + kz * d2 - p2;
|
||||
|
||||
float val = 0.0f;
|
||||
if (sx >= 0 && sx < src1->ne[0] && sy >= 0 && sy < src1->ne[1] && sz >= 0 && sz < src1->ne[2]) {
|
||||
if (sx >= 0 && sx < src1_w && sy >= 0 && sy < src1_h && sz >= 0 && sz < src1_d) {
|
||||
const int64_t channel_idx = batch_idx * c + ic;
|
||||
const char * ptr = src1_base + sx * src1_nb0 + sy * src1_nb1 + sz * src1_nb2 + channel_idx * src1_nb3;
|
||||
val = *(const float *) ptr;
|
||||
@ -184,9 +190,9 @@ void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
const int64_t row = t % k;
|
||||
const int64_t col = t / k;
|
||||
const char * src_ptr = (const char *) src0->data + row * src0_packed_nb0 + col * src0_packed_nb1;
|
||||
const char * src_ptr = src0_base + row * src0_packed_nb0 + col * src0_packed_nb1;
|
||||
float v;
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
if (src0_is_f32) {
|
||||
v = *(const float *) src_ptr;
|
||||
} else {
|
||||
v = sycl::vec<sycl::half, 1>(*(const sycl::half *) src_ptr).convert<float, sycl::rounding_mode::automatic>()[0];
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user