diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 5953746b..898640dd 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -476,6 +476,18 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { else if (dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32) { k_fast_add_2<<>>(dst->ne[0], nelem, (const float *)dst->src[0]->data, (const float *)dst->src[1]->data, (float *)dst->data); + } + else if (dst->src[0]->type == GGML_TYPE_BF16 && dst->src[1]->type == GGML_TYPE_BF16) { + k_fast_add_2<<>>(dst->ne[0], nelem, + (const nv_bfloat16 *)dst->src[0]->data, (const nv_bfloat16 *)dst->src[1]->data, (float *)dst->data); + } + else if (dst->src[0]->type == GGML_TYPE_BF16 && dst->src[1]->type == GGML_TYPE_F32) { + k_fast_add_2<<>>(dst->ne[0], nelem, + (const nv_bfloat16 *)dst->src[0]->data, (const float *)dst->src[1]->data, (float *)dst->data); + } + else if (dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_BF16) { + k_fast_add_2<<>>(dst->ne[0], nelem, + (const float *)dst->src[0]->data, (const nv_bfloat16 *)dst->src[1]->data, (float *)dst->data); } else { k_fast_add_2<<>>(dst->ne[0], nelem, (const float *)dst->src[0]->data, (const half *)dst->src[1]->data, (float *)dst->data);