llama.cpp/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-tiled.h
Max Krasnyansky 8be759e6f7
hexagon: MUL_MAT and MUL_MAT_ID rework : 32x32 tiled weight repack, kernel-params, cached graphs (#24954)
* hex-mm: new weight layout and fusion updates

* hvx-mm: unroll the new tiled vec_dots to optimize hvx register util

* hex-mm: optimize dyn.quant format for q8_0 and q8_1 to reduce overhead in vec_dots.

* hvx-mm: parallel quantizer per block for large rows

* hvx-mm: simplify and further optimize dyn.quant and vec_dots

* hvx-mm: keep intermediate per tile accumulators in fp16

* hmx-mm: optimize weight dequant by aligning the repacked tiles with the DMA

* hmx-mm: remove qweight scratch and just use vtcm_weight

* hmx-mm: remove all unused and obsolete code

* hmx-mm: the new tiled repack format is here to stay -- rename all x4x2 to _tiled

* hmx-mm: improve activation processing with dma prefetch

* hex-mm: fix hmx/hvx fallback logic and MUL_MAT_ID allocation (unbreaks OLMoE)

* hex-mm: align the weight tiles with dma just like we did in hmx-mm

* hex-mm: factor out common mm bits into htp/matmul-ops.h

* hex-mm: start moving mm kernel selection to the host

* hex-mm: move all of the matmul param compute into the host

* hmx-mm: restore pipelined mode

* hmx-mm: unroll the dequant functions to optimize register usage

* hmx-mm: further improve activation process

* hex-mm: use vtcm_seq_alloc for all vtcm allocations and define more common functions

* hex-mm: improve mm optimizer to account for number of activation threads

* hex-mm: fix matmul-id kernel params selection (unbreaks OLMoE and LFM)

* hexagon: remove support for arch < v73 since HMX is now required for most use-cases

* hex-mm: cleanup naming for consistency

* hex-mm: make sure matmul fusion accounts for vtcm allocation

* hex-mm: minor cleanup for kernel_params definition

* hex-mm: replace hardcoded limits with proper checks for vtcm requirements

* hex-mm: add support for non-tiled mm as a fallback option and factor out hvx kernels into separate header

* hex-mm: remove unused functions

* hex-mm: add shorthand for MM_SELECT in run-tool script

* hvx-mm: factor out hvx/hmx microkernels and unify matmul entry and dispatch

* hex-mm: further cleanup matmul fallback path

* hex-mm: refactor matmul entry point and dispatch a bit further

* hexagon: update cmake build to enable hmx for everything

* hex-ops: optimize kernel_param updates and include summary in the logs

* hex-mm: add support for GGML_HEXAGON_MM_SELECT

* hex-mm: add hex-common header

* hex-mm: pass correct number of tasks to workpool

* hex-mm: add proper checks for no-work in dyn.quant tasks

* hex-mm: convert all quantizers into a macro

* hex-mm: fix hvx-flat fallback to pass all MUL_MAT tests

* hex-mm: vectorize q8_1 quantizer

* hex-mm: improve fused ffn mm stride handling

* hex-mm: consistent use of n_threads and pipeline in kernel_params

* hexagon: minor formatting

* hex-mm: update MUL_MAT_ID kernel_param handling to make sure host/npu are in sync

* hvx-mm: go back to accumulating in fp32 in tiled hvx kernels, more accurate and same perf

* hvx-mm: unroll the loops and remove masking that is not needed for tiled accums

* hmx-mm: optimize activation processing (split loops, some unrolling, etc)

* hmx-mm: minor optimization for output processing

* hex-mm: consistent use of uint32_t and size_t in mm kernels

* hex-mm: remove legacy restrictions for rows to be multiple of 256

* hexagon: replace sprintf with snprintf

* hex-mm: relax hardcoded nrows checks and rely on VTCM size requirements

* hexagon: minor alignment fix

* hexagon: fix trailing spaces

* hex-mm: relax padding from 256 to 128 (leftovers)

* hex-mm: remove redundant checks for weight align to 128

we always use 2D dma for the weights and align them properly

* hmx-mm: MUL_MAT_ID better work distribution between hvx threads and hmx tracing

* hex-mm: specialize per-token mmid activation handling

* hex-profile: update python scripts to handle kernel-params section in the logging output

* hex-mm: move n_prefetch (aka dma_depth) into kernel params and remove unused fields

* hex-trace: use easier to parse format, simplify and fix post-proc scripts

* hmx-mm: relax 32 row limit for output processing which helps utilization

* hmx-mm: use start-chunk idx for tracing info

* hmx-mm: parameterize activation dma pipeline

* hexagon: add support for simple graph caching to avoid recomputing kernel-params

* hex-mm: remove left-over repack functions

* hex-mm: tighten n_prefetch asserts

* hex-mm: remove duplicate round/align_up helper

* hexagon: cleanup common header used in host/npu

* hexagon: update early wakeup threshold

* hmx-mm: define cost constants and update solver to assume that repacked ne[1] is padded to 32

* hmx-mm: make precompute_matmul a bit more readable (split into smaller functions, etc)

* hex-mm: remove n_threads constraint

* hex-mm: minor formatting updates

* hex-mm: remove obsolete profiling logs

* hex-mm: restore hardcode gate to refuse lm-head to avoid repacking that tensor
2026-06-24 12:14:25 -07:00

1141 lines
51 KiB
C

// Dynamic quantizers that produce tiled activations
// Quantize four HVX vectors of fp32 activations read from x into the tiled q8_1
// activation layout at y_block.
//
// Every input vector gets its own scale d = max|x| / 127. For each sub-block b
// in 0..3 (32 int8 values), ten HVX vectors are written at y_block + b * 1280:
//   dst[0..7] : the sub-block's int8 values, one 4-byte group per vector, each
//               passed through the `repl` vdelta control
//   dst[8]    : d as fp16, splatted into every halfword
//   dst[9]    : (sum of the sub-block's int8 values) * d as fp16, splatted
//
// x and y_block must both be 128-byte aligned (asserted).
static inline void quantize_block_f32_q8_1_tiled(float * restrict x, uint8_t * restrict y_block) {
    assert((unsigned long) x % 128 == 0);
    assert((unsigned long) y_block % 128 == 0);

    HVX_Vector * vx = (HVX_Vector *) x;
    HVX_Vector zero = Q6_V_vzero();

    // Absolute maximum of each input vector, kept in fp32.
    // NOTE(review): the code below uses these both as per-lane operands and via
    // element 0, so hvx_vec_reduce_max_f32 presumably broadcasts the result -- confirm.
    HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0]));
    HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1]));
    HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2]));
    HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3]));

    // fp32 -> qf32 by subtracting zero, for both the inputs and the maxima.
    HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero);
    HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero);
    HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero);
    HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero);
    HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
    HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
    HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
    HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);

    // Narrow pairs of qf32 vectors to one fp16 vector each.
    HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
    HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf)));
    HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
    HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));

    // d = max * (1 / 127) in fp16, then scale the inputs by 1 / d.
    HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
    HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
    HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
    HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);
    HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
    HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
    vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
    vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));

    // fp16 -> int16 (round + saturate) -> int8 (saturating pack).
    HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
    HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
    HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);

    // Sum of the quantized values: a vrmpy against all-ones sums each group of
    // 4 bytes into a word, and the three rotate-and-add steps fold 8 neighbouring
    // words together, so word b * 8 ends up with the sum of sub-block b.
    const HVX_Vector ones = Q6_Vb_vsplat_R(1);
    HVX_Vector v_sums = Q6_Vw_vrmpy_VbVb(vx_i8, ones);
    v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 4));
    v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 8));
    v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 16));

    // Spill the maxima and the sums so the scalar code below can read them.
    float vmax0[32] __attribute__((aligned(128)));
    float vmax1[32] __attribute__((aligned(128)));
    float vmax2[32] __attribute__((aligned(128)));
    float vmax3[32] __attribute__((aligned(128)));
    int32_t sums[32] __attribute__((aligned(128)));
    hvx_vec_store_u(vmax0, 128, vmax0_sf);
    hvx_vec_store_u(vmax1, 128, vmax1_sf);
    hvx_vec_store_u(vmax2, 128, vmax2_sf);
    hvx_vec_store_u(vmax3, 128, vmax3_sf);
    hvx_vec_store_u(sums, 128, v_sums);

    // Per-sub-block scale in fp32, indexed by b in the loop below.
    const float d[4] = {
        vmax0[0] / 127.0f,
        vmax1[0] / 127.0f,
        vmax2[0] / 127.0f,
        vmax3[0] / 127.0f,
    };

    // vdelta control applied to every output activation vector.
    // NOTE(review): judging by the name this replicates the lowest 4 bytes across
    // the whole vector -- confirm against the vdelta network definition.
    static const uint8_t __attribute__((aligned(128))) repl[128] = {
        0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
    };
    HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl;

    for (int b = 0; b < 4; b++) {
        // Bring sub-block b (bytes b*32 .. b*32+31) to the bottom of the vector.
        HVX_Vector v_act = Q6_V_vror_VR(vx_i8, b * 32);

        // One output vector per 4-byte group of the sub-block.
        HVX_Vector r0 = Q6_V_vdelta_VV(v_act, v_repl_ctrl);
        HVX_Vector r1 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 4), v_repl_ctrl);
        HVX_Vector r2 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 8), v_repl_ctrl);
        HVX_Vector r3 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 12), v_repl_ctrl);
        HVX_Vector r4 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 16), v_repl_ctrl);
        HVX_Vector r5 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 20), v_repl_ctrl);
        HVX_Vector r6 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 24), v_repl_ctrl);
        HVX_Vector r7 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 28), v_repl_ctrl);

        // The fp16 bit patterns are read back through a union: reading a __fp16
        // object through an int16_t pointer cast violates strict aliasing.
        union { __fp16 f; int16_t bits; } scale_h, offset_h;
        scale_h.f  = (__fp16) d[b];
        offset_h.f = (__fp16) (sums[b * 8] * d[b]);

        HVX_Vector r_scale  = Q6_Vh_vsplat_R(scale_h.bits);
        HVX_Vector r_offset = Q6_Vh_vsplat_R(offset_h.bits);

        HVX_Vector * restrict dst = (HVX_Vector *) (y_block + b * 1280);
        dst[0] = r0;
        dst[1] = r1;
        dst[2] = r2;
        dst[3] = r3;
        dst[4] = r4;
        dst[5] = r5;
        dst[6] = r6;
        dst[7] = r7;
        dst[8] = r_scale;
        dst[9] = r_offset;
    }
}
// Quantize four HVX vectors of fp32 activations read from x into the tiled q8_0
// activation layout at y_block.
//
// A single scale d = max|x| / 127 (computed in fp16 over all four vectors) is
// shared by the whole block. For each sub-block b in 0..3 (32 int8 values), nine
// HVX vectors are written at y_block + b * 1152:
//   out[0..7] : the sub-block's int8 values, one 4-byte group per vector, each
//               passed through the vdelta control table
//   out[8]    : the fp16 scale as produced by hvx_vec_repl_f16
//
// x and y_block must both be 128-byte aligned (asserted).
static inline void quantize_block_f32_q8_0_tiled(float * restrict x, uint8_t * restrict y_block) {
    assert((unsigned long) x % 128 == 0);
    assert((unsigned long) y_block % 128 == 0);

    const HVX_Vector * src    = (const HVX_Vector *) x;
    const HVX_Vector   v_zero = Q6_V_vzero();

    // fp32 -> qf32 (subtract zero) -> fp16; two source vectors per fp16 vector.
    HVX_Vector v_lo_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(
        Q6_Vqf32_vsub_VsfVsf(src[1], v_zero),
        Q6_Vqf32_vsub_VsfVsf(src[0], v_zero))));
    HVX_Vector v_hi_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(
        Q6_Vqf32_vsub_VsfVsf(src[3], v_zero),
        Q6_Vqf32_vsub_VsfVsf(src[2], v_zero))));

    // Absolute maximum across both fp16 vectors.
    const HVX_Vector v_amax_lo_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(v_lo_hf));
    const HVX_Vector v_amax_hf    = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(v_hi_hf), v_amax_lo_hf);

    // d = amax * (1 / 127); 0x2008 is the fp16 encoding of ~1/127.
    const HVX_Vector v_d_hf     = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_amax_hf, Q6_Vh_vsplat_R(0x2008)));
    const HVX_Vector v_d_inv_hf = hvx_vec_inverse_f16(v_d_hf);

    // x / d, rounded and saturated to int16, then packed to int8.
    v_lo_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo_hf, v_d_inv_hf));
    v_hi_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi_hf, v_d_inv_hf));
    const HVX_Vector v_lo_i16 = hvx_vec_i16_from_hf_rnd_sat(v_lo_hf);
    const HVX_Vector v_hi_i16 = hvx_vec_i16_from_hf_rnd_sat(v_hi_hf);
    const HVX_Vector v_q      = Q6_Vb_vpack_VhVh_sat(v_hi_i16, v_lo_i16);

    const HVX_Vector v_scale_out = hvx_vec_repl_f16(v_d_hf);

    // vdelta control applied to every output activation vector.
    // NOTE(review): presumably replicates the lowest 4 bytes across the vector -- confirm.
    static const uint8_t __attribute__((aligned(128))) repl_ctrl_bytes[128] = {
        0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
    };
    const HVX_Vector v_ctrl = * (const HVX_Vector *) repl_ctrl_bytes;

    for (int b = 0; b < 4; b++) {
        // Bring sub-block b (bytes b*32 .. b*32+31) to the bottom of the vector.
        const HVX_Vector v_blk = Q6_V_vror_VR(v_q, b * 32);
        HVX_Vector * restrict out = (HVX_Vector *) (y_block + b * 1152);

        // Group 0 needs no rotation; groups 1..7 are rotated down by 4 bytes each.
        out[0] = Q6_V_vdelta_VV(v_blk, v_ctrl);
        #pragma unroll
        for (int g = 1; g < 8; g++) {
            out[g] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_blk, g * 4), v_ctrl);
        }
        out[8] = v_scale_out;
    }
}
// Quantize a row of k fp32 values into tiled q8_0 activations, one
// QK_Q8_0_TILED-element block at a time. Each block produces 4 * 1152 bytes of
// output. k must be a multiple of 32 (asserted); the block count is rounded up.
static void quantize_row_f32_q8_0_tiled(float * restrict x, uint8_t * restrict y, uint32_t k) {
    assert(k % 32 == 0);

    const uint32_t block_elems = QK_Q8_0_TILED;
    const uint32_t block_bytes = 4 * 1152;
    const uint32_t n_blocks    = (k + block_elems - 1) / block_elems;

    uint32_t blk = 0;
    while (blk < n_blocks) {
        quantize_block_f32_q8_0_tiled(x + blk * block_elems, y + blk * block_bytes);
        blk++;
    }
}
// Quantize a row of k fp32 values into tiled q8_1 activations, one
// QK_Q8_0_TILED-element block at a time. Each block produces 4 * 1280 bytes of
// output. k must be a multiple of 32 (asserted); the block count is rounded up.
static void quantize_row_f32_q8_1_tiled(float * restrict x, uint8_t * restrict y, uint32_t k) {
    assert(k % 32 == 0);

    const uint32_t block_elems = QK_Q8_0_TILED;
    const uint32_t block_bytes = 4 * 1280;
    const uint32_t n_blocks    = (k + block_elems - 1) / block_elems;

    uint32_t blk = 0;
    while (blk < n_blocks) {
        quantize_block_f32_q8_1_tiled(x + blk * block_elems, y + blk * block_bytes);
        blk++;
    }
}
// Dot kernels & helpers that consume tiled activations
// Multiply two fp16 vectors into fp32 and return the low vector of the
// 4-byte-granularity shuffle of the two product halves.
// On HVX arch >= 79 the multiply yields fp32 directly; older archs multiply
// into qf32 and convert each half to fp32 first.
static inline HVX_Vector hvx_vec_mul_f16_f16_to_f32_lower32(HVX_Vector v1, HVX_Vector v2) {
    HVX_Vector prod_hi;
    HVX_Vector prod_lo;
#if __HVX_ARCH__ >= 79
    const HVX_VectorPair prod = Q6_Wsf_vmpy_VhfVhf(v1, v2);
    prod_hi = Q6_V_hi_W(prod);
    prod_lo = Q6_V_lo_W(prod);
#else
    const HVX_VectorPair prod = Q6_Wqf32_vmpy_VhfVhf(v1, v2);
    prod_hi = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(prod));
    prod_lo = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(prod));
#endif
    const HVX_VectorPair shuffled = Q6_W_vshuff_VVR(prod_hi, prod_lo, -4);
    return Q6_V_lo_W(shuffled);
}
// Split the bytes of v_a and v_b into nibbles (low: AND with mask_h4, high:
// logical shift right by 4), shuffle low/high nibbles of each source together
// at byte granularity, then shuffle the two low result vectors together at
// 2-byte granularity and return the low vector of that pair.
static inline HVX_Vector unpack_and_interleave_4bit(HVX_Vector v_a, HVX_Vector v_b, HVX_Vector mask_h4) {
    const HVX_Vector a_lo_nib = Q6_V_vand_VV(v_a, mask_h4);
    const HVX_Vector a_hi_nib = Q6_Vub_vlsr_VubR(v_a, 4);
    const HVX_Vector b_lo_nib = Q6_V_vand_VV(v_b, mask_h4);
    const HVX_Vector b_hi_nib = Q6_Vub_vlsr_VubR(v_b, 4);

    const HVX_Vector a_bytes = Q6_V_lo_W(Q6_W_vshuff_VVR(a_hi_nib, a_lo_nib, -1));
    const HVX_Vector b_bytes = Q6_V_lo_W(Q6_W_vshuff_VVR(b_hi_nib, b_lo_nib, -1));

    return Q6_V_lo_W(Q6_W_vshuff_VVR(b_bytes, a_bytes, -2));
}
// Unpack one vector of packed 4-bit values into a pair of byte vectors.
//   v_src   : vector whose bytes each hold two 4-bit values
//   mask_h4 : byte mask ANDed with v_src to extract the low nibbles
// Returns a vector pair: lo = v_W0, hi = v_W1.
// NOTE(review): the final lane order depends on the repacked weight layout,
// which is not visible in this function -- confirm against the repack code.
static inline HVX_VectorPair unpack_and_interleave_4bit_x2(HVX_Vector v_src, HVX_Vector mask_h4) {
// Low nibbles via the mask, high nibbles via a per-byte logical shift right.
HVX_Vector v_lo = Q6_V_vand_VV(v_src, mask_h4);
HVX_Vector v_hi = Q6_Vub_vlsr_VubR(v_src, 4);
// Byte-granularity shuffle of low and high nibbles into a two-vector pair.
HVX_VectorPair v01_pair = Q6_W_vshuff_VVR(v_hi, v_lo, -1);
HVX_Vector v01_lo = Q6_V_lo_W(v01_pair);
HVX_Vector v01_hi = Q6_V_hi_W(v01_pair);
// First output: shuffle v01_lo with a 64-byte valign of (v01_hi, v01_lo) at
// 2-byte granularity and keep the low vector.
HVX_Vector v23_lo = Q6_V_valign_VVR(v01_hi, v01_lo, 64);
HVX_Vector v_W0 = Q6_V_lo_W(Q6_W_vshuff_VVR(v23_lo, v01_lo, -2));
// Second output: same pattern with the roles of v01_lo and v01_hi swapped.
HVX_Vector v67_lo = Q6_V_valign_VVR(v01_lo, v01_hi, 64);
HVX_Vector v_W1 = Q6_V_lo_W(Q6_W_vshuff_VVR(v67_lo, v01_hi, -2));
return Q6_W_vcombine_VV(v_W1, v_W0);
}
// Integer dot-product accumulation of one 4-bit weight tile against one
// activation column.
//   vptr  : four vectors of packed 4-bit weights
//   v_act : eight activation vectors (two per packed weight vector)
//   i8    : byte vector subtracted from every unpacked weight
// Returns the 32-bit word accumulator vector (sum of the two partial accumulators).
static inline HVX_Vector accum_4bit_32x1(
    const HVX_Vector * restrict vptr,
    const HVX_Vector * restrict v_act,
    HVX_Vector i8
) {
    const HVX_Vector nibble_mask = Q6_Vb_vsplat_R(0x0F);
    HVX_Vector acc_even = Q6_V_vzero();
    HVX_Vector acc_odd  = Q6_V_vzero();

    #pragma unroll
    for (int t = 0; t < 4; t++) {
        const HVX_VectorPair w = unpack_and_interleave_4bit_x2(vptr[t], nibble_mask);
        const HVX_Vector w_even = Q6_Vb_vsub_VbVb(Q6_V_lo_W(w), i8);
        const HVX_Vector w_odd  = Q6_Vb_vsub_VbVb(Q6_V_hi_W(w), i8);
        const HVX_Vector * a = v_act + 2 * t;

        // Two independent accumulators, one per unpacked weight vector.
        acc_even = Q6_Vw_vrmpyacc_VwVbVb(acc_even, w_even, a[0]);
        acc_odd  = Q6_Vw_vrmpyacc_VwVbVb(acc_odd,  w_odd,  a[1]);
    }

    return Q6_Vw_vadd_VwVw(acc_even, acc_odd);
}
// Same as the non-LUT 4-bit accumulator for one activation column, except that
// each unpacked 4-bit value is translated through `lut` (vlut32, index 0)
// instead of having an offset subtracted.
//   vptr    : four vectors of packed 4-bit weights
//   v_act   : eight activation vectors (two per packed weight vector)
//   mask_h4 : nibble mask forwarded to the unpack helper
//   lut     : lookup table vector for Q6_Vb_vlut32_VbVbI
static inline HVX_Vector accum_4bit_32x1_lut(
    const HVX_Vector * restrict vptr,
    const HVX_Vector * restrict v_act,
    HVX_Vector mask_h4,
    HVX_Vector lut
) {
    HVX_Vector acc_even = Q6_V_vzero();
    HVX_Vector acc_odd  = Q6_V_vzero();

    #pragma unroll
    for (int t = 0; t < 4; t++) {
        const HVX_VectorPair idx = unpack_and_interleave_4bit_x2(vptr[t], mask_h4);
        const HVX_Vector w_even = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(idx), lut, 0);
        const HVX_Vector w_odd  = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(idx), lut, 0);
        const HVX_Vector * a = v_act + 2 * t;

        acc_even = Q6_Vw_vrmpyacc_VwVbVb(acc_even, w_even, a[0]);
        acc_odd  = Q6_Vw_vrmpyacc_VwVbVb(acc_odd,  w_odd,  a[1]);
    }

    return Q6_Vw_vadd_VwVw(acc_even, acc_odd);
}
// Integer dot-product accumulation of one 4-bit weight tile against two
// activation columns; the weights are unpacked once and reused for both.
//   vptr           : four vectors of packed 4-bit weights
//   v_act0, v_act1 : eight activation vectors per column
//   i8             : byte vector subtracted from every unpacked weight
// Returns a pair: lo = column 0 accumulator, hi = column 1 accumulator.
static inline HVX_VectorPair accum_4bit_32x2(
    const HVX_Vector * restrict vptr,
    const HVX_Vector * restrict v_act0,
    const HVX_Vector * restrict v_act1,
    HVX_Vector i8
) {
    const HVX_Vector nibble_mask = Q6_Vb_vsplat_R(0x0F);
    HVX_Vector acc_c0 = Q6_V_vzero();
    HVX_Vector acc_c1 = Q6_V_vzero();

    #pragma unroll
    for (int t = 0; t < 4; t++) {
        const HVX_VectorPair w = unpack_and_interleave_4bit_x2(vptr[t], nibble_mask);
        const HVX_Vector w_even = Q6_Vb_vsub_VbVb(Q6_V_lo_W(w), i8);
        const HVX_Vector w_odd  = Q6_Vb_vsub_VbVb(Q6_V_hi_W(w), i8);
        const int even = 2 * t;
        const int odd  = even + 1;

        acc_c0 = Q6_Vw_vrmpyacc_VwVbVb(acc_c0, w_even, v_act0[even]);
        acc_c0 = Q6_Vw_vrmpyacc_VwVbVb(acc_c0, w_odd,  v_act0[odd]);
        acc_c1 = Q6_Vw_vrmpyacc_VwVbVb(acc_c1, w_even, v_act1[even]);
        acc_c1 = Q6_Vw_vrmpyacc_VwVbVb(acc_c1, w_odd,  v_act1[odd]);
    }

    return Q6_W_vcombine_VV(acc_c1, acc_c0);
}
// Two-column 4-bit accumulator that translates each unpacked 4-bit value
// through `lut` (vlut32, index 0) instead of subtracting an offset.
//   vptr           : four vectors of packed 4-bit weights
//   v_act0, v_act1 : eight activation vectors per column
//   mask_h4        : nibble mask forwarded to the unpack helper
//   lut            : lookup table vector for Q6_Vb_vlut32_VbVbI
// Returns a pair: lo = column 0 accumulator, hi = column 1 accumulator.
static inline HVX_VectorPair accum_4bit_32x2_lut(
    const HVX_Vector * restrict vptr,
    const HVX_Vector * restrict v_act0,
    const HVX_Vector * restrict v_act1,
    HVX_Vector mask_h4,
    HVX_Vector lut
) {
    HVX_Vector acc_c0 = Q6_V_vzero();
    HVX_Vector acc_c1 = Q6_V_vzero();

    #pragma unroll
    for (int t = 0; t < 4; t++) {
        const HVX_VectorPair idx = unpack_and_interleave_4bit_x2(vptr[t], mask_h4);
        const HVX_Vector w_even = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(idx), lut, 0);
        const HVX_Vector w_odd  = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(idx), lut, 0);
        const int even = 2 * t;
        const int odd  = even + 1;

        acc_c0 = Q6_Vw_vrmpyacc_VwVbVb(acc_c0, w_even, v_act0[even]);
        acc_c0 = Q6_Vw_vrmpyacc_VwVbVb(acc_c0, w_odd,  v_act0[odd]);
        acc_c1 = Q6_Vw_vrmpyacc_VwVbVb(acc_c1, w_even, v_act1[even]);
        acc_c1 = Q6_Vw_vrmpyacc_VwVbVb(acc_c1, w_odd,  v_act1[odd]);
    }

    return Q6_W_vcombine_VV(acc_c1, acc_c0);
}
// Integer dot-product accumulation of one 8-bit weight tile (eight vectors)
// against one activation column (eight vectors).
// Each weight vector is shuffled with its own 64-byte rotation at 2-byte
// granularity before the multiply-accumulate.
static inline HVX_Vector accum_q8_0_32x1(
    const HVX_Vector * restrict vptr,
    const HVX_Vector * restrict v_act
) {
    HVX_Vector acc = Q6_V_vzero();

    #pragma unroll
    for (int g = 0; g < 8; g++) {
        const HVX_Vector w_raw     = vptr[g];
        const HVX_Vector w_rotated = Q6_V_vror_VR(w_raw, 64);
        const HVX_Vector w_ordered = Q6_V_lo_W(Q6_W_vshuff_VVR(w_rotated, w_raw, -2));

        acc = Q6_Vw_vrmpyacc_VwVbVb(acc, w_ordered, v_act[g]);
    }

    return acc;
}
// Integer dot-product accumulation of one 8-bit weight tile (eight vectors)
// against two activation columns; each reordered weight vector is reused for
// both columns.
// Returns a pair: lo = column 0 accumulator, hi = column 1 accumulator.
static inline HVX_VectorPair accum_q8_0_32x2(
    const HVX_Vector * restrict vptr,
    const HVX_Vector * restrict v_act0,
    const HVX_Vector * restrict v_act1
) {
    HVX_Vector acc_c0 = Q6_V_vzero();
    HVX_Vector acc_c1 = Q6_V_vzero();

    #pragma unroll
    for (int g = 0; g < 8; g++) {
        const HVX_Vector w_raw     = vptr[g];
        const HVX_Vector w_rotated = Q6_V_vror_VR(w_raw, 64);
        const HVX_Vector w_ordered = Q6_V_lo_W(Q6_W_vshuff_VVR(w_rotated, w_raw, -2));

        acc_c0 = Q6_Vw_vrmpyacc_VwVbVb(acc_c0, w_ordered, v_act0[g]);
        acc_c1 = Q6_Vw_vrmpyacc_VwVbVb(acc_c1, w_ordered, v_act1[g]);
    }

    return Q6_W_vcombine_VV(acc_c1, acc_c0);
}
// Dot product of tiled q4_0 weights with one column of tiled q8_0 activations.
//   n          : K dimension; n / 32 tiles are processed
//   s          : output; valid_rows * sizeof(float) bytes of the fp32 accumulator are stored
//   vx         : weight tiles, 640 bytes each (4 packed-nibble vectors + 1 scale vector)
//   vy         : activation tiles, 1152 bytes each (8 activation vectors + 1 scale vector)
//   valid_rows : number of output floats to store
static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
    const uint8_t * restrict w_base = vx;
    const uint8_t * restrict a_base = vy;

    const HVX_Vector v_bias  = Q6_Vb_vsplat_R(8); // subtracted from every unpacked nibble
    const uint32_t   n_tiles = n / 32;

    HVX_Vector v_acc = Q6_V_vzero();

    for (uint32_t t = 0; t < n_tiles; t++) {
        const HVX_Vector * restrict w = (const HVX_Vector *) (w_base + t * 640);
        const HVX_Vector * restrict a = (const HVX_Vector *) (a_base + t * 1152);

        // Integer dot products for this tile, converted to fp32.
        const HVX_Vector v_dot_sf = Q6_Vsf_equals_Vw(accum_4bit_32x1(w, a, v_bias));

        // Combined scale = weight scale (w[4]) * activation scale (a[8]).
        const HVX_Vector v_scale = hvx_vec_mul_f16_f16_to_f32_lower32(w[4], a[8]);

        v_acc = hvx_vec_add_f32_f32(v_acc, hvx_vec_mul_f32_f32(v_dot_sf, v_scale));
    }

    hvx_vec_store_u(s, valid_rows * sizeof(float), v_acc);
}
// Dot product of tiled q4_0 weights with two columns of tiled q8_0 activations.
//   n          : K dimension; n / 32 tiles are processed
//   s0, s1     : outputs for column 0 / column 1; valid_rows * sizeof(float) bytes are stored to each
//   vx         : weight tiles, 640 bytes each (vptr[0..3] packed nibbles, vptr[4] scales)
//   vy0, vy1   : activation tiles per column, 1152 bytes each (v_act[0..7] values, v_act[8] scales)
//   valid_rows : number of output floats to store per column
// The main loop handles two K tiles per iteration; a tail loop handles a leftover tile.
static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
// Running fp32 accumulators, one per activation column.
HVX_Vector v_sum_float_c0 = Q6_V_vzero();
HVX_Vector v_sum_float_c1 = Q6_V_vzero();
// Offset subtracted from every unpacked 4-bit weight.
HVX_Vector i8 = Q6_Vb_vsplat_R(8);
uint32_t n_k_tiles = n / 32;
uint32_t kt = 0;
// Main loop: two K tiles (kt, kt + 1) per iteration.
for (; kt + 1 < n_k_tiles; kt += 2) {
const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 640);
const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1152);
const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 1152);
const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 640);
const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1152);
const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1152);
// Integer dot products: pair lo = column 0, pair hi = column 1.
HVX_VectorPair v_sums0 = accum_4bit_32x2(vptr0, v_act0_0, v_act1_0, i8);
HVX_VectorPair v_sums1 = accum_4bit_32x2(vptr1, v_act0_1, v_act1_1, i8);
HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0);
HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0);
HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1);
HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1);
// int32 -> fp32.
HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0);
HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0);
HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1);
HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1);
// Weight scales (vector 4 of each weight tile) and activation scales (vector 8 of each activation tile).
HVX_Vector v_scale_w0 = vptr0[4];
HVX_Vector v_scale_w1 = vptr1[4];
HVX_Vector v_scale_a_c0_0 = v_act0_0[8];
HVX_Vector v_scale_a_c1_0 = v_act1_0[8];
HVX_Vector v_scale_a_c0_1 = v_act0_1[8];
HVX_Vector v_scale_a_c1_1 = v_act1_1[8];
// Combined fp32 scale = weight scale * activation scale.
HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c0_0);
HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c1_0);
HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c0_1);
HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c1_1);
HVX_Vector v_sum_scaled_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0);
HVX_Vector v_sum_scaled_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0);
HVX_Vector v_sum_scaled_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1);
HVX_Vector v_sum_scaled_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1);
// The two tiles are added together first, then folded into the running sum.
v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1));
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1));
}
// Tail: at most one remaining tile when n_k_tiles is odd.
for (; kt < n_k_tiles; kt++) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640);
const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1152);
const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1152);
HVX_VectorPair v_sums = accum_4bit_32x2(vptr, v_act0, v_act1, i8);
HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums);
HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums);
HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0);
HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1);
HVX_Vector v_scale_w = vptr[4];
HVX_Vector v_scale_a_c0 = v_act0[8];
HVX_Vector v_scale_a_c1 = v_act1[8];
HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c0);
HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c1);
HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0);
HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1);
v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0);
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
// Store valid_rows floats' worth of bytes per column.
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
// Dot product of tiled q4_1 weights with one column of tiled q8_1 activations.
//   n          : K dimension; n / 32 tiles are processed
//   s          : output; valid_rows * sizeof(float) bytes of the fp32 accumulator are stored
//   vx         : weight tiles, 640 bytes each (4 packed-nibble vectors + 1 scale/offset vector)
//   vy         : activation tiles, 1280 bytes each (8 activation vectors, then vectors 8 and 9)
//   valid_rows : number of output floats to store
// Per tile: result += dot * (w_scale * a[8]) + (w_offset * a[9]).
static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
    const uint8_t * restrict w_base = vx;
    const uint8_t * restrict a_base = vy;

    const uint32_t n_tiles = n / 32;

    HVX_Vector v_acc = Q6_V_vzero();

    for (uint32_t t = 0; t < n_tiles; t++) {
        const HVX_Vector * restrict w = (const HVX_Vector *) (w_base + t * 640);
        const HVX_Vector * restrict a = (const HVX_Vector *) (a_base + t * 1280);

        // No offset is subtracted from the unpacked nibbles here (zero vector passed).
        const HVX_Vector v_dot_sf = Q6_Vsf_equals_Vw(accum_4bit_32x1(w, a, Q6_V_vzero()));

        // Split the fifth weight vector into scale (lo) and offset (hi) halves.
        const HVX_Vector     v_meta     = w[4];
        const HVX_VectorPair meta_pair  = Q6_W_vdeal_VVR(v_meta, v_meta, -2);
        const HVX_Vector     v_w_scale  = Q6_V_lo_W(meta_pair);
        const HVX_Vector     v_w_offset = Q6_V_hi_W(meta_pair);

        const HVX_Vector v_scale_comb  = hvx_vec_mul_f16_f16_to_f32_lower32(v_w_scale,  a[8]);
        const HVX_Vector v_offset_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_w_offset, a[9]);

        const HVX_Vector v_tile = hvx_vec_add_f32_f32(hvx_vec_mul_f32_f32(v_dot_sf, v_scale_comb), v_offset_comb);
        v_acc = hvx_vec_add_f32_f32(v_acc, v_tile);
    }

    hvx_vec_store_u(s, valid_rows * sizeof(float), v_acc);
}
// Dot product of tiled q4_1 weights with two columns of tiled q8_1 activations.
//   n          : K dimension; n / 32 tiles are processed
//   s0, s1     : outputs for column 0 / column 1; valid_rows * sizeof(float) bytes are stored to each
//   vx         : weight tiles, 640 bytes each (vptr[0..3] packed nibbles, vptr[4] scale/offset)
//   vy0, vy1   : activation tiles per column, 1280 bytes each (v_act[0..7] values, v_act[8], v_act[9])
//   valid_rows : number of output floats to store per column
// Per tile and column: result += dot * (w_scale * v_act[8]) + (w_offset * v_act[9]).
// The main loop handles two K tiles per iteration; a tail loop handles a leftover tile.
static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
// Running fp32 accumulators, one per activation column.
HVX_Vector v_sum_float_c0 = Q6_V_vzero();
HVX_Vector v_sum_float_c1 = Q6_V_vzero();
uint32_t n_k_tiles = n / 32;
uint32_t kt = 0;
// Main loop: two K tiles (kt, kt + 1) per iteration.
for (; kt + 1 < n_k_tiles; kt += 2) {
const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 640);
const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1280);
const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 1280);
const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 640);
const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1280);
const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1280);
// Integer dot products with a zero weight offset: pair lo = column 0, pair hi = column 1.
HVX_VectorPair v_sums0 = accum_4bit_32x2(vptr0, v_act0_0, v_act1_0, Q6_V_vzero());
HVX_VectorPair v_sums1 = accum_4bit_32x2(vptr1, v_act0_1, v_act1_1, Q6_V_vzero());
HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0);
HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0);
HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1);
HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1);
// int32 -> fp32.
HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0);
HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0);
HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1);
HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1);
// Split vector 4 of each weight tile into scale (lo) and offset (hi) via a 2-byte deal.
HVX_Vector v_scale_offset0 = vptr0[4];
HVX_VectorPair p_deal0 = Q6_W_vdeal_VVR(v_scale_offset0, v_scale_offset0, -2);
HVX_Vector v_scale0 = Q6_V_lo_W(p_deal0);
HVX_Vector v_offset0 = Q6_V_hi_W(p_deal0);
HVX_Vector v_scale_offset1 = vptr1[4];
HVX_VectorPair p_deal1 = Q6_W_vdeal_VVR(v_scale_offset1, v_scale_offset1, -2);
HVX_Vector v_scale1 = Q6_V_lo_W(p_deal1);
HVX_Vector v_offset1 = Q6_V_hi_W(p_deal1);
// Activation-side vectors 8 and 9 for each column and tile.
HVX_Vector v_scale_a_c0_0 = v_act0_0[8];
HVX_Vector v_sum_a_c0_0 = v_act0_0[9];
HVX_Vector v_scale_a_c1_0 = v_act1_0[8];
HVX_Vector v_sum_a_c1_0 = v_act1_0[9];
HVX_Vector v_scale_a_c0_1 = v_act0_1[8];
HVX_Vector v_sum_a_c0_1 = v_act0_1[9];
HVX_Vector v_scale_a_c1_1 = v_act1_1[8];
HVX_Vector v_sum_a_c1_1 = v_act1_1[9];
// Combined fp32 scale and offset terms.
HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale0, v_scale_a_c0_0);
HVX_Vector v_offset_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset0, v_sum_a_c0_0);
HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale0, v_scale_a_c1_0);
HVX_Vector v_offset_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset0, v_sum_a_c1_0);
HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale1, v_scale_a_c0_1);
HVX_Vector v_offset_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset1, v_sum_a_c0_1);
HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale1, v_scale_a_c1_1);
HVX_Vector v_offset_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset1, v_sum_a_c1_1);
// Per tile and column: dot * scale + offset.
HVX_Vector v_scaled_dot_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0);
HVX_Vector v_sum_scaled_c0_0 = hvx_vec_add_f32_f32(v_scaled_dot_c0_0, v_offset_comb_c0_0);
HVX_Vector v_scaled_dot_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0);
HVX_Vector v_sum_scaled_c1_0 = hvx_vec_add_f32_f32(v_scaled_dot_c1_0, v_offset_comb_c1_0);
HVX_Vector v_scaled_dot_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1);
HVX_Vector v_sum_scaled_c0_1 = hvx_vec_add_f32_f32(v_scaled_dot_c0_1, v_offset_comb_c0_1);
HVX_Vector v_scaled_dot_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1);
HVX_Vector v_sum_scaled_c1_1 = hvx_vec_add_f32_f32(v_scaled_dot_c1_1, v_offset_comb_c1_1);
// The two tiles are added together first, then folded into the running sum.
v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1));
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1));
}
// Tail: at most one remaining tile when n_k_tiles is odd.
for (; kt < n_k_tiles; kt++) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640);
const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1280);
const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1280);
HVX_VectorPair v_sums = accum_4bit_32x2(vptr, v_act0, v_act1, Q6_V_vzero());
HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums);
HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums);
HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0);
HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1);
HVX_Vector v_scale_offset = vptr[4];
HVX_VectorPair p_deal = Q6_W_vdeal_VVR(v_scale_offset, v_scale_offset, -2);
HVX_Vector v_scale = Q6_V_lo_W(p_deal);
HVX_Vector v_offset = Q6_V_hi_W(p_deal);
HVX_Vector v_scale_a_c0 = v_act0[8];
HVX_Vector v_sum_a_c0 = v_act0[9];
HVX_Vector v_scale_a_c1 = v_act1[8];
HVX_Vector v_sum_a_c1 = v_act1[9];
HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a_c0);
HVX_Vector v_offset_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a_c0);
HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a_c1);
HVX_Vector v_offset_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a_c1);
HVX_Vector v_scaled_dot_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0);
HVX_Vector v_sum_scaled_c0 = hvx_vec_add_f32_f32(v_scaled_dot_c0, v_offset_comb_c0);
HVX_Vector v_scaled_dot_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1);
HVX_Vector v_sum_scaled_c1 = hvx_vec_add_f32_f32(v_scaled_dot_c1, v_offset_comb_c1);
v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0);
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
// Store valid_rows floats' worth of bytes per column.
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
// Dot product of tiled q8_0 weights with one column of tiled q8_0 activations.
//   n          : K dimension; n / 32 tiles are processed
//   s          : output; valid_rows * sizeof(float) bytes of the fp32 accumulator are stored
//   vx         : weight tiles, 1152 bytes each (8 weight vectors + 1 scale vector)
//   vy         : activation tiles, 1152 bytes each (8 activation vectors + 1 scale vector)
//   valid_rows : number of output floats to store
static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
    const uint8_t * restrict w_base = vx;
    const uint8_t * restrict a_base = vy;

    const uint32_t n_tiles = n / 32;

    HVX_Vector v_acc = Q6_V_vzero();

    for (uint32_t t = 0; t < n_tiles; t++) {
        const HVX_Vector * restrict w = (const HVX_Vector *) (w_base + t * 1152);
        const HVX_Vector * restrict a = (const HVX_Vector *) (a_base + t * 1152);

        // Integer dot products for this tile, converted to fp32.
        const HVX_Vector v_dot_sf = Q6_Vsf_equals_Vw(accum_q8_0_32x1(w, a));

        // Combined scale = weight scale (w[8]) * activation scale (a[8]).
        const HVX_Vector v_scale = hvx_vec_mul_f16_f16_to_f32_lower32(w[8], a[8]);

        v_acc = hvx_vec_add_f32_f32(v_acc, hvx_vec_mul_f32_f32(v_dot_sf, v_scale));
    }

    hvx_vec_store_u(s, valid_rows * sizeof(float), v_acc);
}
// Dot product of a tiled q8_0 weight panel with TWO quantized activation columns.
//
//   n          - length of the K dimension; n / 32 tiles are processed and any
//                remainder (n % 32) is not visited by this function.
//   s0, s1     - destinations for the float results of column 0 (vy0) and column 1 (vy1).
//   vx         - tiled weights, one tile every 1152 bytes.
//   vy0, vy1   - tiled activations for the two columns, one tile every 1152 bytes.
//   valid_rows - number of float results written to each of s0 and s1.
//
// The main loop handles two K-tiles per iteration; the second loop picks up the
// single leftover tile when the tile count is odd.
// NOTE(review): the statement grouping in the unrolled loop looks deliberately
// scheduled (all loads / converts / multiplies batched together) — keep the order
// when editing unless the generated code is re-checked.
static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
// One float accumulator vector per activation column.
HVX_Vector v_sum_float_c0 = Q6_V_vzero();
HVX_Vector v_sum_float_c1 = Q6_V_vzero();
uint32_t n_k_tiles = n / 32;
uint32_t kt = 0;
// Unrolled by two: tiles kt (suffix _0) and kt + 1 (suffix _1).
for (; kt + 1 < n_k_tiles; kt += 2) {
const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 1152);
const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1152);
const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 1152);
const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 1152);
const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1152);
const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1152);
// Integer accumulation per tile; the returned pair carries column 0 in the
// low vector and column 1 in the high vector.
HVX_VectorPair v_sums0 = accum_q8_0_32x2(vptr0, v_act0_0, v_act1_0);
HVX_VectorPair v_sums1 = accum_q8_0_32x2(vptr1, v_act0_1, v_act1_1);
HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0);
HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0);
HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1);
HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1);
// Word (integer) lanes -> single-precision float lanes.
HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0);
HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0);
HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1);
HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1);
// Vector slot 8 of every weight and activation tile is read as the scales.
HVX_Vector v_scale_w0 = vptr0[8];
HVX_Vector v_scale_w1 = vptr1[8];
HVX_Vector v_scale_a_c0_0 = v_act0_0[8];
HVX_Vector v_scale_a_c1_0 = v_act1_0[8];
HVX_Vector v_scale_a_c0_1 = v_act0_1[8];
HVX_Vector v_scale_a_c1_1 = v_act1_1[8];
// Combined scale = weight scale * activation scale, per tile and per column.
HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c0_0);
HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c1_0);
HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c0_1);
HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c1_1);
HVX_Vector v_sum_scaled_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0);
HVX_Vector v_sum_scaled_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0);
HVX_Vector v_sum_scaled_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1);
HVX_Vector v_sum_scaled_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1);
// The two tiles are summed with each other first, then added to the running total.
v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1));
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1));
}
// Remainder: at most one tile is left when n_k_tiles is odd.
for (; kt < n_k_tiles; kt++) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 1152);
const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1152);
const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1152);
HVX_VectorPair v_sums = accum_q8_0_32x2(vptr, v_act0, v_act1);
HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums);
HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums);
HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0);
HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1);
HVX_Vector v_scale_w = vptr[8];
HVX_Vector v_scale_a_c0 = v_act0[8];
HVX_Vector v_scale_a_c1 = v_act1[8];
HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c0);
HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c1);
HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0);
HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1);
v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0);
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
// Store valid_rows floats per column; destinations may be unaligned (_u store).
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
// Dot product of a tiled IQ4_NL weight panel with ONE quantized activation column.
//
//   n          - length of the K dimension; n / 32 tiles are processed and any
//                remainder (n % 32) is not visited by this function.
//   s          - destination for the float results.
//   vx         - tiled 4-bit weights, one tile every 640 bytes; vector slot 4
//                of a tile is read as the weight scales.
//   vy         - tiled activations, one tile every 1152 bytes; vector slot 8
//                of a tile is read as the activation scales.
//   valid_rows - number of float results written to s.
//
// The 4-bit codes are expanded through kvalues_iq4nl_lut inside
// accum_4bit_32x1_lut; this function only applies the scales and accumulates.
static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
    const uint32_t w_stride    = 640;   // bytes between consecutive weight tiles
    const uint32_t a_stride    = 1152;  // bytes between consecutive activation tiles
    const uint32_t w_scale_idx = 4;     // scale slot inside a weight tile
    const uint32_t a_scale_idx = 8;     // scale slot inside an activation tile

    const uint8_t * restrict w_base = vx;
    const uint8_t * restrict a_base = vy;

    // Loop-invariant operands handed to the LUT accumulator.
    const HVX_Vector nibble_mask = Q6_Vb_vsplat_R(0x0F);
    const HVX_Vector codebook    = *(const HVX_Vector *) kvalues_iq4nl_lut;

    const uint32_t tile_count = n / 32;
    HVX_Vector     acc_f32    = Q6_V_vzero();

    for (uint32_t t = 0; t != tile_count; ++t) {
        const HVX_Vector * restrict w_tile = (const HVX_Vector *) (w_base + t * w_stride);
        const HVX_Vector * restrict a_tile = (const HVX_Vector *) (a_base + t * a_stride);

        // Integer accumulation for this tile, converted to float lanes.
        const HVX_Vector dot_sf = Q6_Vsf_equals_Vw(accum_4bit_32x1_lut(w_tile, a_tile, nibble_mask, codebook));

        // Per-tile combined scale: weight scale times activation scale.
        const HVX_Vector scale = hvx_vec_mul_f16_f16_to_f32_lower32(w_tile[w_scale_idx], a_tile[a_scale_idx]);

        acc_f32 = hvx_vec_add_f32_f32(acc_f32, hvx_vec_mul_f32_f32(dot_sf, scale));
    }

    // Only valid_rows floats are stored; the destination may be unaligned.
    hvx_vec_store_u(s, valid_rows * sizeof(float), acc_f32);
}
// Dot product of a tiled IQ4_NL weight panel with TWO quantized activation columns.
//
//   n          - length of the K dimension; n / 32 tiles are processed and any
//                remainder (n % 32) is not visited by this function.
//   s0, s1     - destinations for the float results of column 0 (vy0) and column 1 (vy1).
//   vx         - tiled 4-bit weights, one tile every 640 bytes; vector slot 4
//                of a tile is read as the weight scales.
//   vy0, vy1   - tiled activations, one tile every 1152 bytes; vector slot 8
//                of a tile is read as the activation scales.
//   valid_rows - number of float results written to each of s0 and s1.
//
// The main loop handles two K-tiles per iteration; the second loop picks up the
// single leftover tile when the tile count is odd.
// NOTE(review): the statement grouping in the unrolled loop looks deliberately
// scheduled — keep the order when editing unless the generated code is re-checked.
static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
// One float accumulator vector per activation column.
HVX_Vector v_sum_float_c0 = Q6_V_vzero();
HVX_Vector v_sum_float_c1 = Q6_V_vzero();
// Loop-invariant operands for the LUT accumulator: 0x0F byte mask and the
// IQ4_NL value table loaded as one vector.
HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut;
uint32_t n_k_tiles = n / 32;
uint32_t kt = 0;
// Unrolled by two: tiles kt (suffix _0) and kt + 1 (suffix _1).
for (; kt + 1 < n_k_tiles; kt += 2) {
const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 640);
const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1152);
const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 1152);
const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 640);
const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1152);
const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1152);
// Integer accumulation per tile; the returned pair carries column 0 in the
// low vector and column 1 in the high vector.
HVX_VectorPair v_sums0 = accum_4bit_32x2_lut(vptr0, v_act0_0, v_act1_0, mask_h4, lut);
HVX_VectorPair v_sums1 = accum_4bit_32x2_lut(vptr1, v_act0_1, v_act1_1, mask_h4, lut);
HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0);
HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0);
HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1);
HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1);
// Word (integer) lanes -> single-precision float lanes.
HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0);
HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0);
HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1);
HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1);
// Weight scales sit in slot 4 of the 640-byte weight tile; activation scales
// sit in slot 8 of the 1152-byte activation tile.
HVX_Vector v_scale_w0 = vptr0[4];
HVX_Vector v_scale_w1 = vptr1[4];
HVX_Vector v_scale_a_c0_0 = v_act0_0[8];
HVX_Vector v_scale_a_c1_0 = v_act1_0[8];
HVX_Vector v_scale_a_c0_1 = v_act0_1[8];
HVX_Vector v_scale_a_c1_1 = v_act1_1[8];
// Combined scale = weight scale * activation scale, per tile and per column.
HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c0_0);
HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c1_0);
HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c0_1);
HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c1_1);
HVX_Vector v_sum_scaled_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0);
HVX_Vector v_sum_scaled_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0);
HVX_Vector v_sum_scaled_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1);
HVX_Vector v_sum_scaled_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1);
// The two tiles are summed with each other first, then added to the running total.
v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1));
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1));
}
// Remainder: at most one tile is left when n_k_tiles is odd.
for (; kt < n_k_tiles; kt++) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640);
const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1152);
const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1152);
HVX_VectorPair v_sums = accum_4bit_32x2_lut(vptr, v_act0, v_act1, mask_h4, lut);
HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums);
HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums);
HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0);
HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1);
HVX_Vector v_scale_w = vptr[4];
HVX_Vector v_scale_a_c0 = v_act0[8];
HVX_Vector v_scale_a_c1 = v_act1[8];
HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c0);
HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c1);
HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0);
HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1);
v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0);
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
// Store valid_rows floats per column; destinations may be unaligned (_u store).
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
// Dot product of a tiled MXFP4 weight panel with ONE quantized activation column.
//
//   n          - length of the K dimension; n / 32 tiles are processed and any
//                remainder (n % 32) is not visited by this function.
//   s          - destination for the float results.
//   vx         - tiled 4-bit weights, one tile every 640 bytes; the vector at
//                byte offset 512 of a tile holds the weight scale bytes.
//   vy         - tiled activations, one tile every 1152 bytes; vector slot 8
//                of a tile is read as the activation scales.
//   valid_rows - number of float results written to s.
//
// Weight scales are turned into floats by permuting the scale bytes with
// expand_x32_e8m0, keeping the low byte of each 32-bit lane and shifting it
// left by 23, i.e. into the exponent field of an IEEE-754 single.
// The final sum is multiplied by 0.5 (presumably because kvalues_mxfp4_lut
// stores doubled values — TODO confirm against the table definition).
static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
    const uint32_t w_stride    = 640;   // bytes between consecutive weight tiles
    const uint32_t w_scale_off = 512;   // byte offset of the scale vector inside a weight tile
    const uint32_t a_stride    = 1152;  // bytes between consecutive activation tiles
    const uint32_t a_scale_idx = 8;     // scale slot inside an activation tile

    const uint8_t * restrict w_base = vx;
    const uint8_t * restrict a_base = vy;

    // Loop-invariant operands.
    const HVX_Vector nibble_mask = Q6_Vb_vsplat_R(0x0F);
    const HVX_Vector codebook    = *(const HVX_Vector *) kvalues_mxfp4_lut;
    const HVX_Vector e8m0_perm   = *(const HVX_Vector *) expand_x32_e8m0;
    const HVX_Vector low_byte    = Q6_V_vsplat_R(0x000000ff);

    const uint32_t tile_count = n / 32;
    HVX_Vector     acc_f32    = Q6_V_vzero();

    for (uint32_t t = 0; t != tile_count; ++t) {
        const uint8_t * w_bytes = w_base + t * w_stride;
        const HVX_Vector * restrict w_tile = (const HVX_Vector *) w_bytes;
        const HVX_Vector * restrict a_tile = (const HVX_Vector *) (a_base + t * a_stride);

        // Integer accumulation for this tile, converted to float lanes.
        const HVX_Vector dot_sf = Q6_Vsf_equals_Vw(accum_4bit_32x1_lut(w_tile, a_tile, nibble_mask, codebook));

        // Weight scale bytes -> one 32-bit lane each -> float with that exponent.
        const uint8_t * w_scale_bytes = w_bytes + w_scale_off;
        const HVX_Vector w_scale_raw  = hvx_vmem(w_scale_bytes);
        const HVX_Vector w_scale_exp  = Q6_V_vand_VV(Q6_V_vdelta_VV(w_scale_raw, e8m0_perm), low_byte);
        const HVX_Vector w_scale_f32  = Q6_Vw_vasl_VwR(w_scale_exp, 23);

        // Activation scales widened to f32; only the low half of the pair is used.
        const HVX_VectorPair a_scale_pair = hvx_vec_f16_to_f32_shuff(a_tile[a_scale_idx]);
        const HVX_Vector     a_scale_f32  = Q6_V_lo_W(a_scale_pair);

        const HVX_Vector scale = hvx_vec_mul_f32_f32(w_scale_f32, a_scale_f32);
        acc_f32 = hvx_vec_add_f32_f32(acc_f32, hvx_vec_mul_f32_f32(dot_sf, scale));
    }

    // Global 0.5 factor applied once after the reduction over K.
    acc_f32 = hvx_vec_mul_f32_f32(acc_f32, hvx_vec_splat_f32(0.5f));

    // Only valid_rows floats are stored; the destination may be unaligned.
    hvx_vec_store_u(s, valid_rows * sizeof(float), acc_f32);
}
// Dot product of a tiled MXFP4 weight panel with TWO quantized activation columns.
//
//   n          - length of the K dimension; n / 32 tiles are processed and any
//                remainder (n % 32) is not visited by this function.
//   s0, s1     - destinations for the float results of column 0 (vy0) and column 1 (vy1).
//   vx         - tiled 4-bit weights, one tile every 640 bytes; the vector at
//                byte offset 512 of a tile holds the weight scale bytes.
//   vy0, vy1   - tiled activations, one tile every 1152 bytes; vector slot 8
//                of a tile is read as the activation scales.
//   valid_rows - number of float results written to each of s0 and s1.
//
// The main loop handles two K-tiles per iteration; the second loop picks up the
// single leftover tile when the tile count is odd. Both accumulators are
// multiplied by 0.5 before being stored (presumably because kvalues_mxfp4_lut
// stores doubled values — TODO confirm against the table definition).
// NOTE(review): the statement grouping in the unrolled loop looks deliberately
// scheduled — keep the order when editing unless the generated code is re-checked.
static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
// One float accumulator vector per activation column.
HVX_Vector v_sum_float_c0 = Q6_V_vzero();
HVX_Vector v_sum_float_c1 = Q6_V_vzero();
// Loop-invariant operands: 0x0F byte mask and value table for the LUT
// accumulator, the permute control used to spread scale bytes across 32-bit
// lanes, and a mask keeping the low byte of each 32-bit lane.
HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
uint32_t n_k_tiles = n / 32;
uint32_t kt = 0;
// Unrolled by two: tiles kt (suffix _0) and kt + 1 (suffix _1).
for (; kt + 1 < n_k_tiles; kt += 2) {
const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 640);
const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1152);
const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 1152);
const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 640);
const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1152);
const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1152);
// Integer accumulation per tile; the returned pair carries column 0 in the
// low vector and column 1 in the high vector.
HVX_VectorPair v_sums0 = accum_4bit_32x2_lut(vptr0, v_act0_0, v_act1_0, mask_h4, lut);
HVX_VectorPair v_sums1 = accum_4bit_32x2_lut(vptr1, v_act0_1, v_act1_1, mask_h4, lut);
HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0);
HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0);
HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1);
HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1);
// Word (integer) lanes -> single-precision float lanes.
HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0);
HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0);
HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1);
HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1);
// Weight scales, tile 0: permute the scale bytes, keep the low byte of each
// 32-bit lane, then shift it left by 23 into the IEEE-754 single exponent field.
HVX_Vector v_scale_w0 = hvx_vmem(tile_ptr + (kt + 0) * 640 + 512);
HVX_Vector r0_d0 = Q6_V_vdelta_VV(v_scale_w0, expand);
r0_d0 = Q6_V_vand_VV(r0_d0, e8m0_mask);
HVX_Vector v_scale_w_f32_0 = Q6_Vw_vasl_VwR(r0_d0, 23);
// Same conversion for tile 1.
HVX_Vector v_scale_w1 = hvx_vmem(tile_ptr + (kt + 1) * 640 + 512);
HVX_Vector r0_d1 = Q6_V_vdelta_VV(v_scale_w1, expand);
r0_d1 = Q6_V_vand_VV(r0_d1, e8m0_mask);
HVX_Vector v_scale_w_f32_1 = Q6_Vw_vasl_VwR(r0_d1, 23);
// Activation scales live in vector slot 8 of every activation tile.
HVX_Vector v_scale_a_c0_f16_0 = v_act0_0[8];
HVX_Vector v_scale_a_c1_f16_0 = v_act1_0[8];
HVX_Vector v_scale_a_c0_f16_1 = v_act0_1[8];
HVX_Vector v_scale_a_c1_f16_1 = v_act1_1[8];
// Widen the activation scales to f32; only the low vector of each pair is used.
HVX_VectorPair p_scale_a_c0_f32_0 = hvx_vec_f16_to_f32_shuff(v_scale_a_c0_f16_0);
HVX_VectorPair p_scale_a_c1_f32_0 = hvx_vec_f16_to_f32_shuff(v_scale_a_c1_f16_0);
HVX_VectorPair p_scale_a_c0_f32_1 = hvx_vec_f16_to_f32_shuff(v_scale_a_c0_f16_1);
HVX_VectorPair p_scale_a_c1_f32_1 = hvx_vec_f16_to_f32_shuff(v_scale_a_c1_f16_1);
HVX_Vector v_scale_a_c0_0 = Q6_V_lo_W(p_scale_a_c0_f32_0);
HVX_Vector v_scale_a_c1_0 = Q6_V_lo_W(p_scale_a_c1_f32_0);
HVX_Vector v_scale_a_c0_1 = Q6_V_lo_W(p_scale_a_c0_f32_1);
HVX_Vector v_scale_a_c1_1 = Q6_V_lo_W(p_scale_a_c1_f32_1);
// Combined scale = weight scale * activation scale, per tile and per column.
HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f32_f32(v_scale_w_f32_0, v_scale_a_c0_0);
HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f32_f32(v_scale_w_f32_0, v_scale_a_c1_0);
HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f32_f32(v_scale_w_f32_1, v_scale_a_c0_1);
HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f32_f32(v_scale_w_f32_1, v_scale_a_c1_1);
HVX_Vector v_sum_scaled_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0);
HVX_Vector v_sum_scaled_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0);
HVX_Vector v_sum_scaled_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1);
HVX_Vector v_sum_scaled_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1);
// The two tiles are summed with each other first, then added to the running total.
v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1));
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1));
}
// Remainder: at most one tile is left when n_k_tiles is odd.
for (; kt < n_k_tiles; kt++) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640);
const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1152);
const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1152);
HVX_VectorPair v_sums = accum_4bit_32x2_lut(vptr, v_act0, v_act1, mask_h4, lut);
HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums);
HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums);
HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0);
HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1);
// Same scale-byte -> float-exponent conversion as in the unrolled loop.
HVX_Vector v_scale_w = hvx_vmem(tile_ptr + kt * 640 + 512);
HVX_Vector r0_d = Q6_V_vdelta_VV(v_scale_w, expand);
r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
HVX_Vector v_scale_w_f32 = Q6_Vw_vasl_VwR(r0_d, 23);
HVX_Vector v_scale_a_c0_f16 = v_act0[8];
HVX_Vector v_scale_a_c1_f16 = v_act1[8];
HVX_VectorPair p_scale_a_c0_f32 = hvx_vec_f16_to_f32_shuff(v_scale_a_c0_f16);
HVX_VectorPair p_scale_a_c1_f32 = hvx_vec_f16_to_f32_shuff(v_scale_a_c1_f16);
HVX_Vector v_scale_a_c0 = Q6_V_lo_W(p_scale_a_c0_f32);
HVX_Vector v_scale_a_c1 = Q6_V_lo_W(p_scale_a_c1_f32);
HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a_c0);
HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a_c1);
HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0);
HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1);
v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0);
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
// Global 0.5 factor applied once per column after the reduction over K.
v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f));
v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f));
// Store valid_rows floats per column; destinations may be unaligned (_u store).
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
// Quantize nrows consecutive f32 rows into the tiled q8_0 activation layout.
//
//   src_data     - first source row; rows are src_row_size bytes apart.
//   dst_data     - first destination row; rows are dst_row_size bytes apart.
//   tmp_data     - scratch buffer; must hold src_row_size rounded up to a
//                  multiple of QK_Q8_0_TILED floats.
//   ne0          - number of valid floats per row.
//   nrows        - number of rows to quantize.
//
// The scratch buffer is zero-filled once up front. Each row then overwrites
// only its first ne0 floats, so the padding past ne0 stays zero for every row
// (as long as the row quantizer does not write to the scratch buffer).
static inline void quantize_f32_q8_0_tiled_kernel(
    const uint8_t * restrict src_data,
    uint8_t * restrict dst_data,
    uint8_t * restrict tmp_data,
    uint32_t ne0,
    uint32_t nrows,
    size_t src_row_size,
    size_t dst_row_size
) {
    const size_t padded_row_bytes = hex_round_up(src_row_size, QK_Q8_0_TILED * sizeof(float));
    const size_t padded_row_elems = padded_row_bytes / sizeof(float);

    hvx_splat_f32_a(tmp_data, 0.0f, padded_row_elems);

    uint32_t rows_left = nrows;
    while (rows_left > 0) {
        // Prefetch starting at the current row (last argument 2 = two rows' worth).
        hex_l2fetch(src_data, src_row_size, src_row_size, 2);

        // Stage the row in scratch, then quantize from there.
        hvx_copy_f32_aa(tmp_data, src_data, ne0);
        quantize_row_f32_q8_0_tiled((float *) tmp_data, dst_data, ne0);

        src_data += src_row_size;
        dst_data += dst_row_size;
        --rows_left;
    }
}
// Quantize nrows consecutive f32 rows into the tiled q8_1 activation layout.
//
//   src_data     - first source row; rows are src_row_size bytes apart.
//   dst_data     - first destination row; rows are dst_row_size bytes apart.
//   tmp_data     - scratch buffer; must hold src_row_size rounded up to a
//                  multiple of QK_Q8_0_TILED floats (the q8_0 block constant
//                  is used for the padding here as well).
//   ne0          - number of valid floats per row.
//   nrows        - number of rows to quantize.
//
// The scratch buffer is zero-filled once up front. Each row then overwrites
// only its first ne0 floats, so the padding past ne0 stays zero for every row
// (as long as the row quantizer does not write to the scratch buffer).
static inline void quantize_f32_q8_1_tiled_kernel(
    const uint8_t * restrict src_data,
    uint8_t * restrict dst_data,
    uint8_t * restrict tmp_data,
    uint32_t ne0,
    uint32_t nrows,
    size_t src_row_size,
    size_t dst_row_size
) {
    const size_t padded_row_bytes = hex_round_up(src_row_size, QK_Q8_0_TILED * sizeof(float));
    const size_t padded_row_elems = padded_row_bytes / sizeof(float);

    hvx_splat_f32_a(tmp_data, 0.0f, padded_row_elems);

    uint32_t rows_left = nrows;
    while (rows_left > 0) {
        // Prefetch starting at the current row (last argument 2 = two rows' worth).
        hex_l2fetch(src_data, src_row_size, src_row_size, 2);

        // Stage the row in scratch, then quantize from there.
        hvx_copy_f32_aa(tmp_data, src_data, ne0);
        quantize_row_f32_q8_1_tiled((float *) tmp_data, dst_data, ne0);

        src_data += src_row_size;
        dst_data += dst_row_size;
        --rows_left;
    }
}
// Quantize a contiguous range of blocks [ib_first, ib_last) into the tiled
// q8_0 activation layout, one QK_Q8_0_TILED-float block at a time.
//
//   src          - base of the f32 source; rows are src_row_size bytes apart.
//   dst          - base of the destination; rows are dst_row_size bytes apart
//                  and each block occupies 4 * 1152 bytes within a row.
//   tmp_data     - scratch buffer for one block (QK_Q8_0_TILED floats).
//   ne0          - number of valid floats per row.
//   ib_first,
//   ib_last      - only their difference is used: it is the number of blocks
//                  to process.
//   r, c         - row and block-within-row of the first block to process;
//                  the position advances row-major and wraps to the next row
//                  after the last block of a row.
//
// The last block of each row is zero-padded in scratch before the valid
// ne0 - c * qk floats are copied in.
static inline void quantize_f32_q8_0_tiled_block_kernel(
    const float * restrict src,
    uint8_t * restrict dst,
    uint8_t * restrict tmp_data,
    uint32_t ne0,
    uint32_t ib_first,
    uint32_t ib_last,
    size_t src_row_size,
    size_t dst_row_size,
    uint32_t r,
    uint32_t c
) {
    const uint32_t qk             = QK_Q8_0_TILED;
    const uint32_t blocks_per_row = (ne0 + qk - 1) / qk;
    const uint8_t * src_bytes     = (const uint8_t *) src;

    uint32_t ib = ib_first;
    while (ib < ib_last) {
        const uint8_t * restrict block_in  = src_bytes + r * src_row_size + c * qk * sizeof(float);
        uint8_t * restrict       block_out = dst + r * dst_row_size + c * 4 * 1152;

        hex_l2fetch(block_in, qk * sizeof(float), qk * sizeof(float), 1);

        if (c + 1 != blocks_per_row) {
            // Interior block: always a full qk floats.
            hvx_copy_f32_aa(tmp_data, block_in, qk);
        } else {
            // Row tail: clear the scratch block, then copy only the valid floats.
            const uint32_t tail_elems = ne0 - c * qk;
            hvx_splat_f32_a(tmp_data, 0.0f, qk);
            hvx_copy_f32_aa(tmp_data, block_in, tail_elems);
        }

        quantize_block_f32_q8_0_tiled((float *) tmp_data, block_out);

        // Advance row-major to the next block.
        if (++c == blocks_per_row) {
            c = 0;
            ++r;
        }
        ++ib;
    }
}
// Quantize a contiguous range of blocks [ib_first, ib_last) into the tiled
// q8_1 activation layout, one QK_Q8_0_TILED-float block at a time.
//
//   src          - base of the f32 source; rows are src_row_size bytes apart.
//   dst          - base of the destination; rows are dst_row_size bytes apart
//                  and each block occupies 4 * 1280 bytes within a row.
//   tmp_data     - scratch buffer for one block (QK_Q8_0_TILED floats).
//   ne0          - number of valid floats per row.
//   ib_first,
//   ib_last      - only their difference is used: it is the number of blocks
//                  to process.
//   r, c         - row and block-within-row of the first block to process;
//                  the position advances row-major and wraps to the next row
//                  after the last block of a row.
//
// The last block of each row is zero-padded in scratch before the valid
// ne0 - c * qk floats are copied in.
static inline void quantize_f32_q8_1_tiled_block_kernel(
    const float * restrict src,
    uint8_t * restrict dst,
    uint8_t * restrict tmp_data,
    uint32_t ne0,
    uint32_t ib_first,
    uint32_t ib_last,
    size_t src_row_size,
    size_t dst_row_size,
    uint32_t r,
    uint32_t c
) {
    const uint32_t qk             = QK_Q8_0_TILED;
    const uint32_t blocks_per_row = (ne0 + qk - 1) / qk;
    const uint8_t * src_bytes     = (const uint8_t *) src;

    uint32_t ib = ib_first;
    while (ib < ib_last) {
        const uint8_t * restrict block_in  = src_bytes + r * src_row_size + c * qk * sizeof(float);
        uint8_t * restrict       block_out = dst + r * dst_row_size + c * 4 * 1280;

        hex_l2fetch(block_in, qk * sizeof(float), qk * sizeof(float), 1);

        if (c + 1 != blocks_per_row) {
            // Interior block: always a full qk floats.
            hvx_copy_f32_aa(tmp_data, block_in, qk);
        } else {
            // Row tail: clear the scratch block, then copy only the valid floats.
            const uint32_t tail_elems = ne0 - c * qk;
            hvx_splat_f32_a(tmp_data, 0.0f, qk);
            hvx_copy_f32_aa(tmp_data, block_in, tail_elems);
        }

        quantize_block_f32_q8_1_tiled((float *) tmp_data, block_out);

        // Advance row-major to the next block.
        if (++c == blocks_per_row) {
            c = 0;
            ++r;
        }
        ++ib;
    }
}