diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index 2d11a33804..b5c4089232 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -247,7 +247,9 @@ int32_t mtmd_helper_decode_image_chunk( llama_pos n_past, llama_seq_id seq_id, int32_t n_batch, - llama_pos * new_n_past) { + llama_pos * new_n_past, + mtmd_helper_post_decode_callback callback, + void * user_data) { GGML_ASSERT(n_batch > 0); auto chunk_type = mtmd_input_chunk_get_type(chunk); const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio"; @@ -302,10 +304,23 @@ int32_t mtmd_helper_decode_image_chunk( int32_t ret = llama_decode(lctx, batch_embd_view); if (ret != 0) { LOG_ERR("failed to decode %s\n", name); - llama_set_causal_attn(lctx, true); // restore causal attn + if (use_non_causal) { + llama_set_causal_attn(lctx, true); + } return ret; } + if (callback != nullptr) { + ret = callback(batch_embd_view, user_data); + if (ret != 0) { + LOG_ERR("post-decode callback failed\n"); + if (use_non_causal) { + llama_set_causal_attn(lctx, true); + } + return ret; + } + } + LOG_INF("%s decoded (batch %d/%d) in %" PRId64 " ms\n", name, i_batch+1, n_img_batches, ggml_time_ms() - t1); i_batch++; @@ -379,7 +394,7 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, ggml_time_ms() - t0); float * embd = mtmd_get_output_embd(ctx); - ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past); + ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past, nullptr, nullptr); if (ret != 0) { LOG_ERR("failed to decode %s\n", name); llama_batch_free(text_batch); diff --git a/tools/mtmd/mtmd-helper.h b/tools/mtmd/mtmd-helper.h index 719aae9885..680a2317df 100644 --- a/tools/mtmd/mtmd-helper.h +++ b/tools/mtmd/mtmd-helper.h @@ -91,6 +91,8 @@ MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, bool logits_last, llama_pos * new_n_past); +typedef int32_t (*mtmd_helper_post_decode_callback)(struct llama_batch batch, void * user_data); + // helper function to decode an image whose embeddings have already been calculated // this helper will handle batching and pre/post decoding setup (for ex. gemma 3 requires non-causal attention) // ret 0 on success, -1 on chunk not being a valid image chunk, 1 on decode failure @@ -101,7 +103,9 @@ MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx, llama_pos n_past, llama_seq_id seq_id, int32_t n_batch, - llama_pos * new_n_past); + llama_pos * new_n_past, + mtmd_helper_post_decode_callback callback, + void * user_data); // // video input helpers (requires ffmpeg/ffprobe installed on the system) diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 2b89a8bc5a..75729e62dd 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -539,37 +539,6 @@ bool server_tokens::validate(const struct llama_context * ctx) const { return true; } -int32_t server_tokens::process_chunk( - llama_context * ctx, - mtmd_context * mctx, - size_t idx, - llama_pos pos, - int32_t seq_id, - size_t & n_tokens_out) const { - const auto & chunk = find_chunk(idx); - const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE - ? "image" : "audio"; - SRV_INF("processing %s...\n", name); - int32_t n_batch = llama_n_batch(ctx); - int64_t t0 = ggml_time_ms(); - llama_pos new_n_past; // unused for now - int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, - chunk.get(), - pos, - seq_id, - n_batch, - true, // logits last - &new_n_past); - SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); - if (result != 0) { - LOG_ERR("mtmd_helper_eval failed with status %d", result); - n_tokens_out = 0; - return result; - } - n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get()); - return 0; -} - server_tokens server_tokens::clone() const { server_tokens res; res.has_mtmd = has_mtmd; diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 857ffe1479..f286b3d156 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -221,15 +221,6 @@ public: // make sure all text tokens are within the vocab range bool validate(const struct llama_context * ctx) const; - // encode and decode the image chunk - int32_t process_chunk( - llama_context * ctx, - mtmd_context * mctx, - size_t idx, - llama_pos pos, - int32_t seq_id, - size_t & n_tokens_out) const; - server_tokens clone() const; }; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 986b2f15d5..bcae39a109 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -15,11 +15,6 @@ #include "mtmd.h" #include "mtmd-helper.h" -#include "ggml-cpp.h" - -// TODO: tmp until the mtmd draft processing is refactored [TAG_MTMD_DRAFT_PROCESSING] -#include "../../src/llama-ext.h" - #include #include #include @@ -81,7 +76,6 @@ struct server_slot { // multimodal mtmd_context * mctx = nullptr; mtmd::batch_ptr mbatch = nullptr; - std::array mtgt = {nullptr, nullptr}; // [0] for main context, [1] for optional draft context // speculative decoding common_speculative * spec; @@ -244,15 +238,6 @@ struct server_slot { // clear multimodal state mbatch.reset(); - mtgt[0] = ctx_tgt; - mtgt[1] = nullptr; - if (ctx_dft && llama_get_ctx_other(ctx_dft) != ctx_tgt) { - // TODO: in the future, figure out how to infuse target embeddings to the images - // for now, we re-decode the same chunk in both ctx_tgt and ctx_dft - // maybe we simply need to call `common_speculative_process()` ? - // [TAG_MTMD_DRAFT_PROCESSING] - mtgt[1] = ctx_dft; - } } void init_sampler() const { @@ -598,32 +583,38 @@ struct server_slot { int process_mtmd_chunk(size_t idx, size_t & n_tokens_out) { GGML_ASSERT(mctx); const auto & input_tokens = task->tokens; - auto & chunk = input_tokens.find_chunk(idx); + const auto & chunk = input_tokens.find_chunk(idx); int32_t res = 0; auto try_decode = [&]() -> int32_t { if (mbatch) { float * embd = mtmd_batch_get_output_embd(mbatch.get(), chunk.get()); if (embd) { - for (auto * lctx : mtgt) { - if (lctx == nullptr) { - continue; - } - llama_pos new_n_past; // unused for now - res = mtmd_helper_decode_image_chunk( - mctx, - lctx, - chunk.get(), - embd, - prompt.tokens.pos_next(), - id, - llama_n_batch(lctx), - &new_n_past - ); - if (res != 0) { - SLT_ERR(*this, "failed to decode mtmd chunk, idx = %zu, res = %d\n", idx, res); - return -1; + void * cb_data = spec; + static auto cb = [](llama_batch batch, void * user_data) { + common_speculative * spec = static_cast(user_data); + if (!common_speculative_process(spec, batch)) { + return 1; } + return 0; + }; + + llama_pos new_n_past; // unused for now + res = mtmd_helper_decode_image_chunk( + mctx, + ctx_tgt, + chunk.get(), + embd, + prompt.tokens.pos_next(), + id, + llama_n_batch(ctx_tgt), + &new_n_past, + cb, + cb_data + ); + if (res != 0) { + SLT_ERR(*this, "failed to decode mtmd chunk, idx = %zu, res = %d\n", idx, res); + return -1; } n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get()); return 0; // success @@ -636,7 +627,8 @@ struct server_slot { res = try_decode(); if (res == 0) { return 0; - } else if (res < 0) { + } + if (res < 0) { // fatal error return res; } @@ -3350,48 +3342,6 @@ private: // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] // for now, always re-evaluate for simplicity // ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384 - // - // | spec type | need re-eval | - // | --- | --- | - // | draft model | no | because the draft model does not use embeddings from the target - // | MTP (std) | yes | - // | MTP Gemma4 | no | because the KV cache is shared - // | Eagle3 | yes | - // | DFlash | yes | https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4405406982 - // - // note: this logic is now moved in `common_speculative_process()` - // keeping the sketch here until for a bit, until the logic is finalized - // - //if (ctx_dft) { - // // TODO: update as needed for MTP, Eagle3, etc. - // const bool need_tgt_embd = false; - - // if (need_tgt_embd) { - // llama_synchronize(ctx_tgt); - // } - - // // the logic here varies depending on the speculative decoding method - // // - some draft contexts require embeddings from the target context, others don't - // // - some draft contexts involve an encoder step to transform the target embeddings to draft embeddings - // // TODO: extract this in a function ? - // { - // // TODO: hook the embeddings from the last target batch here - // if (llama_model_has_encoder(model_dft.get())) { - // //llama_encode(ctx_dft, ...); - - // GGML_ABORT("not implemented yet\n"); - // } - - // const int ret = llama_decode(ctx_dft.get(), batch_view); - - // if (ret != 0) { - // SRV_ERR("failed to decode draft batch, ret = %d\n", ret); - - // // TODO: handle error - // break; - // } - // } - //} if (!common_speculative_process(spec.get(), batch_view)) { SRV_ERR("%s", "failed to process speculative batch\n");