mtmd : add post-decode callback (#24645)

Assisted-by: pi:llama.cpp/Qwen3.6-27B
This commit is contained in:
Georgi Gerganov 2026-06-15 16:02:05 +03:00 committed by GitHub
parent 9dbc6621ae
commit e3cab403bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 50 additions and 121 deletions

View File

@ -247,7 +247,9 @@ int32_t mtmd_helper_decode_image_chunk(
llama_pos n_past,
llama_seq_id seq_id,
int32_t n_batch,
llama_pos * new_n_past) {
llama_pos * new_n_past,
mtmd_helper_post_decode_callback callback,
void * user_data) {
GGML_ASSERT(n_batch > 0);
auto chunk_type = mtmd_input_chunk_get_type(chunk);
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
@ -302,10 +304,23 @@ int32_t mtmd_helper_decode_image_chunk(
int32_t ret = llama_decode(lctx, batch_embd_view);
if (ret != 0) {
LOG_ERR("failed to decode %s\n", name);
llama_set_causal_attn(lctx, true); // restore causal attn
if (use_non_causal) {
llama_set_causal_attn(lctx, true);
}
return ret;
}
if (callback != nullptr) {
ret = callback(batch_embd_view, user_data);
if (ret != 0) {
LOG_ERR("post-decode callback failed\n");
if (use_non_causal) {
llama_set_causal_attn(lctx, true);
}
return ret;
}
}
LOG_INF("%s decoded (batch %d/%d) in %" PRId64 " ms\n", name, i_batch+1, n_img_batches, ggml_time_ms() - t1);
i_batch++;
@ -379,7 +394,7 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
float * embd = mtmd_get_output_embd(ctx);
ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past, nullptr, nullptr);
if (ret != 0) {
LOG_ERR("failed to decode %s\n", name);
llama_batch_free(text_batch);

View File

@ -91,6 +91,8 @@ MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
bool logits_last,
llama_pos * new_n_past);
// callback invoked by mtmd_helper_decode_image_chunk() right after a batch has been decoded successfully
//   batch     : the batch view that was just passed to llama_decode()
//   user_data : the opaque pointer that was given to mtmd_helper_decode_image_chunk()
// return 0 to continue; any other value stops the helper, which then returns that value to its caller
typedef int32_t (*mtmd_helper_post_decode_callback)(struct llama_batch batch, void * user_data);
// helper function to decode an image whose embeddings have already been calculated
// this helper will handle batching and pre/post decoding setup (for ex. gemma 3 requires non-causal attention)
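// ret 0 on success, -1 on chunk not being a valid image chunk, the non-zero llama_decode() status on decode failure,
// or the non-zero value returned by the post-decode callback (if one is provided)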
@ -101,7 +103,9 @@ MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx,
llama_pos n_past,
llama_seq_id seq_id,
int32_t n_batch,
llama_pos * new_n_past);
llama_pos * new_n_past,
mtmd_helper_post_decode_callback callback,
void * user_data);
//
// video input helpers (requires ffmpeg/ffprobe installed on the system)

View File

@ -539,37 +539,6 @@ bool server_tokens::validate(const struct llama_context * ctx) const {
return true;
}
// Encode and decode the multimodal (image or audio) chunk found at token index `idx`.
//
//   ctx          : llama context the chunk is evaluated in; also supplies the batch size
//   mctx         : mtmd context forwarded to mtmd_helper_eval_chunk_single()
//   idx          : index used to look the chunk up via find_chunk()
//   pos          : position forwarded to the helper as the starting n_past
//   seq_id       : sequence id forwarded to the helper
//   n_tokens_out : set to the chunk's token count on success, 0 on failure
//
// Returns 0 on success, otherwise the non-zero status of mtmd_helper_eval_chunk_single().
int32_t server_tokens::process_chunk(
        llama_context * ctx,
        mtmd_context * mctx,
        size_t idx,
        llama_pos pos,
        int32_t seq_id,
        size_t & n_tokens_out) const {
    const auto & chunk = find_chunk(idx);
    const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
                        ? "image" : "audio";
    SRV_INF("processing %s...\n", name);

    const int32_t n_batch = llama_n_batch(ctx);
    const int64_t t0      = ggml_time_ms();

    llama_pos new_n_past = 0; // required by the helper, unused for now
    const int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
            chunk.get(),
            pos,
            seq_id,
            n_batch,
            true, // logits last
            &new_n_past);
    SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);

    if (result != 0) {
        // terminate the record with a newline so the next log line does not run into it
        LOG_ERR("mtmd_helper_eval failed with status %d\n", result);
        n_tokens_out = 0;
        return result;
    }

    n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
    return 0;
}
server_tokens server_tokens::clone() const {
server_tokens res;
res.has_mtmd = has_mtmd;

View File

@ -221,15 +221,6 @@ public:
// make sure all text tokens are within the vocab range
bool validate(const struct llama_context * ctx) const;
// encode and decode the multimodal (image or audio) chunk located at token index `idx`
//   ctx          : llama context used for evaluation (its n_batch is used as the batch size)
//   mctx         : mtmd context used for evaluation
//   idx          : index of the chunk inside this token list
//   pos          : starting position passed to the evaluation helper
//   seq_id       : sequence id passed to the evaluation helper
//   n_tokens_out : receives the chunk's token count on success, 0 on failure
// returns 0 on success, otherwise the non-zero status reported by the evaluation helper
int32_t process_chunk(
llama_context * ctx,
mtmd_context * mctx,
size_t idx,
llama_pos pos,
int32_t seq_id,
size_t & n_tokens_out) const;
server_tokens clone() const;
};

View File

@ -15,11 +15,6 @@
#include "mtmd.h"
#include "mtmd-helper.h"
#include "ggml-cpp.h"
// TODO: tmp until the mtmd draft processing is refactored [TAG_MTMD_DRAFT_PROCESSING]
#include "../../src/llama-ext.h"
#include <algorithm>
#include <cstddef>
#include <cinttypes>
@ -81,7 +76,6 @@ struct server_slot {
// multimodal
mtmd_context * mctx = nullptr;
mtmd::batch_ptr mbatch = nullptr;
std::array<llama_context *, 2> mtgt = {nullptr, nullptr}; // [0] for main context, [1] for optional draft context
// speculative decoding
common_speculative * spec;
@ -244,15 +238,6 @@ struct server_slot {
// clear multimodal state
mbatch.reset();
mtgt[0] = ctx_tgt;
mtgt[1] = nullptr;
if (ctx_dft && llama_get_ctx_other(ctx_dft) != ctx_tgt) {
// TODO: in the future, figure out how to infuse target embeddings to the images
// for now, we re-decode the same chunk in both ctx_tgt and ctx_dft
// maybe we simply need to call `common_speculative_process()` ?
// [TAG_MTMD_DRAFT_PROCESSING]
mtgt[1] = ctx_dft;
}
}
void init_sampler() const {
@ -598,32 +583,38 @@ struct server_slot {
int process_mtmd_chunk(size_t idx, size_t & n_tokens_out) {
GGML_ASSERT(mctx);
const auto & input_tokens = task->tokens;
auto & chunk = input_tokens.find_chunk(idx);
const auto & chunk = input_tokens.find_chunk(idx);
int32_t res = 0;
auto try_decode = [&]() -> int32_t {
if (mbatch) {
float * embd = mtmd_batch_get_output_embd(mbatch.get(), chunk.get());
if (embd) {
for (auto * lctx : mtgt) {
if (lctx == nullptr) {
continue;
}
llama_pos new_n_past; // unused for now
res = mtmd_helper_decode_image_chunk(
mctx,
lctx,
chunk.get(),
embd,
prompt.tokens.pos_next(),
id,
llama_n_batch(lctx),
&new_n_past
);
if (res != 0) {
SLT_ERR(*this, "failed to decode mtmd chunk, idx = %zu, res = %d\n", idx, res);
return -1;
void * cb_data = spec;
static auto cb = [](llama_batch batch, void * user_data) {
common_speculative * spec = static_cast<common_speculative *>(user_data);
if (!common_speculative_process(spec, batch)) {
return 1;
}
return 0;
};
llama_pos new_n_past; // unused for now
res = mtmd_helper_decode_image_chunk(
mctx,
ctx_tgt,
chunk.get(),
embd,
prompt.tokens.pos_next(),
id,
llama_n_batch(ctx_tgt),
&new_n_past,
cb,
cb_data
);
if (res != 0) {
SLT_ERR(*this, "failed to decode mtmd chunk, idx = %zu, res = %d\n", idx, res);
return -1;
}
n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
return 0; // success
@ -636,7 +627,8 @@ struct server_slot {
res = try_decode();
if (res == 0) {
return 0;
} else if (res < 0) {
}
if (res < 0) {
// fatal error
return res;
}
@ -3350,48 +3342,6 @@ private:
// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
// for now, always re-evaluate for simplicity
// ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384
//
// | spec type | need re-eval |
// | --- | --- |
// | draft model | no | because the draft model does not use embeddings from the target
// | MTP (std) | yes |
// | MTP Gemma4 | no | because the KV cache is shared
// | Eagle3 | yes |
// | DFlash | yes | https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4405406982
//
// note: this logic has now moved into `common_speculative_process()`
// keeping the sketch here for a bit, until the logic is finalized
//
//if (ctx_dft) {
// // TODO: update as needed for MTP, Eagle3, etc.
// const bool need_tgt_embd = false;
// if (need_tgt_embd) {
// llama_synchronize(ctx_tgt);
// }
// // the logic here varies depending on the speculative decoding method
// // - some draft contexts require embeddings from the target context, others don't
// // - some draft contexts involve an encoder step to transform the target embeddings to draft embeddings
// // TODO: extract this in a function ?
// {
// // TODO: hook the embeddings from the last target batch here
// if (llama_model_has_encoder(model_dft.get())) {
// //llama_encode(ctx_dft, ...);
// GGML_ABORT("not implemented yet\n");
// }
// const int ret = llama_decode(ctx_dft.get(), batch_view);
// if (ret != 0) {
// SRV_ERR("failed to decode draft batch, ret = %d\n", ret);
// // TODO: handle error
// break;
// }
// }
//}
if (!common_speculative_process(spec.get(), batch_view)) {
SRV_ERR("%s", "failed to process speculative batch\n");