Q8_KV: 8-bit quantization type targeting the KV cache (#208)

* Adding q8_KV - Basics + AVX2 gemm/gemv

* q8_KV: Better AVX2 gemm

* q8_KV: Better Zen4 gemm

We get 225.7 t/s for L3-8B. In comparison q8_0 without
run-tinme-repacking is at 169 t/s.

* q8_KV: AVX2 gemm/gemv

We get 254 t/s for L3-8B vs 194 t/s for q8_0 without rtr.

* q8_KV: be able to use it for K cache

This required quite a few fixes in ggml and llama.cpp:
* ggml: do not calculate row size as n/block_size*type_size. I had
  removed most of it when implementing the quants with per row scale,
  bit it was stull lurking in ggml_copy. Not sure if these were the last
  remnants of ggmil-style row sizes, or if there are still places left
* llama.cpp: get rid of the the 1d K cache assumption. Create and manage
  the K-cache as a 2D tensor so we can have per row meta data as needed
  by q8_KV.

Using q8_KV for K-cache results in non-negligible performance gains.
More details to follow, but for DeepSeek-Lite with MLA, we get
18% speedup for PP-8192 compared to q8_0 K-cache.

* q8_KV: be able to use it for K cache in FA

* q8_KV: repack it for K*Q in FA

* q8_KV: slightly faster gemv on Zen4

* q8_KV: slightly faster gemv on Zen4

* q8_KV: ARM_NEON

We get PP-512 = 167 t/s for L3-8B without interleaving!
We do the interleaving on the fly, so I wonder if this
could be done for other quants as well.

* q8_KV: use it in FA on NEON

* q8_KV_r8 - repacked q8_KV

On Zen4 it is slower than q8_k_r8 (292 vs 370 t/s)
This makes no sense whatsoever as the q8_KV_r8 GEMM is
basically the q8_k_r8 GEMM with the unnecessary block stuff
removed (so, one would think that it would be faster).

* q8_KV_r8: don't use nrc_y = 16 on Zen4

This is faster - 350 t/s. Why?
Much better than the 290 t/s we had before, but still slower
than the 370 t/s for q8_k_r8.

* q8_KV: nrc_y = 16 also doesn't pay off in FA

* Minor

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow 2025-02-19 11:47:07 +02:00 committed by GitHub
parent 9c74d3ef12
commit 1140b4568d
11 changed files with 983 additions and 34 deletions

View File

@ -2259,6 +2259,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
if (s == "q6_0") {
return GGML_TYPE_Q6_0;
}
if (s == "q8_KV") {
return GGML_TYPE_Q8_KV;
}
throw std::runtime_error("Invalid cache type: " + s);
}

View File

@ -339,6 +339,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
if (s == "q6_0") {
return GGML_TYPE_Q6_0;
}
if (s == "q8_KV") {
return GGML_TYPE_Q8_KV;
}
return GGML_TYPE_COUNT;
}

View File

@ -56,6 +56,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q5_0_R4", LLAMA_FTYPE_MOSTLY_Q5_0_R4, " 5.50 bpw quantization", },
{ "Q6_0_R4", LLAMA_FTYPE_MOSTLY_Q6_0_R4, " 6.50 bpw quantization", },
{ "Q8_0_R8", LLAMA_FTYPE_MOSTLY_Q8_0_R8, " 8.50 bpw quantization", },
{ "Q8_KV", LLAMA_FTYPE_MOSTLY_Q8_KV, " 8.00 bpw quantization", },
{ "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", },
{ "IQ4_KS", LLAMA_FTYPE_MOSTLY_IQ4_KS, " 4.25 bpw non-linear quantization", },
{ "IQ4_KS_R4",LLAMA_FTYPE_MOSTLY_IQ4_KS_R4,"IQ4_KS repacked", },
@ -82,6 +83,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", },
{ "Q6_K_R4", LLAMA_FTYPE_MOSTLY_Q6_K_R4, "Q6_K repacked", },
{ "Q8_K_R8", LLAMA_FTYPE_MOSTLY_Q8_K_R8, "Q8_K repacked", },
{ "Q8_KV_R8", LLAMA_FTYPE_MOSTLY_Q8_KV_R8, "Q8_KV repacked", },
{ "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", },
{ "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
{ "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", },

View File

@ -416,6 +416,7 @@ extern "C" {
GGML_TYPE_Q8_K32 = 148,
GGML_TYPE_Q8_KR8 = 149,
GGML_TYPE_Q8_K128 = 150,
GGML_TYPE_Q8_KV = 151,
GGML_TYPE_Q4_0_R8 = 202,
GGML_TYPE_Q5_0_R4 = 206,
@ -442,6 +443,7 @@ extern "C" {
GGML_TYPE_IQ4_K_R4 = 339,
GGML_TYPE_IQ5_K_R4 = 340,
GGML_TYPE_IQ4_KS_R4 = 344,
GGML_TYPE_Q8_KV_R8 = 398,
GGML_TYPE_Q8_K_R8 = 399,
GGML_TYPE_COUNT,
};
@ -501,6 +503,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ4_KS = 137, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ2_KS = 138, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_KSS = 139, // except 1d tensors
GGML_FTYPE_MOSTLY_Q8_KV = 140, // except 1d tensors
//
GGML_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors
GGML_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors
@ -527,6 +530,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ4_K_R4 = 332, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ5_K_R4 = 333, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_KS_R4 = 337, // except 1d tensors
GGML_FTYPE_MOSTLY_Q8_KV_R8 = 398, // except 1d tensors
GGML_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors
};

View File

@ -15214,8 +15214,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_IQ3_K_R4: break;
case GGML_TYPE_IQ4_K_R4: break;
case GGML_TYPE_IQ5_K_R4: break;
case GGML_TYPE_IQ4_KS_R4: break;
case GGML_TYPE_Q8_K_R8: break;
case GGML_TYPE_IQ4_KS_R4:break;
case GGML_TYPE_Q8_KV_R8: break;
case GGML_TYPE_Q8_K_R8: break;
case GGML_TYPE_Q8_KV: break;
case GGML_TYPE_BF16_R16: break;
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:

View File

@ -1362,6 +1362,30 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q8_K128,
.row_meta_size = 0,
},
[GGML_TYPE_Q8_KV] = {
.type_name = "q8_KV",
.blck_size = 32,
.type_size = 32,
.is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q8_KV,
.from_float = quantize_row_q8_KV,
.from_float_ref = (ggml_from_float_t)quantize_row_q8_KV_ref,
.vec_dot = vec_dot_q8_KV_q8_KV,
.vec_dot_type = GGML_TYPE_Q8_KV,
.row_meta_size = 8,
},
[GGML_TYPE_Q8_KV_R8] = {
.type_name = "q8_KV_r8",
.blck_size = 32,
.type_size = 32,
.is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q8_KV_r8,
.from_float = quantize_row_q8_KV_r8,
.from_float_ref = (ggml_from_float_t)quantize_row_q8_KV_r8_ref,
.vec_dot = vec_dot_q8_KV_r8_q8_KV,
.vec_dot_type = GGML_TYPE_Q8_KV,
.row_meta_size = 4,
},
[GGML_TYPE_Q8_K16] = {
.type_name = "q8_K16",
.blck_size = 64,
@ -4373,6 +4397,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
case GGML_FTYPE_MOSTLY_Q6_0: wtype = GGML_TYPE_Q6_0; break;
case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
case GGML_FTYPE_MOSTLY_Q8_KV: wtype = GGML_TYPE_Q8_KV; break;
case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
case GGML_FTYPE_MOSTLY_Q2_K_R4: wtype = GGML_TYPE_Q2_K_R4; break;
case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
@ -4384,6 +4409,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
case GGML_FTYPE_MOSTLY_Q6_K_R4: wtype = GGML_TYPE_Q6_K_R4; break;
case GGML_FTYPE_MOSTLY_Q8_K_R8: wtype = GGML_TYPE_Q8_K_R8; break;
case GGML_FTYPE_MOSTLY_Q8_KV_R8: wtype = GGML_TYPE_Q8_KV_R8; break;
case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
case GGML_FTYPE_MOSTLY_IQ2_XXS_R4: wtype = GGML_TYPE_IQ2_XXS_R4;break;
case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break;
@ -9436,7 +9462,7 @@ static void ggml_compute_forward_dup_f16(
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
size_t id = 0;
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type));
char * dst_ptr = (char *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
@ -9722,7 +9748,7 @@ static void ggml_compute_forward_dup_bf16(
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
size_t id = 0;
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type));
char * dst_ptr = (char *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
@ -10042,7 +10068,7 @@ static void ggml_compute_forward_dup_f32(
ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
size_t id = 0;
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type));
char * dst_ptr = (char *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
@ -10936,6 +10962,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS:
@ -11406,6 +11433,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS:
@ -11573,6 +11601,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS:
@ -14061,7 +14090,7 @@ static void ggml_compute_forward_mul_mat(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
#if GGML_USE_IQK_MULMAT || GGML_USE_LLAMAFILE
#if GGML_USE_LLAMAFILE
// broadcast factors
const int64_t r2 = ne12 / ne02;
const int64_t r3 = ne13 / ne03;
@ -14344,7 +14373,7 @@ static void ggml_compute_forward_mul_mat_id(
char * wdata_src1_end = (src1->type == vec_dot_type) ?
(char *) params->wdata :
(char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
(char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, src1->ne[0])*ggml_nrows(src1), sizeof(int64_t));
struct mmid_row_mapping {
int32_t i1;
@ -14768,6 +14797,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_KV:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
@ -14779,6 +14809,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS:
@ -15186,6 +15217,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS:
@ -15473,6 +15505,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_KV:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q8_0_X4:
case GGML_TYPE_Q8_1_X4:
@ -15487,6 +15520,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS:
@ -16116,6 +16150,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_Q8_KR8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
@ -16159,6 +16194,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q8_K:
case GGML_TYPE_Q8_K64:
case GGML_TYPE_Q8_K128:
case GGML_TYPE_Q8_KV:
case GGML_TYPE_Q8_K16:
case GGML_TYPE_Q8_K32:
case GGML_TYPE_Q4_0_4_4:
@ -22970,6 +23006,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q6_0: result = quantize_q6_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q8_KV: result = quantize_q8_KV(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q2_K_R4: result = quantize_q2_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
@ -22981,6 +23018,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q6_K_R4: result = quantize_q6_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q8_K_R8: result = quantize_q8_k_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q8_KV_R8:result = quantize_q8_KV_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_XXS_R4:result = quantize_iq2_xxs_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;

View File

@ -269,6 +269,8 @@ struct MulMat {
case GGML_TYPE_IQ4_XS_R8:
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q5_K_R4:
case GGML_TYPE_Q8_KV:
case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_Q8_K_R8: return 8;
case GGML_TYPE_Q4_0_R8:
case GGML_TYPE_Q8_0_R8:
@ -301,6 +303,8 @@ struct MulMat {
case GGML_TYPE_IQ4_XS_R8:
case GGML_TYPE_Q4_0_R8:
case GGML_TYPE_Q8_0_R8:
case GGML_TYPE_Q8_KV:
case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_Q8_K_R8: return 8;
case GGML_TYPE_BF16_R16: return 16;
default: return 1;
@ -6107,7 +6111,7 @@ static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
// The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__)
template <int nrc_y>
static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_K> q8(info);
#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
@ -6169,6 +6173,230 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
}
}
// The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__)
template <int nrc_y>
static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
GGML_ASSERT(n%32 == 0);
#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
#endif
int nb = n / 16;
__m256i acc[nrc_y] = {};
__m256i qx[4];
float dy[nrc_y];
#ifdef HAVE_FANCY_SIMD
float sy[nrc_y];
#endif
const int8_t * q8y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
dy[iy] = dptr[0];
#ifdef HAVE_FANCY_SIMD
auto iptr = (const int32_t *)(dptr + 1);
sy[iy] = -127*iptr[0];
#endif
q8y[iy] = (const int8_t *)(dptr + 2);
}
for (int ix = 0; ix < nrc_x; ix += 8) {
auto dptr = (const float *)((const char *)vx + ix*bx);
auto dx = _mm256_loadu_ps(dptr);
auto q8x = (const int8_t *)(dptr + 8);
for (int ib = 0; ib < nb; ++ib) { // Blocks of 16 for 8 interleaved rows
qx[0] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+0);
qx[1] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+1);
qx[2] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+2);
qx[3] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+3);
#ifndef HAVE_FANCY_SIMD
auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto y128 = _mm_loadu_si128((const __m128i*)q8y[iy]+ib);
auto y = MM256_SET_M128I(y128, y128);
#ifdef HAVE_FANCY_SIMD
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
#else
auto sumi1 = _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
auto sumi2 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
auto sumi3 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
auto sumi4 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
auto sumi12 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
auto sumi34 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi3), _mm256_madd_epi16(m1, sumi4));
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(sumi12, sumi34));
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto scale = _mm256_mul_ps(dx, _mm256_set1_ps(dy[iy]));
#ifdef HAVE_FANCY_SIMD
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_set1_epi32(sy[iy]));
#endif
info.store(ix, iy, _mm256_mul_ps(scale, _mm256_cvtepi32_ps(acc[iy])));
acc[iy] = _mm256_setzero_si256();
}
}
}
template <int nrc_y>
static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
GGML_ASSERT(n%32 == 0);
__m256i qx[2];
__m256i acc[2*nrc_y] = {};
float dy[nrc_y];
#ifdef HAVE_FANCY_SIMD
int32_t sy[nrc_y];
#else
__m256i sx[2];
auto m1 = _mm256_set1_epi16(1);
#endif
const int8_t * q8y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
dy[iy] = dptr[0];
#ifdef HAVE_FANCY_SIMD
auto iptr = (const int32_t *)(dptr+1);
sy[iy] = -127*iptr[0];
#endif
q8y[iy] = (const int8_t *)(dptr + 2);
}
for (int ix = 0; ix < nrc_x; ++ix) {
auto dx = (const float *)((const char *)vx + ix*bx);
auto q8x = (const int8_t *)(dx + 2);
for (int i = 0; i < n/64; ++i) {
for (int j = 0; j < 2; ++j) {
#ifdef HAVE_FANCY_SIMD
qx[j] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + 2*i + j), _mm256_set1_epi8(127));
#else
qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j);
sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
#endif
}
for (int iy = 0; iy < nrc_y; ++iy) {
for (int j = 0; j < 2; ++j) {
#ifdef HAVE_FANCY_SIMD
acc[2*iy+j] = _mm256_dpbusd_epi32(acc[2*iy+j], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j));
#else
auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
acc[2*iy+j] = _mm256_add_epi32(acc[2*iy+j], _mm256_madd_epi16(m1, dot));
#endif
}
}
}
if (int i = 2*(n/64); i < n/32) {
#ifdef HAVE_FANCY_SIMD
qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + i), _mm256_set1_epi8(127));
#else
qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i);
sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
#ifdef HAVE_FANCY_SIMD
acc[2*iy] = _mm256_dpbusd_epi32(acc[2*iy], qx[0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i));
#else
auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
acc[2*iy] = _mm256_add_epi32(acc[2*iy], _mm256_madd_epi16(m1, dot));
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sumi = hsum_i32_8(_mm256_add_epi32(acc[2*iy], acc[2*iy+1]));
#ifdef HAVE_FANCY_SIMD
info.store(ix, iy, dx[0]*dy[iy]*(sumi+sy[iy]));
#else
info.store(ix, iy, dx[0]*dy[iy]*sumi);
#endif
acc[2*iy] = acc[2*iy+1] = _mm256_setzero_si256();
}
}
}
template <int nrc_y>
static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
GGML_ASSERT(n%32 == 0);
__m256i qx[4];
#ifndef HAVE_FANCY_SIMD
__m256i sx[4];
auto m1 = _mm256_set1_epi16(1);
#endif
__m256i acc[nrc_y] = {};
float dy[nrc_y];
#ifdef HAVE_FANCY_SIMD
int32_t sy[nrc_y];
#endif
const int8_t * q8y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
dy[iy] = dptr[0];
#ifdef HAVE_FANCY_SIMD
auto iptr = (const int32_t *)(dptr + 1);
sy[iy] = -127*iptr[0];
#endif
q8y[iy] = (const int8_t *)(dptr + 2);
}
const int8_t * q8x[4];
float dx[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
for (int kx = 0; kx < 4; ++kx) {
auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
dx[kx] = dptr[0];
q8x[kx] = (const int8_t *)(dptr + 2);
}
for (int i = 0; i < n/32; ++i) {
for (int kx = 0; kx < 4; ++kx) qx[kx] = _mm256_loadu_si256((const __m256i *)q8x[kx] + i);
auto t0 = _mm256_unpacklo_epi32(qx[0], qx[1]);
auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]);
auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]);
auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]);
#ifdef HAVE_FANCY_SIMD
qx[0] = _mm256_add_epi8(_mm256_unpacklo_epi64(t0, t1), _mm256_set1_epi8(127));
qx[1] = _mm256_add_epi8(_mm256_unpackhi_epi64(t0, t1), _mm256_set1_epi8(127));
qx[2] = _mm256_add_epi8(_mm256_unpacklo_epi64(t2, t3), _mm256_set1_epi8(127));
qx[3] = _mm256_add_epi8(_mm256_unpackhi_epi64(t2, t3), _mm256_set1_epi8(127));
#else
qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
#ifdef HAVE_FANCY_SIMD
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
#else
auto dot1 = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
auto dot2 = _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
auto dot3 = _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
auto dot4 = _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
auto dot12 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot1), _mm256_madd_epi16(m1, dot2));
auto dot34 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot3), _mm256_madd_epi16(m1, dot4));
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(dot12, dot34));
#endif
}
}
auto scales_x = _mm_loadu_ps(dx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1));
#ifdef HAVE_FANCY_SIMD
sumi = _mm_add_epi32(sumi, _mm_set1_epi32(sy[iy]));
#endif
auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy]));
info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi)));
acc[iy] = _mm256_setzero_si256();
}
}
}
#ifdef __AVX512BF16__
template <int nrc_y>
static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
@ -9114,6 +9342,33 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
#endif
expected_typeB = GGML_TYPE_Q8_KR8;
break;
case GGML_TYPE_Q8_KV:
assert (ne00 % 32 == 0);
mm.funcs[0] = mul_mat_q8_KV_q8_KV_1<1>;
mm.funcs[1] = mul_mat_q8_KV_q8_KV<2>;
mm.funcs[2] = mul_mat_q8_KV_q8_KV<3>;
mm.funcs[3] = mul_mat_q8_KV_q8_KV<4>;
mm.funcs[4] = mul_mat_q8_KV_q8_KV<5>;
mm.funcs[5] = mul_mat_q8_KV_q8_KV<6>;
mm.funcs[6] = mul_mat_q8_KV_q8_KV<7>;
mm.funcs[7] = mul_mat_q8_KV_q8_KV<8>;
#ifdef HAVE_FANCY_SIMD
mm.func16 = mul_mat_q8_KV_q8_KV<16>;
#endif
expected_typeB = GGML_TYPE_Q8_KV;
break;
case GGML_TYPE_Q8_KV_R8:
assert (ne00 % 32 == 0);
mm.funcs[0] = mul_mat_q8_KV_r8_q8_KV<1>;
mm.funcs[1] = mul_mat_q8_KV_r8_q8_KV<2>;
mm.funcs[2] = mul_mat_q8_KV_r8_q8_KV<3>;
mm.funcs[3] = mul_mat_q8_KV_r8_q8_KV<4>;
mm.funcs[4] = mul_mat_q8_KV_r8_q8_KV<5>;
mm.funcs[5] = mul_mat_q8_KV_r8_q8_KV<6>;
mm.funcs[6] = mul_mat_q8_KV_r8_q8_KV<7>;
mm.funcs[7] = mul_mat_q8_KV_r8_q8_KV<8>;
expected_typeB = GGML_TYPE_Q8_KV;
break;
case GGML_TYPE_IQ4_K_R4:
assert (ne00 % QK_K == 0);
mm.funcs[0] = mul_mat_iq4_k_r4_q8_k<1>;
@ -13424,6 +13679,123 @@ void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& inf
}
}
static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%32 == 0);
int32x4_t acc[4] = {};
auto dptr = (const float *)info.src1_row(0);
const float dy = dptr[0];
auto q8y = (const int8_t *)(dptr + 2);
for (int ix = 0; ix < nrc_x; ++ix) {
auto dx = (const float *)((const char *)vx + ix*bx);
auto q8x = (const int8_t *)(dx + 2);
for (int i = 0; i < n/64; ++i) {
auto qx = vld1q_s8_x4(q8x + 64*i);
for (int j = 0; j < 4; ++j) {
acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 64*i + 16*j));
}
}
if (int i = 2*(n/64); i < n/32) {
auto qx = vld1q_s8_x2(q8x + 32*i);
for (int j = 0; j < 2; ++j) {
acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 32*i + 16*j));
}
}
acc[0] = vaddq_s32(acc[0], acc[1]);
acc[2] = vaddq_s32(acc[2], acc[3]);
acc[0] = vaddq_s32(acc[0], acc[2]);
info.store(ix, 0, dx[0]*dy*vaddvq_s32(acc[0]));
acc[0] = acc[1] = acc[2] = acc[3] = vdupq_n_s32(0);
}
}
template <int nrc_y>
static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
GGML_ASSERT(n%16 == 0);
int8x16_t qx[4];
int32x4_t acc[nrc_y] = {};
float dy[nrc_y];
const int8_t * q8y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
dy[iy] = dptr[0];
q8y[iy] = (const int8_t *)(dptr + 2);
}
const int8_t * q8x[4];
float dx[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
for (int kx = 0; kx < 4; ++kx) {
auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
dx[kx] = dptr[0];
q8x[kx] = (const int8_t *)(dptr + 2);
}
for (int i = 0; i < n/16; ++i) {
for (int kx = 0; kx < 4; ++kx) qx[kx] = vld1q_s8(q8x[kx] + 16*i);
auto row01 = vtrnq_s32(qx[0], qx[1]);
auto row23 = vtrnq_s32(qx[2], qx[3]);
qx[0] = vtrn1q_s64(row01.val[0], row23.val[0]);
qx[1] = vtrn1q_s64(row01.val[1], row23.val[1]);
qx[2] = vtrn2q_s64(row01.val[0], row23.val[0]);
qx[3] = vtrn2q_s64(row01.val[1], row23.val[1]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8(q8y[iy] + 16*i);
acc[iy] = vdotq_laneq_s32(acc[iy], qx[0], y, 0);
acc[iy] = vdotq_laneq_s32(acc[iy], qx[1], y, 1);
acc[iy] = vdotq_laneq_s32(acc[iy], qx[2], y, 2);
acc[iy] = vdotq_laneq_s32(acc[iy], qx[3], y, 3);
}
}
auto scales_x = vld1q_f32(dx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto scale = vmulq_f32(scales_x, vdupq_n_f32(dy[iy]));
info.store(ix, iy, vmulq_f32(scale, vcvtq_f32_s32(acc[iy])));
acc[iy] = vdupq_n_s32(0);
}
}
}
template <int nrc_y>
void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
int32x4_t acc[2*nrc_y] = {};
float dy[nrc_y];
const int8_t * q8y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
dy[iy] = dptr[0];
q8y[iy] = (const int8_t *)(dptr + 2);
}
for (int ix = 0; ix < nrc_x; ix += 8) {
const float * dptr = (const float *)((const char *)vx + ix*bx);
auto q8x = (const int8_t *)(dptr + 8);
for (int ib = 0; ib < n/16; ++ib) {
auto q1 = vld1q_s8_x4(q8x + 128*ib + 0);
auto q2 = vld1q_s8_x4(q8x + 128*ib + 64);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8(q8y[iy]+16*ib);
acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[0], y, 0);
acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[1], y, 0);
acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[2], y, 1);
acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[3], y, 1);
acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[0], y, 2);
acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[1], y, 2);
acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[2], y, 3);
acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[3], y, 3);
}
}
auto scale1_x = vld1q_f32(dptr+0);
auto scale2_x = vld1q_f32(dptr+4);
for (int iy = 0; iy < nrc_y; ++iy) {
auto scale_y = vdupq_n_f32(dy[iy]);
auto scale1 = vmulq_f32(scale1_x, scale_y);
auto scale2 = vmulq_f32(scale2_x, scale_y);
info.store(ix+0, iy, vmulq_f32(scale1, vcvtq_f32_s32(acc[2*iy+0])));
info.store(ix+4, iy, vmulq_f32(scale2, vcvtq_f32_s32(acc[2*iy+1])));
acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f);
}
}
}
void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<1, block_q8_0_x4> q8(info);
@ -14000,6 +14372,16 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_k_r8_q8_k);
expected_Btype = GGML_TYPE_Q8_KR8;
break;
case GGML_TYPE_Q8_KV:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_KV_q8_KV);
m.funcs[0] = mul_mat_q8_KV_q8_KV_1;
m.func16 = mul_mat_q8_KV_q8_KV<16>;
expected_Btype = GGML_TYPE_Q8_KV;
break;
case GGML_TYPE_Q8_KV_R8:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_KV_r8_q8_KV);
expected_Btype = GGML_TYPE_Q8_KV;
break;
case GGML_TYPE_IQ2_K_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
@ -14347,13 +14729,49 @@ struct HelperF16 final : public BaseHelper<step> {
}
};
template <int D> struct block_q8_KV {
float d;
int s;
int8_t qs[D];
};
template <int D, int step>
struct HelperQ8KV final : public BaseHelper<step> {
using Base = BaseHelper<step>;
using block_q8 = block_q8_KV<D>;
constexpr static int block_size_q = D;
HelperQ8KV(const char * data, int stride) : Base(data, stride) {}
// Needed for v * softmax(k * q)
inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
auto q8 = (const block_q8_KV<D> *)Base::lblock(l1);
#ifdef __aarch64__
auto vd = F16::set1(q8->d);
auto qs = vld1_s8_x2(q8->qs + 8*i);
v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
#else
auto vd = F16::set1(q8->d);
#ifdef HAVE_FANCY_SIMD
v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+0))));
v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+1))));
#else
v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+0)))));
v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+8)))));
#endif
#endif
}
};
template <int D, int step>
struct HelperQ80 final : public BaseHelper<step> {
using Base = BaseHelper<step>;
#ifdef HAVE_FANCY_SIMD
using block_q8 = block_q8_1;
constexpr static int block_size_q = QK8_1;
#else
using block_q8 = block_q8_0;
constexpr static int block_size_q = QK8_0;
#endif
HelperQ80(const char * data, int stride) : Base(data, stride) {}
@ -14397,23 +14815,33 @@ struct HelperQ80 final : public BaseHelper<step> {
y += D/QK8_1;
}
}
static inline void convert(int nq, int stride_q, const float * q, block_q8_KV<D> * y) {
for (int i = 0; i < nq; ++i) {
quantize_row_q8_KV(q, y, D);
q += stride_q;
++y;
}
}
};
template <int D, int step>
struct HelperQ80R4 : public BaseHelper<step> {
struct HelperQ80R8 : public BaseHelper<step> {
using Base = BaseHelper<step>;
#ifdef __AVX2__
constexpr static int block_size_q = QK8_1;
using block_q8 = block_q8_1;
#else
constexpr static int block_size_q = QK8_0;
using block_q8 = block_q8_0;
#endif
HelperQ80R4(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) {
HelperQ80R8(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) {
r4 = repack(nk, q8);
Base::data = (const char *)r4.data();
Base::stride = (D/QK8_0)*sizeof(block_q8_0);
}
static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step> q8) {
static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step>& q8) {
static_assert(D%QK8_0 == 0);
GGML_ASSERT(nk%8 == 0);
constexpr int nblock = D/QK8_0;
@ -14512,10 +14940,107 @@ struct HelperQ80R4 : public BaseHelper<step> {
std::vector<block_q8_0_r8> r4;
};
// TODO: unite this with the above
template <int D, int step>
struct HelperQ8KVR8 : public BaseHelper<step> {
using Base = BaseHelper<step>;
constexpr static int block_size_q = D;
using block_q8 = block_q8_KV<D>;
struct block_q8_KV_r8 {
float d[8];
int8_t qs[8*D];
};
HelperQ8KVR8(int nk, const HelperQ8KV<D, step>& q8) : Base(q8.data, q8.stride) {
r4 = repack(nk, q8);
Base::data = (const char *)r4.data();
Base::stride = sizeof(block_q8_KV_r8)/8;
}
static std::vector<block_q8_KV_r8> repack(int nk, const HelperQ8KV<D, step>& q8) {
static_assert(D%32 == 0);
GGML_ASSERT(nk%8 == 0);
std::vector<block_q8_KV_r8> result(nk/8);
auto y = result.data();
#ifdef __ARM_NEON
int8x16x2_t m0, m1, m2, m3;
#endif
const int8_t * x8[8];
for (int ix = 0; ix < nk/8; ++ix) {
for (int k = 0; k < 8; ++k) {
auto dptr = (const float *)(q8.data + (8*ix + k)*q8.stride);
y[ix].d[k] = dptr[0];
x8[k] = (const int8_t *)(dptr + 2);
}
for (int ib = 0; ib < D/16; ++ib) {
#ifdef __AVX2__
auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib));
auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib));
auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib));
auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib));
auto t0 = _mm256_unpacklo_epi32(m0, m1);
auto t1 = _mm256_unpacklo_epi32(m2, m3);
auto t2 = _mm256_unpackhi_epi32(m0, m1);
auto t3 = _mm256_unpackhi_epi32(m2, m3);
m0 = _mm256_unpacklo_epi64(t0, t1);
m1 = _mm256_unpackhi_epi64(t0, t1);
m2 = _mm256_unpacklo_epi64(t2, t3);
m3 = _mm256_unpackhi_epi64(t2, t3);
#ifdef HAVE_FANCY_SIMD
m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
#endif
_mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+0, m0);
_mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+1, m1);
_mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+2, m2);
_mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+3, m3);
#elif defined __ARM_NEON
// TODO
m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib);
m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib);
m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib);
m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib);
auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
vst1q_s8_x2(y[ix].qs + 0 + 128*ib, m0);
vst1q_s8_x2(y[ix].qs + 32 + 128*ib, m1);
vst1q_s8_x2(y[ix].qs + 64 + 128*ib, m2);
vst1q_s8_x2(y[ix].qs + 96 + 128*ib, m3);
#else
// TODO
for (int l = 0; l < 4; ++l) {
for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0];
y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16];
}
}
#endif
}
}
return result;
}
std::vector<block_q8_KV_r8> r4;
};
template <int D, int step>
struct HelperQ40 final : public BaseHelper<step> {
using Base = BaseHelper<step>;
using block_q8 = block_q8_0;
constexpr static int block_size_q = QK8_0;
HelperQ40(const char * data, int stride) : Base(data, stride) {}
// Needed for v * softmax(k * q)
@ -14559,6 +15084,7 @@ template <int D, int step>
struct HelperQ41 final : public BaseHelper<step> {
using Base = BaseHelper<step>;
using block_q8 = block_q8_1;
constexpr static int block_size_q = QK8_1;
HelperQ41(const char * data, int stride) : Base(data, stride) {}
// Needed for v * softmax(k * q)
@ -14649,8 +15175,10 @@ template <int D, int step>
struct HelperQ60 final : public BaseHelper<step> {
#ifdef __aarch64__
using block_q8 = block_q8_0;
constexpr static int block_size_q = QK8_0;
#else
using block_q8 = block_q8_1;
constexpr static int block_size_q = QK8_1;
#endif
using Base = BaseHelper<step>;
HelperQ60(const char * data, int stride) : Base(data, stride) {}
@ -15071,9 +15599,9 @@ struct FlashQKV {
}
inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
GGML_ASSERT(fms.S[j] > 0);
auto norm = F16::set1(1/fms.S[j]);
//auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
//GGML_ASSERT(fms.S[j] > 0);
//auto norm = F16::set1(1/fms.S[j]);
auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
for (int i = 0; i < D/F16::block_size; ++i) {
auto r = F16::load(R + F16::block_size*i);
F16::store(qkv + F16::block_size*i, F16::mul(norm, r));
@ -15357,13 +15885,29 @@ struct FlashQKfp32 {
#endif
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ80R4<D, k_step>>) {
else if constexpr (std::is_same_v<KHelper, HelperQ8KV<D, k_step>>) {
#ifdef __aarch64__
if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1);
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
#else
#ifdef HAVE_FANCY_SIMD
if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
#endif
if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1);
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ80R8<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_0, nq);
#else
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_1, nq);
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D, k_step>>) {
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_r8_q8_KV, nq);
}
else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
@ -15406,7 +15950,7 @@ struct FlashQKfp32 {
constexpr int kMaxQ = 8;
static_assert(q_step < kMaxQ || q_step%kMaxQ == 0);
auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(q_step);
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
for (int iq = 0; iq < q_step/nrc_q; ++iq) {
mul_mat(D, kh.block, kh.stride, info, k_step);
info.cur_y += nrc_q;
@ -15428,7 +15972,7 @@ struct FlashQKfp32 {
static inline void mul_mask_kq(int nq, const KHelper& kh, int stride_m,
const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(nq);
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
for (int iq = 0; iq < nq/nrc_q; ++iq) {
mul_mat(D, kh.block, kh.stride, info, k_step);
info.cur_y += nrc_q;
@ -15516,7 +16060,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
FlashMS<q_step, k_step>& fms,
FlashQKV<Dv, q_step, k_step>& fqkv,
const float * q, const char * mask, float * qkv) {
typename KHelper::block_q8 q8[q_step*(Dk/QK8_0)];
typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
#if FA_TIMING
Perf perf(false);
#endif
@ -15613,12 +16157,28 @@ struct FlashAttn {
if (nq1 >= 8) {
#if FA_TIMING
auto t1 = Perf::cur_time();
HelperQ80R4<Dk, k_step> khr4(nk1, kh);
HelperQ80R8<Dk, k_step> khr4(nk1, kh);
Perf::instance().accum(4, t1);
#else
HelperQ80R4<Dk, k_step> khr4(nk1, kh);
HelperQ80R8<Dk, k_step> khr4(nk1, kh);
#endif
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R4<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
} else{
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
}
}
else if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
if (nq1 >= 8) {
#if FA_TIMING
auto t1 = Perf::cur_time();
HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
Perf::instance().accum(4, t1);
#else
HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
#endif
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
} else{
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
@ -16142,6 +16702,10 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperQ80<Dv, k_step> vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q8_KV: {
HelperQ8KV<Dv, k_step> vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q6_0: {
HelperQ60<Dv, k_step> vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
@ -16179,6 +16743,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperQ80<Dk, k_step> kh(k, stride_k);
iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q8_KV: {
HelperQ8KV<Dk, k_step> kh(k, stride_k);
iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q6_0: {
HelperQ60<Dk, k_step> kh(k, stride_k);
iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
@ -16210,7 +16778,7 @@ inline bool flash_attn_is_supported(ggml_type type) {
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true;
#else
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0) return true;
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV) return true;
#endif
return false;
}

View File

@ -2967,6 +2967,103 @@ void iqk_quantize_row_q8_K128(const float * x, void * vy, int64_t k) {
}
#endif
}
// TODO: merge this with the above template
void iqk_quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
assert(k % 32 == 0);
auto dptr = (float *)vy;
auto q8 = (int8_t *)(dptr + 2);
#ifdef __AVX2__
const __m256 signBit = _mm256_set1_ps(-0.0f);
const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
__m256 maxAbs = _mm256_setzero_ps();
for (int ib = 0; ib < k/8; ++ib) {
const __m256 v = _mm256_loadu_ps(x + 8*ib);
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps(signBit, v));
}
const float maxScalar = hmax_f32_8(maxAbs);
if (!maxScalar) {
dptr[0] = dptr[1] = 0;
std::memset(q8, 0, k*sizeof(int8_t));
return;
}
dptr[0] = maxScalar / 127.f;
auto mul = _mm256_set1_ps(1/dptr[0]);
auto isum = _mm256_setzero_si256();
for (int i = 0; i < k/32; i++) {
__m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 0));
__m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 8));
__m256 v2 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 16));
__m256 v3 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 24));
v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST);
v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST);
v2 = _mm256_round_ps(v2, _MM_ROUND_NEAREST);
v3 = _mm256_round_ps(v3, _MM_ROUND_NEAREST);
__m256i i0 = _mm256_cvtps_epi32(v0);
__m256i i1 = _mm256_cvtps_epi32(v1);
__m256i i2 = _mm256_cvtps_epi32(v2);
__m256i i3 = _mm256_cvtps_epi32(v3);
isum = _mm256_add_epi32(isum, _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
i0 = _mm256_packs_epi32( i0, i1 );
i2 = _mm256_packs_epi32( i2, i3 );
i0 = _mm256_packs_epi16( i0, i2 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
_mm256_storeu_si256((__m256i *)q8, i0);
q8 += 32;
}
auto iptr = (int32_t *)(dptr + 1);
iptr[0] = hsum_i32_8(isum);
#elif defined __ARM_NEON
int32x4_t ival[8];
auto vmax = vdupq_n_f32(0.f);
for (int j = 0; j < k; j += 4) {
vmax = vmaxq_f32(vmax, vabsq_f32(vld1q_f32(x + j)));
}
auto smax = vmaxvq_f32(vmax);
if (!smax) {
dptr[0] = dptr[1] = 0;
std::memset(q8, 0, k*sizeof(int8_t));
return;
}
dptr[0] = smax/127;
auto vid = vdupq_n_f32(1/dptr[0]);
auto isum = vdupq_n_s32(0);
for (int ib = 0; ib < k/32; ++ib) {
auto xb = x + 32*ib;
for (int k = 0; k < 8; ++k) {
auto val = vld1q_f32(xb + 4*k);
ival[k] = vcvtnq_s32_f32(vmulq_f32(val, vid));
isum = vaddq_s32(isum, ival[k]);
}
for (int k = 0; k < 4; ++k) {
auto i16 = vcombine_s16(vmovn_s32(ival[2*k+0]), vmovn_s32(ival[2*k+1]));
vst1_s8(q8, vmovn_s16(i16));
q8 += 8;
}
}
auto iptr = (int32_t *)(dptr + 1);
iptr[0] = vaddvq_s32(isum);
#else
float amax = 0;
for (int j = 0; j < k; ++j) {
float ax = std::abs(x[j]);
amax = std::max(amax, ax);
}
if (!amax) {
dptr[0] = dptr[1] = 0;
std::memset(q8, 0, k*sizeof(int8_t));
return;
}
dptr[0] = amax/127;
float id = 1/dptr[0];
int isum = 0;
for (int i = 0; i < k; i++) {
q8[i] = nearest_int(id*x[i]);
isum += q8[i];
}
auto iptr = (int32_t *)(dptr + 1);
iptr[0] = isum;
#endif
}
}
void quantize_row_q8_K128(const float * x, void * vy, int64_t k) {
@ -3886,7 +3983,7 @@ static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8
#ifdef HAVE_FANCY_SIMD
static void modify_q8_0_r8(int64_t k, char * cy) {
auto y = (block_iq4_nl_r8 *)cy;
auto y = (block_q8_0_r8 *)cy;
int nb = k/(32*8);
for (int ib = 0; ib < nb; ++ib) {
for (int l = 0; l < 4; ++l) {
@ -5412,6 +5509,150 @@ void vec_dot_q8_k_r8_q8_k(int n, float * s, size_t bs, const void * vx, size_t b
GGML_UNUSED(by);
}
//
// ========================================= q8_KV_r8
//
void quantize_row_q8_KV_r8_ref(const float * x, void * y, int64_t k) {
quantize_q8_KV_r8(x, y, 8, k/8, nullptr);
}
void quantize_row_q8_KV_r8(const float * x, void * y, int64_t k) {
quantize_q8_KV_r8(x, y, 8, k/8, nullptr);
}
static void repack_q8_KV(int nrows, int n_per_row, const char * cx, char * cy, [[maybe_unused]] bool online) {
GGML_ASSERT(nrows%8 == 0);
GGML_ASSERT(n_per_row%16 == 0);
auto row_size_x = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
auto row_size_y = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row);
const int8_t * x8[8];
#ifdef __ARM_NEON
int8x16x2_t m0, m1, m2, m3;
#endif
for (int row = 0; row < nrows; row += 8) {
auto dy = (float *)cy;
auto qy = (int8_t *)(dy + 8);
for (int k = 0; k < 8; ++k) {
auto dx = (const float *)(cx + k*row_size_x);
dy[k] = dx[0];
x8[k] = (const int8_t *)(dx + 2);
}
for (int ib = 0; ib < n_per_row/16; ++ib) {
#ifdef __AVX2__
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib));
auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib));
auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib));
auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib));
auto t0 = _mm256_unpacklo_epi32(m0, m1);
auto t1 = _mm256_unpacklo_epi32(m2, m3);
auto t2 = _mm256_unpackhi_epi32(m0, m1);
auto t3 = _mm256_unpackhi_epi32(m2, m3);
m0 = _mm256_unpacklo_epi64(t0, t1);
m1 = _mm256_unpackhi_epi64(t0, t1);
m2 = _mm256_unpacklo_epi64(t2, t3);
m3 = _mm256_unpackhi_epi64(t2, t3);
#ifdef HAVE_FANCY_SIMD
if (online) {
m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
}
#endif
_mm256_storeu_si256((__m256i *)qy + 4*ib+0, m0);
_mm256_storeu_si256((__m256i *)qy + 4*ib+1, m1);
_mm256_storeu_si256((__m256i *)qy + 4*ib+2, m2);
_mm256_storeu_si256((__m256i *)qy + 4*ib+3, m3);
#elif defined __ARM_NEON
m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib);
m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib);
m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib);
m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib);
auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
vst1q_s8_x2(qy + 0 + 128*ib, m0);
vst1q_s8_x2(qy + 32 + 128*ib, m1);
vst1q_s8_x2(qy + 64 + 128*ib, m2);
vst1q_s8_x2(qy + 96 + 128*ib, m3);
#else
// TODO
for (int l = 0; l < 4; ++l) {
for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0];
y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16];
}
}
#endif
}
cx += 8*row_size_x;
cy += online ? 8*row_size_x : 8*row_size_y;
//So, if we are run-time-repacking (online = true) we don't want to change the stride, so we just leave some unused space at the end of each row
}
}
#ifdef HAVE_FANCY_SIMD
static void modify_q8_KV_r8(int64_t k, char * cy) {
int8_t * q8 = (int8_t *)(cy + 8*sizeof(float));
for (int j = 0; j < k; ++j) q8[j] += 127;
}
#endif
size_t quantize_q8_KV_r8(const float * src, void * dst, int64_t nrows, int64_t n_per_row, [[maybe_unused]] const float * imatrix) {
GGML_ASSERT(nrows%8 == 0);
GGML_ASSERT(n_per_row%16 == 0);
char * qcur = (char *)dst;
auto row_size_0 = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
auto row_size_1 = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row);
std::vector<char> qtmp(8*row_size_0);
for (int row = 0; row < nrows; row += 8) {
quantize_q8_KV(src, (void *)qtmp.data(), 8, n_per_row, imatrix);
repack_q8_KV(8, n_per_row, qtmp.data(), qcur, false);
qcur += 8*row_size_1;
src += 8*n_per_row;
}
return nrows*row_size_1;
}
void dequantize_row_q8_KV_r8(const void * vx, float * y, int64_t k) {
auto n_per_row = k/8;
float * y8[8];
for (int k = 0; k < 8; ++k) y8[k] = y + n_per_row*k;
auto dptr = (const float *)vx;
auto q8 = (const int8_t *)(dptr + 8);
for (int ib = 0; ib < n_per_row/16; ++ib) {
for (int k = 0; k < 8; ++k) {
for (int l = 0; l < 4; ++l) {
for (int i = 0; i < 4; ++i) y8[k][16*ib + 4*l + i] = dptr[k] * q8[128*ib + 32*l + 4*k + i];
}
}
}
}
void vec_dot_q8_KV_r8_q8_KV(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
#if GGML_USE_IQK_MULMAT
if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_KV_R8, vx, 0, GGML_TYPE_Q8_KV, vy, 0, s, 0, 0, 1)) {
return;
}
#endif
GGML_ASSERT(n%QK4_NL == 0);
GGML_ASSERT(nrc == 1);
GGML_UNUSED(bs);
GGML_UNUSED(bx);
GGML_UNUSED(by);
}
//
// ========================================= bf16_r4
//
@ -6450,6 +6691,47 @@ void vec_dot_iq1_m_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t
GGML_UNUSED(by);
}
void quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
iqk_quantize_row_q8_KV(x, vy, k);
}
void quantize_row_q8_KV_ref(const float * x, void * y, int64_t k) {
quantize_row_q8_KV(x, y, k);
}
size_t quantize_q8_KV(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
(void)imatrix;
auto row_size = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
auto q = (char *)dst;
for (int row = 0; row < nrows; ++row) {
quantize_row_q8_KV(src, q, n_per_row);
src += n_per_row;
q += row_size;
}
return row_size*nrows;
}
void dequantize_row_q8_KV(const void * x, float * y, int64_t k) {
auto dptr = (const float *)x;
float d = dptr[0];
auto q8 = (const int8_t *)(dptr + 2);
for (int j = 0; j < k; ++j) y[j] = d * q8[j];
}
void vec_dot_q8_KV_q8_KV(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
#if GGML_USE_IQK_MULMAT
if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_KV, vx, 0, GGML_TYPE_Q8_KV, vy, 0, s, 0, 0, 1)) {
return;
}
#endif
GGML_ASSERT(n%QK4_NL == 0);
GGML_ASSERT(nrc == 1);
GGML_UNUSED(bs);
GGML_UNUSED(bx);
GGML_UNUSED(by);
}
//================================================
namespace {
@ -6472,8 +6754,9 @@ bool iqk_modify_tensor(struct ggml_tensor * tensor) {
{ GGML_TYPE_Q4_0_R8, {modify_q4_0_r8, 8} },
#endif
#ifdef HAVE_FANCY_SIMD
{ GGML_TYPE_Q8_0_R8, {modify_q8_0_r8, 8} },
{ GGML_TYPE_Q8_K_R8, {modify_q8_k_r8, 8} },
{ GGML_TYPE_Q8_0_R8, {modify_q8_0_r8, 8} },
{ GGML_TYPE_Q8_K_R8, {modify_q8_k_r8, 8} },
{ GGML_TYPE_Q8_KV_R8, {modify_q8_KV_r8, 8} },
#endif
};
auto it = k_mod_map.find(tensor->type);
@ -6532,6 +6815,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) {
{ GGML_TYPE_Q6_0, { GGML_TYPE_Q6_0_R4, 4, (Repack::repack_func)repack_q6_0} },
{ GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R8, 8, (Repack::repack_func)repack_q8_0} },
{ GGML_TYPE_Q8_K, { GGML_TYPE_Q8_K_R8, 8, (Repack::repack_func)repack_q8_k} },
{ GGML_TYPE_Q8_KV, { GGML_TYPE_Q8_KV_R8, 8, (Repack::repack_func)repack_q8_KV} },
#ifdef __AVX512BF16__
{ GGML_TYPE_BF16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_bf16_t>}},
{ GGML_TYPE_F16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_half>} },

View File

@ -217,6 +217,18 @@ size_t quantize_q8_k_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds
void dequantize_row_q8_k_r8(const block_q8_k_r8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_q8_k_r8_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void quantize_row_q8_KV_ref(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_KV(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
size_t quantize_q8_KV(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
void dequantize_row_q8_KV(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_q8_KV_q8_KV(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void quantize_row_q8_KV_r8_ref(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_KV_r8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
size_t quantize_q8_KV_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
void dequantize_row_q8_KV_r8(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_q8_KV_r8_q8_KV(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void quantize_row_q8_K64_ref(const float * GGML_RESTRICT x, block_q8_K64 * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);

View File

@ -180,6 +180,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ3_KL = 146, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ2_KS = 147, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ4_KSS = 148, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q8_KV = 149, // except 1d tensors
//
LLAMA_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors
@ -206,6 +207,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ4_K_R4 = 340, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ5_K_R4 = 341, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ4_KS_R4 = 345, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q8_KV_R8 = 398, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file

View File

@ -3180,6 +3180,10 @@ static bool llama_kv_cache_init(
for (int i = 0; i < (int) n_layer; i++) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
const uint32_t n_head = hparams.n_head(i);
const uint32_t n_head_kv = hparams.n_head_kv(i);
const uint32_t n_embd_head_k= hparams.n_embd_head_k;
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
ggml_tensor * k;
@ -3201,7 +3205,8 @@ static bool llama_kv_cache_init(
const uint32_t kv_lora_rank = hparams.n_lora_kv;
LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
#if MLA_USE_TRANSPOSED_CACHE
ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size);
//ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
#else
ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
#endif
@ -3215,7 +3220,10 @@ static bool llama_kv_cache_init(
n_mla++;
}
else {
k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
//printf("Creating cache tensors:\n");
//printf("n_embd_k_gqa = %d, kv_size = %d, n_head = %d, n_head_kv = %d, n_embd_head_k = %d\n", (int)n_embd_k_gqa, (int)kv_size, (int)n_head, (int)n_head_kv, (int)n_embd_head_k);
//k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
@ -4002,6 +4010,7 @@ struct llama_model_loader {
case GGML_TYPE_Q5_1: ftype = LLAMA_FTYPE_MOSTLY_Q5_1; break;
case GGML_TYPE_Q6_0: ftype = LLAMA_FTYPE_MOSTLY_Q6_0; break;
case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; break;
case GGML_TYPE_Q8_KV: ftype = LLAMA_FTYPE_MOSTLY_Q8_KV; break;
case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break;
case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break;
case GGML_TYPE_Q3_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_R4; break;
@ -4012,6 +4021,7 @@ struct llama_model_loader {
case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break;
case GGML_TYPE_Q6_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_K_R4; break;
case GGML_TYPE_Q8_K_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_K_R8; break;
case GGML_TYPE_Q8_KV_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_KV_R8; break;
case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break;
case GGML_TYPE_IQ2_XXS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4; break;
case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break;
@ -4730,6 +4740,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1";
case LLAMA_FTYPE_MOSTLY_Q6_0: return "Q6_0";
case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0";
case LLAMA_FTYPE_MOSTLY_Q8_KV: return "Q8_KV";
case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q2_K_R4: return "Q2_K_R4";
case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small";
@ -4746,6 +4757,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K";
case LLAMA_FTYPE_MOSTLY_Q6_K_R4: return "Q6_K_R4";
case LLAMA_FTYPE_MOSTLY_Q8_K_R8: return "Q8_K_R8";
case LLAMA_FTYPE_MOSTLY_Q8_KV_R8: return "Q8_KV_R8";
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4:return "IQ2_XXS_R4 - 2.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw";
@ -8283,11 +8295,20 @@ static void llm_build_kv_store(
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
const int64_t n_head = hparams.n_head(il);
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_head_v = hparams.n_embd_head_v;
GGML_ASSERT(kv.size == n_ctx);
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
(ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
cb(k_cache_view, "k_cache_view", il);
//struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
// (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
//cb(k_cache_view, "k_cache_view", il);
auto k_row_size = ggml_row_size(kv.k_l[il]->type, n_embd_head_k);
ggml_tensor * k_cache_view = ggml_view_2d(ctx, kv.k_l[il], n_embd_head_k, n_tokens*n_head_kv,
k_row_size, k_row_size*n_head_kv*kv_head);
// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
@ -8706,7 +8727,7 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * k =
ggml_view_3d(ctx, kv.k_l[il],
n_embd_head_k, n_kv, n_head_kv,
ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv.k_l[il]->type, n_embd_head_k)*n_head_kv, //n_embd_k_gqa),
ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
0);
cb(k, "k", il);
@ -13507,8 +13528,9 @@ struct llm_build_context {
ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0);
cb(kvr, "kvr", il);
ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*(kv_lora_rank + n_embd_head_qk_rope),
ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope)*kv_head);
auto row_size = ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope);
ggml_tensor * kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_tokens,
row_size, row_size*kv_head);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, kvr, kv_cache_view));
ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.kv_l[il],
kv_lora_rank + n_embd_head_qk_rope, n_kv,
@ -16164,7 +16186,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
new_type = GGML_TYPE_IQ5_K;
}
else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_Q8_0_R8 && new_type != GGML_TYPE_IQ6_K && new_type != GGML_TYPE_Q6_K_R4 &&
new_type != GGML_TYPE_Q8_K_R8) {
new_type != GGML_TYPE_Q8_K_R8 && new_type != GGML_TYPE_Q8_KV && new_type != GGML_TYPE_Q8_KV_R8) {
new_type = GGML_TYPE_Q6_K;
}
}
@ -16218,6 +16240,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (new_type == GGML_TYPE_Q8_K_R8) {
new_type = GGML_TYPE_Q8_0;
}
else if (new_type == GGML_TYPE_Q8_KV_R8) {
new_type = GGML_TYPE_Q8_0;
}
else if (new_type == GGML_TYPE_IQ2_K_R4) {
new_type = GGML_TYPE_IQ2_K;
}
@ -16728,6 +16753,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
case LLAMA_FTYPE_MOSTLY_Q6_0: default_type = GGML_TYPE_Q6_0; break;
case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
case LLAMA_FTYPE_MOSTLY_Q8_KV:default_type = GGML_TYPE_Q8_KV;break;
case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
case LLAMA_FTYPE_MOSTLY_BF16_R16: default_type = GGML_TYPE_BF16_R16; break;
@ -16751,6 +16777,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break;
case LLAMA_FTYPE_MOSTLY_Q6_K_R4: default_type = GGML_TYPE_Q6_K_R4; break;
case LLAMA_FTYPE_MOSTLY_Q8_K_R8: default_type = GGML_TYPE_Q8_K_R8; break;
case LLAMA_FTYPE_MOSTLY_Q8_KV_R8: default_type = GGML_TYPE_Q8_KV_R8; break;
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break;
case LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4:default_type = GGML_TYPE_IQ2_XXS_R4; break;
case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break;
@ -17194,6 +17221,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0;
else chunk_size_multiplier = 8;
}
else if (new_type == GGML_TYPE_Q8_KV_R8) {
if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0;
else chunk_size_multiplier = 8;
}
else if (new_type == GGML_TYPE_IQ2_BN_R4) {
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_BN;
else chunk_size_multiplier = 4;