CPU Flash Attention improvements (#172)

* Slightly faster FA for bf16 KV cache

~2-3% sort of thing. Sadly, when we go beyond 8k tokens, the
advantage kind of goes away.

* Slightly faster FA for Q8_0 KV cache

* FA: allow bf16 for V-cache with any supported K-cache

E.g., -ctk q8_0 -ctv bf16 is slightly faster than
-ctk q8_0 -ctv q8_0 on Zen4 for not too long context lengths
(say, <= 4096).

* FA: much better bf16 kv-cache speed for large contexts

We now hit 122 t/s for LLaMA-3.1-8B (quantized as iq4_xs and
run-time-repacked) with a context of 32768. IIRC, the previous
best for such large context was ~90 t/s.
Non-negligible improvement at 16384 and 8192 as well:
173.4 and 214 t/s.

* FA: slightly better quantized kv-cache speed for large contexts

E.g., for q8_0 and context of 32768, we are now at 113 t/s
for LLaMA-3.1-8B.

Also simplified the quantized K*Q multiplication.

* Fix q8_0 KV cache when not using FA - WIP (AVX2)

1. We add new types GGML_TYPE_Q8_0_X4 and GGML_TYPE_Q8_1_X4, and use
   those to quantize activations for quants that use Q8_0 or Q8_1
   as their vec_dot type.
2. We revert the changes to quantize_row_q8_0 and quantize_row_q8_1
3. We use GGML_TYPE_Q8_0_X4 and GGML_TYPE_Q8_1_X4 as the vec_dot type
4. We change the FA implementation to use GGML_TYPE_Q8_0 rather than
   GGML_TYPE_Q8_0_X4 as the K and V types
5. We change the expected type to GGML_TYPE_Q8_0_X4/GGML_TYPE_Q8_1_X4
   in iqk_mul_mat

Also added an optimization in ggml_compute_forward_mul_mat when
ne12*ne13 > 1 (K*Q and V*softmax(K*Q)) to process
n12*ne13/GCD(n12*ne13, nthread) threads simultaneously using
nthread/GCD(n12*ne13, nthread) threads per head. This results in
a non-negligible performance gain for large contexts.

Question: why is it not allowed to use quantized V-cache when
not using FA?

* Fix q8_0 KV cache when not using FA - NEON

* Fix AVX2

Again the issue with _mm256_maddubs_epi16 overflowing that I
keep forgetting.

* FA: don't use large Q steps on AVX2 for fp16 K-cache

* On Zen4 it is also better to not use large Q steps for fp16 K-cache

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow 2025-01-15 18:19:22 +02:00 committed by GitHub
parent c6503556b7
commit c606c19101
6 changed files with 737 additions and 605 deletions

View File

@ -396,6 +396,8 @@ extern "C" {
//
GGML_TYPE_I2_S = 36,
//
GGML_TYPE_Q8_0_X4 = 98,
GGML_TYPE_Q8_1_X4 = 99,
GGML_TYPE_Q6_0 = 133,
GGML_TYPE_IQ1_BN = 134,
GGML_TYPE_IQ2_BN = 135,

View File

@ -934,13 +934,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
block_q8_0 * restrict y = vy;
#if GGML_USE_IQK_MULMAT
const int nb4 = 4*(nb/4);
#else
const int nb4 = -1;
#endif
#if defined(__ARM_NEON)
block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy;
for (int i = 0; i < nb; i++) {
int i4 = i/4, ir = i%4;
float32x4_t srcv [8];
@ -959,27 +953,16 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
if (i < nb4) {
y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
} else {
y[i].d = GGML_FP32_TO_FP16(d);
}
y[i].d = GGML_FP32_TO_FP16(d);
for (int j = 0; j < 8; j++) {
const float32x4_t v = vmulq_n_f32(srcv[j], id);
const int32x4_t vi = vcvtnq_s32_f32(v);
if (i < nb4) {
y4[i4].qs[32*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
y4[i4].qs[32*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
y4[i4].qs[32*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
y4[i4].qs[32*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
} else {
y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
}
y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
}
}
#elif defined(__wasm_simd128__)
@ -1016,14 +999,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
}
}
#elif defined(__AVX2__) || defined(__AVX__)
block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy;
#ifdef __AVX2__
const bool pack = true;
#else
const bool pack = false;
#endif
for (int i = 0; i < nb; i++) {
int i4 = i/4, ir = i%4;
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
__m256 v1 = _mm256_loadu_ps( x + 8 );
@ -1045,11 +1021,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
// Quantize these floats
const float d = maxScalar / 127.f;
if (pack && i < nb4) {
y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
} else {
y[i].d = GGML_FP32_TO_FP16(d);
}
y[i].d = GGML_FP32_TO_FP16(d);
const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
@ -1084,11 +1056,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
if (i < nb4) {
_mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0);
} else {
_mm256_storeu_si256((__m256i *)y[i].qs, i0);
}
_mm256_storeu_si256((__m256i *)y[i].qs, i0);
#else
// Since we don't have in AVX some necessary functions,
// we split the registers in half and call AVX2 analogs from SSE
@ -1287,15 +1255,8 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
block_q8_1 * restrict y = vy;
#if GGML_USE_IQK_MULMAT
const int nb4 = 4*(nb/4);
#else
const int nb4 = -1;
#endif
#if defined(__ARM_NEON)
block_q8_1_x4 * restrict y4 = vy;
for (int i = 0; i < nb; i++) {
int i4 = i/4, ir = i%4;
float32x4_t srcv [8];
float32x4_t asrcv[8];
float32x4_t amaxv[8];
@ -1312,11 +1273,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
if (i < nb4) {
y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
} else {
y[i].d = GGML_FP32_TO_FP16(d);
}
y[i].d = GGML_FP32_TO_FP16(d);
int32x4_t accv = vdupq_n_s32(0);
@ -1324,26 +1281,15 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
const float32x4_t v = vmulq_n_f32(srcv[j], id);
const int32x4_t vi = vcvtnq_s32_f32(v);
if (i < nb4) {
y4[i4].qs[QK8_1*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
y4[i4].qs[QK8_1*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
y4[i4].qs[QK8_1*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
y4[i4].qs[QK8_1*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
} else {
y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
}
y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
accv = vaddq_s32(accv, vi);
}
if (i < nb4) {
y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
} else {
y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
}
y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
}
#elif defined(__wasm_simd128__)
for (int i = 0; i < nb; i++) {
@ -1389,14 +1335,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
wasm_i32x4_extract_lane(accv, 3)));
}
#elif defined(__AVX2__) || defined(__AVX__)
block_q8_1_x4 * restrict y4 = vy;
#ifdef __AVX2__
const bool pack = true;
#else
const bool pack = false;
#endif
for (int i = 0; i < nb; i++) {
int i4 = i/4, ir = i%4;
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
__m256 v1 = _mm256_loadu_ps( x + 8 );
@ -1418,11 +1357,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
// Quantize these floats
const float d = max_scalar / 127.f;
if (pack && i < nb4) {
y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
} else {
y[i].d = GGML_FP32_TO_FP16(d);
}
y[i].d = GGML_FP32_TO_FP16(d);
const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
@ -1446,11 +1381,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
#if defined(__AVX2__)
// Compute the sum of the quants and set y[i].s
if (i < nb4) {
y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
} else {
y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
}
y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
@ -1464,11 +1395,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
if (i < nb4) {
_mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0);
} else {
_mm256_storeu_si256((__m256i *)y[i].qs, i0);
}
_mm256_storeu_si256((__m256i *)y[i].qs, i0);
#else
// Since we don't have in AVX some necessary functions,
// we split the registers in half and call AVX2 analogs from SSE

View File

@ -714,8 +714,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q4_0,
.from_float_ref = (ggml_from_float_t) quantize_row_q4_0_ref,
.vec_dot = ggml_vec_dot_q4_0_q8_0,
#if GGML_USE_IQK_MULMAT && defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1,
#if GGML_USE_IQK_MULMAT
#if defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_0_X4,
#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@ -735,7 +739,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q4_1,
.from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref,
.vec_dot = ggml_vec_dot_q4_1_q8_1,
#if GGML_USE_IQK_MULMAT
.vec_dot_type = GGML_TYPE_Q8_1_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_1,
#endif
#if defined (__ARM_FEATURE_MATMUL_INT8)
.nrows = 2,
#else
@ -778,8 +786,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q5_0,
.from_float_ref = (ggml_from_float_t) quantize_row_q5_0_ref,
.vec_dot = ggml_vec_dot_q5_0_q8_0,
#if GGML_USE_IQK_MULMAT && defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1,
#if GGML_USE_IQK_MULMAT
#if defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_0_X4,
#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@ -795,7 +807,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q5_1,
.from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref,
.vec_dot = ggml_vec_dot_q5_1_q8_1,
#if GGML_USE_IQK_MULMAT
.vec_dot_type = GGML_TYPE_Q8_1_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_1,
#endif
.nrows = 1,
.row_meta_size = 0,
},
@ -808,8 +824,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q6_0,
.from_float_ref = (ggml_from_float_t) quantize_row_q6_0_ref,
.vec_dot = ggml_vec_dot_q6_0_q8_0,
#if GGML_USE_IQK_MULMAT && defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1,
#if GGML_USE_IQK_MULMAT
#if defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_0_X4,
#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@ -826,8 +846,16 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_ref = (ggml_from_float_t) quantize_row_q8_0_ref,
.from_float_to_mat = quantize_mat_q8_0,
.vec_dot = ggml_vec_dot_q8_0_q8_0,
#if GGML_USE_IQK_MULMAT && defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1,
#if GGML_USE_IQK_MULMAT
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)
// Remember: we cannot add 128 to the Q8 quants and use iblock sum in Q8_1 to subtract as we do on Zen4 for pure AVX2
// because there the result of the _mm256_maddubs_epi16() instruction may overflow the int16_t range
// (and it gets satured if it does), leading to wrong results.
// TODO: expose HAVE_FANCY_SIMD from iqk_mul_mat.cpp and use #ifdef HAVE_FANCY_SIMD instead of the above.
.vec_dot_type = GGML_TYPE_Q8_1_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_0_X4,
#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@ -849,6 +877,26 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.nrows = 1,
.row_meta_size = 0,
},
[GGML_TYPE_Q8_0_X4] = {
.type_name = "q8_0_x4",
.blck_size = QK8_0,
.type_size = sizeof(block_q8_0),
.is_quantized = true,
.from_float = quantize_row_q8_0_x4,
.from_float_ref = quantize_row_q8_0_x4,
.nrows = 1,
.row_meta_size = 0,
},
[GGML_TYPE_Q8_1_X4] = {
.type_name = "q8_1_x4",
.blck_size = QK8_1,
.type_size = sizeof(block_q8_1),
.is_quantized = true,
.from_float = quantize_row_q8_1_x4,
.from_float_ref = quantize_row_q8_1_x4,
.nrows = 1,
.row_meta_size = 0,
},
[GGML_TYPE_Q2_K] = {
.type_name = "q2_K",
.blck_size = QK_K,
@ -1196,8 +1244,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq4_nl,
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref,
.vec_dot = ggml_vec_dot_iq4_nl_q8_0,
#if GGML_USE_IQK_MULMAT && defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1,
#if GGML_USE_IQK_MULMAT
#if defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_0_X4,
#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@ -1516,8 +1568,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq4_nl_r4,
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_r4_ref,
.vec_dot = vec_dot_iq4_nl_r4_q8_0,
#if GGML_USE_IQK_MULMAT && defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1,
#if GGML_USE_IQK_MULMAT
#if defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_0_X4,
#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@ -1546,8 +1602,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q4_0_r4,
.from_float_ref = (ggml_from_float_t)quantize_row_q4_0_r4_ref,
.vec_dot = vec_dot_q4_0_r4_q8_0,
#if GGML_USE_IQK_MULMAT && defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1,
#if GGML_USE_IQK_MULMAT
#if defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_0_X4,
#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@ -1563,8 +1623,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q8_0_r4,
.from_float_ref = (ggml_from_float_t)quantize_row_q8_0_r4_ref,
.vec_dot = vec_dot_q8_0_r4_q8_0,
#if GGML_USE_IQK_MULMAT && defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1,
#if GGML_USE_IQK_MULMAT
#if defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_0_X4,
#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@ -1580,8 +1644,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q5_0_r4,
.from_float_ref = (ggml_from_float_t)quantize_row_q5_0_r4_ref,
.vec_dot = vec_dot_q5_0_r4_q8_0,
#if GGML_USE_IQK_MULMAT && defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1,
#if GGML_USE_IQK_MULMAT
#if defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_0_X4,
#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@ -1597,8 +1665,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q6_0_r4,
.from_float_ref = (ggml_from_float_t)quantize_row_q6_0_r4_ref,
.vec_dot = vec_dot_q6_0_r4_q8_0,
#if GGML_USE_IQK_MULMAT && defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1,
#if GGML_USE_IQK_MULMAT
#if defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_1_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_0_X4,
#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@ -11280,6 +11352,8 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q8_0_X4:
case GGML_TYPE_Q8_1_X4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
@ -11443,6 +11517,8 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q8_0_X4:
case GGML_TYPE_Q8_1_X4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
@ -13889,6 +13965,14 @@ static void ggml_compute_forward_mul_mat_one_chunk(
}
}
static inline uint32_t simple_gcd(uint32_t a, uint32_t b) {
while (a != b) {
if (a > b) a -= b;
else b -= a;
}
return a;
}
static void ggml_compute_forward_mul_mat(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
@ -13905,10 +13989,12 @@ static void ggml_compute_forward_mul_mat(
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float;
ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat;
int64_t const vec_dot_num_rows = type_traits[type].nrows;
int64_t const matmul_num_cols = type_traits[type].ncols;
#if !GGML_USE_IQK_MULMAT
ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat;
int64_t const blck_size_interleave = type_traits[type].blck_size_interleave;
#endif
ggml_gemv_t const gemv = type_traits[type].gemv;
ggml_gemm_t const gemm = type_traits[type].gemm;
@ -14011,6 +14097,7 @@ UseGgmlGemm1:;
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
int64_t i11_processed = 0;
#if !GGML_USE_IQK_MULMAT
if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
@ -14019,6 +14106,7 @@ UseGgmlGemm1:;
}
i11_processed = ne11 - ne11 % 4;
}
#endif
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
@ -14049,14 +14137,31 @@ AlreadyQuantized:;
#if GGML_USE_IQK_MULMAT
if (src1->type != vec_dot_type && dst->type == GGML_TYPE_F32) {
// When K*Q and V*softmax(K*Q) (so ne12*ne13 > 1), it is better (faster) to have fewer threads processing
// one matrix multiplication, but work on several heads at once.
// Hence, we find the GCD(n12*ne13, nth) and have nth/GCD(n12*ne13, nth) threads per head.
// Leaving the previous version commented out for now just in case.
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++)
if (!iqk_mul_mat(ne01, ne11, ne00,
src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type),
(float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
ith, nth)) goto IQK_MulMat_Not_Available2;
int ntg = simple_gcd(ne12*ne13, nth);
int counter = 0;
for (int64_t i13 = 0; i13 < ne13; i13++) {
for (int64_t i12 = 0; i12 < ne12; i12++) {
if (counter++ % ntg == ith%ntg) {
if (!iqk_mul_mat(ne01, ne11, ne00,
src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type),
(float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
ith/ntg, nth/ntg)) goto IQK_MulMat_Not_Available2;
}
}
}
//for (int64_t i13 = 0; i13 < ne13; i13++)
// for (int64_t i12 = 0; i12 < ne12; i12++)
// if (!iqk_mul_mat(ne01, ne11, ne00,
// src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
// vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type),
// (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
// ith, nth)) goto IQK_MulMat_Not_Available2;
return;
}
IQK_MulMat_Not_Available2:;
@ -15055,6 +15160,8 @@ static void ggml_compute_forward_set(
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q8_0_X4:
case GGML_TYPE_Q8_1_X4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
@ -15352,6 +15459,8 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q8_0_X4:
case GGML_TYPE_Q8_1_X4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
@ -15977,6 +16086,8 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q8_0_X4:
case GGML_TYPE_Q8_1_X4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:

File diff suppressed because it is too large Load Diff

View File

@ -654,6 +654,257 @@ void quantize_row_q8_K16(const float * x, void * vy, int64_t nk) {
#endif
}
void quantize_row_q8_0_x4(const float * x, void * vy, int64_t k) {
const int nb = k / QK8_0;
const int nb4 = 4*(nb/4);
block_q8_0 * y = (block_q8_0 *)vy;
block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy;
#if defined(__aarch64__)
for (int i = 0; i < nb; i++) {
int i4 = i/4, ir = i%4;
float32x4_t srcv [8];
float32x4_t asrcv[8];
float32x4_t amaxv[8];
for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
const float amax = vmaxvq_f32(amaxv[0]);
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
if (i < nb4) {
y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
} else {
y[i].d = GGML_FP32_TO_FP16(d);
}
for (int j = 0; j < 8; j++) {
const float32x4_t v = vmulq_n_f32(srcv[j], id);
const int32x4_t vi = vcvtnq_s32_f32(v);
if (i < nb4) {
y4[i4].qs[32*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
y4[i4].qs[32*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
y4[i4].qs[32*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
y4[i4].qs[32*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
} else {
y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
}
}
}
#else
for (int i = 0; i < nb; i++) {
int i4 = i/4, ir = i%4;
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
__m256 v1 = _mm256_loadu_ps( x + 8 );
__m256 v2 = _mm256_loadu_ps( x + 16 );
__m256 v3 = _mm256_loadu_ps( x + 24 );
x += 32;
const __m256 signBit = _mm256_set1_ps( -0.0f );
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
const float maxScalar = _mm_cvtss_f32( max4 );
const float d = maxScalar / 127.f;
if (i < nb4) {
y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
} else {
y[i].d = GGML_FP32_TO_FP16(d);
}
const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
v0 = _mm256_mul_ps( v0, mul );
v1 = _mm256_mul_ps( v1, mul );
v2 = _mm256_mul_ps( v2, mul );
v3 = _mm256_mul_ps( v3, mul );
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
__m256i i0 = _mm256_cvtps_epi32( v0 );
__m256i i1 = _mm256_cvtps_epi32( v1 );
__m256i i2 = _mm256_cvtps_epi32( v2 );
__m256i i3 = _mm256_cvtps_epi32( v3 );
// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
// Convert int16 to int8
i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
// We got our precious signed bytes, but the order is now wrong
// These AVX2 pack instructions process 16-byte pieces independently
// The following instruction is fixing the order
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
if (i < nb4) {
_mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0);
} else {
_mm256_storeu_si256((__m256i *)y[i].qs, i0);
}
}
#endif
}
void quantize_row_q8_1_x4(const float * x, void * vy, int64_t k) {
assert(k % QK8_1 == 0);
const int nb = k / QK8_1;
const int nb4 = 4*(nb/4);
block_q8_1 * y = (block_q8_1 *)vy;
block_q8_1_x4 * y4 = (block_q8_1_x4 *)vy;
#if defined(__aarch64__)
for (int i = 0; i < nb; i++) {
int i4 = i/4, ir = i%4;
float32x4_t srcv [8];
float32x4_t asrcv[8];
float32x4_t amaxv[8];
for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
const float amax = vmaxvq_f32(amaxv[0]);
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
if (i < nb4) {
y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
} else {
y[i].d = GGML_FP32_TO_FP16(d);
}
int32x4_t accv = vdupq_n_s32(0);
for (int j = 0; j < 8; j++) {
const float32x4_t v = vmulq_n_f32(srcv[j], id);
const int32x4_t vi = vcvtnq_s32_f32(v);
if (i < nb4) {
y4[i4].qs[QK8_1*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
y4[i4].qs[QK8_1*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
y4[i4].qs[QK8_1*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
y4[i4].qs[QK8_1*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
} else {
y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
}
accv = vaddq_s32(accv, vi);
}
if (i < nb4) {
y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
} else {
y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
}
}
#else
for (int i = 0; i < nb; i++) {
int i4 = i/4, ir = i%4;
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
__m256 v1 = _mm256_loadu_ps( x + 8 );
__m256 v2 = _mm256_loadu_ps( x + 16 );
__m256 v3 = _mm256_loadu_ps( x + 24 );
x += 32;
// Compute max(abs(e)) for the block
const __m256 signBit = _mm256_set1_ps( -0.0f );
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
const float max_scalar = _mm_cvtss_f32( max4 );
// Quantize these floats
const float d = max_scalar / 127.f;
if (i < nb4) {
y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
} else {
y[i].d = GGML_FP32_TO_FP16(d);
}
const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
// Apply the multiplier
v0 = _mm256_mul_ps( v0, mul );
v1 = _mm256_mul_ps( v1, mul );
v2 = _mm256_mul_ps( v2, mul );
v3 = _mm256_mul_ps( v3, mul );
// Round to nearest integer
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
// Convert floats to integers
__m256i i0 = _mm256_cvtps_epi32( v0 );
__m256i i1 = _mm256_cvtps_epi32( v1 );
__m256i i2 = _mm256_cvtps_epi32( v2 );
__m256i i3 = _mm256_cvtps_epi32( v3 );
// Compute the sum of the quants and set y[i].s
if (i < nb4) {
y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
} else {
y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
}
// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
// Convert int16 to int8
i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
// We got our precious signed bytes, but the order is now wrong
// These AVX2 pack instructions process 16-byte pieces independently
// The following instruction is fixing the order
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
if (i < nb4) {
_mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0);
} else {
_mm256_storeu_si256((__m256i *)y[i].qs, i0);
}
}
#endif
}
//
// ============================================== iq2_K
//

View File

@ -211,6 +211,8 @@ void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y,
void quantize_row_q8_K16(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_K32(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_KR8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_0_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_1_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void repack_f32_bf16_r16 (const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row);
void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row);