mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
mtmd: be able to use alternative types for the K*Q multiplication (#1567)
* mtmd: allow using types other than f32 for K*Q * Do not cast q if kq_type is quantized * Fix formatting * More formatting
This commit is contained in:
parent
6ea7f321e8
commit
73742c5db9
@ -1181,6 +1181,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
params.mmproj_use_gpu = false;
|
||||
return true;
|
||||
}
|
||||
if (arg == "--mtmd-kq-type") {
|
||||
CHECK_ARG
|
||||
params.mtmd_kq_type = argv[i];
|
||||
return true;
|
||||
}
|
||||
if (arg == "--image" || arg == "--audio") {
|
||||
CHECK_ARG
|
||||
params.image.emplace_back(argv[i]);
|
||||
@ -2489,9 +2494,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
||||
options.push_back({ "multi-modality" });
|
||||
options.push_back({ "*", " --mmproj FILE", "path to a multimodal projector file for LLaVA. see examples/llava/README.md" });
|
||||
options.push_back({ "*", " --image FILE", "path to an image file. use with multimodal models. Specify multiple times for batching" });
|
||||
options.push_back({ "*", " --image-min-tokens N", "minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)"});
|
||||
options.push_back({ "*", " --image-max-tokens N", "maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)" });
|
||||
options.push_back({ "*", " --no-context-shift", "disable context-shift." });
|
||||
options.push_back({ "*", " --image-min-tokens N", "minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)"});
|
||||
options.push_back({ "*", " --image-max-tokens N", "maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)" });
|
||||
options.push_back({ "*", " --mtmd-kq-type TYPE", "data type for multimodality K*Q (default: %s)", params.mtmd_kq_type.c_str() });
|
||||
options.push_back({ "*", " --no-context-shift", "disable context-shift." });
|
||||
options.push_back({ "*", "--context-shift (auto|on|off|0|1)", "set context-shift (default: %s)", params.ctx_shift ? "on" : "off" });
|
||||
options.push_back({ "backend" });
|
||||
options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" });
|
||||
|
||||
@ -378,6 +378,7 @@ struct gpt_params {
|
||||
std::vector<std::string> image; // path to image file(s)
|
||||
int image_min_tokens = -1;
|
||||
int image_max_tokens = -1;
|
||||
std::string mtmd_kq_type = "f32";
|
||||
|
||||
// embedding
|
||||
bool embedding = false; // get only sentence embedding
|
||||
|
||||
@ -443,6 +443,7 @@ struct clip_ctx {
|
||||
int max_nodes = 8192;
|
||||
ggml_backend_sched_ptr sched;
|
||||
clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
|
||||
ggml_type kq_type = GGML_TYPE_F32;
|
||||
|
||||
// for debugging
|
||||
bool debug_graph = false;
|
||||
@ -450,6 +451,7 @@ struct clip_ctx {
|
||||
|
||||
clip_ctx(clip_context_params & ctx_params) {
|
||||
flash_attn_type = ctx_params.flash_attn_type;
|
||||
kq_type = ctx_params.kq_type;
|
||||
debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
|
||||
//backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
||||
backend_cpu = ggml_backend_cpu_init();
|
||||
@ -1011,7 +1013,9 @@ struct clip_graph {
|
||||
// self-attention
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
|
||||
cb(cur, "qkv_w", il);
|
||||
cur = ggml_add(ctx0, cur, layer.qkv_b);
|
||||
cb(cur, "qkv_b", il);
|
||||
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
|
||||
cur->nb[1], 0);
|
||||
@ -1072,12 +1076,14 @@ struct clip_graph {
|
||||
nullptr, nullptr,
|
||||
layer.deepstack_fc2_w, layer.deepstack_fc2_b,
|
||||
ffn_op_type::FFN_GELU, il);
|
||||
cb(feat, "ffn_feat", il);
|
||||
|
||||
if(!deepstack_features) {
|
||||
deepstack_features = feat;
|
||||
} else {
|
||||
// concat along the feature dimension
|
||||
deepstack_features = ggml_concat(ctx0, deepstack_features, feat, 0);
|
||||
cb(deepstack_features, "feat_concat", il);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1098,9 +1104,11 @@ struct clip_graph {
|
||||
nullptr, nullptr,
|
||||
model.mm_1_w, model.mm_1_b,
|
||||
ffn_op_type::FFN_GELU, -1);
|
||||
cb(embeddings, "ffn_postl", -1);
|
||||
|
||||
if (deepstack_features) {
|
||||
embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); // concat along the feature dimension
|
||||
cb(embeddings, "ffn_postl_concat", -1);
|
||||
}
|
||||
|
||||
// build the graph
|
||||
@ -2425,6 +2433,22 @@ private:
|
||||
ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
|
||||
v = ggml_cont(ctx0, v);
|
||||
|
||||
if (ctx->kq_type != k->type) {
|
||||
auto bs = ggml_blck_size(ctx->kq_type);
|
||||
if (k->ne[0] % bs != 0) {
|
||||
int nbs = bs*((k->ne[0] + bs - 1)/bs);
|
||||
k = ggml_pad(ctx0, k, nbs - k->ne[0], 0, 0, 0);
|
||||
}
|
||||
if (q->ne[0] % bs != 0) {
|
||||
int nbs = bs*((q->ne[0] + bs - 1)/bs);
|
||||
q = ggml_pad(ctx0, q, nbs - q->ne[0], 0, 0, 0);
|
||||
}
|
||||
k = ggml_cast(ctx0, k, ctx->kq_type);
|
||||
if (!ggml_is_quantized(ctx->kq_type)) {
|
||||
q = ggml_cast(ctx0, q, ctx->kq_type);
|
||||
}
|
||||
}
|
||||
|
||||
if (q->ne[3] == 1 && q->ne[2] > 1 && q->ne[2] == k->ne[2] && q->ne[2] == v->ne[2] && q->ne[1]*k->ne[1]*q->ne[2]/1024./1024. >= 256.) {
|
||||
cur = nullptr;
|
||||
for (int64_t i2 = 0; i2 < q->ne[2]; ++i2) {
|
||||
@ -2432,10 +2456,14 @@ private:
|
||||
auto ki = ggml_view_2d(ctx0, k, k->ne[0], k->ne[1], k->nb[1], k->nb[2]*i2);
|
||||
auto vi = ggml_view_2d(ctx0, v, v->ne[0], v->ne[1], v->nb[1], v->nb[2]*i2);
|
||||
auto kq = ggml_mul_mat(ctx0, ki, qi);
|
||||
cb(kq, "kq_i", il);
|
||||
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);
|
||||
cb(kq, "softmax(kq_i)", il);
|
||||
auto kqv = ggml_mul_mat(ctx0, vi, kq);
|
||||
cb(kqv, "kqv_i", il);
|
||||
if (cur) {
|
||||
cur = ggml_concat(ctx0, cur, kqv, 0);
|
||||
cb(cur, "kqv_i_concat", il);
|
||||
} else {
|
||||
cur = kqv;
|
||||
}
|
||||
|
||||
@ -35,6 +35,7 @@ struct clip_context_params {
|
||||
enum clip_flash_attn_type flash_attn_type;
|
||||
int image_min_tokens;
|
||||
int image_max_tokens;
|
||||
ggml_type kq_type;
|
||||
};
|
||||
|
||||
struct clip_init_result {
|
||||
|
||||
@ -99,6 +99,22 @@ void common_init() {
|
||||
#endif
|
||||
// ======================= end compat ================================
|
||||
|
||||
static ggml_type ggml_type_from_str(const std::string & s) {
|
||||
if (s == "f32") {
|
||||
return GGML_TYPE_F32;
|
||||
}
|
||||
if (s == "f16") {
|
||||
return GGML_TYPE_F16;
|
||||
}
|
||||
if (s == "bf16") {
|
||||
return GGML_TYPE_BF16;
|
||||
}
|
||||
if (s == "q8_0") {
|
||||
return GGML_TYPE_Q8_0;
|
||||
}
|
||||
throw std::runtime_error("Invalid cache type: " + s);
|
||||
}
|
||||
|
||||
struct mtmd_cli_context {
|
||||
mtmd::context_ptr ctx_vision;
|
||||
common_init_result llama_init;
|
||||
@ -171,6 +187,7 @@ struct mtmd_cli_context {
|
||||
mparams.flash_attn_type = params.flash_attn ? LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED;
|
||||
mparams.image_min_tokens = params.image_min_tokens;
|
||||
mparams.image_max_tokens = params.image_max_tokens;
|
||||
mparams.kq_type = ggml_type_from_str(params.mtmd_kq_type);
|
||||
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
|
||||
if (!ctx_vision.get()) {
|
||||
LOG_ERR("Failed to load vision model from %s\n", clip_path);
|
||||
|
||||
@ -102,6 +102,7 @@ mtmd_context_params mtmd_context_params_default() {
|
||||
/* flash_attn_type */ LLAMA_FLASH_ATTN_TYPE_AUTO,
|
||||
/* image_min_tokens */ -1,
|
||||
/* image_max_tokens */ -1,
|
||||
/* kq_type */ GGML_TYPE_F32,
|
||||
};
|
||||
return params;
|
||||
}
|
||||
@ -170,6 +171,7 @@ struct mtmd_context {
|
||||
/* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_DISABLED,
|
||||
/* image_min_tokens */ ctx_params.image_min_tokens,
|
||||
/* image_max_tokens */ ctx_params.image_max_tokens,
|
||||
/* kq_type */ ctx_params.kq_type,
|
||||
};
|
||||
|
||||
auto res = clip_init(mmproj_fname, ctx_clip_params);
|
||||
|
||||
@ -87,6 +87,7 @@ struct mtmd_context_params {
|
||||
// limit number of image tokens, only for vision models with dynamic resolution
|
||||
int image_min_tokens; // minimum number of tokens for image input (default: read from metadata)
|
||||
int image_max_tokens; // maximum number of tokens for image input (default: read from metadata)
|
||||
ggml_type kq_type;
|
||||
};
|
||||
|
||||
MTMD_API const char * mtmd_default_marker(void);
|
||||
|
||||
@ -3971,12 +3971,15 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x,
|
||||
int i = 0;
|
||||
ggml_float sum = 0;
|
||||
#if defined(__AVX512F__) && defined(__AVX512DQ__)
|
||||
__m512 vsum = _mm512_setzero_ps();
|
||||
for (; i + 15 < n; i += 16) {
|
||||
__m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
|
||||
_mm512_set1_ps(max)));
|
||||
_mm512_storeu_ps(y + i, val);
|
||||
sum += (ggml_float)_mm512_reduce_add_ps(val);
|
||||
vsum = _mm512_add_ps(vsum, val);
|
||||
//sum += (ggml_float)_mm512_reduce_add_ps(val);
|
||||
}
|
||||
sum = (ggml_float)_mm512_reduce_add_ps(vsum);
|
||||
#elif defined(__AVX2__) && defined(__FMA__)
|
||||
for (; i + 7 < n; i += 8) {
|
||||
__m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
|
||||
@ -21667,16 +21670,20 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
|
||||
#if GGML_USE_IQK_MULMAT
|
||||
// For now we do not implement sinks in the iqk FA implementation
|
||||
if (iqk_flash_attn_noalibi(q->type, mask->type, max_bias,
|
||||
if (iqk_flash_attn_noalibi(q->type, mask ? mask->type : GGML_TYPE_F16, max_bias,
|
||||
q->ne[3], q->ne[2], q->nb[3], q->nb[2],
|
||||
k->ne[3], k->ne[2], k->nb[3], k->nb[2],
|
||||
v->ne[3], v->ne[2], v->nb[3], v->nb[2],
|
||||
dst->ne[2], dst->ne[1], dst->nb[1],
|
||||
k->type, v->type,
|
||||
Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1],
|
||||
q->data, k->data, v->data, mask->data, sinks ? sinks->data : NULL,
|
||||
Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask ? mask->nb[1] : 0,
|
||||
q->data, k->data, v->data, mask ? mask->data : NULL, sinks ? sinks->data : NULL,
|
||||
scale, softcap, (float *)dst->data,
|
||||
params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth, dst->op_params[4])) return;
|
||||
printf("iqk_flash_attn_noalibi returned false for Dk = %ld, Dv = %ld, mask = %p:\n", Dk, Dv, (const void *)mask);
|
||||
printf(" q(%s): %ld x %ld x %ld x %ld\n", ggml_type_name(q->type), q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
|
||||
printf(" k(%s): %ld x %ld x %ld x %ld\n", ggml_type_name(k->type), k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
|
||||
printf(" v(%s): %ld x %ld x %ld x %ld\n", ggml_type_name(v->type), v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
|
||||
|
||||
// if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
|
||||
// //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",
|
||||
|
||||
@ -184,7 +184,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
GGML_ABORT("Fatal error");
|
||||
}
|
||||
|
||||
if (n_swa > 0) {
|
||||
if (n_swa > 0 && mask) {
|
||||
constexpr int kMinBatch = 256;
|
||||
int ntokens = std::max(kMinBatch, neq1);
|
||||
int nblock = (ntokens + n_swa + kMinBatch - 1)/kMinBatch;
|
||||
@ -203,7 +203,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
int rv3 = neq3/nev3;
|
||||
|
||||
int first_k = 0, last_k = nek1;
|
||||
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) {
|
||||
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256 && mask) {
|
||||
// This is a quick hack for SWA models.
|
||||
// Given that the mask is the same for all layers, ideally we should determine the
|
||||
// cache bounds once, and reuse for the whole graph. But even with this simple hack
|
||||
@ -271,7 +271,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
auto kth = (const char *)k + kv_offset*stride_k;
|
||||
auto vth = (const char *)v + kv_offset*stride_v;
|
||||
auto qth = (const char *)q;
|
||||
auto mth = (const char *)mask + kv_offset*sizeof(uint16_t); // we don't have ggml_half available here
|
||||
auto mth = mask ? (const char *)mask + kv_offset*sizeof(uint16_t) : nullptr; // we don't have ggml_half available here
|
||||
|
||||
auto work = (char *)work_buffer;
|
||||
auto size_thread = (Dv + 16)*rk2*sizeof(float);
|
||||
@ -322,7 +322,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
auto vth = (const char *)v + ith_k*(nek1/gcd_k)*stride_v;
|
||||
auto q_offset = ith_q < ith_mid ? ith_q*nq_per_thread*nbq2 : (ith_mid*nq_per_thread + (ith_q - ith_mid)*nq_this_thread)*nbq2;
|
||||
auto qth = (const char *)q + q_offset;
|
||||
auto mth = (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t); // we don't have ggml_half available here
|
||||
auto mth = mask ? (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t) : nullptr; // we don't have ggml_half available here
|
||||
|
||||
// Each thread will produce a result of size Dv*nq_this_thread*sizeof(float)
|
||||
// In addition, we need M, S for the nq_this_thread rows the thread is processing
|
||||
@ -403,7 +403,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2);
|
||||
auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2;
|
||||
auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2;
|
||||
auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
|
||||
auto this_m = mask ? (const char *)mask + ik01*sizeof(uint16_t) : nullptr; // we don't have ggml_half available here
|
||||
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
|
||||
Dk, Dv, rk2, this_nk, nbq2, stride_k, stride_v, 0, Dv,
|
||||
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, nullptr, 0,
|
||||
@ -473,7 +473,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
(const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
|
||||
(const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3),
|
||||
(const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3),
|
||||
(const void *)((const char *)mask + iq1*stride_m), sinksf, 1,
|
||||
mask ? (const void *)((const char *)mask + iq1*stride_m) : nullptr, sinksf, 1,
|
||||
scale, softcap,
|
||||
(float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false;
|
||||
}
|
||||
|
||||
@ -420,6 +420,32 @@ template <int nrc_in> struct QFTBF16 final : public QFBaseBF16 {
|
||||
IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
|
||||
const ggml_bf16_t * y[nrc];
|
||||
};
|
||||
struct QFBaseBF16x8 {
|
||||
constexpr static int k_step = 16;
|
||||
using Data = __m256bh;
|
||||
using Acc = __m256;
|
||||
static inline Data load(const ggml_bf16_t * x) { return __m256bh(_mm256_loadu_si256((const __m256i *)x)); }
|
||||
static inline Acc acc(Acc prev, Data y, Data x) {
|
||||
return _mm256_dpbf16_ps(prev, y, x);
|
||||
}
|
||||
static inline Acc acc_first(const Data& y, const Data& x) {
|
||||
return _mm256_dpbf16_ps(_mm256_setzero_ps(), y, x);
|
||||
}
|
||||
static inline float hsum(Acc acc) {
|
||||
return hsum_float_8(acc);
|
||||
}
|
||||
};
|
||||
template <int nrc_in> struct QFTBF16x8 final : public QFBaseBF16x8 {
|
||||
constexpr static int nrc = nrc_in;
|
||||
QFTBF16x8(const DataInfo& info) {
|
||||
for (int iy = 0; iy < nrc; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
|
||||
}
|
||||
QFTBF16x8(const char * cx, size_t bx) {
|
||||
for (int iy = 0; iy < nrc; ++iy) y[iy] = (const ggml_bf16_t *)(cx + iy*bx);
|
||||
}
|
||||
IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
|
||||
const ggml_bf16_t * y[nrc];
|
||||
};
|
||||
|
||||
template <int nrc_y, int nrc_x>
|
||||
IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||
@ -476,6 +502,61 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in
|
||||
case 4: mul_mat_Qx_Qy_MxN<nrc_y, 4>(n, cx, bx, last_x, info); break;
|
||||
}
|
||||
}
|
||||
template <int nrc_y, int nrc_x>
|
||||
IQK_NOINLINE void mul_mat_Qx_Qy_MxNx8(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||
int nb = n/QFBaseBF16x8::k_step;
|
||||
QFTBF16x8<nrc_y> y(info);
|
||||
QFTBF16x8<nrc_x> x(cx + ix0*bx, bx);
|
||||
QFBaseBF16x8::Data xv[nrc_x];
|
||||
QFBaseBF16x8::Acc acc[nrc_x*nrc_y];
|
||||
auto yv = y.load1(0, 0);
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
xv[ix] = x.load1(ix, 0);
|
||||
acc[ix] = QFBaseBF16x8::acc_first(yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < nrc_y; ++iy) {
|
||||
yv = y.load1(iy, 0);
|
||||
for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QFBaseBF16x8::acc_first(yv, xv[ix]);
|
||||
}
|
||||
for (int i = 1; i < nb; ++i) {
|
||||
yv = y.load1(0, i);
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
xv[ix] = x.load1(ix, i);
|
||||
acc[ix] = QFBaseBF16x8::acc(acc[ix], yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < nrc_y; ++iy) {
|
||||
yv = y.load1(iy, i);
|
||||
for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QFBaseBF16x8::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QFBaseBF16x8::hsum(acc[nrc_x*iy+ix]));
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
void mul_mat_fX_fY_Tx8(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
constexpr int k_nx = nrc_y <= 2 ? 8 : 5;
|
||||
const char * cx = (const char *)vx;
|
||||
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
|
||||
mul_mat_Qx_Qy_MxNx8<nrc_y, k_nx>(n, cx, bx, ix*k_nx, info);
|
||||
}
|
||||
int last_x = k_nx*(nrc_x/k_nx);
|
||||
if (last_x == nrc_x) return;
|
||||
int nx = nrc_x - last_x;
|
||||
if constexpr (nrc_y <= 2) {
|
||||
if (nx >= 4) {
|
||||
mul_mat_Qx_Qy_MxNx8<nrc_y, 4>(n, cx, bx, last_x, info);
|
||||
last_x += 4;
|
||||
if (last_x == nrc_x) return;
|
||||
nx = nrc_x - last_x;
|
||||
}
|
||||
}
|
||||
switch (nx) {
|
||||
case 1: mul_mat_Qx_Qy_MxNx8<nrc_y, 1>(n, cx, bx, last_x, info); break;
|
||||
case 2: mul_mat_Qx_Qy_MxNx8<nrc_y, 2>(n, cx, bx, last_x, info); break;
|
||||
case 3: mul_mat_Qx_Qy_MxNx8<nrc_y, 3>(n, cx, bx, last_x, info); break;
|
||||
case 4: mul_mat_Qx_Qy_MxNx8<nrc_y, 4>(n, cx, bx, last_x, info); break;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@ -501,6 +582,14 @@ void set_mul_mat_bf16(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
|
||||
funcs[3] = mul_mat_fX_fY_T<4>;
|
||||
funcs[4] = mul_mat_fX_fY_T<5>;
|
||||
}
|
||||
void set_mul_mat_bf16x8(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
|
||||
for (auto& f : funcs) f = nullptr;
|
||||
funcs[0] = mul_mat_fX_fY_Tx8<1>;
|
||||
funcs[1] = mul_mat_fX_fY_Tx8<2>;
|
||||
funcs[2] = mul_mat_fX_fY_Tx8<3>;
|
||||
funcs[3] = mul_mat_fX_fY_Tx8<4>;
|
||||
funcs[4] = mul_mat_fX_fY_Tx8<5>;
|
||||
}
|
||||
void set_mul_mat_bf16_r16(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
|
||||
for (auto& f : funcs) f = nullptr;
|
||||
funcs[0] = mul_mat_bf16_r16_bf16<1>;
|
||||
@ -519,10 +608,16 @@ void set_mul_mat_bf16_r16(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
|
||||
bool iqk_set_kernels_float(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels) {
|
||||
|
||||
if (typeA == GGML_TYPE_BF16) {
|
||||
if (ne00 % 16) return false;
|
||||
if (ne00 % 8) return false;
|
||||
switch (typeB) {
|
||||
#ifdef __AVX512BF16__
|
||||
case GGML_TYPE_BF16: set_mul_mat_bf16(kernels); break;
|
||||
case GGML_TYPE_BF16: {
|
||||
if (ne00 % 16 == 0) {
|
||||
set_mul_mat_bf16(kernels);
|
||||
} else {
|
||||
set_mul_mat_bf16x8(kernels);
|
||||
}
|
||||
} break;
|
||||
#else
|
||||
case GGML_TYPE_BF16: set_mul_mat_f<ggml_bf16_t, ggml_bf16_t>(kernels); break;
|
||||
case GGML_TYPE_F32: set_mul_mat_f<ggml_bf16_t, float>(kernels); break;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user