mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
mtmd: build_vit batching (#24352)
This commit is contained in:
parent
d6d0ce8215
commit
49f3542190
@ -314,11 +314,17 @@ ggml_tensor * clip_graph::build_vit(
|
||||
std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos,
|
||||
const build_vit_opts & opts
|
||||
) {
|
||||
// batch dim: inp is [n_embd, n_pos] (B==1) or [n_embd, n_pos, B] (multi-tile encode)
|
||||
const int64_t B = inp->ne[2];
|
||||
|
||||
if (learned_pos_embd) {
|
||||
inp = ggml_add(ctx0, inp, learned_pos_embd);
|
||||
cb(inp, "pos_embed", -1);
|
||||
}
|
||||
|
||||
// flatten batch; unflatten again in attention
|
||||
inp = ggml_reshape_2d(ctx0, inp, n_embd, n_pos * B);
|
||||
|
||||
ggml_tensor * inpL = inp;
|
||||
|
||||
// pre-layernorm
|
||||
@ -348,20 +354,24 @@ ggml_tensor * clip_graph::build_vit(
|
||||
cur = ggml_add(ctx0, cur, layer.qkv_b);
|
||||
}
|
||||
|
||||
Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
|
||||
/* nb1 */ ggml_row_size(cur->type, d_head),
|
||||
/* nb2 */ cur->nb[1],
|
||||
/* offset */ 0);
|
||||
// Q/K/V as [d_head, n_head, n_pos, B], the batch stride is cur->nb[1]*n_pos.
|
||||
Qcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B,
|
||||
/* nb1 */ ggml_row_size(cur->type, d_head),
|
||||
/* nb2 */ cur->nb[1],
|
||||
/* nb3 */ cur->nb[1] * n_pos,
|
||||
/* offset */ 0);
|
||||
|
||||
Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
|
||||
/* nb1 */ ggml_row_size(cur->type, d_head),
|
||||
/* nb2 */ cur->nb[1],
|
||||
/* offset */ ggml_row_size(cur->type, n_embd));
|
||||
Kcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B,
|
||||
/* nb1 */ ggml_row_size(cur->type, d_head),
|
||||
/* nb2 */ cur->nb[1],
|
||||
/* nb3 */ cur->nb[1] * n_pos,
|
||||
/* offset */ ggml_row_size(cur->type, n_embd));
|
||||
|
||||
Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
|
||||
/* nb1 */ ggml_row_size(cur->type, d_head),
|
||||
/* nb2 */ cur->nb[1],
|
||||
/* offset */ ggml_row_size(cur->type, 2 * n_embd));
|
||||
Vcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B,
|
||||
/* nb1 */ ggml_row_size(cur->type, d_head),
|
||||
/* nb2 */ cur->nb[1],
|
||||
/* nb3 */ cur->nb[1] * n_pos,
|
||||
/* offset */ ggml_row_size(cur->type, 2 * n_embd));
|
||||
|
||||
if (layer.q_norm) {
|
||||
GGML_ASSERT(layer.q_norm->ne[0] == Qcur->ne[0]);
|
||||
@ -406,9 +416,9 @@ ggml_tensor * clip_graph::build_vit(
|
||||
}
|
||||
}
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head_kv, n_pos);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head_kv, n_pos);
|
||||
Qcur = ggml_reshape_4d(ctx0, Qcur, d_head, n_head, n_pos, B);
|
||||
Kcur = ggml_reshape_4d(ctx0, Kcur, d_head, n_head_kv, n_pos, B);
|
||||
Vcur = ggml_reshape_4d(ctx0, Vcur, d_head, n_head_kv, n_pos, B);
|
||||
|
||||
if (norm_per_head) {
|
||||
if (layer.q_norm) {
|
||||
@ -438,6 +448,7 @@ ggml_tensor * clip_graph::build_vit(
|
||||
cb(Vcur, "Vcur_normed", il);
|
||||
}
|
||||
|
||||
// build_attn returns a flat 2D [n_embd, n_pos*B]
|
||||
cur = build_attn(layer.o_w, layer.o_b,
|
||||
Qcur, Kcur, Vcur, opts.attn_mask, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
@ -509,6 +520,10 @@ ggml_tensor * clip_graph::build_vit(
|
||||
if (model.post_ln_w) {
|
||||
inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1);
|
||||
}
|
||||
|
||||
// restore the batch dim
|
||||
GGML_ASSERT(inpL->ne[1] % B == 0);
|
||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, inpL->ne[1] / B, B);
|
||||
return inpL;
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user