Use AVX version VNNI intrinsic when AVX512VNNI not available. (#1748)

* Use AVX version VNNI intrinsic when AVX512VNNI not available.

* remove changes under HAVE_FANCY_SIMD

---------

Co-authored-by: XZiar <xziar@xziar.xziar>
This commit is contained in:
XZiar 2026-05-08 23:02:06 -07:00 committed by GitHub
parent 51331f4973
commit ab0f22b819
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 59 additions and 50 deletions

View File

@ -115,7 +115,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
const __m256i zero = _mm256_setzero_si256();
const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
const __m256i summed_pairs = ggml_mm256_dpbusd_epi32(zero, ax, sy);
return _mm256_cvtepi32_ps(summed_pairs);
#else
// Perform multiplication and create 16-bit values

View File

@ -52,5 +52,14 @@
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
#define HAVE_VNNI256
#endif
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
#define ggml_mm256_dpbusd_epi32 _mm256_dpbusd_epi32
#define ggml_mm256_dpwssd_epi32 _mm256_dpwssd_epi32
#define ggml_mm_dpbusd_epi32 _mm_dpbusd_epi32
#elif defined(__AVXVNNI__)
#define ggml_mm256_dpbusd_epi32 _mm256_dpbusd_avx_epi32
#define ggml_mm256_dpwssd_epi32 _mm256_dpwssd_avx_epi32
#define ggml_mm_dpbusd_epi32 _mm_dpbusd_avx_epi32
#endif
#endif

View File

@ -1695,7 +1695,7 @@ static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const Dat
auto sas = _mm_loadu_si128((const __m128i *)iq3[ibl].sas + ib);
auto scales = _mm_and_si128(sas, _mm_set1_epi8(1));
#ifdef HAVE_VNNI256
scales = _mm_dpbusd_epi32(_mm_set1_epi32(1), scales, _mm_set1_epi32(0x10080402));
scales = ggml_mm_dpbusd_epi32(_mm_set1_epi32(1), scales, _mm_set1_epi32(0x10080402));
#else
scales = _mm_maddubs_epi16(scales, _mm_set1_epi32(0x10080402));
scales = _mm_add_epi32(_mm_madd_epi16(_mm_set1_epi16(1), scales), _mm_set1_epi32(1));
@ -1732,10 +1732,10 @@ static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const Dat
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
#ifdef HAVE_VNNI256
auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_sign_epi8(y, s1));
auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], _mm256_sign_epi8(y, s2));
auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], _mm256_sign_epi8(y, s3));
auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], _mm256_sign_epi8(y, s4));
auto sumi1 = ggml_mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_sign_epi8(y, s1));
auto sumi2 = ggml_mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], _mm256_sign_epi8(y, s2));
auto sumi3 = ggml_mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], _mm256_sign_epi8(y, s3));
auto sumi4 = ggml_mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], _mm256_sign_epi8(y, s4));
#else
auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1)));
auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2)));
@ -1832,10 +1832,10 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
auto sumi = _mm256_setzero_si256();
#ifdef HAVE_VNNI256
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), s1));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), s2));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), s3));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), s4));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), s1));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), s2));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), s3));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), s4));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(sumi, scales));
#else
sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), s1)));

View File

@ -1416,10 +1416,10 @@ static void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
isum[iy] = sumi;
#elif defined(HAVE_VNNI256)
sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
sumi = ggml_mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
sumi = ggml_mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
sumi = ggml_mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
sumi = ggml_mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
if constexpr (nrc_y == 1) {
acc[iy] = _mm256_fmadd_ps(min, _mm256_cvtepi32_ps(sumi), acc[iy]);
} else {
@ -1461,10 +1461,10 @@ static void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#elif defined(HAVE_VNNI256)
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
if constexpr (nrc_y == 1) {
acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
} else {
@ -1574,10 +1574,10 @@ static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_VNNI256
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
@ -1645,10 +1645,10 @@ static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_VNNI256
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
@ -1728,10 +1728,10 @@ static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
isum[iy] = sumi;
#elif defined(HAVE_VNNI256)
sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
sumi = ggml_mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
sumi = ggml_mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
sumi = ggml_mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
sumi = ggml_mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
if constexpr (nrc_y == 1) {
acc[iy] = _mm256_fmadd_ps(min, _mm256_cvtepi32_ps(sumi), acc[iy]);
} else {
@ -1774,10 +1774,10 @@ static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#elif defined(HAVE_VNNI256)
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
if constexpr (nrc_y == 1) {
acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
} else {
@ -1861,10 +1861,10 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+ib);
auto y = MM256_SET_M128I(y128, y128);
#ifdef HAVE_VNNI256
isum[iy] = _mm256_dpbusd_epi32(isum[iy], s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
isum[iy] = _mm256_dpbusd_epi32(isum[iy], s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
isum[iy] = _mm256_dpbusd_epi32(isum[iy], s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
isum[iy] = _mm256_dpbusd_epi32(isum[iy], s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
isum[iy] = ggml_mm256_dpbusd_epi32(isum[iy], s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
isum[iy] = ggml_mm256_dpbusd_epi32(isum[iy], s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
isum[iy] = ggml_mm256_dpbusd_epi32(isum[iy], s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
isum[iy] = ggml_mm256_dpbusd_epi32(isum[iy], s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
#else
auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])));
auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])));

View File

@ -915,10 +915,10 @@ static void mul_mat_iq4_nl_r4_q8_2(int n, const void * vx, size_t bx, const Data
#ifdef HAVE_VNNI256
auto dot = [&qs] (__m256i y) {
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, _mm256_sign_epi8(qs[0], qs[0]), _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qs[0]));
sumi = _mm256_dpbusd_epi32(sumi, _mm256_sign_epi8(qs[1], qs[1]), _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qs[1]));
sumi = _mm256_dpbusd_epi32(sumi, _mm256_sign_epi8(qs[2], qs[2]), _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qs[2]));
sumi = _mm256_dpbusd_epi32(sumi, _mm256_sign_epi8(qs[3], qs[3]), _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qs[3]));
sumi = ggml_mm256_dpbusd_epi32(sumi, _mm256_sign_epi8(qs[0], qs[0]), _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qs[0]));
sumi = ggml_mm256_dpbusd_epi32(sumi, _mm256_sign_epi8(qs[1], qs[1]), _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qs[1]));
sumi = ggml_mm256_dpbusd_epi32(sumi, _mm256_sign_epi8(qs[2], qs[2]), _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qs[2]));
sumi = ggml_mm256_dpbusd_epi32(sumi, _mm256_sign_epi8(qs[3], qs[3]), _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qs[3]));
return sumi;
};
#else
@ -1680,10 +1680,10 @@ static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn
auto y = MM256_SET_M128I(y128, y128);
#ifdef HAVE_VNNI256
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
sumi = _mm256_dpbusd_epi32(sumi, sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
sumi = _mm256_dpbusd_epi32(sumi, sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
sumi = _mm256_dpbusd_epi32(sumi, sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
sumi = ggml_mm256_dpbusd_epi32(sumi, sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
sumi = ggml_mm256_dpbusd_epi32(sumi, sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
sumi = ggml_mm256_dpbusd_epi32(sumi, sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
sumi = ggml_mm256_dpbusd_epi32(sumi, sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
return sumi;
#else
auto sumi1 = _mm256_add_epi32(
@ -1879,10 +1879,10 @@ static void mul_mat_q8_1_r8_q8_2(int n, const void * vx, size_t bx, const DataIn
auto y = MM256_SET_M128I(y128, y128);
#ifdef HAVE_VNNI256
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = ggml_mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
return sumi;
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),