diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 9a40c4366a..ca4bf0d1be 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -3515,6 +3515,10 @@ const llama_model * llama_get_model(const llama_context * ctx) { return &ctx->get_model(); } +ggml_backend_sched_t llama_get_sched(const llama_context * ctx) { + return ctx->get_sched(); +} + enum llama_pooling_type llama_pooling_type(const llama_context * ctx) { return ctx->pooling_type(); } diff --git a/src/llama-ext.h b/src/llama-ext.h index bd74544129..806a36bdc2 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -102,3 +102,5 @@ LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx); LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i); LLAMA_API llama_context * llama_get_ctx_other(struct llama_context * ctx); + +LLAMA_API ggml_backend_sched_t llama_get_sched(const struct llama_context * ctx); diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index bd33f43062..1a2a006a82 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -160,10 +160,15 @@ struct clip_ctx { int max_nodes = 8192; - ggml_backend_sched_ptr sched; + ggml_backend_sched_ptr sched; // owned scheduler (when no external sched is provided) + ggml_backend_sched_t ext_sched = nullptr; // borrowed scheduler (not freed by this context) clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO; bool is_allocated = false; + ggml_backend_sched_t get_sched() const { + return ext_sched ? ext_sched : sched.get(); + } + bool debug_output_embeddings = false; // for measuring memory usage @@ -211,12 +216,16 @@ struct clip_ctx { backend_ptrs.push_back(backend_cpu); backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu)); - sched.reset( - ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false, true) - ); + if (ctx_params.sched) { + ext_sched = ctx_params.sched; + } else { + sched.reset( + ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false, true) + ); + } if (ctx_params.cb_eval != nullptr) { - ggml_backend_sched_set_eval_callback(sched.get(), ctx_params.cb_eval, ctx_params.cb_eval_user_data); + ggml_backend_sched_set_eval_callback(get_sched(), ctx_params.cb_eval, ctx_params.cb_eval_user_data); } debug_output_embeddings = std::getenv("MTMD_DEBUG_EMBEDDINGS") != nullptr; @@ -2893,22 +2902,31 @@ struct clip_model_loader { ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead()); ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch); - ggml_backend_sched_reserve(ctx_clip.sched.get(), gf); - - ctx_clip.mem_compute.clear(); - for (size_t i = 0; i < ctx_clip.backend_ptrs.size(); ++i) { - ggml_backend_t backend = ctx_clip.backend_ptrs[i]; - ggml_backend_buffer_type_t buft = ctx_clip.backend_buft[i]; - size_t size = ggml_backend_sched_get_buffer_size(ctx_clip.sched.get(), backend); - if (size > 1) { - LOG_INF("%s: %10s compute buffer size = %8.2f MiB\n", __func__, - ggml_backend_buft_name(buft), - size / 1024.0 / 1024.0); - } - ctx_clip.mem_compute[ggml_backend_get_device(backend)] += size; + if (ctx_clip.sched) { + ggml_backend_sched_reserve(ctx_clip.get_sched(), gf); } - const int n_splits = ggml_backend_sched_get_n_splits(ctx_clip.sched.get()); + // TODO @ngxson : what to do here when we have ctx_clip.ext_sched? + // running ggml_backend_sched_reserve in ext_sched breaks pre-alloc tensors + + ctx_clip.mem_compute.clear(); + { + ggml_backend_sched_t sched = ctx_clip.get_sched(); + int n_backends = ggml_backend_sched_get_n_backends(sched); + for (int i = 0; i < n_backends; ++i) { + ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i); + ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched, backend); + size_t size = ggml_backend_sched_get_buffer_size(sched, backend); + if (size > 1) { + LOG_INF("%s: %10s compute buffer size = %8.2f MiB\n", __func__, + ggml_backend_buft_name(buft), + size / 1024.0 / 1024.0); + } + ctx_clip.mem_compute[ggml_backend_get_device(backend)] += size; + } + } + + const int n_splits = ggml_backend_sched_get_n_splits(ctx_clip.get_sched()); const int n_nodes = ggml_graph_n_nodes(gf); LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes); @@ -3497,9 +3515,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } // build the inference graph - ggml_backend_sched_reset(ctx->sched.get()); + ggml_backend_sched_reset(ctx->get_sched()); ggml_cgraph * gf = clip_image_build_graph(ctx, imgs); - ggml_backend_sched_alloc_graph(ctx->sched.get(), gf); + ggml_backend_sched_alloc_graph(ctx->get_sched(), gf); // set inputs const auto & model = ctx->model; @@ -4392,7 +4410,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } - auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf); + auto status = ggml_backend_sched_graph_compute(ctx->get_sched(), gf); if (status != GGML_STATUS_SUCCESS) { LOG_ERR("%s: ggml_backend_sched_graph_compute failed with error %d\n", __func__, status); return false; diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index 18c7a1d1a7..0067f936ca 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -52,6 +52,9 @@ struct clip_context_params { ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; bool no_alloc; + // optional: share an existing scheduler instead of creating a new one + // caller must ensure the scheduler outlives the clip context + ggml_backend_sched_t sched; }; struct clip_init_result { diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 4140a3c4aa..cd9b185670 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -197,6 +197,7 @@ mtmd_context_params mtmd_context_params_default() { /* image_max_tokens */ -1, /* cb_eval */ nullptr, /* cb_eval_user_data */ nullptr, + /* sched */ nullptr, }; return params; } @@ -287,6 +288,7 @@ struct mtmd_context { /* cb_eval */ ctx_params.cb_eval, /* cb_eval_user_data */ ctx_params.cb_eval_user_data, /* no_alloc */ no_alloc, + /* sched */ ctx_params.sched, }; auto res = clip_init(mmproj_fname, ctx_clip_params); diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index a76a6ec2b8..89837852ca 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -97,6 +97,11 @@ struct mtmd_context_params { // callback function passed over to mtmd proper ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; + + // optional: share an existing scheduler (e.g. from llama_context via llama_get_sched()) + // instead of creating a separate one, saving compute buffer memory + // caller must ensure the scheduler outlives the mtmd context + ggml_backend_sched_t sched; }; MTMD_API const char * mtmd_default_marker(void); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index bdfa517180..6366a59566 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -982,6 +982,7 @@ private: mtmd_helper_log_set(common_log_default_callback, nullptr); } + mparams.sched = llama_get_sched(ctx_tgt); mctx = mtmd_init_from_file(mmproj_path.c_str(), model_tgt, mparams); if (mctx == nullptr) { SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());