Fix bf16 graph reduce type

This commit is contained in:
Kawrakow 2026-06-08 13:35:20 +00:00
parent 1660459db5
commit 17f05fc6ec

View File

@ -476,6 +476,18 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
else if (dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32) {
k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
(const float *)dst->src[0]->data, (const float *)dst->src[1]->data, (float *)dst->data);
}
else if (dst->src[0]->type == GGML_TYPE_BF16 && dst->src[1]->type == GGML_TYPE_BF16) {
k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
(const nv_bfloat16 *)dst->src[0]->data, (const nv_bfloat16 *)dst->src[1]->data, (float *)dst->data);
}
else if (dst->src[0]->type == GGML_TYPE_BF16 && dst->src[1]->type == GGML_TYPE_F32) {
k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
(const nv_bfloat16 *)dst->src[0]->data, (const float *)dst->src[1]->data, (float *)dst->data);
}
else if (dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_BF16) {
k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
(const float *)dst->src[0]->data, (const nv_bfloat16 *)dst->src[1]->data, (float *)dst->data);
} else {
k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
(const float *)dst->src[0]->data, (const half *)dst->src[1]->data, (float *)dst->data);