diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 7e2c8b3a..7e680ef3 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -185,6 +185,10 @@ static void test_roundtrip_on_chunk( for (int i = 0; i < chunk_size; i++) { input_scratch[i] = ggml_get_f32_1d(layer, i + offset); } + } else if (layer->type == GGML_TYPE_BF16) { + for (int i = 0; i < chunk_size; i++) { + input_scratch[i] = ggml_get_f32_1d(layer, i + offset); + } } else { input_scratch = ggml_get_data_f32(layer) + offset; } @@ -211,7 +215,7 @@ static void test_roundtrip_on_layer( uint64_t nelements = ggml_nelements(layer); float* input_scratch_ptr = nullptr; - if (layer->type == GGML_TYPE_F16) { + if (layer->type == GGML_TYPE_F16 || layer->type == GGML_TYPE_BF16) { if (input_scratch.size() < nelements) input_scratch.resize(nelements); input_scratch_ptr = input_scratch.data(); } @@ -1587,6 +1591,7 @@ int main(int argc, char ** argv) { int included_layers = 0; int64_t max_nelements = 0; bool is_f16 = false; + bool is_bf16 = false; for (const auto& kv_tensor : tensors) { if (!layer_included(params, kv_tensor.first)) { continue; @@ -1600,6 +1605,8 @@ int main(int argc, char ** argv) { } if (kv_tensor.second->type == GGML_TYPE_F16) { is_f16 = true; + } else if (kv_tensor.second->type == GGML_TYPE_BF16) { + is_bf16 = true; } else if (kv_tensor.second->type != GGML_TYPE_F32) { fprintf(stderr, "%s: error: Quantization should be tested with a float model, " "this model contains already quantized layers (%s is type %d)\n", __func__, kv_tensor.first.c_str(), kv_tensor.second->type); @@ -1614,6 +1621,9 @@ int main(int argc, char ** argv) { if (is_f16) { printf("note: source model is f16\n"); } + if (is_bf16) { + printf("note: source model is bf16\n"); + } printf("testing %d layers with max size %" PRId64 "\n", included_layers, max_nelements); // allocate scratch space std::vector input_scratch; diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 078876dd..0d4f4023 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -300,6 +300,9 @@ if (GGML_IQK_MUL_MAT) message(STATUS "Disabling IQK Flash Attention kernels") endif() endif() +if (IQK_SLOW_IQ2KS_QUANTIZE) + set_source_files_properties(iqk/iqk_quantize.cpp PROPERTIES COMPILE_DEFINITIONS IQK_SLOW_IQ2KS_QUANTIZE) +endif() if (GGML_CUDA) cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 07edaf03..1d718049 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -1316,13 +1316,297 @@ void vec_dot_iq2_k_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, } namespace { +#if defined(__AVX2__) && !defined(IQK_SLOW_IQ2KS_QUANTIZE) +inline void to_values_i32(__m256i idx, __m256i ivalues, __m256i * iv) { + auto ival = _mm256_shuffle_epi8(ivalues, idx); + auto ival_1 = _mm256_srli_si256(ival, 8); + iv[0] = _mm256_cvtepi8_epi32(_mm256_castsi256_si128(ival)); + iv[1] = _mm256_cvtepi8_epi32(_mm256_castsi256_si128(ival_1)); + iv[2] = _mm256_cvtepi8_epi32(_mm256_extracti128_si256(ival, 1)); + iv[3] = _mm256_cvtepi8_epi32(_mm256_extracti128_si256(ival_1, 1)); +} +inline __m256i to_int8(const __m256i * ibest) { + auto i0 = _mm256_packs_epi32(ibest[0], ibest[1]); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + auto i1 = _mm256_packs_epi32(ibest[2], ibest[3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + auto idx = _mm256_packs_epi16(i0, i1); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + auto idx_l = _mm256_castsi256_si128(idx); + auto idx_h = _mm256_extracti128_si256(idx, 1); + auto idx1 = _mm_unpacklo_epi32(idx_l, idx_h); // 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + auto idx2 = _mm_unpackhi_epi32(idx_l, idx_h); + return MM256_SET_M128I(idx2, idx1); +} +bool compute_1block_iq2ks(float d, const __m256 * vx, const __m256 * vw, const int8_t * values, __m256i & this_idx, float & best_d, float & score) { + constexpr int kBlockSize = 32; + uint32_t aux32; + std::memcpy(&aux32, values, sizeof(aux32)); + auto ivalues = _mm256_set1_epi32(aux32); + __m256 vbest[8]; + __m256i ibest[8]; + auto val = _mm256_set1_ps(d*values[0]); + auto ival = _mm256_set1_epi32(0); + for (int k = 0; k < kBlockSize/8; ++k) { + auto diff = _mm256_sub_ps(vx[k], val); + vbest[k] = _mm256_mul_ps(diff, diff); + ibest[k] = ival; + diff = _mm256_add_ps(vx[k], val); + vbest[k+4] = _mm256_mul_ps(diff, diff); + ibest[k+4] = ival; + } + for (int j = 1; j < 4; ++j) { + val = _mm256_set1_ps(d*values[j]); + ival = _mm256_set1_epi32(j); + for (int k = 0; k < kBlockSize/8; ++k) { + auto diff = _mm256_sub_ps(vx[k], val); + diff = _mm256_mul_ps(diff, diff); + auto mask = _mm256_cmp_ps(diff, vbest[k], _CMP_LT_OQ); + vbest[k] = _mm256_or_ps(_mm256_and_ps(mask, diff), _mm256_andnot_ps(mask, vbest[k])); + auto imask = _mm256_castps_si256(mask); + ibest[k] = _mm256_or_si256(_mm256_and_si256(imask, ival), _mm256_andnot_si256(imask, ibest[k])); + diff = _mm256_add_ps(vx[k], val); + diff = _mm256_mul_ps(diff, diff); + mask = _mm256_cmp_ps(diff, vbest[k+4], _CMP_LT_OQ); + vbest[k+4] = _mm256_or_ps(_mm256_and_ps(mask, diff), _mm256_andnot_ps(mask, vbest[k+4])); + imask = _mm256_castps_si256(mask); + ibest[k+4] = _mm256_or_si256(_mm256_and_si256(imask, ival), _mm256_andnot_si256(imask, ibest[k+4])); + } + } + bool result = false; + auto idx1 = to_int8(ibest+0); + auto idx2 = to_int8(ibest+4); + to_values_i32(idx1, ivalues, ibest+0); + to_values_i32(idx2, ivalues, ibest+4); + auto vsqx_1 = _mm256_setzero_ps(); + auto vsq2_1 = _mm256_setzero_ps(); + auto vsqx_2 = _mm256_setzero_ps(); + auto vsq2_2 = _mm256_setzero_ps(); + for (int k = 0; k < 4; ++k) { + auto vq1 = _mm256_cvtepi32_ps(ibest[k+0]); + auto vwq1 = _mm256_mul_ps(vw[k], vq1); + auto vq2 = _mm256_cvtepi32_ps(ibest[k+4]); + auto vwq2 = _mm256_mul_ps(vw[k], vq2); + vsqx_1 = _mm256_fmadd_ps(vwq1, vx[k], vsqx_1); + vsq2_1 = _mm256_fmadd_ps(vwq1, vq1, vsq2_1); + vsqx_2 = _mm256_fmadd_ps(vwq2, vx[k], vsqx_2); + vsq2_2 = _mm256_fmadd_ps(vwq2, vq2, vsq2_2); + } + auto sumqx_1 = hsum_float_8(vsqx_1); + auto sumq2_1 = hsum_float_8(vsq2_1); + auto sumqx_2 = hsum_float_8(vsqx_2); + auto sumq2_2 = hsum_float_8(vsq2_2); + if (sumq2_1 > 0) { + best_d = sumqx_1/sumq2_1; + score = sumqx_1 * best_d; + this_idx = idx1; + result = true; + } + if (sumq2_2 > 0 && (!result || sumqx_2*sumqx_2 > score*sumq2_2)) { + best_d = sumqx_2/sumq2_2; + score = sumqx_2 * best_d; + this_idx = idx2; + result = true; + } + return result; +} +float compute_1block_iq2ks_rmse(float d, const __m256 * vx, const __m256 * vw, const int8_t * values, __m256i & this_idx) { + constexpr int kBlockSize = 32; + uint32_t aux32; + std::memcpy(&aux32, values, sizeof(aux32)); + auto ivalues = _mm256_set1_epi32(aux32); + __m256 vbest[4]; + __m256i ibest[4]; + auto val = _mm256_set1_ps(d*values[0]); + auto ival = _mm256_set1_epi32(0); + for (int k = 0; k < kBlockSize/8; ++k) { + auto diff = _mm256_sub_ps(vx[k], val); + vbest[k] = _mm256_mul_ps(diff, diff); + ibest[k] = ival; + } + for (int j = 1; j < 4; ++j) { + val = _mm256_set1_ps(d*values[j]); + ival = _mm256_set1_epi32(j); + for (int k = 0; k < kBlockSize/8; ++k) { + auto diff = _mm256_sub_ps(vx[k], val); + diff = _mm256_mul_ps(diff, diff); + auto mask = _mm256_cmp_ps(diff, vbest[k], _CMP_LT_OQ); + vbest[k] = _mm256_or_ps(_mm256_and_ps(mask, diff), _mm256_andnot_ps(mask, vbest[k])); + auto imask = _mm256_castps_si256(mask); + ibest[k] = _mm256_or_si256(_mm256_and_si256(imask, ival), _mm256_andnot_si256(imask, ibest[k])); + } + } + auto idx = to_int8(ibest); + to_values_i32(idx, ivalues, ibest); + auto vd = _mm256_set1_ps(-d); + auto vrmse = _mm256_setzero_ps(); + for (int k = 0; k < 4; ++k) { + auto vq = _mm256_cvtepi32_ps(ibest[k]); + auto diff = _mm256_fmadd_ps(vd, vq, vx[k]); + auto wdiff = _mm256_mul_ps(vw[k], diff); + vrmse = _mm256_fmadd_ps(wdiff, diff, vrmse); + } + this_idx = idx; + return hsum_float_8(vrmse); +} +void quantize_row_iq2_ks_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_sw, int8_t * all_Ls) { + + constexpr int kBlockSize = 32; + + ggml_half * dptr = (ggml_half *)vy; + *dptr = GGML_FP32_TO_FP16(0.f); + + block_iq2_ks * y = (block_iq2_ks *)(dptr + 1); + + float weight[kBlockSize]; + + const int8_t * shifted_values = iq2nl_values + 4; + + const int nblock = n_per_row/QK_K; + + __m256 vx[4], vw[4]; + + for (int ibl = 0; ibl < nblock; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq2_ks)); + + auto scales = all_scales + ibl*(QK_K/kBlockSize); + auto sw = all_sw + ibl*(QK_K/kBlockSize); + + const float * xbl = x + ibl*QK_K; + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += xbl[j]*xbl[j]; + const float sigma2 = 1.5f*sumx2/QK_K; + + uint16_t extra = 0; + + for (int ib = 0; ib < QK_K/kBlockSize; ++ib) { + const float * xb = xbl + kBlockSize*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*kBlockSize; + for (int j = 0; j < kBlockSize; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < kBlockSize; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; + } + float amax = 0, max = 0, sumw = 0; + for (int j = 0; j < kBlockSize; ++j) { + float ax = fabsf(xb[j]); + if (ax > amax) { + amax = ax; max = xb[j]; + } + sumw += weight[j]; + } + sw[ib] = sumw; + if (amax < 1e-14f) { + scales[ib] = 0; + continue; + } + for (int k = 0; k < 4; ++k) { + vx[k] = _mm256_loadu_ps(xb + 8*k); + vw[k] = _mm256_loadu_ps(weight + 8*k); + } + float d = max/iq2nl_values[7]; + float best = 0; + __m256i this_idx; + float this_d, this_score; + if (compute_1block_iq2ks(d, vx, vw, iq2nl_values, this_idx, this_d, this_score)) { + best = this_score; d = this_d; + } + for (int itry = -13; itry <= 13; ++itry) { + if (compute_1block_iq2ks(max/(iq2nl_values[0] + 0.5f*itry), vx, vw, iq2nl_values, this_idx, this_d, this_score)) { + if (this_score > best) { + best = this_score; d = this_d; + } + } + } + bool is_shifted = false; + for (int itry = -13; itry <= 13; ++itry) { + if (compute_1block_iq2ks(max/(iq2nl_values[4] + 0.5f*itry), vx, vw, iq2nl_values + 4, this_idx, this_d, this_score)) { + if (this_score > best) { + best = this_score; d = this_d; is_shifted = true; + } + } + } + scales[ib] = d; + if (is_shifted) extra |= (1 << ib); + } + y[ibl].extra = extra; + } + + float d = make_qx_quants(nblock*(QK_K/kBlockSize), 16, all_scales, all_Ls, all_sw); + + if (!d) return; + + auto vsumqx = _mm256_setzero_ps(); + auto vsumq2 = _mm256_setzero_ps(); + for (int ibl = 0; ibl < nblock; ++ibl) { + auto scales = all_scales + ibl*(QK_K/kBlockSize); + auto xbl = x + ibl*QK_K; + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += xbl[j]*xbl[j]; + const float sigma2 = 1.5f*sumx2/QK_K; + auto Ls = all_Ls + ibl*(QK_K/kBlockSize); + __m256i idx[4]; + for (int ib = 0; ib < QK_K/kBlockSize; ++ib) { + const int8_t * block_values = y[ibl].extra & (1 << ib) ? shifted_values : iq2nl_values; + uint32_t aux32; + std::memcpy(&aux32, block_values, sizeof(aux32)); + auto ivalues = _mm256_set1_epi32(aux32); + const float * xb = xbl + kBlockSize*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*kBlockSize; + for (int j = 0; j < kBlockSize; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < kBlockSize; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; + } + for (int k = 0; k < 4; ++k) { + vx[k] = _mm256_loadu_ps(xb + 8*k); + vw[k] = _mm256_loadu_ps(weight + 8*k); + } + int ls = Ls[ib] - 16; + float dl = d*ls; + __m256i idx1, idx2; + auto rmse1 = compute_1block_iq2ks_rmse(dl, vx, vw, block_values, idx1); + if (Ls[ib] > 0 && dl > scales[ib]) { + auto rmse2 = compute_1block_iq2ks_rmse(d*(Ls[ib] - 17), vx, vw, block_values, idx2); + if (rmse2 < rmse1) { + --Ls[ib]; idx1 = idx2; + } + } + else if (Ls[ib] < 15 && dl < scales[ib]) { + auto rmse2 = compute_1block_iq2ks_rmse(d*(Ls[ib] - 15), vx, vw, block_values, idx2); + if (rmse2 < rmse1) { + ++Ls[ib]; idx1 = idx2; + } + } + __m256i iv[4]; + to_values_i32(idx1, ivalues, iv); + auto vd = _mm256_set1_ps(Ls[ib] - 16); + for (int k = 0; k < 4; ++k) { + auto vq = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iv[k])); + auto wvq = _mm256_mul_ps(vw[k], vq); + vsumqx = _mm256_fmadd_ps(wvq, vx[k], vsumqx); + vsumq2 = _mm256_fmadd_ps(wvq, vq, vsumq2); + } + ls = Ls[ib]; + y[ibl].scales[ib/2] |= ((ls & 0xf) << 4*(ib%2)); + y[ibl].extra |= ((ls >> 4) << (8 + ib)); + idx[ib % 4] = idx1; + if ((ib % 4) == 3) { + auto vqs1 = _mm256_or_si256(idx[0], _mm256_slli_epi16(idx[1], 2)); + auto vqs2 = _mm256_or_si256(_mm256_slli_epi16(idx[2], 4), _mm256_slli_epi16(idx[3], 6)); + auto vqs = _mm256_or_si256(vqs1, vqs2); + _mm256_storeu_si256((__m256i *)y[ibl].qs + ib/4, vqs); + } + } + } + float sumqx = hsum_float_8(vsumqx); + float sumq2 = hsum_float_8(vsumq2); + *dptr = GGML_FP32_TO_FP16(1.000f*(sumq2 > 0 ? sumqx/sumq2 : d)); +} +#else void quantize_row_iq2_ks_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_sw, int8_t * all_Ls) { constexpr int kBlockSize = 32; constexpr int kMax_i1 = 3*kBlockSize/4; constexpr int kMin_i3 = kBlockSize/4; - //constexpr int kNtry = 5; - //constexpr float kStep = 1.f; ggml_half * dptr = (ggml_half *)vy; *dptr = GGML_FP32_TO_FP16(0.f); @@ -1375,83 +1659,6 @@ void quantize_row_iq2_ks_impl(const float * x, void * vy, int n_per_row, const f scales[ib] = 0; continue; } - //float amax = 0, max = 0; - //for (int j = 0; j < kBlockSize; ++j) { - // float ax = fabsf(xb[j]); - // if (ax > amax) { - // amax = ax; max = xb[j]; - // } - //} - //if (!amax) { - // scales[ib] = 0; - // continue; - //} - //float d = kNtry > 0 ? -max/iq2nl_values[0] : max/iq2nl_values[0]; - //float id = 1/d; - //float sumqx_p = 0, sumq2_p = 0; - //float sumqx_m = 0, sumq2_m = 0; - //for (int j = 0; j < kBlockSize; ++j) { - // float w = weight[j]; - // float al = id*xb[j]; - // int l = best_index_iq2nl(iq2nl_values, al); - // float q = iq2nl_values[l]; - // sumqx_p += w*q*xb[j]; - // sumq2_p += w*q*q; - // l = best_index_iq2nl(iq2nl_values, -al); - // q = iq2nl_values[l]; - // sumqx_m += w*q*xb[j]; - // sumq2_m += w*q*q; - //} - //d = sumqx_p/sumq2_p; - //float best = d*sumqx_p; - //if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { - // d = sumqx_m/sumq2_m; best = d*sumqx_m; - //} - //bool is_shifted = false; - //for (int itry = -kNtry; itry <= kNtry; ++itry) { - // id = (kStep*itry + iq2nl_values[0])/max; - // sumqx_p = sumq2_p = 0; - // sumqx_m = sumq2_m = 0; - // for (int j = 0; j < kBlockSize; ++j) { - // float w = weight[j]; - // float al = id*xb[j]; - // int l = best_index_iq2nl(iq2nl_values, al); - // float q = iq2nl_values[l]; - // sumqx_p += w*q*xb[j]; - // sumq2_p += w*q*q; - // l = best_index_iq2nl(iq2nl_values, -al); - // q = iq2nl_values[l]; - // sumqx_m += w*q*xb[j]; - // sumq2_m += w*q*q; - // } - // if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { - // d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false; - // } - // if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { - // d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false; - // } - // id = (kStep*itry + shifted_values[0])/max; - // sumqx_p = sumq2_p = 0; - // sumqx_m = sumq2_m = 0; - // for (int j = 0; j < kBlockSize; ++j) { - // float w = weight[j]; - // float al = id*xb[j]; - // int l = best_index_iq2nl(shifted_values, al); - // float q = shifted_values[l]; - // sumqx_p += w*q*xb[j]; - // sumq2_p += w*q*q; - // l = best_index_iq2nl(shifted_values, -al); - // q = shifted_values[l]; - // sumqx_m += w*q*xb[j]; - // sumq2_m += w*q*q; - // } - // if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { - // d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true; - // } - // if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { - // d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; - // } - //} std::sort(pairs.begin(), pairs.end()); sumx[0] = sumw[0] = 0; for (int j = 0; j < kBlockSize; ++j) { @@ -1498,10 +1705,8 @@ void quantize_row_iq2_ks_impl(const float * x, void * vy, int n_per_row, const f } scales[ib] = d; if (is_shifted) extra |= (1 << ib); - } y[ibl].extra = extra; - } float d = make_qx_quants(nblock*(QK_K/kBlockSize), 16, all_scales, all_Ls, all_sw); @@ -1546,6 +1751,7 @@ void quantize_row_iq2_ks_impl(const float * x, void * vy, int n_per_row, const f } *dptr = GGML_FP32_TO_FP16(1.030f*(sumq2 > 0 ? sumqx/sumq2 : d)); } +#endif } void quantize_row_iq2_ks_ref(const float * x, block_iq2_ks * y, int64_t k) {