diff --git a/ggml/src/ggml-cpu/amx/mmq.cpp b/ggml/src/ggml-cpu/amx/mmq.cpp index d9383a04be..9f3a744b5d 100644 --- a/ggml/src/ggml-cpu/amx/mmq.cpp +++ b/ggml/src/ggml-cpu/amx/mmq.cpp @@ -2417,15 +2417,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size); - parallel_for_ggml(params, n_batch, [&](int begin, int end) { - for (int batch_idx = begin; batch_idx < end; ++batch_idx) { + parallel_for_ggml(params, n_batch * M, [&](int begin, int end) { + for (int idx = begin; idx < end; ++idx) { + int batch_idx = idx / M; + int m = idx % M; int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2); const float * A_data = (const float *)((const char *)src1->data + src1_offset); char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A; - - for (int m = 0; m < M; ++m) { - from_float(A_data + m * K, wdata_batch + m * row_size_A, K); - } + from_float(A_data + m * K, wdata_batch + m * row_size_A, K); } }); });