diff --git a/src/graphs/build_dflash.cpp b/src/graphs/build_dflash.cpp index adb583ef..4fcc43fb 100644 --- a/src/graphs/build_dflash.cpp +++ b/src/graphs/build_dflash.cpp @@ -366,6 +366,13 @@ ggml_cgraph * llm_build_context::build_dflash() { cb(Vcur, "dflash_main_v_pad", il); } + if (Kcur->type == GGML_TYPE_F32) { + Kcur = ggml_cast(ctx0, Kcur, GGML_TYPE_F16); + } + if (Vcur->type == GGML_TYPE_F32) { + Vcur = ggml_cast(ctx0, Vcur, GGML_TYPE_F16); + } + cb(Qcur, "Qcur", il); ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);