mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
mtmd : add post-decode callback (#24645)
Assisted-by: pi:llama.cpp/Qwen3.6-27B
This commit is contained in:
parent
9dbc6621ae
commit
e3cab403bf
@ -247,7 +247,9 @@ int32_t mtmd_helper_decode_image_chunk(
|
||||
llama_pos n_past,
|
||||
llama_seq_id seq_id,
|
||||
int32_t n_batch,
|
||||
llama_pos * new_n_past) {
|
||||
llama_pos * new_n_past,
|
||||
mtmd_helper_post_decode_callback callback,
|
||||
void * user_data) {
|
||||
GGML_ASSERT(n_batch > 0);
|
||||
auto chunk_type = mtmd_input_chunk_get_type(chunk);
|
||||
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
|
||||
@ -302,10 +304,23 @@ int32_t mtmd_helper_decode_image_chunk(
|
||||
int32_t ret = llama_decode(lctx, batch_embd_view);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to decode %s\n", name);
|
||||
llama_set_causal_attn(lctx, true); // restore causal attn
|
||||
if (use_non_causal) {
|
||||
llama_set_causal_attn(lctx, true);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (callback != nullptr) {
|
||||
ret = callback(batch_embd_view, user_data);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("post-decode callback failed\n");
|
||||
if (use_non_causal) {
|
||||
llama_set_causal_attn(lctx, true);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
LOG_INF("%s decoded (batch %d/%d) in %" PRId64 " ms\n", name, i_batch+1, n_img_batches, ggml_time_ms() - t1);
|
||||
|
||||
i_batch++;
|
||||
@ -379,7 +394,7 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
|
||||
LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
|
||||
|
||||
float * embd = mtmd_get_output_embd(ctx);
|
||||
ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
|
||||
ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past, nullptr, nullptr);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to decode %s\n", name);
|
||||
llama_batch_free(text_batch);
|
||||
|
||||
@ -91,6 +91,8 @@ MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
|
||||
bool logits_last,
|
||||
llama_pos * new_n_past);
|
||||
|
||||
// callback invoked by mtmd_helper_decode_image_chunk after each batch of embeddings has been decoded successfully
//   batch:     the batch view that was just passed to llama_decode
//   user_data: the opaque pointer that was supplied together with the callback
// return 0 to continue; any non-zero value stops the decoding and is returned to the caller of the helper
typedef int32_t (*mtmd_helper_post_decode_callback)(struct llama_batch batch, void * user_data);
|
||||
|
||||
// helper function to decode an image whose embeddings have already been calculated
|
||||
// this helper will handle batching and pre/post decoding setup (for ex. gemma 3 requires non-causal attention)
|
||||
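// ret 0 on success, -1 on chunk not being a valid image chunk, 1 on decode failure
// if callback is not NULL, it is called after each successfully decoded batch; a non-zero callback result stops the decoding and is returned as-is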
|
||||
@ -101,7 +103,9 @@ MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx,
|
||||
llama_pos n_past,
|
||||
llama_seq_id seq_id,
|
||||
int32_t n_batch,
|
||||
llama_pos * new_n_past);
|
||||
llama_pos * new_n_past,
|
||||
mtmd_helper_post_decode_callback callback,
|
||||
void * user_data);
|
||||
|
||||
//
|
||||
// video input helpers (requires ffmpeg/ffprobe installed on the system)
|
||||
|
||||
@ -539,37 +539,6 @@ bool server_tokens::validate(const struct llama_context * ctx) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
int32_t server_tokens::process_chunk(
|
||||
llama_context * ctx,
|
||||
mtmd_context * mctx,
|
||||
size_t idx,
|
||||
llama_pos pos,
|
||||
int32_t seq_id,
|
||||
size_t & n_tokens_out) const {
|
||||
const auto & chunk = find_chunk(idx);
|
||||
const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
|
||||
? "image" : "audio";
|
||||
SRV_INF("processing %s...\n", name);
|
||||
int32_t n_batch = llama_n_batch(ctx);
|
||||
int64_t t0 = ggml_time_ms();
|
||||
llama_pos new_n_past; // unused for now
|
||||
int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
|
||||
chunk.get(),
|
||||
pos,
|
||||
seq_id,
|
||||
n_batch,
|
||||
true, // logits last
|
||||
&new_n_past);
|
||||
SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
|
||||
if (result != 0) {
|
||||
LOG_ERR("mtmd_helper_eval failed with status %d", result);
|
||||
n_tokens_out = 0;
|
||||
return result;
|
||||
}
|
||||
n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
|
||||
return 0;
|
||||
}
|
||||
|
||||
server_tokens server_tokens::clone() const {
|
||||
server_tokens res;
|
||||
res.has_mtmd = has_mtmd;
|
||||
|
||||
@ -221,15 +221,6 @@ public:
|
||||
// make sure all text tokens are within the vocab range
|
||||
bool validate(const struct llama_context * ctx) const;
|
||||
|
||||
// encode and decode the image chunk
|
||||
int32_t process_chunk(
|
||||
llama_context * ctx,
|
||||
mtmd_context * mctx,
|
||||
size_t idx,
|
||||
llama_pos pos,
|
||||
int32_t seq_id,
|
||||
size_t & n_tokens_out) const;
|
||||
|
||||
server_tokens clone() const;
|
||||
};
|
||||
|
||||
|
||||
@ -15,11 +15,6 @@
|
||||
#include "mtmd.h"
|
||||
#include "mtmd-helper.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
// TODO: tmp until the mtmd draft processing is refactored [TAG_MTMD_DRAFT_PROCESSING]
|
||||
#include "../../src/llama-ext.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <cinttypes>
|
||||
@ -81,7 +76,6 @@ struct server_slot {
|
||||
// multimodal
|
||||
mtmd_context * mctx = nullptr;
|
||||
mtmd::batch_ptr mbatch = nullptr;
|
||||
std::array<llama_context *, 2> mtgt = {nullptr, nullptr}; // [0] for main context, [1] for optional draft context
|
||||
|
||||
// speculative decoding
|
||||
common_speculative * spec;
|
||||
@ -244,15 +238,6 @@ struct server_slot {
|
||||
|
||||
// clear multimodal state
|
||||
mbatch.reset();
|
||||
mtgt[0] = ctx_tgt;
|
||||
mtgt[1] = nullptr;
|
||||
if (ctx_dft && llama_get_ctx_other(ctx_dft) != ctx_tgt) {
|
||||
// TODO: in the future, figure out how to infuse target embeddings to the images
|
||||
// for now, we re-decode the same chunk in both ctx_tgt and ctx_dft
|
||||
// maybe we simply need to call `common_speculative_process()` ?
|
||||
// [TAG_MTMD_DRAFT_PROCESSING]
|
||||
mtgt[1] = ctx_dft;
|
||||
}
|
||||
}
|
||||
|
||||
void init_sampler() const {
|
||||
@ -598,32 +583,38 @@ struct server_slot {
|
||||
int process_mtmd_chunk(size_t idx, size_t & n_tokens_out) {
|
||||
GGML_ASSERT(mctx);
|
||||
const auto & input_tokens = task->tokens;
|
||||
auto & chunk = input_tokens.find_chunk(idx);
|
||||
const auto & chunk = input_tokens.find_chunk(idx);
|
||||
int32_t res = 0;
|
||||
|
||||
auto try_decode = [&]() -> int32_t {
|
||||
if (mbatch) {
|
||||
float * embd = mtmd_batch_get_output_embd(mbatch.get(), chunk.get());
|
||||
if (embd) {
|
||||
for (auto * lctx : mtgt) {
|
||||
if (lctx == nullptr) {
|
||||
continue;
|
||||
}
|
||||
llama_pos new_n_past; // unused for now
|
||||
res = mtmd_helper_decode_image_chunk(
|
||||
mctx,
|
||||
lctx,
|
||||
chunk.get(),
|
||||
embd,
|
||||
prompt.tokens.pos_next(),
|
||||
id,
|
||||
llama_n_batch(lctx),
|
||||
&new_n_past
|
||||
);
|
||||
if (res != 0) {
|
||||
SLT_ERR(*this, "failed to decode mtmd chunk, idx = %zu, res = %d\n", idx, res);
|
||||
return -1;
|
||||
void * cb_data = spec;
|
||||
static auto cb = [](llama_batch batch, void * user_data) {
|
||||
common_speculative * spec = static_cast<common_speculative *>(user_data);
|
||||
if (!common_speculative_process(spec, batch)) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
|
||||
llama_pos new_n_past; // unused for now
|
||||
res = mtmd_helper_decode_image_chunk(
|
||||
mctx,
|
||||
ctx_tgt,
|
||||
chunk.get(),
|
||||
embd,
|
||||
prompt.tokens.pos_next(),
|
||||
id,
|
||||
llama_n_batch(ctx_tgt),
|
||||
&new_n_past,
|
||||
cb,
|
||||
cb_data
|
||||
);
|
||||
if (res != 0) {
|
||||
SLT_ERR(*this, "failed to decode mtmd chunk, idx = %zu, res = %d\n", idx, res);
|
||||
return -1;
|
||||
}
|
||||
n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
|
||||
return 0; // success
|
||||
@ -636,7 +627,8 @@ struct server_slot {
|
||||
res = try_decode();
|
||||
if (res == 0) {
|
||||
return 0;
|
||||
} else if (res < 0) {
|
||||
}
|
||||
if (res < 0) {
|
||||
// fatal error
|
||||
return res;
|
||||
}
|
||||
@ -3350,48 +3342,6 @@ private:
|
||||
// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
|
||||
// for now, always re-evaluate for simplicity
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384
|
||||
//
|
||||
// | spec type | need re-eval |
|
||||
// | --- | --- |
|
||||
// | draft model | no | because the draft model does not use embeddings from the target
|
||||
// | MTP (std) | yes |
|
||||
// | MTP Gemma4 | no | because the KV cache is shared
|
||||
// | Eagle3 | yes |
|
||||
// | DFlash | yes | https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4405406982
|
||||
//
|
||||
// note: this logic is now moved in `common_speculative_process()`
|
||||
// keeping the sketch here for a bit, until the logic is finalized
|
||||
//
|
||||
//if (ctx_dft) {
|
||||
// // TODO: update as needed for MTP, Eagle3, etc.
|
||||
// const bool need_tgt_embd = false;
|
||||
|
||||
// if (need_tgt_embd) {
|
||||
// llama_synchronize(ctx_tgt);
|
||||
// }
|
||||
|
||||
// // the logic here varies depending on the speculative decoding method
|
||||
// // - some draft contexts require embeddings from the target context, others don't
|
||||
// // - some draft contexts involve an encoder step to transform the target embeddings to draft embeddings
|
||||
// // TODO: extract this in a function ?
|
||||
// {
|
||||
// // TODO: hook the embeddings from the last target batch here
|
||||
// if (llama_model_has_encoder(model_dft.get())) {
|
||||
// //llama_encode(ctx_dft, ...);
|
||||
|
||||
// GGML_ABORT("not implemented yet\n");
|
||||
// }
|
||||
|
||||
// const int ret = llama_decode(ctx_dft.get(), batch_view);
|
||||
|
||||
// if (ret != 0) {
|
||||
// SRV_ERR("failed to decode draft batch, ret = %d\n", ret);
|
||||
|
||||
// // TODO: handle error
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
if (!common_speculative_process(spec.get(), batch_view)) {
|
||||
SRV_ERR("%s", "failed to process speculative batch\n");
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user