From f818065d753700341b5355ce36ca676ecfb5d139 Mon Sep 17 00:00:00 2001 From: leonardHONG <2695316095@qq.com> Date: Fri, 26 Jun 2026 13:51:25 +0800 Subject: [PATCH] CUDA: batch out_prod broadcast (dps2>1) path with cublasSgemmBatched (#24426) --- ggml/src/ggml-cuda/out-prod.cu | 67 ++++++++++++++++++++++++++++------ tests/test-backend-ops.cpp | 6 +++ 2 files changed, 61 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/out-prod.cu b/ggml/src/ggml-cuda/out-prod.cu index 499903d09b..46b9f3a67e 100644 --- a/ggml/src/ggml-cuda/out-prod.cu +++ b/ggml/src/ggml-cuda/out-prod.cu @@ -2,6 +2,28 @@ #include +static __global__ void k_compute_out_prod_ptrs( + const float * src0_d, const float * src1_d, float * dst_d, + const float ** ptrs_a, const float ** ptrs_b, float ** ptrs_c, + const int64_t ne2, const int64_t ne3, + const int64_t dps2, const int64_t dps3, + const size_t s02, const size_t s03, + const size_t s12, const size_t s13, + const size_t s2, const size_t s3) { + const int64_t i2 = blockIdx.x*blockDim.x + threadIdx.x; + const int64_t i3 = blockIdx.y*blockDim.y + threadIdx.y; + + if (i2 >= ne2 || i3 >= ne3) { + return; + } + + const int64_t idx = i3*ne2 + i2; + + ptrs_a[idx] = src0_d + (i3/dps3)*s03 + (i2/dps2)*s02; + ptrs_b[idx] = src1_d + i3 *s13 + i2 *s12; + ptrs_c[idx] = dst_d + i3 *s3 + i2 *s2; +} + void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -67,18 +89,39 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { &beta, dst_d + i3 *s3, ldc, s2, batch_count)); } + } else if (ne2 > 1 || ne3 > 1) { + // dps2 > 1 (src0 broadcast along dim 2 with non-uniform stride) or multiple GEMMs + // along dim 3: compute per-GEMM pointers on the device and use a single batched GEMM. + GGML_ASSERT(ne3 > 0); + GGML_ASSERT(ne2 <= (int64_t) std::numeric_limits::max() / ne3); + const int batch_count = (int) (ne2 * ne3); + + ggml_cuda_pool_alloc ptrs_a(ctx.pool(), batch_count); + ggml_cuda_pool_alloc ptrs_b(ctx.pool(), batch_count); + ggml_cuda_pool_alloc< float *> ptrs_c(ctx.pool(), batch_count); + + const dim3 block_dims(16, 16); + const dim3 grid_dims((ne2 + block_dims.x - 1)/block_dims.x, (ne3 + block_dims.y - 1)/block_dims.y); + k_compute_out_prod_ptrs<<>>( + src0_d, src1_d, dst_d, + ptrs_a.get(), ptrs_b.get(), ptrs_c.get(), + ne2, ne3, dps2, dps3, s02, s03, s12, s13, s2, s3); + CUDA_CHECK(cudaGetLastError()); + + CUBLAS_CHECK( + cublasSgemmBatched(handle, CUBLAS_OP_N, src1_cublas_op, + ne0, ne1, ne01, + &alpha, ptrs_a.get(), lda, + ptrs_b.get(), ldb, + &beta, ptrs_c.get(), ldc, + batch_count)); } else { - // Fallback: ne2 == 1 (no batching benefit) or dps2 > 1 (src0 broadcast along dim 2 - // with non-uniform stride; would need cublasSgemmBatched with pointer arrays). - for (int64_t i3 = 0; i3 < ne3; ++i3) { - for (int64_t i2 = 0; i2 < ne2; ++i2) { - CUBLAS_CHECK( - cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, - ne0, ne1, ne01, - &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda, - src1_d + i3 *s13 + i2 *s12, ldb, - &beta, dst_d + i3 *s3 + i2 *s2, ldc)); - } - } + // ne2 == 1 && ne3 == 1: single GEMM + CUBLAS_CHECK( + cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, + ne0, ne1, ne01, + &alpha, src0_d, lda, + src1_d, ldb, + &beta, dst_d, ldc)); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c83e91fbdb..0a017d57e7 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8672,6 +8672,12 @@ static std::vector> make_test_cases_eval() { 256, 16, 16, {ne2, 1}, {1, 1})); } + // nr2 sweep to cover the cublasSgemmBatched pointer-array path (dps2 > 1) + for (int64_t nr2 : {8, 16, 32}) { + test_cases.emplace_back(new test_out_prod(GGML_TYPE_F32, GGML_TYPE_F32, + 256, 16, 16, {1, 1}, {nr2, 1})); + } + // add_id for (ggml_type type_a : {GGML_TYPE_F32}) { for (ggml_type type_b : {GGML_TYPE_F32}) {