From 49f354219059fc22316ae3efa54e54ba37f77860 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Tue, 9 Jun 2026 16:32:08 +0200 Subject: [PATCH] mtmd: build_vit batching (#24352) --- tools/mtmd/clip.cpp | 45 ++++++++++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index bd33f43062..02e7a3a8f6 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -314,11 +314,17 @@ ggml_tensor * clip_graph::build_vit( std::function add_pos, const build_vit_opts & opts ) { + // batch dim: inp is [n_embd, n_pos] (B==1) or [n_embd, n_pos, B] (multi-tile encode) + const int64_t B = inp->ne[2]; + if (learned_pos_embd) { inp = ggml_add(ctx0, inp, learned_pos_embd); cb(inp, "pos_embed", -1); } + // flatten batch; unflatten again in attention + inp = ggml_reshape_2d(ctx0, inp, n_embd, n_pos * B); + ggml_tensor * inpL = inp; // pre-layernorm @@ -348,20 +354,24 @@ ggml_tensor * clip_graph::build_vit( cur = ggml_add(ctx0, cur, layer.qkv_b); } - Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, - /* nb1 */ ggml_row_size(cur->type, d_head), - /* nb2 */ cur->nb[1], - /* offset */ 0); + // Q/K/V as [d_head, n_head, n_pos, B], the batch stride is cur->nb[1]*n_pos. + Qcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* nb3 */ cur->nb[1] * n_pos, + /* offset */ 0); - Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, - /* nb1 */ ggml_row_size(cur->type, d_head), - /* nb2 */ cur->nb[1], - /* offset */ ggml_row_size(cur->type, n_embd)); + Kcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* nb3 */ cur->nb[1] * n_pos, + /* offset */ ggml_row_size(cur->type, n_embd)); - Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, - /* nb1 */ ggml_row_size(cur->type, d_head), - /* nb2 */ cur->nb[1], - /* offset */ ggml_row_size(cur->type, 2 * n_embd)); + Vcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* nb3 */ cur->nb[1] * n_pos, + /* offset */ ggml_row_size(cur->type, 2 * n_embd)); if (layer.q_norm) { GGML_ASSERT(layer.q_norm->ne[0] == Qcur->ne[0]); @@ -406,9 +416,9 @@ ggml_tensor * clip_graph::build_vit( } } - Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos); - Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head_kv, n_pos); - Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head_kv, n_pos); + Qcur = ggml_reshape_4d(ctx0, Qcur, d_head, n_head, n_pos, B); + Kcur = ggml_reshape_4d(ctx0, Kcur, d_head, n_head_kv, n_pos, B); + Vcur = ggml_reshape_4d(ctx0, Vcur, d_head, n_head_kv, n_pos, B); if (norm_per_head) { if (layer.q_norm) { @@ -438,6 +448,7 @@ ggml_tensor * clip_graph::build_vit( cb(Vcur, "Vcur_normed", il); } + // build_attn returns a flat 2D [n_embd, n_pos*B] cur = build_attn(layer.o_w, layer.o_b, Qcur, Kcur, Vcur, opts.attn_mask, kq_scale, il); cb(cur, "attn_out", il); @@ -509,6 +520,10 @@ ggml_tensor * clip_graph::build_vit( if (model.post_ln_w) { inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1); } + + // restore the batch dim + GGML_ASSERT(inpL->ne[1] % B == 0); + inpL = ggml_reshape_3d(ctx0, inpL, n_embd, inpL->ne[1] / B, B); return inpL; }