mtmd, llama: shared backend sched

This commit is contained in:
Xuan Son Nguyen 2026-06-09 15:34:17 +02:00
parent d6d0ce8215
commit b6cf9cd8fe
7 changed files with 57 additions and 22 deletions

View File

@ -3515,6 +3515,10 @@ const llama_model * llama_get_model(const llama_context * ctx) {
return &ctx->get_model();
}
ggml_backend_sched_t llama_get_sched(const llama_context * ctx) {
return ctx->get_sched();
}
enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
return ctx->pooling_type();
}

View File

@ -102,3 +102,5 @@ LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);
LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i);
LLAMA_API llama_context * llama_get_ctx_other(struct llama_context * ctx);
LLAMA_API ggml_backend_sched_t llama_get_sched(const struct llama_context * ctx);

View File

@ -160,10 +160,15 @@ struct clip_ctx {
int max_nodes = 8192;
ggml_backend_sched_ptr sched;
ggml_backend_sched_ptr sched; // owned scheduler (when no external sched is provided)
ggml_backend_sched_t ext_sched = nullptr; // borrowed scheduler (not freed by this context)
clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
bool is_allocated = false;
ggml_backend_sched_t get_sched() const {
return ext_sched ? ext_sched : sched.get();
}
bool debug_output_embeddings = false;
// for measuring memory usage
@ -211,12 +216,16 @@ struct clip_ctx {
backend_ptrs.push_back(backend_cpu);
backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu));
sched.reset(
ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false, true)
);
if (ctx_params.sched) {
ext_sched = ctx_params.sched;
} else {
sched.reset(
ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false, true)
);
}
if (ctx_params.cb_eval != nullptr) {
ggml_backend_sched_set_eval_callback(sched.get(), ctx_params.cb_eval, ctx_params.cb_eval_user_data);
ggml_backend_sched_set_eval_callback(get_sched(), ctx_params.cb_eval, ctx_params.cb_eval_user_data);
}
debug_output_embeddings = std::getenv("MTMD_DEBUG_EMBEDDINGS") != nullptr;
@ -2893,22 +2902,31 @@ struct clip_model_loader {
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch);
ggml_backend_sched_reserve(ctx_clip.sched.get(), gf);
ctx_clip.mem_compute.clear();
for (size_t i = 0; i < ctx_clip.backend_ptrs.size(); ++i) {
ggml_backend_t backend = ctx_clip.backend_ptrs[i];
ggml_backend_buffer_type_t buft = ctx_clip.backend_buft[i];
size_t size = ggml_backend_sched_get_buffer_size(ctx_clip.sched.get(), backend);
if (size > 1) {
LOG_INF("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
ggml_backend_buft_name(buft),
size / 1024.0 / 1024.0);
}
ctx_clip.mem_compute[ggml_backend_get_device(backend)] += size;
if (ctx_clip.sched) {
ggml_backend_sched_reserve(ctx_clip.get_sched(), gf);
}
const int n_splits = ggml_backend_sched_get_n_splits(ctx_clip.sched.get());
// TODO @ngxson : what to do here when we have ctx_clip.ext_sched?
// running ggml_backend_sched_reserve in ext_sched breaks pre-alloc tensors
ctx_clip.mem_compute.clear();
{
ggml_backend_sched_t sched = ctx_clip.get_sched();
int n_backends = ggml_backend_sched_get_n_backends(sched);
for (int i = 0; i < n_backends; ++i) {
ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched, backend);
size_t size = ggml_backend_sched_get_buffer_size(sched, backend);
if (size > 1) {
LOG_INF("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
ggml_backend_buft_name(buft),
size / 1024.0 / 1024.0);
}
ctx_clip.mem_compute[ggml_backend_get_device(backend)] += size;
}
}
const int n_splits = ggml_backend_sched_get_n_splits(ctx_clip.get_sched());
const int n_nodes = ggml_graph_n_nodes(gf);
LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);
@ -3497,9 +3515,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
// build the inference graph
ggml_backend_sched_reset(ctx->sched.get());
ggml_backend_sched_reset(ctx->get_sched());
ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
ggml_backend_sched_alloc_graph(ctx->get_sched(), gf);
// set inputs
const auto & model = ctx->model;
@ -4392,7 +4410,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
}
auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf);
auto status = ggml_backend_sched_graph_compute(ctx->get_sched(), gf);
if (status != GGML_STATUS_SUCCESS) {
LOG_ERR("%s: ggml_backend_sched_graph_compute failed with error %d\n", __func__, status);
return false;

View File

@ -52,6 +52,9 @@ struct clip_context_params {
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
bool no_alloc;
// optional: share an existing scheduler instead of creating a new one
// caller must ensure the scheduler outlives the clip context
ggml_backend_sched_t sched;
};
struct clip_init_result {

View File

@ -197,6 +197,7 @@ mtmd_context_params mtmd_context_params_default() {
/* image_max_tokens */ -1,
/* cb_eval */ nullptr,
/* cb_eval_user_data */ nullptr,
/* sched */ nullptr,
};
return params;
}
@ -287,6 +288,7 @@ struct mtmd_context {
/* cb_eval */ ctx_params.cb_eval,
/* cb_eval_user_data */ ctx_params.cb_eval_user_data,
/* no_alloc */ no_alloc,
/* sched */ ctx_params.sched,
};
auto res = clip_init(mmproj_fname, ctx_clip_params);

View File

@ -97,6 +97,11 @@ struct mtmd_context_params {
// callback function passed over to mtmd proper
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
// optional: share an existing scheduler (e.g. from llama_context via llama_get_sched())
// instead of creating a separate one, saving compute buffer memory
// caller must ensure the scheduler outlives the mtmd context
ggml_backend_sched_t sched;
};
MTMD_API const char * mtmd_default_marker(void);

View File

@ -982,6 +982,7 @@ private:
mtmd_helper_log_set(common_log_default_callback, nullptr);
}
mparams.sched = llama_get_sched(ctx_tgt);
mctx = mtmd_init_from_file(mmproj_path.c_str(), model_tgt, mparams);
if (mctx == nullptr) {
SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());