Much faster IQ2_KS quantization (#1672)

* Much faster iq2_ks quantization

* Slightly better

* Make the iq2_ks slow quantization path a compile time option
This commit is contained in:
Kawrakow 2026-04-22 11:00:07 +02:00 committed by GitHub
parent e0596bf614
commit 286ce324ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 301 additions and 82 deletions

View File

@ -185,6 +185,10 @@ static void test_roundtrip_on_chunk(
for (int i = 0; i < chunk_size; i++) {
input_scratch[i] = ggml_get_f32_1d(layer, i + offset);
}
} else if (layer->type == GGML_TYPE_BF16) {
for (int i = 0; i < chunk_size; i++) {
input_scratch[i] = ggml_get_f32_1d(layer, i + offset);
}
} else {
input_scratch = ggml_get_data_f32(layer) + offset;
}
@ -211,7 +215,7 @@ static void test_roundtrip_on_layer(
uint64_t nelements = ggml_nelements(layer);
float* input_scratch_ptr = nullptr;
if (layer->type == GGML_TYPE_F16) {
if (layer->type == GGML_TYPE_F16 || layer->type == GGML_TYPE_BF16) {
if (input_scratch.size() < nelements) input_scratch.resize(nelements);
input_scratch_ptr = input_scratch.data();
}
@ -1587,6 +1591,7 @@ int main(int argc, char ** argv) {
int included_layers = 0;
int64_t max_nelements = 0;
bool is_f16 = false;
bool is_bf16 = false;
for (const auto& kv_tensor : tensors) {
if (!layer_included(params, kv_tensor.first)) {
continue;
@ -1600,6 +1605,8 @@ int main(int argc, char ** argv) {
}
if (kv_tensor.second->type == GGML_TYPE_F16) {
is_f16 = true;
} else if (kv_tensor.second->type == GGML_TYPE_BF16) {
is_bf16 = true;
} else if (kv_tensor.second->type != GGML_TYPE_F32) {
fprintf(stderr, "%s: error: Quantization should be tested with a float model, "
"this model contains already quantized layers (%s is type %d)\n", __func__, kv_tensor.first.c_str(), kv_tensor.second->type);
@ -1614,6 +1621,9 @@ int main(int argc, char ** argv) {
if (is_f16) {
printf("note: source model is f16\n");
}
if (is_bf16) {
printf("note: source model is bf16\n");
}
printf("testing %d layers with max size %" PRId64 "\n", included_layers, max_nelements);
// allocate scratch space
std::vector<float> input_scratch;

View File

@ -300,6 +300,9 @@ if (GGML_IQK_MUL_MAT)
message(STATUS "Disabling IQK Flash Attention kernels")
endif()
endif()
if (IQK_SLOW_IQ2KS_QUANTIZE)
set_source_files_properties(iqk/iqk_quantize.cpp PROPERTIES COMPILE_DEFINITIONS IQK_SLOW_IQ2KS_QUANTIZE)
endif()
if (GGML_CUDA)
cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES

View File

@ -1316,13 +1316,297 @@ void vec_dot_iq2_k_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx,
}
namespace {
#if defined(__AVX2__) && !defined(IQK_SLOW_IQ2KS_QUANTIZE)
inline void to_values_i32(__m256i idx, __m256i ivalues, __m256i * iv) {
auto ival = _mm256_shuffle_epi8(ivalues, idx);
auto ival_1 = _mm256_srli_si256(ival, 8);
iv[0] = _mm256_cvtepi8_epi32(_mm256_castsi256_si128(ival));
iv[1] = _mm256_cvtepi8_epi32(_mm256_castsi256_si128(ival_1));
iv[2] = _mm256_cvtepi8_epi32(_mm256_extracti128_si256(ival, 1));
iv[3] = _mm256_cvtepi8_epi32(_mm256_extracti128_si256(ival_1, 1));
}
inline __m256i to_int8(const __m256i * ibest) {
auto i0 = _mm256_packs_epi32(ibest[0], ibest[1]); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
auto i1 = _mm256_packs_epi32(ibest[2], ibest[3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
auto idx = _mm256_packs_epi16(i0, i1); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
auto idx_l = _mm256_castsi256_si128(idx);
auto idx_h = _mm256_extracti128_si256(idx, 1);
auto idx1 = _mm_unpacklo_epi32(idx_l, idx_h); // 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
auto idx2 = _mm_unpackhi_epi32(idx_l, idx_h);
return MM256_SET_M128I(idx2, idx1);
}
bool compute_1block_iq2ks(float d, const __m256 * vx, const __m256 * vw, const int8_t * values, __m256i & this_idx, float & best_d, float & score) {
constexpr int kBlockSize = 32;
uint32_t aux32;
std::memcpy(&aux32, values, sizeof(aux32));
auto ivalues = _mm256_set1_epi32(aux32);
__m256 vbest[8];
__m256i ibest[8];
auto val = _mm256_set1_ps(d*values[0]);
auto ival = _mm256_set1_epi32(0);
for (int k = 0; k < kBlockSize/8; ++k) {
auto diff = _mm256_sub_ps(vx[k], val);
vbest[k] = _mm256_mul_ps(diff, diff);
ibest[k] = ival;
diff = _mm256_add_ps(vx[k], val);
vbest[k+4] = _mm256_mul_ps(diff, diff);
ibest[k+4] = ival;
}
for (int j = 1; j < 4; ++j) {
val = _mm256_set1_ps(d*values[j]);
ival = _mm256_set1_epi32(j);
for (int k = 0; k < kBlockSize/8; ++k) {
auto diff = _mm256_sub_ps(vx[k], val);
diff = _mm256_mul_ps(diff, diff);
auto mask = _mm256_cmp_ps(diff, vbest[k], _CMP_LT_OQ);
vbest[k] = _mm256_or_ps(_mm256_and_ps(mask, diff), _mm256_andnot_ps(mask, vbest[k]));
auto imask = _mm256_castps_si256(mask);
ibest[k] = _mm256_or_si256(_mm256_and_si256(imask, ival), _mm256_andnot_si256(imask, ibest[k]));
diff = _mm256_add_ps(vx[k], val);
diff = _mm256_mul_ps(diff, diff);
mask = _mm256_cmp_ps(diff, vbest[k+4], _CMP_LT_OQ);
vbest[k+4] = _mm256_or_ps(_mm256_and_ps(mask, diff), _mm256_andnot_ps(mask, vbest[k+4]));
imask = _mm256_castps_si256(mask);
ibest[k+4] = _mm256_or_si256(_mm256_and_si256(imask, ival), _mm256_andnot_si256(imask, ibest[k+4]));
}
}
bool result = false;
auto idx1 = to_int8(ibest+0);
auto idx2 = to_int8(ibest+4);
to_values_i32(idx1, ivalues, ibest+0);
to_values_i32(idx2, ivalues, ibest+4);
auto vsqx_1 = _mm256_setzero_ps();
auto vsq2_1 = _mm256_setzero_ps();
auto vsqx_2 = _mm256_setzero_ps();
auto vsq2_2 = _mm256_setzero_ps();
for (int k = 0; k < 4; ++k) {
auto vq1 = _mm256_cvtepi32_ps(ibest[k+0]);
auto vwq1 = _mm256_mul_ps(vw[k], vq1);
auto vq2 = _mm256_cvtepi32_ps(ibest[k+4]);
auto vwq2 = _mm256_mul_ps(vw[k], vq2);
vsqx_1 = _mm256_fmadd_ps(vwq1, vx[k], vsqx_1);
vsq2_1 = _mm256_fmadd_ps(vwq1, vq1, vsq2_1);
vsqx_2 = _mm256_fmadd_ps(vwq2, vx[k], vsqx_2);
vsq2_2 = _mm256_fmadd_ps(vwq2, vq2, vsq2_2);
}
auto sumqx_1 = hsum_float_8(vsqx_1);
auto sumq2_1 = hsum_float_8(vsq2_1);
auto sumqx_2 = hsum_float_8(vsqx_2);
auto sumq2_2 = hsum_float_8(vsq2_2);
if (sumq2_1 > 0) {
best_d = sumqx_1/sumq2_1;
score = sumqx_1 * best_d;
this_idx = idx1;
result = true;
}
if (sumq2_2 > 0 && (!result || sumqx_2*sumqx_2 > score*sumq2_2)) {
best_d = sumqx_2/sumq2_2;
score = sumqx_2 * best_d;
this_idx = idx2;
result = true;
}
return result;
}
float compute_1block_iq2ks_rmse(float d, const __m256 * vx, const __m256 * vw, const int8_t * values, __m256i & this_idx) {
constexpr int kBlockSize = 32;
uint32_t aux32;
std::memcpy(&aux32, values, sizeof(aux32));
auto ivalues = _mm256_set1_epi32(aux32);
__m256 vbest[4];
__m256i ibest[4];
auto val = _mm256_set1_ps(d*values[0]);
auto ival = _mm256_set1_epi32(0);
for (int k = 0; k < kBlockSize/8; ++k) {
auto diff = _mm256_sub_ps(vx[k], val);
vbest[k] = _mm256_mul_ps(diff, diff);
ibest[k] = ival;
}
for (int j = 1; j < 4; ++j) {
val = _mm256_set1_ps(d*values[j]);
ival = _mm256_set1_epi32(j);
for (int k = 0; k < kBlockSize/8; ++k) {
auto diff = _mm256_sub_ps(vx[k], val);
diff = _mm256_mul_ps(diff, diff);
auto mask = _mm256_cmp_ps(diff, vbest[k], _CMP_LT_OQ);
vbest[k] = _mm256_or_ps(_mm256_and_ps(mask, diff), _mm256_andnot_ps(mask, vbest[k]));
auto imask = _mm256_castps_si256(mask);
ibest[k] = _mm256_or_si256(_mm256_and_si256(imask, ival), _mm256_andnot_si256(imask, ibest[k]));
}
}
auto idx = to_int8(ibest);
to_values_i32(idx, ivalues, ibest);
auto vd = _mm256_set1_ps(-d);
auto vrmse = _mm256_setzero_ps();
for (int k = 0; k < 4; ++k) {
auto vq = _mm256_cvtepi32_ps(ibest[k]);
auto diff = _mm256_fmadd_ps(vd, vq, vx[k]);
auto wdiff = _mm256_mul_ps(vw[k], diff);
vrmse = _mm256_fmadd_ps(wdiff, diff, vrmse);
}
this_idx = idx;
return hsum_float_8(vrmse);
}
void quantize_row_iq2_ks_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_sw, int8_t * all_Ls) {
constexpr int kBlockSize = 32;
ggml_half * dptr = (ggml_half *)vy;
*dptr = GGML_FP32_TO_FP16(0.f);
block_iq2_ks * y = (block_iq2_ks *)(dptr + 1);
float weight[kBlockSize];
const int8_t * shifted_values = iq2nl_values + 4;
const int nblock = n_per_row/QK_K;
__m256 vx[4], vw[4];
for (int ibl = 0; ibl < nblock; ++ibl) {
memset(&y[ibl], 0, sizeof(block_iq2_ks));
auto scales = all_scales + ibl*(QK_K/kBlockSize);
auto sw = all_sw + ibl*(QK_K/kBlockSize);
const float * xbl = x + ibl*QK_K;
float sumx2 = 0;
for (int j = 0; j < QK_K; ++j) sumx2 += xbl[j]*xbl[j];
const float sigma2 = 1.5f*sumx2/QK_K;
uint16_t extra = 0;
for (int ib = 0; ib < QK_K/kBlockSize; ++ib) {
const float * xb = xbl + kBlockSize*ib;
if (quant_weights) {
const float * qw = quant_weights + ibl*QK_K + ib*kBlockSize;
for (int j = 0; j < kBlockSize; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
} else {
for (int j = 0; j < kBlockSize; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j];
}
float amax = 0, max = 0, sumw = 0;
for (int j = 0; j < kBlockSize; ++j) {
float ax = fabsf(xb[j]);
if (ax > amax) {
amax = ax; max = xb[j];
}
sumw += weight[j];
}
sw[ib] = sumw;
if (amax < 1e-14f) {
scales[ib] = 0;
continue;
}
for (int k = 0; k < 4; ++k) {
vx[k] = _mm256_loadu_ps(xb + 8*k);
vw[k] = _mm256_loadu_ps(weight + 8*k);
}
float d = max/iq2nl_values[7];
float best = 0;
__m256i this_idx;
float this_d, this_score;
if (compute_1block_iq2ks(d, vx, vw, iq2nl_values, this_idx, this_d, this_score)) {
best = this_score; d = this_d;
}
for (int itry = -13; itry <= 13; ++itry) {
if (compute_1block_iq2ks(max/(iq2nl_values[0] + 0.5f*itry), vx, vw, iq2nl_values, this_idx, this_d, this_score)) {
if (this_score > best) {
best = this_score; d = this_d;
}
}
}
bool is_shifted = false;
for (int itry = -13; itry <= 13; ++itry) {
if (compute_1block_iq2ks(max/(iq2nl_values[4] + 0.5f*itry), vx, vw, iq2nl_values + 4, this_idx, this_d, this_score)) {
if (this_score > best) {
best = this_score; d = this_d; is_shifted = true;
}
}
}
scales[ib] = d;
if (is_shifted) extra |= (1 << ib);
}
y[ibl].extra = extra;
}
float d = make_qx_quants(nblock*(QK_K/kBlockSize), 16, all_scales, all_Ls, all_sw);
if (!d) return;
auto vsumqx = _mm256_setzero_ps();
auto vsumq2 = _mm256_setzero_ps();
for (int ibl = 0; ibl < nblock; ++ibl) {
auto scales = all_scales + ibl*(QK_K/kBlockSize);
auto xbl = x + ibl*QK_K;
float sumx2 = 0;
for (int j = 0; j < QK_K; ++j) sumx2 += xbl[j]*xbl[j];
const float sigma2 = 1.5f*sumx2/QK_K;
auto Ls = all_Ls + ibl*(QK_K/kBlockSize);
__m256i idx[4];
for (int ib = 0; ib < QK_K/kBlockSize; ++ib) {
const int8_t * block_values = y[ibl].extra & (1 << ib) ? shifted_values : iq2nl_values;
uint32_t aux32;
std::memcpy(&aux32, block_values, sizeof(aux32));
auto ivalues = _mm256_set1_epi32(aux32);
const float * xb = xbl + kBlockSize*ib;
if (quant_weights) {
const float * qw = quant_weights + ibl*QK_K + ib*kBlockSize;
for (int j = 0; j < kBlockSize; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
} else {
for (int j = 0; j < kBlockSize; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j];
}
for (int k = 0; k < 4; ++k) {
vx[k] = _mm256_loadu_ps(xb + 8*k);
vw[k] = _mm256_loadu_ps(weight + 8*k);
}
int ls = Ls[ib] - 16;
float dl = d*ls;
__m256i idx1, idx2;
auto rmse1 = compute_1block_iq2ks_rmse(dl, vx, vw, block_values, idx1);
if (Ls[ib] > 0 && dl > scales[ib]) {
auto rmse2 = compute_1block_iq2ks_rmse(d*(Ls[ib] - 17), vx, vw, block_values, idx2);
if (rmse2 < rmse1) {
--Ls[ib]; idx1 = idx2;
}
}
else if (Ls[ib] < 15 && dl < scales[ib]) {
auto rmse2 = compute_1block_iq2ks_rmse(d*(Ls[ib] - 15), vx, vw, block_values, idx2);
if (rmse2 < rmse1) {
++Ls[ib]; idx1 = idx2;
}
}
__m256i iv[4];
to_values_i32(idx1, ivalues, iv);
auto vd = _mm256_set1_ps(Ls[ib] - 16);
for (int k = 0; k < 4; ++k) {
auto vq = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iv[k]));
auto wvq = _mm256_mul_ps(vw[k], vq);
vsumqx = _mm256_fmadd_ps(wvq, vx[k], vsumqx);
vsumq2 = _mm256_fmadd_ps(wvq, vq, vsumq2);
}
ls = Ls[ib];
y[ibl].scales[ib/2] |= ((ls & 0xf) << 4*(ib%2));
y[ibl].extra |= ((ls >> 4) << (8 + ib));
idx[ib % 4] = idx1;
if ((ib % 4) == 3) {
auto vqs1 = _mm256_or_si256(idx[0], _mm256_slli_epi16(idx[1], 2));
auto vqs2 = _mm256_or_si256(_mm256_slli_epi16(idx[2], 4), _mm256_slli_epi16(idx[3], 6));
auto vqs = _mm256_or_si256(vqs1, vqs2);
_mm256_storeu_si256((__m256i *)y[ibl].qs + ib/4, vqs);
}
}
}
float sumqx = hsum_float_8(vsumqx);
float sumq2 = hsum_float_8(vsumq2);
*dptr = GGML_FP32_TO_FP16(1.000f*(sumq2 > 0 ? sumqx/sumq2 : d));
}
#else
void quantize_row_iq2_ks_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_sw, int8_t * all_Ls) {
constexpr int kBlockSize = 32;
constexpr int kMax_i1 = 3*kBlockSize/4;
constexpr int kMin_i3 = kBlockSize/4;
//constexpr int kNtry = 5;
//constexpr float kStep = 1.f;
ggml_half * dptr = (ggml_half *)vy;
*dptr = GGML_FP32_TO_FP16(0.f);
@ -1375,83 +1659,6 @@ void quantize_row_iq2_ks_impl(const float * x, void * vy, int n_per_row, const f
scales[ib] = 0;
continue;
}
//float amax = 0, max = 0;
//for (int j = 0; j < kBlockSize; ++j) {
// float ax = fabsf(xb[j]);
// if (ax > amax) {
// amax = ax; max = xb[j];
// }
//}
//if (!amax) {
// scales[ib] = 0;
// continue;
//}
//float d = kNtry > 0 ? -max/iq2nl_values[0] : max/iq2nl_values[0];
//float id = 1/d;
//float sumqx_p = 0, sumq2_p = 0;
//float sumqx_m = 0, sumq2_m = 0;
//for (int j = 0; j < kBlockSize; ++j) {
// float w = weight[j];
// float al = id*xb[j];
// int l = best_index_iq2nl(iq2nl_values, al);
// float q = iq2nl_values[l];
// sumqx_p += w*q*xb[j];
// sumq2_p += w*q*q;
// l = best_index_iq2nl(iq2nl_values, -al);
// q = iq2nl_values[l];
// sumqx_m += w*q*xb[j];
// sumq2_m += w*q*q;
//}
//d = sumqx_p/sumq2_p;
//float best = d*sumqx_p;
//if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
// d = sumqx_m/sumq2_m; best = d*sumqx_m;
//}
//bool is_shifted = false;
//for (int itry = -kNtry; itry <= kNtry; ++itry) {
// id = (kStep*itry + iq2nl_values[0])/max;
// sumqx_p = sumq2_p = 0;
// sumqx_m = sumq2_m = 0;
// for (int j = 0; j < kBlockSize; ++j) {
// float w = weight[j];
// float al = id*xb[j];
// int l = best_index_iq2nl(iq2nl_values, al);
// float q = iq2nl_values[l];
// sumqx_p += w*q*xb[j];
// sumq2_p += w*q*q;
// l = best_index_iq2nl(iq2nl_values, -al);
// q = iq2nl_values[l];
// sumqx_m += w*q*xb[j];
// sumq2_m += w*q*q;
// }
// if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
// d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false;
// }
// if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
// d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false;
// }
// id = (kStep*itry + shifted_values[0])/max;
// sumqx_p = sumq2_p = 0;
// sumqx_m = sumq2_m = 0;
// for (int j = 0; j < kBlockSize; ++j) {
// float w = weight[j];
// float al = id*xb[j];
// int l = best_index_iq2nl(shifted_values, al);
// float q = shifted_values[l];
// sumqx_p += w*q*xb[j];
// sumq2_p += w*q*q;
// l = best_index_iq2nl(shifted_values, -al);
// q = shifted_values[l];
// sumqx_m += w*q*xb[j];
// sumq2_m += w*q*q;
// }
// if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
// d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true;
// }
// if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
// d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true;
// }
//}
std::sort(pairs.begin(), pairs.end());
sumx[0] = sumw[0] = 0;
for (int j = 0; j < kBlockSize; ++j) {
@ -1498,10 +1705,8 @@ void quantize_row_iq2_ks_impl(const float * x, void * vy, int n_per_row, const f
}
scales[ib] = d;
if (is_shifted) extra |= (1 << ib);
}
y[ibl].extra = extra;
}
float d = make_qx_quants(nblock*(QK_K/kBlockSize), 16, all_scales, all_Ls, all_sw);
@ -1546,6 +1751,7 @@ void quantize_row_iq2_ks_impl(const float * x, void * vy, int n_per_row, const f
}
*dptr = GGML_FP32_TO_FP16(1.030f*(sumq2 > 0 ? sumqx/sumq2 : d));
}
#endif
}
void quantize_row_iq2_ks_ref(const float * x, block_iq2_ks * y, int64_t k) {