Fusing a mat mul op followed by scale op on the CPU

This is useful for Bitnet here we have almost all matricx
multiplications be followed by scale operations.
As a result, we get a ~2% boost in Bitnet PP performance.

Implementation is easy when the matrix multiplication is done
by iqk_mul_mat. But if iqk_mul_mat is not implemented for the
quant type/architecture, we need to add the scaling to
llamafile sgemm and to ggml itself, which is way more
messy, so I didn't do it yet.
Given that Bitnet is just a niche thing for now, I'll just
leave it on a branch for now.
This commit is contained in:
Iwan Kawrakow 2024-07-27 10:45:56 +03:00
parent f62615b44f
commit 473e280500
5 changed files with 38 additions and 21 deletions

View File

@ -3812,7 +3812,7 @@ static inline __m128i get_scale_shuffle(int i) {
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
#if GGML_USE_IQK_MULMAT
if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q4_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q4_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1, 1.f)) {
return;
}
#endif
@ -4296,7 +4296,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
#if GGML_USE_IQK_MULMAT
if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q4_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1)) {
if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q4_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1, 1.f)) {
return;
}
#endif
@ -4585,7 +4585,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
#if GGML_USE_IQK_MULMAT
if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q5_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q5_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1, 1.f)) {
return;
}
#endif
@ -4942,7 +4942,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
#if GGML_USE_IQK_MULMAT
if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q5_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1)) {
if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q5_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1, 1.f)) {
return;
}
#endif
@ -5318,7 +5318,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
#if GGML_USE_IQK_MULMAT
if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1, 1.f)) {
return;
}
#endif
@ -11692,7 +11692,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
#if GGML_USE_IQK_MULMAT
if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_IQ4_NL, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_IQ4_NL, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1, 1.f)) {
return;
}
#endif

View File

@ -12295,7 +12295,8 @@ static void ggml_compute_forward_mul_mat_one_chunk(
static void ggml_compute_forward_mul_mat(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
struct ggml_tensor * dst,
float scale) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
@ -12350,7 +12351,7 @@ static void ggml_compute_forward_mul_mat(
src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
(float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
0, 1)) goto IQK_MulMat_Not_Available1;
0, 1, scale)) goto IQK_MulMat_Not_Available1;
}
}
}
@ -12363,7 +12364,7 @@ static void ggml_compute_forward_mul_mat(
src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
(float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
ith, nth)) goto IQK_MulMat_Not_Available1;
ith, nth, scale)) goto IQK_MulMat_Not_Available1;
return;
}
IQK_MulMat_Not_Available1:;
@ -12388,6 +12389,11 @@ IQK_MulMat_Not_Available1:;
src1->type,
dst->type))
goto UseGgmlGemm1;
//TODO: apply scale if different from 1
//if (fabsf(scale-1.f) > 1e-4f) {
// ggml_barrier(params->shared);
// ggml_compute_forward_scale_f32(params, scale);
//}
return;
}
UseGgmlGemm1:;
@ -12441,7 +12447,7 @@ UseGgmlGemm1:;
src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size/ggml_type_size(vec_dot_type),
(float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
ith, nth)) goto IQK_MulMat_Not_Available2;
ith, nth, scale)) goto IQK_MulMat_Not_Available2;
return;
}
IQK_MulMat_Not_Available2:;
@ -12554,6 +12560,7 @@ UseGgmlGemm2:;
current_chunk = atomic_fetch_add(&params->shared->current_chunk, 1);
}
//TODO: apply scale if different from 1
}
// ggml_compute_forward_mul_mat_id
@ -16811,11 +16818,11 @@ static void ggml_compute_forward_cross_entropy_loss_back(
/////////////////////////////////
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) {
GGML_ASSERT(params);
if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
return;
return false;
}
switch (tensor->op) {
@ -16909,7 +16916,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_MUL_MAT:
{
ggml_compute_forward_mul_mat(params, tensor);
if (next && next->op == GGML_OP_SCALE) {
float scale;
memcpy(&scale, next->op_params, sizeof(float));
ggml_compute_forward_mul_mat(params, tensor, scale);
return true;
}
ggml_compute_forward_mul_mat(params, tensor, 1.f);
} break;
case GGML_OP_MUL_MAT_ID:
{
@ -17143,6 +17156,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
GGML_ASSERT(false);
} break;
}
return false;
}
////////////////////////////////////////////////////////////////////////////////
@ -18991,7 +19005,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
struct ggml_tensor * node = cgraph->nodes[node_n];
ggml_compute_forward(&params, node);
if (ggml_compute_forward(&params, node, node_n < cgraph->n_nodes - 1 ? cgraph->nodes[node_n+1] : NULL)) {
++node_n;
}
if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
state->shared->ec = GGML_STATUS_ABORTED;

View File

@ -73,6 +73,7 @@ struct DataInfo {
int ne11;
const mmid_row_mapping * row_mapping = nullptr;
size_t bs2 = 0;
float scale;
inline const char * src1_row(int iy) const {
if (!row_mapping) return cy + (cur_y + iy)*by;
@ -82,7 +83,7 @@ struct DataInfo {
}
inline void store(int ix, int iy, float result) const {
*(dst_row(iy) + ix) = result;
*(dst_row(iy) + ix) = result*scale;
}
inline float * dst_row(int iy) const {
if (!row_mapping) return s + (cur_y + iy)*bs;
@ -133,7 +134,7 @@ private:
bool iqk_mul_mat(long Nx, long Ny, long ne00,
int typeA, const void * A, long strideA,
int typeB, const void * B, long strideB,
float * C, long stride_C, int ith, int nth) {
float * C, long stride_C, int ith, int nth, float scale) {
MulMat mm;
if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
@ -147,7 +148,7 @@ bool iqk_mul_mat(long Nx, long Ny, long ne00,
auto first_x = ith*nrc_x;
if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;
DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};
DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0, scale};
mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);
@ -171,7 +172,7 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
int first_x = ith*nrc_x;
if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;
DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float),
row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float), 1.f};
mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);
return true;
}

View File

@ -14,7 +14,7 @@ extern "C" {
bool iqk_mul_mat(long Nx, long Ny, long ne00,
int typeA, const void * A, long strideA,
int typeB, const void * B, long strideB,
float * C, long stride_C, int ith, int nth);
float * C, long stride_C, int ith, int nth, float scale);
bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
int typeA, const void * A, long strideA,

View File

@ -236,7 +236,7 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
static_assert(QK_IQ1BN == 64, "This dot product implementation for iq1_bn requires a block size of 64");
#if GGML_USE_IQK_MULMAT
if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ1_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1)) {
if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ1_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1, 1.f)) {
return;
}
#endif
@ -286,7 +286,7 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
static_assert(QK_IQ1BN == 64, "This dot product implementation for iq2_bn requires a block size of 64");
if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1)) {
if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1, 1.f)) {
return;
}