diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 5d4b10d34b..ce847dd8b6 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1738,10 +1738,14 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_meta GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); GGML_ASSERT(op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32); + const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1; + const int64_t KH = is_2D ? ne01 : 1; + const int64_t KW = ne00; + char base[256]; char name[256]; - if (ne00*ne01 <= 1024) { + if (KH*KW <= 1024) { snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); } else { snprintf(base, 256, "kernel_im2col_ext_%s", ggml_type_name(op->type)); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index ba89a94fc9..4d40ff77d3 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7771,6 +7771,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 384, 1, 1}, {3, 384, 384, 1}, 1, 0, 1, 0, 1, 0, false)); for (int s0 : {1, 3}) { for (int p0 : {0, 3}) { for (int d0 : {1, 3}) {