From 8b0cd0357ad156ef65a555f359ffdaa24acf916e Mon Sep 17 00:00:00 2001
From: Jun Yamog
Date: Tue, 12 May 2026 16:38:42 +1200
Subject: [PATCH] fix: keep sm70 cublas f32 outputs in f32 (#1776)

---
 ggml/src/ggml-cuda.cu | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 4604db95..10959122 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -1603,6 +1603,34 @@ static void ggml_cuda_op_mul_mat_cublas(
             to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), src1_ncols, ne10, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
+
+        // When compute_capability <= CC_VOLTA and dst is f32, let cuBLAS write f32 directly into dst_dd_i
+        // instead of a temporary f16 buffer; finite matmul results outside fp16 range would become +/-inf there.
+        const bool sm70_f32_output =
+            compute_capability <= CC_VOLTA &&
+            dst->type == GGML_TYPE_F32;
+        if (sm70_f32_output) {
+            const float alpha_f32 = 1.0f;
+            const float beta_f32 = 0.0f;
+            // Warn for the first 8 hits only; the load() guard stops further increments so the counter cannot wrap.
+            static std::atomic<int> sm70_f32_output_logs{0};
+            if (sm70_f32_output_logs.load() < 8 && sm70_f32_output_logs.fetch_add(1) < 8) {
+                GGML_CUDA_LOG_WARN(
+                    "%s: using f32 cublas output for %s on cc=%d to avoid fp16 output saturation\n",
+                    __func__, dst->name, compute_capability);
+            }
+
+            CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+            CUBLAS_CHECK(
+                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                        row_diff, src1_ncols, ne10,
+                        &alpha_f32, src0_ptr, CUDA_R_16F, ne00,
+                                    src1_ptr, CUDA_R_16F, ne10,
+                        &beta_f32,  dst_dd_i, CUDA_R_32F, ldc, // f16 inputs, f32 output written to dst_dd_i
+                        CUBLAS_COMPUTE_32F,
+                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+            return; // result is already in dst_dd_i; skip the f16 temporary-buffer path that follows
+        }
         ggml_cuda_pool_alloc dst_f16(ctx.pool(id), row_diff*src1_ncols);
 
         const half alpha_f16 = 1.0f;