diff --git a/common/common.cpp b/common/common.cpp index 9785dcda..7a3c054c 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -618,6 +618,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } return true; } + if (arg == "-tm" || arg == "--threads-mtmd") { + CHECK_ARG + params.n_threads_mtmd = std::stoi(argv[i]); + if (params.n_threads_mtmd <= 0) { + params.n_threads_mtmd = std::thread::hardware_concurrency(); + } + return true; + } if (arg == "-td" || arg == "--threads-draft") { CHECK_ARG params.speculative.n_threads = std::stoi(argv[i]); @@ -2461,6 +2469,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", params.seed }); options.push_back({ "*", "-t, --threads N", "number of threads to use during generation (default: %d)", params.n_threads }); options.push_back({ "*", "-tb, --threads-batch N", "number of threads to use during batch and prompt processing (default: same as --threads)" }); + options.push_back({ "multi-modality", "-tm, --threads-mtmd N", "number of threads to use during multimodal image processing (default: same as --threads-batch)" }); options.push_back({ "speculative", "-td, --threads-draft N", "number of threads to use during generation (default: same as --threads)" }); options.push_back({ "speculative", "-tbd, --threads-batch-draft N", "number of threads to use during batch and prompt processing (default: same as --threads-draft)" }); @@ -2896,6 +2905,9 @@ std::string gpt_params_get_system_info(const gpt_params & params) { if (params.n_threads_batch != -1) { os << " (n_threads_batch = " << params.n_threads_batch << ")"; } + if (params.n_threads_mtmd != -1) { + os << " (n_threads_mtmd = " << params.n_threads_mtmd << ")"; + } os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info(); return os.str(); @@ -3021,7 +3033,7 @@ std::string string_lower(const std::string& str) { std::string result = str; for (char& c : result) { if (c >= 'A' && c <= 'Z') { - c = static_cast(c + ('a' - 'A')); + c = static_cast(c + ('a' - 'A')); } } return result; diff --git a/common/common.h b/common/common.h index ffaa2d5d..273c0813 100644 --- a/common/common.h +++ b/common/common.h @@ -410,6 +410,7 @@ struct gpt_params { int image_min_tokens = -1; int image_max_tokens = -1; std::string mtmd_kq_type = "f32"; + int32_t n_threads_mtmd = -1; // number of threads to use for multimodal processing (-1 = use n_threads_batch, then n_threads) // embedding bool embedding = false; // get only sentence embedding diff --git a/examples/mtmd/mtmd-cli.cpp b/examples/mtmd/mtmd-cli.cpp index 496a1e53..02d0c26a 100644 --- a/examples/mtmd/mtmd-cli.cpp +++ b/examples/mtmd/mtmd-cli.cpp @@ -182,7 +182,9 @@ struct mtmd_cli_context { mtmd_context_params mparams = mtmd_context_params_default(); mparams.use_gpu = params.mmproj_use_gpu; mparams.print_timings = true; - mparams.n_threads = params.n_threads; + mparams.n_threads = params.n_threads_mtmd != -1 ? params.n_threads_mtmd + : params.n_threads_batch != -1 ? params.n_threads_batch + : params.n_threads; mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO; mparams.flash_attn_type = params.flash_attn ? LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED; mparams.image_min_tokens = params.image_min_tokens; diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index b272162e..24dcf533 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -292,7 +292,9 @@ bool server_context::load_model(const gpt_params& params_) { mtmd_context_params mparams = mtmd_context_params_default(); mparams.use_gpu = params_base.mmproj_use_gpu; mparams.print_timings = false; - mparams.n_threads = params_base.n_threads; + mparams.n_threads = params_base.n_threads_mtmd != -1 ? params_base.n_threads_mtmd + : params_base.n_threads_batch != -1 ? params_base.n_threads_batch + : params_base.n_threads; mparams.flash_attn_type = params_base.flash_attn ? LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED; mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO; mparams.image_min_tokens = params_base.image_min_tokens;