mtmd: be able to use alternative types for the K*Q multiplication (#1567)

* mtmd: allow using types other than f32 for K*Q

* Do not cast q if kq_type is quantized

* Fix formatting

* More formatting
This commit is contained in:
Kawrakow 2026-04-02 08:04:05 +02:00 committed by GitHub
parent 6ea7f321e8
commit 73742c5db9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 173 additions and 15 deletions

View File

@ -1181,6 +1181,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.mmproj_use_gpu = false;
return true;
}
if (arg == "--mtmd-kq-type") {
CHECK_ARG
params.mtmd_kq_type = argv[i];
return true;
}
if (arg == "--image" || arg == "--audio") {
CHECK_ARG
params.image.emplace_back(argv[i]);
@ -2489,9 +2494,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "multi-modality" });
options.push_back({ "*", " --mmproj FILE", "path to a multimodal projector file for LLaVA. see examples/llava/README.md" });
options.push_back({ "*", " --image FILE", "path to an image file. use with multimodal models. Specify multiple times for batching" });
options.push_back({ "*", " --image-min-tokens N", "minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)"});
options.push_back({ "*", " --image-max-tokens N", "maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)" });
options.push_back({ "*", " --no-context-shift", "disable context-shift." });
options.push_back({ "*", " --image-min-tokens N", "minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)"});
options.push_back({ "*", " --image-max-tokens N", "maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)" });
options.push_back({ "*", " --mtmd-kq-type TYPE", "data type for multimodality K*Q (default: %s)", params.mtmd_kq_type.c_str() });
options.push_back({ "*", " --no-context-shift", "disable context-shift." });
options.push_back({ "*", "--context-shift (auto|on|off|0|1)", "set context-shift (default: %s)", params.ctx_shift ? "on" : "off" });
options.push_back({ "backend" });
options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" });

View File

@ -378,6 +378,7 @@ struct gpt_params {
std::vector<std::string> image; // path to image file(s)
int image_min_tokens = -1;
int image_max_tokens = -1;
std::string mtmd_kq_type = "f32";
// embedding
bool embedding = false; // get only sentence embedding

View File

@ -443,6 +443,7 @@ struct clip_ctx {
int max_nodes = 8192;
ggml_backend_sched_ptr sched;
clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
ggml_type kq_type = GGML_TYPE_F32;
// for debugging
bool debug_graph = false;
@ -450,6 +451,7 @@ struct clip_ctx {
clip_ctx(clip_context_params & ctx_params) {
flash_attn_type = ctx_params.flash_attn_type;
kq_type = ctx_params.kq_type;
debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
//backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
backend_cpu = ggml_backend_cpu_init();
@ -1011,7 +1013,9 @@ struct clip_graph {
// self-attention
{
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
cb(cur, "qkv_w", il);
cur = ggml_add(ctx0, cur, layer.qkv_b);
cb(cur, "qkv_b", il);
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
cur->nb[1], 0);
@ -1072,12 +1076,14 @@ struct clip_graph {
nullptr, nullptr,
layer.deepstack_fc2_w, layer.deepstack_fc2_b,
ffn_op_type::FFN_GELU, il);
cb(feat, "ffn_feat", il);
if(!deepstack_features) {
deepstack_features = feat;
} else {
// concat along the feature dimension
deepstack_features = ggml_concat(ctx0, deepstack_features, feat, 0);
cb(deepstack_features, "feat_concat", il);
}
}
@ -1098,9 +1104,11 @@ struct clip_graph {
nullptr, nullptr,
model.mm_1_w, model.mm_1_b,
ffn_op_type::FFN_GELU, -1);
cb(embeddings, "ffn_postl", -1);
if (deepstack_features) {
embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); // concat along the feature dimension
cb(embeddings, "ffn_postl_concat", -1);
}
// build the graph
@ -2425,6 +2433,22 @@ private:
ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
v = ggml_cont(ctx0, v);
if (ctx->kq_type != k->type) {
auto bs = ggml_blck_size(ctx->kq_type);
if (k->ne[0] % bs != 0) {
int nbs = bs*((k->ne[0] + bs - 1)/bs);
k = ggml_pad(ctx0, k, nbs - k->ne[0], 0, 0, 0);
}
if (q->ne[0] % bs != 0) {
int nbs = bs*((q->ne[0] + bs - 1)/bs);
q = ggml_pad(ctx0, q, nbs - q->ne[0], 0, 0, 0);
}
k = ggml_cast(ctx0, k, ctx->kq_type);
if (!ggml_is_quantized(ctx->kq_type)) {
q = ggml_cast(ctx0, q, ctx->kq_type);
}
}
if (q->ne[3] == 1 && q->ne[2] > 1 && q->ne[2] == k->ne[2] && q->ne[2] == v->ne[2] && q->ne[1]*k->ne[1]*q->ne[2]/1024./1024. >= 256.) {
cur = nullptr;
for (int64_t i2 = 0; i2 < q->ne[2]; ++i2) {
@ -2432,10 +2456,14 @@ private:
auto ki = ggml_view_2d(ctx0, k, k->ne[0], k->ne[1], k->nb[1], k->nb[2]*i2);
auto vi = ggml_view_2d(ctx0, v, v->ne[0], v->ne[1], v->nb[1], v->nb[2]*i2);
auto kq = ggml_mul_mat(ctx0, ki, qi);
cb(kq, "kq_i", il);
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);
cb(kq, "softmax(kq_i)", il);
auto kqv = ggml_mul_mat(ctx0, vi, kq);
cb(kqv, "kqv_i", il);
if (cur) {
cur = ggml_concat(ctx0, cur, kqv, 0);
cb(cur, "kqv_i_concat", il);
} else {
cur = kqv;
}

View File

@ -35,6 +35,7 @@ struct clip_context_params {
enum clip_flash_attn_type flash_attn_type;
int image_min_tokens;
int image_max_tokens;
ggml_type kq_type;
};
struct clip_init_result {

View File

@ -99,6 +99,22 @@ void common_init() {
#endif
// ======================= end compat ================================
static ggml_type ggml_type_from_str(const std::string & s) {
if (s == "f32") {
return GGML_TYPE_F32;
}
if (s == "f16") {
return GGML_TYPE_F16;
}
if (s == "bf16") {
return GGML_TYPE_BF16;
}
if (s == "q8_0") {
return GGML_TYPE_Q8_0;
}
throw std::runtime_error("Invalid cache type: " + s);
}
struct mtmd_cli_context {
mtmd::context_ptr ctx_vision;
common_init_result llama_init;
@ -171,6 +187,7 @@ struct mtmd_cli_context {
mparams.flash_attn_type = params.flash_attn ? LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED;
mparams.image_min_tokens = params.image_min_tokens;
mparams.image_max_tokens = params.image_max_tokens;
mparams.kq_type = ggml_type_from_str(params.mtmd_kq_type);
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
if (!ctx_vision.get()) {
LOG_ERR("Failed to load vision model from %s\n", clip_path);

View File

@ -102,6 +102,7 @@ mtmd_context_params mtmd_context_params_default() {
/* flash_attn_type */ LLAMA_FLASH_ATTN_TYPE_AUTO,
/* image_min_tokens */ -1,
/* image_max_tokens */ -1,
/* kq_type */ GGML_TYPE_F32,
};
return params;
}
@ -170,6 +171,7 @@ struct mtmd_context {
/* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_DISABLED,
/* image_min_tokens */ ctx_params.image_min_tokens,
/* image_max_tokens */ ctx_params.image_max_tokens,
/* kq_type */ ctx_params.kq_type,
};
auto res = clip_init(mmproj_fname, ctx_clip_params);

View File

@ -87,6 +87,7 @@ struct mtmd_context_params {
// limit number of image tokens, only for vision models with dynamic resolution
int image_min_tokens; // minimum number of tokens for image input (default: read from metadata)
int image_max_tokens; // maximum number of tokens for image input (default: read from metadata)
ggml_type kq_type;
};
MTMD_API const char * mtmd_default_marker(void);

View File

@ -3971,12 +3971,15 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x,
int i = 0;
ggml_float sum = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
__m512 vsum = _mm512_setzero_ps();
for (; i + 15 < n; i += 16) {
__m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
_mm512_set1_ps(max)));
_mm512_storeu_ps(y + i, val);
sum += (ggml_float)_mm512_reduce_add_ps(val);
vsum = _mm512_add_ps(vsum, val);
//sum += (ggml_float)_mm512_reduce_add_ps(val);
}
sum = (ggml_float)_mm512_reduce_add_ps(vsum);
#elif defined(__AVX2__) && defined(__FMA__)
for (; i + 7 < n; i += 8) {
__m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
@ -21667,16 +21670,20 @@ static void ggml_compute_forward_flash_attn_ext_f16(
#if GGML_USE_IQK_MULMAT
// For now we do not implement sinks in the iqk FA implementation
if (iqk_flash_attn_noalibi(q->type, mask->type, max_bias,
if (iqk_flash_attn_noalibi(q->type, mask ? mask->type : GGML_TYPE_F16, max_bias,
q->ne[3], q->ne[2], q->nb[3], q->nb[2],
k->ne[3], k->ne[2], k->nb[3], k->nb[2],
v->ne[3], v->ne[2], v->nb[3], v->nb[2],
dst->ne[2], dst->ne[1], dst->nb[1],
k->type, v->type,
Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1],
q->data, k->data, v->data, mask->data, sinks ? sinks->data : NULL,
Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask ? mask->nb[1] : 0,
q->data, k->data, v->data, mask ? mask->data : NULL, sinks ? sinks->data : NULL,
scale, softcap, (float *)dst->data,
params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth, dst->op_params[4])) return;
printf("iqk_flash_attn_noalibi returned false for Dk = %ld, Dv = %ld, mask = %p:\n", Dk, Dv, (const void *)mask);
printf(" q(%s): %ld x %ld x %ld x %ld\n", ggml_type_name(q->type), q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
printf(" k(%s): %ld x %ld x %ld x %ld\n", ggml_type_name(k->type), k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
printf(" v(%s): %ld x %ld x %ld x %ld\n", ggml_type_name(v->type), v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
// if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
// //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",

View File

@ -184,7 +184,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
GGML_ABORT("Fatal error");
}
if (n_swa > 0) {
if (n_swa > 0 && mask) {
constexpr int kMinBatch = 256;
int ntokens = std::max(kMinBatch, neq1);
int nblock = (ntokens + n_swa + kMinBatch - 1)/kMinBatch;
@ -203,7 +203,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
int rv3 = neq3/nev3;
int first_k = 0, last_k = nek1;
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) {
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256 && mask) {
// This is a quick hack for SWA models.
// Given that the mask is the same for all layers, ideally we should determine the
// cache bounds once, and reuse for the whole graph. But even with this simple hack
@ -271,7 +271,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
auto kth = (const char *)k + kv_offset*stride_k;
auto vth = (const char *)v + kv_offset*stride_v;
auto qth = (const char *)q;
auto mth = (const char *)mask + kv_offset*sizeof(uint16_t); // we don't have ggml_half available here
auto mth = mask ? (const char *)mask + kv_offset*sizeof(uint16_t) : nullptr; // we don't have ggml_half available here
auto work = (char *)work_buffer;
auto size_thread = (Dv + 16)*rk2*sizeof(float);
@ -322,7 +322,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
auto vth = (const char *)v + ith_k*(nek1/gcd_k)*stride_v;
auto q_offset = ith_q < ith_mid ? ith_q*nq_per_thread*nbq2 : (ith_mid*nq_per_thread + (ith_q - ith_mid)*nq_this_thread)*nbq2;
auto qth = (const char *)q + q_offset;
auto mth = (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t); // we don't have ggml_half available here
auto mth = mask ? (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t) : nullptr; // we don't have ggml_half available here
// Each thread will produce a result of size Dv*nq_this_thread*sizeof(float)
// In addition, we need M, S for the nq_this_thread rows the thread is processing
@ -403,7 +403,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2);
auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2;
auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2;
auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
auto this_m = mask ? (const char *)mask + ik01*sizeof(uint16_t) : nullptr; // we don't have ggml_half available here
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
Dk, Dv, rk2, this_nk, nbq2, stride_k, stride_v, 0, Dv,
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, nullptr, 0,
@ -473,7 +473,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
(const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
(const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3),
(const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3),
(const void *)((const char *)mask + iq1*stride_m), sinksf, 1,
mask ? (const void *)((const char *)mask + iq1*stride_m) : nullptr, sinksf, 1,
scale, softcap,
(float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false;
}

View File

@ -420,6 +420,32 @@ template <int nrc_in> struct QFTBF16 final : public QFBaseBF16 {
IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
const ggml_bf16_t * y[nrc];
};
struct QFBaseBF16x8 {
constexpr static int k_step = 16;
using Data = __m256bh;
using Acc = __m256;
static inline Data load(const ggml_bf16_t * x) { return __m256bh(_mm256_loadu_si256((const __m256i *)x)); }
static inline Acc acc(Acc prev, Data y, Data x) {
return _mm256_dpbf16_ps(prev, y, x);
}
static inline Acc acc_first(const Data& y, const Data& x) {
return _mm256_dpbf16_ps(_mm256_setzero_ps(), y, x);
}
static inline float hsum(Acc acc) {
return hsum_float_8(acc);
}
};
template <int nrc_in> struct QFTBF16x8 final : public QFBaseBF16x8 {
constexpr static int nrc = nrc_in;
QFTBF16x8(const DataInfo& info) {
for (int iy = 0; iy < nrc; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
}
QFTBF16x8(const char * cx, size_t bx) {
for (int iy = 0; iy < nrc; ++iy) y[iy] = (const ggml_bf16_t *)(cx + iy*bx);
}
IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
const ggml_bf16_t * y[nrc];
};
template <int nrc_y, int nrc_x>
IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
@ -476,6 +502,61 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in
case 4: mul_mat_Qx_Qy_MxN<nrc_y, 4>(n, cx, bx, last_x, info); break;
}
}
template <int nrc_y, int nrc_x>
IQK_NOINLINE void mul_mat_Qx_Qy_MxNx8(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
int nb = n/QFBaseBF16x8::k_step;
QFTBF16x8<nrc_y> y(info);
QFTBF16x8<nrc_x> x(cx + ix0*bx, bx);
QFBaseBF16x8::Data xv[nrc_x];
QFBaseBF16x8::Acc acc[nrc_x*nrc_y];
auto yv = y.load1(0, 0);
for (int ix = 0; ix < nrc_x; ++ix) {
xv[ix] = x.load1(ix, 0);
acc[ix] = QFBaseBF16x8::acc_first(yv, xv[ix]);
}
for (int iy = 1; iy < nrc_y; ++iy) {
yv = y.load1(iy, 0);
for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QFBaseBF16x8::acc_first(yv, xv[ix]);
}
for (int i = 1; i < nb; ++i) {
yv = y.load1(0, i);
for (int ix = 0; ix < nrc_x; ++ix) {
xv[ix] = x.load1(ix, i);
acc[ix] = QFBaseBF16x8::acc(acc[ix], yv, xv[ix]);
}
for (int iy = 1; iy < nrc_y; ++iy) {
yv = y.load1(iy, i);
for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QFBaseBF16x8::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QFBaseBF16x8::hsum(acc[nrc_x*iy+ix]));
}
template <int nrc_y>
void mul_mat_fX_fY_Tx8(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
constexpr int k_nx = nrc_y <= 2 ? 8 : 5;
const char * cx = (const char *)vx;
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
mul_mat_Qx_Qy_MxNx8<nrc_y, k_nx>(n, cx, bx, ix*k_nx, info);
}
int last_x = k_nx*(nrc_x/k_nx);
if (last_x == nrc_x) return;
int nx = nrc_x - last_x;
if constexpr (nrc_y <= 2) {
if (nx >= 4) {
mul_mat_Qx_Qy_MxNx8<nrc_y, 4>(n, cx, bx, last_x, info);
last_x += 4;
if (last_x == nrc_x) return;
nx = nrc_x - last_x;
}
}
switch (nx) {
case 1: mul_mat_Qx_Qy_MxNx8<nrc_y, 1>(n, cx, bx, last_x, info); break;
case 2: mul_mat_Qx_Qy_MxNx8<nrc_y, 2>(n, cx, bx, last_x, info); break;
case 3: mul_mat_Qx_Qy_MxNx8<nrc_y, 3>(n, cx, bx, last_x, info); break;
case 4: mul_mat_Qx_Qy_MxNx8<nrc_y, 4>(n, cx, bx, last_x, info); break;
}
}
#endif
@ -501,6 +582,14 @@ void set_mul_mat_bf16(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
funcs[3] = mul_mat_fX_fY_T<4>;
funcs[4] = mul_mat_fX_fY_T<5>;
}
void set_mul_mat_bf16x8(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
for (auto& f : funcs) f = nullptr;
funcs[0] = mul_mat_fX_fY_Tx8<1>;
funcs[1] = mul_mat_fX_fY_Tx8<2>;
funcs[2] = mul_mat_fX_fY_Tx8<3>;
funcs[3] = mul_mat_fX_fY_Tx8<4>;
funcs[4] = mul_mat_fX_fY_Tx8<5>;
}
void set_mul_mat_bf16_r16(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
for (auto& f : funcs) f = nullptr;
funcs[0] = mul_mat_bf16_r16_bf16<1>;
@ -519,10 +608,16 @@ void set_mul_mat_bf16_r16(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
bool iqk_set_kernels_float(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels) {
if (typeA == GGML_TYPE_BF16) {
if (ne00 % 16) return false;
if (ne00 % 8) return false;
switch (typeB) {
#ifdef __AVX512BF16__
case GGML_TYPE_BF16: set_mul_mat_bf16(kernels); break;
case GGML_TYPE_BF16: {
if (ne00 % 16 == 0) {
set_mul_mat_bf16(kernels);
} else {
set_mul_mat_bf16x8(kernels);
}
} break;
#else
case GGML_TYPE_BF16: set_mul_mat_f<ggml_bf16_t, ggml_bf16_t>(kernels); break;
case GGML_TYPE_F32: set_mul_mat_f<ggml_bf16_t, float>(kernels); break;