diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp
index a3cad7cd06..0ad000ef01 100644
--- a/tools/mtmd/mtmd-cli.cpp
+++ b/tools/mtmd/mtmd-cli.cpp
@@ -32,9 +32,9 @@
 static volatile bool g_is_generating = false;
 static volatile bool g_is_interrupted = false;
 
 /**
- * Please note that this is NOT a production-ready stuff.
+ * Please note that this is NOT a production-ready binary.
  * It is a playground for trying multimodal support in llama.cpp.
- * For contributors: please keep this code simple and easy to understand.
+ * For contributors: please keep this code simple and easy to understand. Do not add unnecessary complexity. The goal is to have a simple CLI for testing multimodal support.
  */
 static void show_additional_info(int /*argc*/, char ** argv) {
@@ -65,6 +65,14 @@ static void sigint_handler(int signo) {
 }
 #endif
 
+// this is only used by tests.sh to capture the response; it's not meant to be used in production
+static void inject_test_response_marker() {
+    const char * env = std::getenv("MTMD_TEST_RESPONSE_MARKER");
+    if (env) {
+        LOG("%s\n", env);
+    }
+}
+
 struct mtmd_cli_context {
     mtmd::context_ptr ctx_vision;
     common_init_result_ptr llama_init;
@@ -79,6 +87,8 @@ struct mtmd_cli_context {
     mtmd::bitmaps bitmaps;
     std::vector videos;
 
+    mtmd::batch_ptr mbatch;
+
     // chat template
     common_chat_templates_ptr tmpls;
     std::vector chat_history;
@@ -233,6 +243,8 @@ static std::string chat_add_and_format(mtmd_cli_context & ctx, common_chat_msg &
 }
 
 static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) {
+    inject_test_response_marker();
+
     bool add_bos = ctx.chat_history.empty();
     auto formatted_chat = chat_add_and_format(ctx, msg);
     LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.c_str());
@@ -259,20 +271,95 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) {
     ctx.bitmaps.entries.clear();
     ctx.videos.clear();
 
-    llama_pos new_n_past;
-    if (mtmd_helper_eval_chunks(ctx.ctx_vision.get(),
-                ctx.lctx, // lctx
-                chunks.ptr.get(), // chunks
-                ctx.n_past, // n_past
-                0, // seq_id
-                ctx.n_batch, // n_batch
-                true, // logits_last
-                &new_n_past)) {
-        LOG_ERR("Unable to eval prompt\n");
-        return 1;
-    }
+    // batch encode all media chunks, then decode each
+    size_t n_chunks = mtmd_input_chunks_size(chunks.ptr.get());
+    for (size_t i = 0; i < n_chunks; i++) {
+        auto chunk = mtmd_input_chunks_get(chunks.ptr.get(), i);
+        auto chunk_type = mtmd_input_chunk_get_type(chunk);
 
-    ctx.n_past = new_n_past;
+        if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+            // decode text chunk
+            llama_pos new_n_past = ctx.n_past;
+            res = mtmd_helper_eval_chunk_single(ctx.ctx_vision.get(),
+                        ctx.lctx,
+                        chunk,
+                        ctx.n_past,
+                        0, // seq_id
+                        ctx.n_batch,
+                        i == n_chunks - 1, // logits_last
+                        &new_n_past);
+            if (res != 0) {
+                LOG_ERR("Unable to eval text chunk %zu\n", i);
+                return 1;
+            }
+            ctx.n_past = new_n_past;
+        } else {
+            // media chunk: try to get embd from existing batch, or create a new batch
+            float * embd = nullptr;
+            if (ctx.mbatch) {
+                embd = mtmd_batch_get_output_embd(ctx.mbatch.get(), chunk);
+
+                if (embd) {
+                    LOG_DBG("found embd for media chunk %zu in existing batch\n", i);
+                } else {
+                    LOG_DBG("media chunk %zu not found in existing batch, creating new batch\n", i);
+                }
+            }
+
+            if (!embd) {
+                // create and encode a new batch with as many media chunks as possible
+                ctx.mbatch.reset(mtmd_batch_init(ctx.ctx_vision.get()));
+                res = mtmd_batch_add_chunk(ctx.mbatch.get(), chunk);
+                GGML_ASSERT(res == 0); // first chunk must always succeed
+
+                int n_added = 1;
+                // add as many subsequent media chunks as possible
+                for (size_t j = i + 1; j < n_chunks; j++) {
+                    auto next_chunk = mtmd_input_chunks_get(chunks.ptr.get(), j);
+                    auto next_type = mtmd_input_chunk_get_type(next_chunk);
+                    if (next_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+                        break; // text chunk splits the batch
+                    }
+                    res = mtmd_batch_add_chunk(ctx.mbatch.get(), next_chunk);
+                    if (res != 0) {
+                        break; // batch full or incompatible
+                    }
+                    n_added++;
+                }
+
+                int64_t time_start = ggml_time_ms();
+                LOG_INF("encoding mtmd batch, n_chunks = %d (done = %zu, total = %zu)\n", n_added, i, n_chunks);
+                res = mtmd_batch_encode(ctx.mbatch.get());
+                if (res != 0) {
+                    LOG_ERR("Failed to encode mtmd batch, res = %d\n", res);
+                    return 1;
+                }
+                LOG_INF("mtmd batch encoding done in %d ms\n", (int)(ggml_time_ms() - time_start));
+
+                embd = mtmd_batch_get_output_embd(ctx.mbatch.get(), chunk);
+            }
+
+            GGML_ASSERT(embd != nullptr);
+
+            llama_pos new_n_past = ctx.n_past;
+            res = mtmd_helper_decode_image_chunk(ctx.ctx_vision.get(),
+                        ctx.lctx,
+                        chunk,
+                        embd,
+                        ctx.n_past,
+                        0, // seq_id
+                        ctx.n_batch,
+                        &new_n_past,
+                        nullptr, // callback
+                        nullptr // user_data
+                        );
+            if (res != 0) {
+                LOG_ERR("Unable to decode media chunk %zu\n", i);
+                return 1;
+            }
+            ctx.n_past = new_n_past;
+        }
+    }
 
     LOG("\n");
 
diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh
index 5da48d61bf..6fe26478ab 100755
--- a/tools/mtmd/tests.sh
+++ b/tools/mtmd/tests.sh
@@ -13,6 +13,8 @@ mkdir -p $SCRIPT_DIR/output
 PROJ_ROOT="$SCRIPT_DIR/../.."
 cd $PROJ_ROOT
 
+export MTMD_TEST_RESPONSE_MARKER="__MTMD_TEST_RESPONSE__"
+
 # Check if the first argument is "big", then run test with big models
 # This is useful if we're running the script on a larger machine, so we can test the big models
 RUN_BIG_TESTS=false
@@ -28,6 +30,15 @@ if [ "${1:-}" = "huge" ]; then
     echo "Include BIG and HUGE models..."
 fi
 
+USE_VIDEO=false
+if [ "${1:-}" = "video" ]; then
+    USE_VIDEO=true
+    echo "Using video as input..."
+    # behavior of USE_VIDEO:
+    # do NOT check if the output contains "new york", only verify if the exit code is 0
+    # when printing the result, print the OK/FAIL line then print the generated text
+fi
+
 # Check if the second argument is "flash", then enable flash attention
 # This is useful to test if flash attention off works correctly
 FLASH_ATTN="on"
@@ -50,13 +61,20 @@ add_test_vision() {
     if [ $# -gt 0 ]; then
         extra_args=$(printf " %q" "$@")
     fi
+    if [ "$USE_VIDEO" = true ]; then
+        arr_file+=("test-3.mp4")
+    else
+        arr_file+=("test-1.jpeg")
+    fi
     arr_prefix+=("[vision]")
     arr_hf+=("$hf")
     arr_extra_args+=("$extra_args")
-    arr_file+=("test-1.jpeg")
 }
 
 add_test_audio() {
+    if [ "$USE_VIDEO" = true ]; then
+        return 0
+    fi
     local hf=$1
     shift
     local extra_args=""
@@ -166,19 +184,36 @@ for i in "${!arr_hf[@]}"; do
         cmd+=" -p \"what is the publisher name of the newspaper?\""
     fi
 
-    output=$(eval "$cmd" 2>&1 | tee /dev/tty)
+    exit_code=0
+    # pipefail (scoped to the subshell) so a failing command is not masked by tee's exit status
+    output=$(set -o pipefail; eval "$cmd" 2>&1 | tee /dev/tty) || exit_code=$?
 
     echo "$output" > $SCRIPT_DIR/output/$bin-$(echo "$hf" | tr '/' '-').log
 
-    # either contains "new york" or both "men" and "walk"
-    if echo "$output" | grep -iq "new york" \
-            || (echo "$output" | grep -iq "men" && echo "$output" | grep -iq "walk")
-    then
-        result="$prefix \033[32mOK\033[0m: $hf"
+    if [ "$USE_VIDEO" = true ]; then
+        # for video, only check exit code; do not grep for "new york"
+        if [ $exit_code -eq 0 ]; then
+            result="$prefix \033[32mOK\033[0m: $hf"
+        else
+            result="$prefix \033[31mFAIL\033[0m: $hf"
+        fi
+        # append generated text (after the response marker)
+        generated_text=$(echo "$output" | sed "1,/${MTMD_TEST_RESPONSE_MARKER}/d" | tail -10)
+        if [ -n "$generated_text" ]; then
+            result+="\n$generated_text"
+        fi
+        echo -e "$result"
     else
-        result="$prefix \033[31mFAIL\033[0m: $hf"
+        # either contains "new york" or both "men" and "walk"
+        if echo "$output" | grep -iq "new york" \
+                || (echo "$output" | grep -iq "men" && echo "$output" | grep -iq "walk")
+        then
+            result="$prefix \033[32mOK\033[0m: $hf"
+        else
+            result="$prefix \033[31mFAIL\033[0m: $hf"
+        fi
+        echo -e "$result"
     fi
-    echo -e "$result"
     arr_res+=("$result")
 
     echo ""