CUDA: batch out_prod broadcast (dps2>1) path with cublasSgemmBatched (#24426)

This commit is contained in:
leonardHONG 2026-06-26 13:51:25 +08:00 committed by GitHub
parent 960d628f46
commit f818065d75
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 61 additions and 12 deletions

View File

@ -2,6 +2,28 @@
#include <cstdint>
static __global__ void k_compute_out_prod_ptrs(
const float * src0_d, const float * src1_d, float * dst_d,
const float ** ptrs_a, const float ** ptrs_b, float ** ptrs_c,
const int64_t ne2, const int64_t ne3,
const int64_t dps2, const int64_t dps3,
const size_t s02, const size_t s03,
const size_t s12, const size_t s13,
const size_t s2, const size_t s3) {
const int64_t i2 = blockIdx.x*blockDim.x + threadIdx.x;
const int64_t i3 = blockIdx.y*blockDim.y + threadIdx.y;
if (i2 >= ne2 || i3 >= ne3) {
return;
}
const int64_t idx = i3*ne2 + i2;
ptrs_a[idx] = src0_d + (i3/dps3)*s03 + (i2/dps2)*s02;
ptrs_b[idx] = src1_d + i3 *s13 + i2 *s12;
ptrs_c[idx] = dst_d + i3 *s3 + i2 *s2;
}
void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
@ -67,18 +89,39 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
&beta, dst_d + i3 *s3, ldc, s2,
batch_count));
}
} else if (ne2 > 1 || ne3 > 1) {
// dps2 > 1 (src0 broadcast along dim 2 with non-uniform stride) or multiple GEMMs
// along dim 3: compute per-GEMM pointers on the device and use a single batched GEMM.
GGML_ASSERT(ne3 > 0);
GGML_ASSERT(ne2 <= (int64_t) std::numeric_limits<int>::max() / ne3);
const int batch_count = (int) (ne2 * ne3);
ggml_cuda_pool_alloc<const float *> ptrs_a(ctx.pool(), batch_count);
ggml_cuda_pool_alloc<const float *> ptrs_b(ctx.pool(), batch_count);
ggml_cuda_pool_alloc< float *> ptrs_c(ctx.pool(), batch_count);
const dim3 block_dims(16, 16);
const dim3 grid_dims((ne2 + block_dims.x - 1)/block_dims.x, (ne3 + block_dims.y - 1)/block_dims.y);
k_compute_out_prod_ptrs<<<grid_dims, block_dims, 0, stream>>>(
src0_d, src1_d, dst_d,
ptrs_a.get(), ptrs_b.get(), ptrs_c.get(),
ne2, ne3, dps2, dps3, s02, s03, s12, s13, s2, s3);
CUDA_CHECK(cudaGetLastError());
CUBLAS_CHECK(
cublasSgemmBatched(handle, CUBLAS_OP_N, src1_cublas_op,
ne0, ne1, ne01,
&alpha, ptrs_a.get(), lda,
ptrs_b.get(), ldb,
&beta, ptrs_c.get(), ldc,
batch_count));
} else {
// Fallback: ne2 == 1 (no batching benefit) or dps2 > 1 (src0 broadcast along dim 2
// with non-uniform stride; would need cublasSgemmBatched with pointer arrays).
for (int64_t i3 = 0; i3 < ne3; ++i3) {
for (int64_t i2 = 0; i2 < ne2; ++i2) {
CUBLAS_CHECK(
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
ne0, ne1, ne01,
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
src1_d + i3 *s13 + i2 *s12, ldb,
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
}
}
// ne2 == 1 && ne3 == 1: single GEMM
CUBLAS_CHECK(
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
ne0, ne1, ne01,
&alpha, src0_d, lda,
src1_d, ldb,
&beta, dst_d, ldc));
}
}

View File

@ -8672,6 +8672,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
256, 16, 16, {ne2, 1}, {1, 1}));
}
// nr2 sweep to cover the cublasSgemmBatched pointer-array path (dps2 > 1)
for (int64_t nr2 : {8, 16, 32}) {
test_cases.emplace_back(new test_out_prod(GGML_TYPE_F32, GGML_TYPE_F32,
256, 16, 16, {1, 1}, {nr2, 1}));
}
// add_id
for (ggml_type type_a : {GGML_TYPE_F32}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {