diff --git a/common/common.cpp b/common/common.cpp index bb8ed772..9785dcda 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -3531,8 +3531,8 @@ static std::pair get_batch_ubatch(const gpt_params & params) { if (params.n_ctx > 0) { n_batch = std::min(n_batch, params.n_ctx); } - if (!params.mmproj.path.empty()) { - // temporary fix for qwen mtmd + if (!params.mmproj.path.empty() && params.mmproj_use_gpu) { + // temporary fix for qwen mtmd (only when mmproj is on GPU) n_batch = std::max(n_batch, n_ubatch); n_ubatch = n_batch; fprintf(stdout, "Adjust batch size for mtmd: u_batch = %d, batch = %d\n", n_ubatch, n_batch);