mtmd: build_vit batching (#24352)

This commit is contained in:
Saba Fallah 2026-06-09 16:32:08 +02:00 committed by GitHub
parent d6d0ce8215
commit 49f3542190
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -314,11 +314,17 @@ ggml_tensor * clip_graph::build_vit(
std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos,
const build_vit_opts & opts
) {
// batch dim: inp is [n_embd, n_pos] (B==1) or [n_embd, n_pos, B] (multi-tile encode)
const int64_t B = inp->ne[2];
if (learned_pos_embd) {
inp = ggml_add(ctx0, inp, learned_pos_embd);
cb(inp, "pos_embed", -1);
}
// flatten the batch dim into the position dim; attention splits it back out when forming Q/K/V
inp = ggml_reshape_2d(ctx0, inp, n_embd, n_pos * B);
ggml_tensor * inpL = inp;
// pre-layernorm
@ -348,20 +354,24 @@ ggml_tensor * clip_graph::build_vit(
cur = ggml_add(ctx0, cur, layer.qkv_b);
}
Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* offset */ 0);
// Q/K/V as [d_head, n_head, n_pos, B]; the batch stride is cur->nb[1] * n_pos.
Qcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* nb3 */ cur->nb[1] * n_pos,
/* offset */ 0);
Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* offset */ ggml_row_size(cur->type, n_embd));
Kcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* nb3 */ cur->nb[1] * n_pos,
/* offset */ ggml_row_size(cur->type, n_embd));
Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* offset */ ggml_row_size(cur->type, 2 * n_embd));
Vcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* nb3 */ cur->nb[1] * n_pos,
/* offset */ ggml_row_size(cur->type, 2 * n_embd));
if (layer.q_norm) {
GGML_ASSERT(layer.q_norm->ne[0] == Qcur->ne[0]);
@ -406,9 +416,9 @@ ggml_tensor * clip_graph::build_vit(
}
}
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head_kv, n_pos);
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head_kv, n_pos);
Qcur = ggml_reshape_4d(ctx0, Qcur, d_head, n_head, n_pos, B);
Kcur = ggml_reshape_4d(ctx0, Kcur, d_head, n_head_kv, n_pos, B);
Vcur = ggml_reshape_4d(ctx0, Vcur, d_head, n_head_kv, n_pos, B);
if (norm_per_head) {
if (layer.q_norm) {
@ -438,6 +448,7 @@ ggml_tensor * clip_graph::build_vit(
cb(Vcur, "Vcur_normed", il);
}
// build_attn returns a flat 2D [n_embd, n_pos*B]
cur = build_attn(layer.o_w, layer.o_b,
Qcur, Kcur, Vcur, opts.attn_mask, kq_scale, il);
cb(cur, "attn_out", il);
@ -509,6 +520,10 @@ ggml_tensor * clip_graph::build_vit(
if (model.post_ln_w) {
inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1);
}
// restore the batch dim
GGML_ASSERT(inpL->ne[1] % B == 0);
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, inpL->ne[1] / B, B);
return inpL;
}