mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Fix bf16 graph reduce type
This commit is contained in:
parent
1660459db5
commit
17f05fc6ec
@ -476,6 +476,18 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
else if (dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32) {
|
||||
k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
|
||||
(const float *)dst->src[0]->data, (const float *)dst->src[1]->data, (float *)dst->data);
|
||||
}
|
||||
else if (dst->src[0]->type == GGML_TYPE_BF16 && dst->src[1]->type == GGML_TYPE_BF16) {
|
||||
k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
|
||||
(const nv_bfloat16 *)dst->src[0]->data, (const nv_bfloat16 *)dst->src[1]->data, (float *)dst->data);
|
||||
}
|
||||
else if (dst->src[0]->type == GGML_TYPE_BF16 && dst->src[1]->type == GGML_TYPE_F32) {
|
||||
k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
|
||||
(const nv_bfloat16 *)dst->src[0]->data, (const float *)dst->src[1]->data, (float *)dst->data);
|
||||
}
|
||||
else if (dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_BF16) {
|
||||
k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
|
||||
(const float *)dst->src[0]->data, (const nv_bfloat16 *)dst->src[1]->data, (float *)dst->data);
|
||||
} else {
|
||||
k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
|
||||
(const float *)dst->src[0]->data, (const half *)dst->src[1]->data, (float *)dst->data);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user