diff --git a/docs/backend/snapdragon/CMakeUserPresets.json b/docs/backend/snapdragon/CMakeUserPresets.json index d37100764f..848d735f1c 100644 --- a/docs/backend/snapdragon/CMakeUserPresets.json +++ b/docs/backend/snapdragon/CMakeUserPresets.json @@ -24,7 +24,6 @@ "GGML_LLAMAFILE": "OFF", "GGML_OPENCL": "ON", "GGML_HEXAGON": "ON", - "GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128", "LLAMA_OPENSSL": "OFF" } }, @@ -47,7 +46,6 @@ "GGML_LLAMAFILE": "OFF", "GGML_OPENCL": "ON", "GGML_HEXAGON": "ON", - "GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128", "LLAMA_OPENSSL": "OFF" } }, @@ -73,7 +71,6 @@ "GGML_LLAMAFILE": "OFF", "GGML_OPENCL": "OFF", "GGML_HEXAGON": "ON", - "GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128", "LLAMA_OPENSSL": "OFF" } }, diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 04069784f1..a0cd4e7158 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -266,7 +266,6 @@ set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING "ggml: OpenCL API version to target") option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF) -set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)") # toolchain for vulkan-shaders-gen set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen") diff --git a/ggml/src/ggml-hexagon/CMakeLists.txt b/ggml/src/ggml-hexagon/CMakeLists.txt index b82bae0c10..c6e49a71d1 100644 --- a/ggml/src/ggml-hexagon/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/CMakeLists.txt @@ -25,7 +25,6 @@ include(ExternalProject) option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF) option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF) set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate") -set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)") add_library(htp_iface OBJECT ${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c) @@ -72,15 +71,12 @@ function(build_htp_skel V) -DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT} -DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT} -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG} - -DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE} -DDSP_VERSION=${V} -DPREBUILT_LIB_DIR="toolv19_${V}") list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so) set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE) endfunction() -build_htp_skel(v68) -build_htp_skel(v69) build_htp_skel(v73) build_htp_skel(v75) build_htp_skel(v79) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index e612ec392b..3d41c47b65 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #ifdef _WIN32 # include @@ -41,6 +42,7 @@ #include "ggml-quants.h" #include "htp-opnode.h" #include "htp-ops.h" +#include "htp/matmul-ops.h" #include "htp_iface.h" #include "htp-drv.h" @@ -51,7 +53,7 @@ using u32vec = std::vector; static int opt_arch = 0; // autodetect static size_t opt_ndev = 1; static size_t opt_nhvx = 0; // use all -static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only +static int opt_nhmx = 1; // when set, enable HMX; when 0, use HVX only static size_t opt_vmem = HTP_OP_MAX_VMEM_DEFAULT; // max available va space for buffer mappings static size_t opt_mbuf = 1ul * 1024 * 1024 * 1024; // max buffer size static int opt_etm = 0; @@ -59,6 +61,8 @@ static int opt_verbose = 0; static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu) static int opt_hostbuf = 1; // hostbuf ON by default +static int opt_mm_select = 3; // 3 = HMX -> Tiled -> Flat -> CPU, 2 = Tiled -> Flat -> CPU, 1 = Flat -> CPU + // Default PMU events, if profiling with PMU (mode=2) is enabled // See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html // https://docs.qualcomm.com/doc/80-N2040-61/topic/hvx-pmu-events.html @@ -68,22 +72,15 @@ static u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C } static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE; static int opt_opbatch = 1024; // max number of ops in a batch static int opt_opqueue = 16; // max number of pending batches -static int opt_oppoll = 0; // polling for batch completions static int opt_optrace = 0; // trace buffer size per thread (0 means default) +static int opt_oppoll = 0; // polling for batch completions +static int opt_opfusion = 1; // enable/disable op fusion static std::regex* opt_opfilter = NULL; // regex of ops to not claim #define HEX_VERBOSE(...) \ if (opt_verbose) GGML_LOG_DEBUG(__VA_ARGS__) -static inline uint64_t hex_is_aligned(void * addr, uint32_t align) { - return ((size_t) addr & (align - 1)) == 0; -} - -static inline size_t hex_round_up(size_t n, size_t m) { - return m * ((n + m - 1) / m); -} - static const char * status_to_str(uint32_t status) { switch (status) { case HTP_STATUS_OK: @@ -107,15 +104,15 @@ static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const htp_op if (!opt_verbose) return; htp_opformat fmt(node); - GGML_LOG_DEBUG("ggml-hex: %s execute-op %s: %s : %s : %s : %s : %s : flags 0x%x\n", sess_name.c_str(), - node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, req_flags); + GGML_LOG_DEBUG("ggml-hex: %s execute-op %s|%s|%s|%s|%s|%s|%s|flags 0x%x\n", sess_name.c_str(), + node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, fmt.kparams, req_flags); } static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct ggml_tensor * op, bool supp) { if (!opt_verbose) return; htp_opformat fmt(htp_opformat(htp_opnode{const_cast(op), {}, HTP_OP_INVALID})); - GGML_LOG_DEBUG("ggml-hex: %s supports-op %s: %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), + GGML_LOG_DEBUG("ggml-hex: %s supports-op %s|%s|%s|%s|%s|%s|%s\n", sess_name.c_str(), ggml_op_desc(op), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, supp ? "yes" : "no"); } @@ -144,16 +141,52 @@ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const htp_op char pmu_str[256] = ""; if (opt_profile == 2) { static_assert(HTP_PROF_PMU_NCNT == 8, "current implementation assumes 8 PMU counters"); - sprintf(pmu_str, " pmu [%u,%u,%u,%u,%u,%u,%u,%u]", + snprintf(pmu_str, sizeof(pmu_str), " pmu [%u,%u,%u,%u,%u,%u,%u,%u]", pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]); } htp_opformat fmt(node); float mhz = op_usec > 0 ? (float) op_cycles / op_usec : 0.0f; - GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u start %u mhz %.1f%s\n", sess_name.c_str(), - node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, op_usec, op_cycles, pd.cycles_start, mhz, pmu_str); + GGML_LOG_DEBUG("ggml-hex: %s profile-op %s|%s|%s|%s|%s|%s|usec %u cycles %u start %u mhz %.1f%s\n", sess_name.c_str(), + node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.kparams, op_usec, op_cycles, pd.cycles_start, mhz, pmu_str); } +// ** + +static inline bool ggml_hexagon_is_repack_type(enum ggml_type type) { + return type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || + type == GGML_TYPE_Q8_0 || type == GGML_TYPE_IQ4_NL || + type == GGML_TYPE_MXFP4; +} + +static inline bool ggml_hexagon_is_hmx_weight_type(enum ggml_type type) { + return type == GGML_TYPE_F16 || type == GGML_TYPE_F32 || ggml_hexagon_is_repack_type(type); +} + +struct htp_mm_kernel_params; +struct ggml_hexagon_session; +static void ggml_hexagon_precompute_matmul_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * dst, + struct htp_mm_kernel_params * kparams +); + +static void ggml_hexagon_precompute_fused_qkv_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct htp_mm_kernel_params * kparams +); + +static void ggml_hexagon_precompute_fused_ffn_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct htp_mm_kernel_params * kparams +); + // ** backend sessions struct ggml_hexagon_opbatch; @@ -180,6 +213,18 @@ struct ggml_hexagon_session { ggml_backend_buffer_type buffer_type = {}; ggml_backend_buffer_type repack_buffer_type = {}; + uint32_t n_threads = 0; + uint32_t n_hvx = 0; + uint32_t n_hmx = 0; + uint64_t vtcm_size = 0; + size_t max_vmem = 0; + size_t max_bufsize = 0; + + struct { + uint64_t uid = 0; + std::vector htp_nodes; + } cached_graph; + ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false); ~ggml_hexagon_session() noexcept(true); @@ -325,47 +370,7 @@ static enum ggml_status ggml_backend_hexagon_buffer_init_tensor(ggml_backend_buf return GGML_STATUS_SUCCESS; } -// ======== Q4x4x2 ==================== -struct x2_q4 { - int v[2]; -}; - -static x2_q4 unpack_q4(uint8_t v) { - x2_q4 x = { (int) (v & 0x0f) - 8, (int) (v >> 4) - 8 }; - return x; -} - -static void dump_block_q4_0(const block_q4_0 * b, int i) { - HEX_VERBOSE("ggml-hex: repack q4_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, unpack_q4(b->qs[0]).v[0], - unpack_q4(b->qs[1]).v[0], unpack_q4(b->qs[2]).v[0], unpack_q4(b->qs[3]).v[0], unpack_q4(b->qs[12]).v[1], - unpack_q4(b->qs[13]).v[1], unpack_q4(b->qs[14]).v[1], unpack_q4(b->qs[15]).v[1], - GGML_FP16_TO_FP32(b->d)); -} - -static void dump_packed_block_q4x4x2(const uint8_t * v, unsigned int i, size_t k) { - static const int qk = QK_Q4_0x4x2; - const int dblk_size = 8 * 2; // 8x __fp16 - const int qblk_size = qk / 2; // int4 - const int qrow_size = k / 2; // int4 (not padded) - - const uint8_t * v_q = v + 0; // quants first - const uint8_t * v_d = v + qrow_size; // then scales - - const uint8_t * q = v_q + i * qblk_size; - const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size); - - HEX_VERBOSE("ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i, - unpack_q4(q[0]).v[0], unpack_q4(q[1]).v[0], unpack_q4(q[2]).v[0], unpack_q4(q[3]).v[0], - unpack_q4(q[60]).v[0], unpack_q4(q[61]).v[0], unpack_q4(q[62]).v[0], unpack_q4(q[63]).v[0], - unpack_q4(q[124]).v[0], unpack_q4(q[125]).v[0], unpack_q4(q[126]).v[0], unpack_q4(q[127]).v[0], - GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3])); - - HEX_VERBOSE("ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", - i + 1, unpack_q4(q[0]).v[1], unpack_q4(q[1]).v[1], unpack_q4(q[2]).v[1], unpack_q4(q[3]).v[1], - unpack_q4(q[60]).v[1], unpack_q4(q[61]).v[1], unpack_q4(q[62]).v[1], unpack_q4(q[63]).v[1], - unpack_q4(q[124]).v[1], unpack_q4(q[125]).v[1], unpack_q4(q[126]).v[1], unpack_q4(q[127]).v[1], - GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7])); -} +// ** Repack helpers for tiled quantized weights static void unpack_q4_0_quants(uint8_t * qs, const block_q4_0 * x, unsigned int bi) { static const int qk = QK4_0; @@ -388,300 +393,6 @@ static void pack_q4_0_quants(block_q4_0 * x, const uint8_t * qs, unsigned int bi } } -static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { - static const int qk = QK_Q4_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers - - const int dblk_size = 8 * 2; // 8x __fp16 - const int qblk_size = qk / 2; // int4 - const int qrow_size = k / 2; // int4 (not padded to blocks) - - uint8_t * y_q = y + 0; // quants first - uint8_t * y_d = y + qrow_size; // then scales - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_block_q4_0(&x[i * 8 + 0], 0); - dump_block_q4_0(&x[i * 8 + 1], 1); - dump_block_q4_0(&x[i * 8 + 2], 2); - dump_block_q4_0(&x[i * 8 + 3], 3); - dump_block_q4_0(&x[i * 8 + 4], 4); - dump_block_q4_0(&x[i * 8 + 5], 5); - dump_block_q4_0(&x[i * 8 + 6], 6); - dump_block_q4_0(&x[i * 8 + 7], 7); - } - } - - // Repack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_Q4_0x4x2]; // unpacked quants - unpack_q4_0_quants(qs, &x[i * 8 + 0], 0); - unpack_q4_0_quants(qs, &x[i * 8 + 1], 1); - unpack_q4_0_quants(qs, &x[i * 8 + 2], 2); - unpack_q4_0_quants(qs, &x[i * 8 + 3], 3); - unpack_q4_0_quants(qs, &x[i * 8 + 4], 4); - unpack_q4_0_quants(qs, &x[i * 8 + 5], 5); - unpack_q4_0_quants(qs, &x[i * 8 + 6], 6); - unpack_q4_0_quants(qs, &x[i * 8 + 7], 7); - - bool partial = (nloe && i == nb-1); - - uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk / 2; j++) { - q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; - } - } - - // Repack the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Repack the scales - ggml_half * d = (ggml_half *) (y_d + i * dblk_size); - d[0] = x[i * 8 + 0].d; - d[1] = x[i * 8 + 1].d; - d[2] = x[i * 8 + 2].d; - d[3] = x[i * 8 + 3].d; - d[4] = x[i * 8 + 4].d; - d[5] = x[i * 8 + 5].d; - d[6] = x[i * 8 + 6].d; - d[7] = x[i * 8 + 7].d; - } - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_packed_block_q4x4x2(y, i, k); - } - } -} - -static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) { - static const int qk = QK_Q4_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers - - const int dblk_size = 8 * 2; // 8x __fp16 - const int qblk_size = qk / 2; // int4 - const int qrow_size = k / 2; // int4 (not padded to blocks) - - const uint8_t * y_q = y + 0; // quants first - const uint8_t * y_d = y + qrow_size; // then scales - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_packed_block_q4x4x2(y, i, k); - } - } - - // Unpack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_Q4_0x4x2]; // unpacked quants - - bool partial = (nloe && i == nb-1); - - const uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk / 2; j++) { - if (partial) { - qs[j*2+0] = q[j] & 0xf; - qs[j*2+1] = q[j] >> 4; - } else { - qs[j+000] = q[j] & 0xf; - qs[j+128] = q[j] >> 4; - } - } - - pack_q4_0_quants(&x[i * 8 + 0], qs, 0); - pack_q4_0_quants(&x[i * 8 + 1], qs, 1); - pack_q4_0_quants(&x[i * 8 + 2], qs, 2); - pack_q4_0_quants(&x[i * 8 + 3], qs, 3); - pack_q4_0_quants(&x[i * 8 + 4], qs, 4); - pack_q4_0_quants(&x[i * 8 + 5], qs, 5); - pack_q4_0_quants(&x[i * 8 + 6], qs, 6); - pack_q4_0_quants(&x[i * 8 + 7], qs, 7); - } - - // Repack the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Unpack the scales - const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size); - x[i * 8 + 0].d = d[0]; - x[i * 8 + 1].d = d[1]; - x[i * 8 + 2].d = d[2]; - x[i * 8 + 3].d = d[3]; - x[i * 8 + 4].d = d[4]; - x[i * 8 + 5].d = d[5]; - x[i * 8 + 6].d = d[6]; - x[i * 8 + 7].d = d[7]; - } - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_block_q4_0(&x[i * 8 + 0], 0); - dump_block_q4_0(&x[i * 8 + 1], 1); - dump_block_q4_0(&x[i * 8 + 2], 2); - dump_block_q4_0(&x[i * 8 + 3], 3); - dump_block_q4_0(&x[i * 8 + 4], 4); - dump_block_q4_0(&x[i * 8 + 5], 5); - dump_block_q4_0(&x[i * 8 + 6], 6); - dump_block_q4_0(&x[i * 8 + 7], 7); - } - } -} - -static void init_row_q4x4x2(block_q4_0 * x, int64_t k) { - static const int qk = QK_Q4_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - - // Init the quants such that they unpack into zeros - uint8_t qs[QK_Q4_0x4x2]; // unpacked quants - memset(qs, 8, sizeof(qs)); - - for (int i = 0; i < nb; i++) { - pack_q4_0_quants(&x[i * 8 + 0], qs, 0); - pack_q4_0_quants(&x[i * 8 + 1], qs, 1); - pack_q4_0_quants(&x[i * 8 + 2], qs, 2); - pack_q4_0_quants(&x[i * 8 + 3], qs, 3); - pack_q4_0_quants(&x[i * 8 + 4], qs, 4); - pack_q4_0_quants(&x[i * 8 + 5], qs, 5); - pack_q4_0_quants(&x[i * 8 + 6], qs, 6); - pack_q4_0_quants(&x[i * 8 + 7], qs, 7); - } - - // Init the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Unpack the scales - x[i * 8 + 0].d = 0; - x[i * 8 + 1].d = 0; - x[i * 8 + 2].d = 0; - x[i * 8 + 3].d = 0; - x[i * 8 + 4].d = 0; - x[i * 8 + 5].d = 0; - x[i * 8 + 6].d = 0; - x[i * 8 + 7].d = 0; - } -} - -// repack q4_0 data into q4x4x2 tensor -static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) - - // Ensure we don't try to read more data than is available in the source buffer 'data' - // or write more than the tensor can hold. - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - // Calculate how many full rows and how many remaining bytes we need to process. - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-q4_0-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, - t->ne[0], nrows, row_size); - - init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros - - // 1. Process all the full rows - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - memcpy(buf_pd, src, row_size); - repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); - } - - // 2. Process the final, potentially partial, row - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - // re-init the row because we are potentially copying a partial row - init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]); - - // Copy only the remaining bytes from the source. - memcpy(buf_pd, src, n_rem_bytes); - - // Repack the entire buffer - repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]); - - // Write only the corresponding remaining bytes to the destination tensor. - memcpy(dst, buf_rp, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); -} - -// repack q4x4x2 tensor into q4_0 data -static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) - - // Ensure we don't try to copy more data than the tensor actually contains. - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - // Calculate how many full rows and how many remaining bytes we need to process. - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_0 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, - t->ne[0], nrows, row_size); - - memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros - - // 1. Process all the full rows - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - memcpy(buf_pd, src, row_size); - unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); - } - - // 2. Process the final, potentially partial, row - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - // We still need to read and unpack the entire source row because quantization is block-based. - memcpy(buf_pd, src, row_size); - unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); - - // But we only copy the remaining number of bytes to the destination. - memcpy(dst, buf_rp, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); -} - static void unpack_q4_1_quants(uint8_t * qs, const block_q4_1 * x, unsigned int bi) { static const int qk = QK4_1; @@ -703,603 +414,19 @@ static void pack_q4_1_quants(block_q4_1 * x, const uint8_t * qs, unsigned int bi } } -static void repack_row_q4_1x4x2(uint8_t * y, const block_q4_1 * x, int64_t k) { - static const int qk = QK_Q4_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers - - const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes - const int qblk_size = qk / 2; // int4 = 128 bytes - const int qrow_size = k / 2; // int4 (not padded to blocks) - - uint8_t * y_q = y + 0; // quants first - uint8_t * y_d = y + qrow_size; // then scales/offsets - - // Repack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_Q4_0x4x2]; // unpacked quants - unpack_q4_1_quants(qs, &x[i * 8 + 0], 0); - unpack_q4_1_quants(qs, &x[i * 8 + 1], 1); - unpack_q4_1_quants(qs, &x[i * 8 + 2], 2); - unpack_q4_1_quants(qs, &x[i * 8 + 3], 3); - unpack_q4_1_quants(qs, &x[i * 8 + 4], 4); - unpack_q4_1_quants(qs, &x[i * 8 + 5], 5); - unpack_q4_1_quants(qs, &x[i * 8 + 6], 6); - unpack_q4_1_quants(qs, &x[i * 8 + 7], 7); - - bool partial = (nloe && i == nb-1); - - uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk / 2; j++) { - q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; - } - } - - // Repack the scales and offsets - for (int i = 0; i < nb; i++) { - ggml_half * d_m = (ggml_half *) (y_d + i * dblk_size); - for (int j = 0; j < 8; j++) { - d_m[j * 2 + 0] = x[i * 8 + j].d; - d_m[j * 2 + 1] = x[i * 8 + j].m; - } - } -} - -static void unpack_row_q4_1x4x2(block_q4_1 * x, const uint8_t * y, int64_t k) { - static const int qk = QK_Q4_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers - - const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes - const int qblk_size = qk / 2; // int4 = 128 bytes - const int qrow_size = k / 2; // int4 (not padded to blocks) - - const uint8_t * y_q = y + 0; // quants first - const uint8_t * y_d = y + qrow_size; // then scales/offsets - - // Unpack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_Q4_0x4x2]; - bool partial = (nloe && i == nb-1); - - const uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk / 2; j++) { - if (partial) { - qs[j*2+0] = q[j] & 0x0F; - qs[j*2+1] = q[j] >> 4; - } else { - qs[j+000] = q[j] & 0x0F; - qs[j+128] = q[j] >> 4; - } - } - - pack_q4_1_quants(&x[i * 8 + 0], qs, 0); - pack_q4_1_quants(&x[i * 8 + 1], qs, 1); - pack_q4_1_quants(&x[i * 8 + 2], qs, 2); - pack_q4_1_quants(&x[i * 8 + 3], qs, 3); - pack_q4_1_quants(&x[i * 8 + 4], qs, 4); - pack_q4_1_quants(&x[i * 8 + 5], qs, 5); - pack_q4_1_quants(&x[i * 8 + 6], qs, 6); - pack_q4_1_quants(&x[i * 8 + 7], qs, 7); - } - - // Unpack the scales and offsets - for (int i = 0; i < nb; i++) { - const ggml_half * d_m = (const ggml_half *) (y_d + i * dblk_size); - for (int j = 0; j < 8; j++) { - x[i * 8 + j].d = d_m[j * 2 + 0]; - x[i * 8 + j].m = d_m[j * 2 + 1]; - } - } -} - -static void init_row_q4_1x4x2(block_q4_1 * x, int64_t k) { - static const int qk = QK_Q4_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - - uint8_t qs[QK_Q4_0x4x2]; // unpacked quants - memset(qs, 0, sizeof(qs)); - - for (int i = 0; i < nb; i++) { - pack_q4_1_quants(&x[i * 8 + 0], qs, 0); - pack_q4_1_quants(&x[i * 8 + 1], qs, 1); - pack_q4_1_quants(&x[i * 8 + 2], qs, 2); - pack_q4_1_quants(&x[i * 8 + 3], qs, 3); - pack_q4_1_quants(&x[i * 8 + 4], qs, 4); - pack_q4_1_quants(&x[i * 8 + 5], qs, 5); - pack_q4_1_quants(&x[i * 8 + 6], qs, 6); - pack_q4_1_quants(&x[i * 8 + 7], qs, 7); - } - - for (int i = 0; i < nb; i++) { - for (int j = 0; j < 8; j++) { - x[i * 8 + j].d = 0; - x[i * 8 + j].m = 0; - } - } -} - -static void repack_q4_1_q4x4x2(ggml_tensor * t, const void * data, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) - - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-q4_1-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, - t->ne[0], nrows, row_size); - - init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); - - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - memcpy(buf_pd, src, row_size); - repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); - } - - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); - memcpy(buf_pd, src, n_rem_bytes); - repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); -} - -static void repack_q4x4x2_q4_1(void * data, const ggml_tensor * t, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) - - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_1 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, - t->ne[0], nrows, row_size); - - memset(buf_rp, 0, row_size_rp); // clear-out padded buffer to make sure the tail is all zeros - - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - memcpy(buf_rp, src, row_size); - unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]); - memcpy(dst, buf_pd, row_size); - } - - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - // We still need to read and unpack the entire source row because quantization is block-based. - memcpy(buf_rp, src, row_size); - unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]); - memcpy(dst, buf_pd, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); -} - -// ======== Q8x4x2 ==================== -static void dump_block_q8_0(const block_q8_0 * b, int i) { - HEX_VERBOSE("ggml-hex: repack q8_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, b->qs[0], b->qs[1], b->qs[2], - b->qs[3], b->qs[28], b->qs[29], b->qs[30], b->qs[31], GGML_FP16_TO_FP32(b->d)); -} - -static void dump_packed_block_q8x4x2(const uint8_t * v, unsigned int i, size_t k) { - static const int qk = QK_Q8_0x4x2; - const int dblk_size = 8 * 2; // 8x __fp16 - const int qblk_size = qk; // int8 - const int qrow_size = k; // int8 (not padded) - - const uint8_t * v_q = v + 0; // quants first - const uint8_t * v_d = v + qrow_size; // then scales - - const uint8_t * q = v_q + i * qblk_size; - const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size); - - HEX_VERBOSE("ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i, - q[0], q[1], q[2], q[3], q[60], q[61], q[62], q[63], q[124], q[125], q[126], q[127], - GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3])); - - HEX_VERBOSE("ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", - i + 1, q[128], q[129], q[130], q[131], q[192], q[193], q[194], q[195], q[252], q[253], q[254], q[255], - GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7])); -} - -static void unpack_q8_0_quants(uint8_t * qs, const block_q8_0 * x, unsigned int bi) { - static const int qk = QK8_0; - - for (unsigned int i = 0; i < qk; ++i) { - qs[bi * qk + i] = x->qs[i]; - } -} - -static void pack_q8_0_quants(block_q8_0 * x, const uint8_t * qs, unsigned int bi) { - static const int qk = QK8_0; - - for (unsigned int i = 0; i < qk; ++i) { - x->qs[i] = qs[bi * qk + i]; - } -} - -static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) { - static const int qk = QK_Q8_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - - const int dblk_size = 8 * 2; // 8x __fp16 - const int qblk_size = qk; // int8 - const int qrow_size = k; // int8 (not padded to blocks) - - uint8_t * y_q = y + 0; // quants first - uint8_t * y_d = y + qrow_size; // then scales - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_block_q8_0(&x[i * 8 + 0], 0); - dump_block_q8_0(&x[i * 8 + 1], 1); - dump_block_q8_0(&x[i * 8 + 2], 2); - dump_block_q8_0(&x[i * 8 + 3], 3); - dump_block_q8_0(&x[i * 8 + 4], 4); - dump_block_q8_0(&x[i * 8 + 5], 5); - dump_block_q8_0(&x[i * 8 + 6], 6); - dump_block_q8_0(&x[i * 8 + 7], 7); - } - } - - // Repack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_Q8_0x4x2]; // unpacked quants - - unpack_q8_0_quants(qs, &x[i * 8 + 0], 0); - unpack_q8_0_quants(qs, &x[i * 8 + 1], 1); - unpack_q8_0_quants(qs, &x[i * 8 + 2], 2); - unpack_q8_0_quants(qs, &x[i * 8 + 3], 3); - unpack_q8_0_quants(qs, &x[i * 8 + 4], 4); - unpack_q8_0_quants(qs, &x[i * 8 + 5], 5); - unpack_q8_0_quants(qs, &x[i * 8 + 6], 6); - unpack_q8_0_quants(qs, &x[i * 8 + 7], 7); - - uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk; j++) { - q[j] = qs[j]; - } - } - - // Repack the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Repack the scales - ggml_half * d = (ggml_half *) (y_d + i * dblk_size); - d[0] = x[i * 8 + 0].d; - d[1] = x[i * 8 + 1].d; - d[2] = x[i * 8 + 2].d; - d[3] = x[i * 8 + 3].d; - d[4] = x[i * 8 + 4].d; - d[5] = x[i * 8 + 5].d; - d[6] = x[i * 8 + 6].d; - d[7] = x[i * 8 + 7].d; - } - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_packed_block_q8x4x2(y, i, k); - } - } -} - -static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) { - static const int qk = QK_Q8_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - - const int dblk_size = 8 * 2; // 8x __fp16 - const int qblk_size = qk; // int8 - const int qrow_size = k; // int8 (not padded to blocks) - - const uint8_t * y_q = y + 0; // quants first - const uint8_t * y_d = y + qrow_size; // then scales - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_packed_block_q8x4x2(y, i, k); - } - } - - // Unpack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_Q4_0x4x2]; // unpacked quants - - const uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk; j++) { - qs[j] = q[j]; - } - - pack_q8_0_quants(&x[i * 8 + 0], qs, 0); - pack_q8_0_quants(&x[i * 8 + 1], qs, 1); - pack_q8_0_quants(&x[i * 8 + 2], qs, 2); - pack_q8_0_quants(&x[i * 8 + 3], qs, 3); - pack_q8_0_quants(&x[i * 8 + 4], qs, 4); - pack_q8_0_quants(&x[i * 8 + 5], qs, 5); - pack_q8_0_quants(&x[i * 8 + 6], qs, 6); - pack_q8_0_quants(&x[i * 8 + 7], qs, 7); - } - - // Repack the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Unpack the scales - const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size); - x[i * 8 + 0].d = d[0]; - x[i * 8 + 1].d = d[1]; - x[i * 8 + 2].d = d[2]; - x[i * 8 + 3].d = d[3]; - x[i * 8 + 4].d = d[4]; - x[i * 8 + 5].d = d[5]; - x[i * 8 + 6].d = d[6]; - x[i * 8 + 7].d = d[7]; - } - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_block_q8_0(&x[i * 8 + 0], 0); - dump_block_q8_0(&x[i * 8 + 1], 1); - dump_block_q8_0(&x[i * 8 + 2], 2); - dump_block_q8_0(&x[i * 8 + 3], 3); - dump_block_q8_0(&x[i * 8 + 4], 4); - dump_block_q8_0(&x[i * 8 + 5], 5); - dump_block_q8_0(&x[i * 8 + 6], 6); - dump_block_q8_0(&x[i * 8 + 7], 7); - } - } -} - -static void init_row_q8x4x2(block_q8_0 * x, int64_t k) { - static const int qk = QK_Q8_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - - // Init the quants such that they unpack into zeros - uint8_t qs[QK_Q8_0x4x2]; // unpacked quants - memset(qs, 0, sizeof(qs)); - - for (int i = 0; i < nb; i++) { - pack_q8_0_quants(&x[i * 8 + 0], qs, 0); - pack_q8_0_quants(&x[i * 8 + 1], qs, 1); - pack_q8_0_quants(&x[i * 8 + 2], qs, 2); - pack_q8_0_quants(&x[i * 8 + 3], qs, 3); - pack_q8_0_quants(&x[i * 8 + 4], qs, 4); - pack_q8_0_quants(&x[i * 8 + 5], qs, 5); - pack_q8_0_quants(&x[i * 8 + 6], qs, 6); - pack_q8_0_quants(&x[i * 8 + 7], qs, 7); - } - - // Init the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q8_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Unpack the scales - x[i * 8 + 0].d = 0; - x[i * 8 + 1].d = 0; - x[i * 8 + 2].d = 0; - x[i * 8 + 3].d = 0; - x[i * 8 + 4].d = 0; - x[i * 8 + 5].d = 0; - x[i * 8 + 6].d = 0; - x[i * 8 + 7].d = 0; - } -} - -// repack q8_0 data into q8x4x2 tensor -static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales) - - // Ensure we don't try to read more data than is available in the source buffer 'data' - // or write more than the tensor can hold. - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - // Calculate how many full rows and how many remaining bytes we need to process. - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-q8_0-q8x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, - t->ne[0], nrows, row_size); - - init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros - - // 1. Process all the full rows - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - memcpy(buf_pd, src, row_size); - repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); - } - - // 2. Process the final, potentially partial, row - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - // re-init the row because we are potentially copying a partial row - init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]); - - // Copy only the remaining bytes from the source. - memcpy(buf_pd, src, n_rem_bytes); - - // Repack the entire buffer - repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]); - - // Write only the corresponding remaining bytes to the destination tensor. - memcpy(dst, buf_rp, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); -} - -// repack q8x4x2 tensor into q8_0 data -static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales) - - // Ensure we don't try to copy more data than the tensor actually contains. - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - // Calculate how many full rows and how many remaining bytes we need to process. - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-q8x4x2-q8_0 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, - t->ne[0], nrows, row_size); - - memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros - - // 1. Process all the full rows - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - memcpy(buf_pd, src, row_size); - unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); - } - - // 2. Process the final, potentially partial, row - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - // We still need to read and unpack the entire source row because quantization is block-based. - memcpy(buf_pd, src, row_size); - unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); - - // But we only copy the remaining number of bytes to the destination. - memcpy(dst, buf_rp, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); -} - -// ======== MXFP4x4x2 ==================== -struct x2_mxfp4 { - int v[2]; -}; - -static x2_mxfp4 unpack_mxfp4(uint8_t v) { - x2_mxfp4 x; - x.v[0] = kvalues_mxfp4[(v & 0x0f)]; - x.v[1] = kvalues_mxfp4[(v >> 4)]; - return x; -} - -static void dump_block_mxfp4(const block_mxfp4 * b, int i) { - HEX_VERBOSE("ggml-hex: repack mxfp4 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, unpack_mxfp4(b->qs[0]).v[0], - unpack_mxfp4(b->qs[1]).v[0], unpack_mxfp4(b->qs[2]).v[0], unpack_mxfp4(b->qs[3]).v[0], - unpack_mxfp4(b->qs[12]).v[1], unpack_mxfp4(b->qs[13]).v[1], unpack_mxfp4(b->qs[14]).v[1], - unpack_mxfp4(b->qs[15]).v[1], GGML_E8M0_TO_FP32_HALF(b->e)); -} - -static void dump_packed_block_mxfp4x4x2(const uint8_t * v, unsigned int i, size_t k) { - static const int qk = QK_MXFP4x4x2; - const int eblk_size = 8 * 1; // 8x E8M0 - const int qblk_size = qk / 2; // int4 - const int qrow_size = k / 2; // int4 (not padded) - - const uint8_t * v_q = v + 0; // quants first - const uint8_t * v_e = v + qrow_size; // then scales - - const uint8_t * q = v_q + i * qblk_size; - const uint8_t * e = (const uint8_t *) (v_e + i * eblk_size); - - HEX_VERBOSE("ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i, - unpack_mxfp4(q[0]).v[0], unpack_mxfp4(q[1]).v[0], unpack_mxfp4(q[2]).v[0], unpack_mxfp4(q[3]).v[0], - unpack_mxfp4(q[60]).v[0], unpack_mxfp4(q[61]).v[0], unpack_mxfp4(q[62]).v[0], unpack_mxfp4(q[63]).v[0], - unpack_mxfp4(q[124]).v[0], unpack_mxfp4(q[125]).v[0], unpack_mxfp4(q[126]).v[0], - unpack_mxfp4(q[127]).v[0], GGML_E8M0_TO_FP32_HALF(e[0]), GGML_E8M0_TO_FP32_HALF(e[1]), - GGML_E8M0_TO_FP32_HALF(e[2]), GGML_E8M0_TO_FP32_HALF(e[3])); - - HEX_VERBOSE("ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", - i + 1, unpack_mxfp4(q[0]).v[1], unpack_mxfp4(q[1]).v[1], unpack_mxfp4(q[2]).v[1], - unpack_mxfp4(q[3]).v[1], unpack_mxfp4(q[60]).v[1], unpack_mxfp4(q[61]).v[1], unpack_mxfp4(q[62]).v[1], - unpack_mxfp4(q[63]).v[1], unpack_mxfp4(q[124]).v[1], unpack_mxfp4(q[125]).v[1], - unpack_mxfp4(q[126]).v[1], unpack_mxfp4(q[127]).v[1], GGML_E8M0_TO_FP32_HALF(e[4]), - GGML_E8M0_TO_FP32_HALF(e[5]), GGML_E8M0_TO_FP32_HALF(e[6]), GGML_E8M0_TO_FP32_HALF(e[7])); -} - static void unpack_mxfp4_quants(uint8_t * qs, const block_mxfp4 * x, unsigned int bi) { static const int qk = QK_MXFP4; for (unsigned int i = 0; i < qk / 2; ++i) { - const uint8_t x0 = (x->qs[i] & 0x0F); - const uint8_t x1 = (x->qs[i] >> 4); + const int x0 = (x->qs[i] & 0x0F); + const int x1 = (x->qs[i] >> 4); qs[bi * qk + i + 0] = x0; qs[bi * qk + i + qk / 2] = x1; } } static void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int bi) { - static const int qk = QK4_0; + static const int qk = QK_MXFP4; for (unsigned int i = 0; i < qk / 2; ++i) { const uint8_t x0 = qs[bi * qk + i + 0]; @@ -1308,299 +435,419 @@ static void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int } } -static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) { - static const int qk = QK_MXFP4x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers +// repack q4_0 data into q4_0_tiled tensor +static void repack_q4_0_tiled(ggml_tensor * t, const void * data, size_t size) { + const block_q4_0 * src_matrix = (const block_q4_0 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); - const int eblk_size = 8 * 1; // 8x E8M0 - const int qblk_size = qk / 2; // int4 - const int qrow_size = k / 2; // int4 (not padded to blocks) + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_Q4_0; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; - uint8_t * y_q = y + 0; // quants first - uint8_t * y_e = y + qrow_size; // then scales + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + const block_q4_0 * src_expert = src_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + uint8_t * matrix_dst = (uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_block_mxfp4(&x[i * 8 + 0], 0); - dump_block_mxfp4(&x[i * 8 + 1], 1); - dump_block_mxfp4(&x[i * 8 + 2], 2); - dump_block_mxfp4(&x[i * 8 + 3], 3); - dump_block_mxfp4(&x[i * 8 + 4], 4); - dump_block_mxfp4(&x[i * 8 + 5], 5); - dump_block_mxfp4(&x[i * 8 + 6], 6); - dump_block_mxfp4(&x[i * 8 + 7], 7); - } - } + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + uint8_t * tile_dst = matrix_dst + (ct * n_k_tiles + kt) * tile_size; - // Repack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_MXFP4x4x2]; // unpacked quants + uint8_t tile_quants[32][32]; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + unpack_q4_0_quants(tile_quants[row], &src_expert[r * (ne0 / 32) + kt], 0); + } else { + memset(tile_quants[row], 8, 32); + } + } - unpack_mxfp4_quants(qs, &x[i * 8 + 0], 0); - unpack_mxfp4_quants(qs, &x[i * 8 + 1], 1); - unpack_mxfp4_quants(qs, &x[i * 8 + 2], 2); - unpack_mxfp4_quants(qs, &x[i * 8 + 3], 3); - unpack_mxfp4_quants(qs, &x[i * 8 + 4], 4); - unpack_mxfp4_quants(qs, &x[i * 8 + 5], 5); - unpack_mxfp4_quants(qs, &x[i * 8 + 6], 6); - unpack_mxfp4_quants(qs, &x[i * 8 + 7], 7); + for (int cp = 0; cp < 16; cp++) { + for (int row = 0; row < 32; row++) { + tile_dst[cp * 32 + row] = (tile_quants[row][2 * cp + 1] << 4) | tile_quants[row][2 * cp]; + } + } - bool partial = (nloe && i == nb-1); - - uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk / 2; j++) { - q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; - } - } - - // Repack the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Repack the scales - uint8_t * e = (uint8_t *) (y_e + i * eblk_size); - e[0] = x[i * 8 + 0].e; - e[1] = x[i * 8 + 1].e; - e[2] = x[i * 8 + 2].e; - e[3] = x[i * 8 + 3].e; - e[4] = x[i * 8 + 4].e; - e[5] = x[i * 8 + 5].e; - e[6] = x[i * 8 + 6].e; - e[7] = x[i * 8 + 7].e; - } - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_packed_block_mxfp4x4x2(y, i, k); - } - } -} - -static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) { - static const int qk = QK_MXFP4x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers - - const int eblk_size = 8 * 1; // 8x E8M0 - const int qblk_size = qk / 2; // int4 - const int qrow_size = k / 2; // int4 (not padded to blocks) - - const uint8_t * y_q = y + 0; // quants first - const uint8_t * y_e = y + qrow_size; // then scales - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_packed_block_mxfp4x4x2(y, i, k); - } - } - - // Unpack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_MXFP4x4x2]; // unpacked quants - - bool partial = (nloe && i == nb-1); - - const uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk / 2; j++) { - if (partial) { - qs[j*2+0] = q[j] & 0xf; - qs[j*2+1] = q[j] >> 4; - } else { - qs[j+000] = q[j] & 0xf; - qs[j+128] = q[j] >> 4; + ggml_half * scale_dst = (ggml_half *)(tile_dst + 512); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + scale_dst[row] = (r < ne1 && kt < ne0 / 32) ? src_expert[r * (ne0 / 32) + kt].d : 0; + } + } } } - - pack_mxfp4_quants(&x[i * 8 + 0], qs, 0); - pack_mxfp4_quants(&x[i * 8 + 1], qs, 1); - pack_mxfp4_quants(&x[i * 8 + 2], qs, 2); - pack_mxfp4_quants(&x[i * 8 + 3], qs, 3); - pack_mxfp4_quants(&x[i * 8 + 4], qs, 4); - pack_mxfp4_quants(&x[i * 8 + 5], qs, 5); - pack_mxfp4_quants(&x[i * 8 + 6], qs, 6); - pack_mxfp4_quants(&x[i * 8 + 7], qs, 7); } +} - // Repack the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Unpack the scales - const uint8_t * e = (const uint8_t *) (y_e + i * eblk_size); - x[i * 8 + 0].e = e[0]; - x[i * 8 + 1].e = e[1]; - x[i * 8 + 2].e = e[2]; - x[i * 8 + 3].e = e[3]; - x[i * 8 + 4].e = e[4]; - x[i * 8 + 5].e = e[5]; - x[i * 8 + 6].e = e[6]; - x[i * 8 + 7].e = e[7]; - } +// repack q4_0_tiled tensor into q4_0 data +static void repack_tiled_q4_0(void * data, const ggml_tensor * t, size_t size) { + block_q4_0 * dst_matrix = (block_q4_0 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_block_mxfp4(&x[i * 8 + 0], 0); - dump_block_mxfp4(&x[i * 8 + 1], 1); - dump_block_mxfp4(&x[i * 8 + 2], 2); - dump_block_mxfp4(&x[i * 8 + 3], 3); - dump_block_mxfp4(&x[i * 8 + 4], 4); - dump_block_mxfp4(&x[i * 8 + 5], 5); - dump_block_mxfp4(&x[i * 8 + 6], 6); - dump_block_mxfp4(&x[i * 8 + 7], 7); + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_Q4_0; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + block_q4_0 * dst_expert = dst_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + const uint8_t * matrix_src = (const uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; + + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + const uint8_t * tile_src = matrix_src + (ct * n_k_tiles + kt) * tile_size; + + uint8_t tile_quants[32][32]; + for (int cp = 0; cp < 16; cp++) { + for (int row = 0; row < 32; row++) { + uint8_t val = tile_src[cp * 32 + row]; + tile_quants[row][2 * cp + 0] = val & 0x0F; + tile_quants[row][2 * cp + 1] = val >> 4; + } + } + + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + pack_q4_0_quants(&dst_expert[r * (ne0 / 32) + kt], tile_quants[row], 0); + } + } + + const ggml_half * scale_src = (const ggml_half *)(tile_src + 512); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + dst_expert[r * (ne0 / 32) + kt].d = scale_src[row]; + } + } + } + } } } } -static void init_row_mxfp4x4x2(block_mxfp4 * x, int64_t k) { - static const int qk = QK_MXFP4x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) +// repack q4_1 data into q4_1_tiled tensor +static void repack_q4_1_tiled(ggml_tensor * t, const void * data, size_t size) { + const block_q4_1 * src_matrix = (const block_q4_1 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); - // Init the quants such that they unpack into zeros - uint8_t qs[QK_MXFP4x4x2]; // unpacked quants - memset(qs, 0, sizeof(qs)); + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_Q4_1; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; - for (int i = 0; i < nb; i++) { - pack_mxfp4_quants(&x[i * 8 + 0], qs, 0); - pack_mxfp4_quants(&x[i * 8 + 1], qs, 1); - pack_mxfp4_quants(&x[i * 8 + 2], qs, 2); - pack_mxfp4_quants(&x[i * 8 + 3], qs, 3); - pack_mxfp4_quants(&x[i * 8 + 4], qs, 4); - pack_mxfp4_quants(&x[i * 8 + 5], qs, 5); - pack_mxfp4_quants(&x[i * 8 + 6], qs, 6); - pack_mxfp4_quants(&x[i * 8 + 7], qs, 7); - } + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + const block_q4_1 * src_expert = src_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + uint8_t * matrix_dst = (uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; - // Init the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Unpack the scales - x[i * 8 + 0].e = 0; - x[i * 8 + 1].e = 0; - x[i * 8 + 2].e = 0; - x[i * 8 + 3].e = 0; - x[i * 8 + 4].e = 0; - x[i * 8 + 5].e = 0; - x[i * 8 + 6].e = 0; - x[i * 8 + 7].e = 0; + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + uint8_t * tile_dst = matrix_dst + (ct * n_k_tiles + kt) * tile_size; + + uint8_t tile_quants[32][32]; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + unpack_q4_1_quants(tile_quants[row], &src_expert[r * (ne0 / 32) + kt], 0); + } else { + memset(tile_quants[row], 0, 32); + } + } + + for (int cp = 0; cp < 16; cp++) { + for (int row = 0; row < 32; row++) { + tile_dst[cp * 32 + row] = (tile_quants[row][2 * cp + 1] << 4) | tile_quants[row][2 * cp]; + } + } + + ggml_half * scale_dst = (ggml_half *)(tile_dst + 512); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + scale_dst[2 * row + 0] = src_expert[r * (ne0 / 32) + kt].d; + scale_dst[2 * row + 1] = src_expert[r * (ne0 / 32) + kt].m; + } else { + scale_dst[2 * row + 0] = 0; + scale_dst[2 * row + 1] = 0; + } + } + } + } + } } } -// repack mxfp4 data into mxfp4x4x2 tensor -static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t size) { - int64_t nrows = ggml_nrows(t); +// repack q4_1_tiled tensor into q4_1 data +static void repack_tiled_q4_1(void * data, const ggml_tensor * t, size_t size) { + block_q4_1 * dst_matrix = (block_q4_1 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_Q4_1; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; - // Ensure we don't try to read more data than is available in the source buffer 'data' - // or write more than the tensor can hold. - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + block_q4_1 * dst_expert = dst_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + const uint8_t * matrix_src = (const uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; - // Calculate how many full rows and how many remaining bytes we need to process. - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + const uint8_t * tile_src = matrix_src + (ct * n_k_tiles + kt) * tile_size; - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); + uint8_t tile_quants[32][32]; + for (int cp = 0; cp < 16; cp++) { + for (int row = 0; row < 32; row++) { + uint8_t val = tile_src[cp * 32 + row]; + tile_quants[row][2 * cp + 0] = val & 0x0F; + tile_quants[row][2 * cp + 1] = val >> 4; + } + } - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + pack_q4_1_quants(&dst_expert[r * (ne0 / 32) + kt], tile_quants[row], 0); + } + } - HEX_VERBOSE("ggml-hex: repack-mxfp4-mxfp4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, - size, t->ne[0], nrows, row_size); - - init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros - - // 1. Process all the full rows - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - memcpy(buf_pd, src, row_size); - repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); + const ggml_half * scale_src = (const ggml_half *)(tile_src + 512); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + dst_expert[r * (ne0 / 32) + kt].d = scale_src[2 * row]; + dst_expert[r * (ne0 / 32) + kt].m = scale_src[2 * row + 1]; + } + } + } + } + } } - - // 2. Process the final, potentially partial, row - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - // re-init the row because we are potentially copying a partial row - init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]); - - // Copy only the remaining bytes from the source. - memcpy(buf_pd, src, n_rem_bytes); - - // Repack the entire buffer (partial data + zero padding). - repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]); - - // Write only the corresponding remaining bytes to the destination tensor. - memcpy(dst, buf_rp, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); } -// repack mxfp4x4x2 tensor into mxfp4 data -static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t size) { - int64_t nrows = ggml_nrows(t); +// repack q8_0 data into q8_0_tiled tensor +static void repack_q8_0_tiled(ggml_tensor * t, const void * data, size_t size) { + const block_q8_0 * src_matrix = (const block_q8_0 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_Q8_0; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; - // Ensure we don't try to copy more data than the tensor actually contains. - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + const block_q8_0 * src_expert = src_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + uint8_t * matrix_dst = (uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; - // Calculate how many full rows and how many remaining bytes we need to process. - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + uint8_t * tile_dst = matrix_dst + (ct * n_k_tiles + kt) * tile_size; - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); + for (int cp = 0; cp < 16; cp++) { + int col0 = cp * 2; + int col1 = col0 + 1; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + const block_q8_0 * b = (r < ne1 && kt < ne0 / 32) ? &src_expert[r * (ne0 / 32) + kt] : NULL; + tile_dst[cp * 64 + 2 * row + 0] = b ? b->qs[col0] : 0; + tile_dst[cp * 64 + 2 * row + 1] = b ? b->qs[col1] : 0; + } + } - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-mxfp4x4x2-mxfp4 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, - size, t->ne[0], nrows, row_size); - - memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros - - // 1. Process all the full rows - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - memcpy(buf_pd, src, row_size); - unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); + ggml_half * scale_dst = (ggml_half *)(tile_dst + 1024); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + scale_dst[row] = (r < ne1 && kt < ne0 / 32) ? src_expert[r * (ne0 / 32) + kt].d : 0; + } + } + } + } } +} - // 2. Process the final, potentially partial, row - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); +// repack q8_0_tiled tensor into q8_0 data +static void repack_tiled_q8_0(void * data, const ggml_tensor * t, size_t size) { + block_q8_0 * dst_matrix = (block_q8_0 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); - // We still need to read and unpack the entire source row because the format is block-based. - memcpy(buf_pd, src, row_size); - unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_Q8_0; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; - // But we only copy the remaining number of bytes to the destination to respect the size limit. - memcpy(dst, buf_rp, n_rem_bytes); + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + block_q8_0 * dst_expert = dst_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + const uint8_t * matrix_src = (const uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; + + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + const uint8_t * tile_src = matrix_src + (ct * n_k_tiles + kt) * tile_size; + + for (int cp = 0; cp < 16; cp++) { + int col0 = cp * 2; + int col1 = col0 + 1; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + block_q8_0 & b = dst_expert[r * (ne0 / 32) + kt]; + b.qs[col0] = tile_src[cp * 64 + 2 * row + 0]; + b.qs[col1] = tile_src[cp * 64 + 2 * row + 1]; + } + } + } + + const ggml_half * scale_src = (const ggml_half *)(tile_src + 1024); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + dst_expert[r * (ne0 / 32) + kt].d = scale_src[row]; + } + } + } + } + } } +} - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); +// repack mxfp4 data into mxfp4_tiled tensor +static void repack_mxfp4_tiled(ggml_tensor * t, const void * data, size_t size) { + const block_mxfp4 * src_matrix = (const block_mxfp4 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); + + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_MXFP4; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + const block_mxfp4 * src_expert = src_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + uint8_t * matrix_dst = (uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; + + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + uint8_t * tile_dst = matrix_dst + (ct * n_k_tiles + kt) * tile_size; + + uint8_t tile_quants[32][32]; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + unpack_mxfp4_quants(tile_quants[row], &src_expert[r * (ne0 / 32) + kt], 0); + } else { + memset(tile_quants[row], 0, 32); + } + } + + for (int cp = 0; cp < 16; cp++) { + for (int row = 0; row < 32; row++) { + tile_dst[cp * 32 + row] = (tile_quants[row][2 * cp + 1] << 4) | tile_quants[row][2 * cp]; + } + } + + uint8_t * scale_dst = tile_dst + 512; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + scale_dst[row] = (r < ne1 && kt < ne0 / 32) ? src_expert[r * (ne0 / 32) + kt].e : 0; + } + } + } + } + } +} + +// repack mxfp4_tiled tensor into mxfp4 data +static void repack_tiled_mxfp4(void * data, const ggml_tensor * t, size_t size) { + block_mxfp4 * dst_matrix = (block_mxfp4 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); + + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_MXFP4; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + block_mxfp4 * dst_expert = dst_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + const uint8_t * matrix_src = (const uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; + + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + const uint8_t * tile_src = matrix_src + (ct * n_k_tiles + kt) * tile_size; + + uint8_t tile_quants[32][32]; + for (int cp = 0; cp < 16; cp++) { + for (int row = 0; row < 32; row++) { + uint8_t val = tile_src[cp * 32 + row]; + tile_quants[row][2 * cp + 0] = val & 0x0F; + tile_quants[row][2 * cp + 1] = val >> 4; + } + } + + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + pack_mxfp4_quants(&dst_expert[r * (ne0 / 32) + kt], tile_quants[row], 0); + } + } + + const uint8_t * scale_src = tile_src + 512; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + dst_expert[r * (ne0 / 32) + kt].e = scale_src[row]; + } + } + } + } + } + } } static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, @@ -1617,32 +864,32 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, case GGML_TYPE_Q4_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q4_0_q4x4x2(tensor, data, size); + repack_q4_0_tiled(tensor, data, size); break; case GGML_TYPE_Q4_1: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q4_1_q4x4x2(tensor, data, size); + repack_q4_1_tiled(tensor, data, size); break; case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q8_0_q8x4x2(tensor, data, size); + repack_q8_0_tiled(tensor, data, size); break; case GGML_TYPE_IQ4_NL: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); // IQ4_NL has identical block layout to Q4_0 (ggml_half d + uint8_t qs[16]) - repack_q4_0_q4x4x2(tensor, data, size); + repack_q4_0_tiled(tensor, data, size); break; case GGML_TYPE_MXFP4: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_mxfp4_mxfp4x4x2(tensor, data, size); + repack_mxfp4_tiled(tensor, data, size); break; default: @@ -1665,31 +912,31 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, case GGML_TYPE_Q4_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q4x4x2_q4_0(data, tensor, size); + repack_tiled_q4_0(data, tensor, size); break; case GGML_TYPE_Q4_1: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q4x4x2_q4_1(data, tensor, size); + repack_tiled_q4_1(data, tensor, size); break; case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q8x4x2_q8_0(data, tensor, size); + repack_tiled_q8_0(data, tensor, size); break; case GGML_TYPE_IQ4_NL: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q4x4x2_q4_0(data, tensor, size); + repack_tiled_q4_0(data, tensor, size); break; case GGML_TYPE_MXFP4: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_mxfp4x4x2_mxfp4(data, tensor, size); + repack_tiled_mxfp4(data, tensor, size); break; default: @@ -1767,12 +1014,19 @@ static size_t ggml_backend_hexagon_buffer_type_get_alignment(ggml_backend_buffer } static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * t) { + if (t->type == GGML_TYPE_Q4_0 || t->type == GGML_TYPE_Q4_1 || t->type == GGML_TYPE_Q8_0 || t->type == GGML_TYPE_IQ4_NL || t->type == GGML_TYPE_MXFP4) { + int64_t ne0 = hex_round_up(t->ne[0], 32); + int64_t ne1 = hex_round_up(t->ne[1], 32); + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + return ggml_row_size(t->type, ne0) * ne1 * ne2 * ne3; + } return ggml_nbytes(t); } static size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { - return opt_mbuf; // typically 1GB per buffer - GGML_UNUSED(buffer_type); + auto * context = static_cast(buffer_type->context); + return context->sess->max_bufsize; } static bool ggml_backend_hexagon_buffer_type_is_host(ggml_backend_buffer_type_t buft) { @@ -1803,6 +1057,17 @@ static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interf /* .is_host = */ ggml_backend_hexagon_repack_buffer_type_is_host, }; +static bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b) { + return b->buft->iface.get_alignment == ggml_backend_hexagon_buffer_type_get_alignment; +} + +static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backend_buffer * b) { + if (!opt_hostbuf) { + return ggml_backend_buffer_is_hexagon(b); + } + return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer; +} + struct ggml_hexagon_opbatch { ggml_hexagon_session* sess; @@ -1883,14 +1148,25 @@ struct ggml_hexagon_opbatch { b_vmem += b.size; - HEX_VERBOSE("ggml-hex: add-buffer #%u : fd %d base %p size %zu : vmem %zu\n", bi, b.fd, (void*) sbuf->base, (size_t) b.size, b_vmem); + HEX_VERBOSE("ggml-hex: %s add-buffer #%u : fd %d base %p size %zu : vmem %zu\n", sess->c_name(), bi, b.fd, (void*) sbuf->base, (size_t) b.size, b_vmem); return bi; } bool same_shape(const htp_tensor * h, const ggml_tensor * t) const { - return (h->ne[0] == t->ne[0]) && (h->ne[1] == t->ne[1]) && (h->ne[2] == t->ne[2]) && (h->ne[3] == t->ne[3]) && - (h->nb[0] == t->nb[0]) && (h->nb[1] == t->nb[1]) && (h->nb[2] == t->nb[2]) && (h->nb[3] == t->nb[3]); + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + const bool is_repack = ggml_backend_buffer_is_hexagon_repack(t->buffer) && ggml_hexagon_is_repack_type(t->type); + if (is_repack) { + ne0 = hex_round_up(ne0, 32); + ne1 = hex_round_up(ne1, 32); + } + int64_t nb1 = is_repack ? ggml_row_size(t->type, ne0) : t->nb[1]; + int64_t nb2 = is_repack ? nb1 * ne1 : t->nb[2]; + int64_t nb3 = is_repack ? nb2 * t->ne[2] : t->nb[3]; + + return (h->ne[0] == ne0) && (h->ne[1] == ne1) && (h->ne[2] == t->ne[2]) && (h->ne[3] == t->ne[3]) && + (h->nb[0] == t->nb[0]) && (h->nb[1] == nb1) && (h->nb[2] == nb2) && (h->nb[3] == nb3); } // add tensor and return its index @@ -1921,19 +1197,35 @@ struct ggml_hexagon_opbatch { htp_tensor &h = h_tens[ti]; h.bi = add_buffer(sbuf); h.data = t_offset; - h.size = t_size; h.type = t->type; - h.ne[0] = t->ne[0]; h.ne[1] = t->ne[1]; h.ne[2] = t->ne[2]; h.ne[3] = t->ne[3]; - h.nb[0] = t->nb[0]; h.nb[1] = t->nb[1]; h.nb[2] = t->nb[2]; h.nb[3] = t->nb[3]; + + const bool is_repack = ggml_backend_buffer_is_hexagon_repack(t->buffer) && ggml_hexagon_is_repack_type(t->type); + if (is_repack) { + h.ne[0] = hex_round_up(t->ne[0], 32); + h.ne[1] = hex_round_up(t->ne[1], 32); + h.ne[2] = t->ne[2]; + h.ne[3] = t->ne[3]; + + h.nb[0] = t->nb[0]; + h.nb[1] = ggml_row_size(t->type, h.ne[0]); + h.nb[2] = h.nb[1] * h.ne[1]; + h.nb[3] = h.nb[2] * h.ne[2]; + h.size = h.nb[3] * h.ne[3]; + t_size = h.size; + } else { + h.size = t_size; + h.ne[0] = t->ne[0]; h.ne[1] = t->ne[1]; h.ne[2] = t->ne[2]; h.ne[3] = t->ne[3]; + h.nb[0] = t->nb[0]; h.nb[1] = t->nb[1]; h.nb[2] = t->nb[2]; h.nb[3] = t->nb[3]; + } h.flags = 0; if (ggml_backend_buffer_get_usage(t->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { h.flags |= HTP_TENSOR_COMPUTE; } - HEX_VERBOSE("ggml-hex: add-tensor #%u %s : bi %d data %p offset %zu size %zu flags 0x%x : %zu:%zu:%zu:%zu\n", + HEX_VERBOSE("ggml-hex: %s add-tensor #%u %s : bi %d data %p offset %zu size %zu flags 0x%x : %zu:%zu:%zu:%zu\n", sess->c_name(), ti, t->name, h.bi, (void*) t->data, (size_t) t_offset, t_size, h.flags, - (size_t) t->ne[0], (size_t) t->ne[1], (size_t) t->ne[2], (size_t) t->ne[3]); + (size_t) h.ne[0], (size_t) h.ne[1], (size_t) h.ne[2], (size_t) h.ne[3]); return ti; } @@ -1962,7 +1254,9 @@ struct ggml_hexagon_opbatch { for (const auto * src : node.get_inputs()) { fit_tensor(src); } - fit_tensor(node.dst()); + for (const auto * output : node.get_outputs()) { + fit_tensor(output); + } if ((extra_bufs + n_bufs) > n_bufs_max) return false; if ((extra_tens + n_tens) > n_tens_max) return false; @@ -1981,7 +1275,8 @@ struct ggml_hexagon_opbatch { ops[n] = node; htp_op_desc &o = h_ops[n]; - memcpy(&o.params, &node.node->op_params, sizeof(node.node->op_params)); + memcpy(o.params, node.node->op_params, sizeof(node.node->op_params)); + memcpy(o.kernel_params, node.kernel_params, sizeof(o.kernel_params)); o.opcode = node.opcode; o.flags = 0; @@ -1989,13 +1284,17 @@ struct ggml_hexagon_opbatch { o.flags |= HTP_OPFLAGS_SKIP_COMPUTE; } - ggml_hexagon_dump_op_exec(sess->c_name(), node, o.flags); + ggml_hexagon_dump_op_exec(sess->c_name(), ops[n], o.flags); auto inputs = node.get_inputs(); for (unsigned int i=0; i < HTP_OP_MAX_INPUTS; i++) { - o.src[i] = (i < inputs.size() && inputs[i]) ? add_tensor(inputs[i]) : 0xffff; + o.src[i] = (i < inputs.size() && inputs[i]) ? add_tensor(inputs[i]) : 0xffff; + } + + auto outputs = node.get_outputs(); + for (unsigned int i=0; i < HTP_OP_MAX_OUTPUTS; i++) { + o.dst[i] = (i < outputs.size() && outputs[i]) ? add_tensor(outputs[i]) : 0xffff; } - o.dst = add_tensor(node.dst()); } }; @@ -2006,14 +1305,14 @@ struct ggml_hexagon_opqueue { using opvec = std::vector; - std::queue done; // completed batch ids - std::vector op_cache; // per batch op cache - std::vector start_usec; // per batch start time + std::queue done; // completed batch ids + std::vector op_cache; // per batch op cache + std::vector start_usec; // per batch start time ggml_hexagon_opqueue(ggml_hexagon_session *sess, size_t batch_size, size_t depth) { size_t n_bufs = HTP_OP_MAX_BUFS; size_t n_ops = batch_size; - size_t n_tensors = n_ops + n_ops * HTP_OP_MAX_INPUTS; + size_t n_tensors = n_ops * HTP_OP_MAX_OUTPUTS + n_ops * HTP_OP_MAX_INPUTS; size_t tr_size = 0; if (opt_profile == 3) { @@ -2200,7 +1499,7 @@ struct ggml_hexagon_opqueue { char evt_str[256] = ""; if (opt_profile == 3) { - sprintf(evt_str, " evt [%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u]", + snprintf(evt_str, sizeof(evt_str), " evt [%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u]", rsp.n_traces[0], rsp.n_traces[1], rsp.n_traces[2], rsp.n_traces[3], rsp.n_traces[4], rsp.n_traces[5], rsp.n_traces[6], rsp.n_traces[7], rsp.n_traces[8], rsp.n_traces[9], rsp.n_traces[10]); @@ -2224,6 +1523,7 @@ void ggml_hexagon_session::flush_pending(bool all) { // Read response packet from queue const uint32_t timeo = opt_oppoll ? 0 : DSPQUEUE_TIMEOUT; + int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, timeo); if (err == AEE_EEXPIRED) { continue; @@ -2404,6 +1704,31 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->valid_handle = true; + // Query HW info and resolve session options + this->max_bufsize = opt_mbuf; + { + unsigned int hw_n_threads = 0; + unsigned int hw_n_hvx = 0; + unsigned int hw_n_hmx = 0; + unsigned long long hw_vtcm_size = 0; + int hw_err = htp_iface_hwinfo(this->handle, &hw_n_threads, &hw_n_hvx, &hw_n_hmx, &hw_vtcm_size); + if (hw_err == 0) { + this->n_threads = opt_nhvx > 0 ? (uint32_t)opt_nhvx : (uint32_t)hw_n_threads; + this->n_hvx = opt_nhvx > 0 ? (uint32_t)opt_nhvx : (uint32_t)hw_n_hvx; + this->n_hmx = (opt_nhmx != 0) ? (uint32_t)hw_n_hmx : 0; + this->vtcm_size = (uint64_t)hw_vtcm_size; + GGML_LOG_INFO("ggml-hex: %s hwinfo: threads %u, hvx %u, hmx %u, vtcm %llu MB\n", + this->c_name(), this->n_threads, this->n_hvx, this->n_hmx, + (unsigned long long)(this->vtcm_size / (1024 * 1024))); + } else { + GGML_LOG_WARN("ggml-hex: %s failed to query hwinfo (0x%x), using defaults\n", this->c_name(), hw_err); + this->n_threads = opt_nhvx > 0 ? (uint32_t)opt_nhvx : 8; + this->n_hvx = opt_nhvx > 0 ? (uint32_t)opt_nhvx : 8; + this->n_hmx = (opt_nhmx != 0) ? 1 : 0; + this->vtcm_size = 8 * 1024 * 1024; + } + } + // Enable FastRPC QoS mode { struct remote_rpc_control_latency l; @@ -2468,11 +1793,12 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { opt_vmem = ggml_hexagon_measure_max_vmem(this); GGML_LOG_INFO("ggml-hex: %s measured max vmem %zu\n", this->c_name(), opt_vmem); } + this->max_vmem = opt_vmem; - this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch, opt_vmem); + this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch, this->max_vmem); // Start dspqueue/opbatch processing - err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx, opt_vmem); + err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_nhmx, this->max_vmem); if (err != 0) { GGML_LOG_ERROR("ggml-hex: %s failed to start session: 0x%08x\n", this->c_name(), (unsigned) err); throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); @@ -2553,16 +1879,6 @@ ggml_hexagon_session::~ggml_hexagon_session() noexcept(true) { // ** backend interface -static bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b) { - return b->buft->iface.get_alignment == ggml_backend_hexagon_buffer_type_get_alignment; -} - -static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backend_buffer * b) { - if (!opt_hostbuf) { - return ggml_backend_buffer_is_hexagon(b); - } - return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer; -} static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const struct ggml_tensor * src0 = op->src[0]; @@ -2653,6 +1969,640 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses return true; } +static bool ggml_hexagon_matmul_is_hmx_eligible( + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * dst, + int ne01_padded, + bool is_matmul_id, + bool is_batched +) { + const int ne00 = src0->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int wtype = src0->type; + + // HMX weight tile requires N to be 32-aligned. + if (ne01_padded % 32 != 0) { + return false; + } + + // HMX supports F16, F32, and repack quantized types. + if (!ggml_hexagon_is_hmx_weight_type((ggml_type) wtype)) { + return false; + } + + // HMX paths require K aligned to 32. + if (ne00 % 32 != 0) { + return false; + } + + // Quantized HMX kernels only handle flat 2D matmul (or matmul_id wrapping flat 2D matmuls). + if (!is_matmul_id && is_batched && wtype != GGML_TYPE_F16) { + return false; + } + + // HMX assumes contiguous row-major layout. + if (src0->nb[0] > src0->nb[1] || src1->nb[0] > src1->nb[1]) { + return false; + } + + // M alignment: Use HMX when M > HTP_MM_HMX_MIN_NROWS + const int m = is_matmul_id ? ne12 : ne11; + if (m <= HTP_MM_HMX_MIN_NROWS) { + return false; + } + + return true; +} + +static bool ggml_hexagon_precompute_hmx_mm_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * dst, + int wtype, + int ne00_padded, + int ne01_padded, + int ne02, + int ne11, + int ne12, + int ne11_padded, + bool is_matmul_id, + bool is_batched, + size_t vtcm_budget, + struct htp_mm_kernel_params * kparams +) { + const int aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + const bool pipeline = is_matmul_id ? false : htp_mm_hmx_pipeline(ne11); + const int n_threads = (int)sess->n_threads; + const int ne10 = src1->ne[0]; + + const bool is_batched_val = is_matmul_id ? false : is_batched; + const int group_size = (ne02 > 0 ? ne12 / ne02 : 1); + + size_t m_chunk = 0; + size_t n_chunk = 0; + size_t vtcm_size = 0; + bool use_grouped = false; + int act_threads_selected = 0; + + if (is_batched_val && wtype == GGML_TYPE_F16 && group_size > 1) { + // Try grouped path first + const bool use_dma_activation = (src1->nb[1]/sizeof(float) > (size_t)ne00_padded); + size_t best_mblocks = SIZE_MAX; + int best_act_threads = 0; + size_t best_m_chunk = 0; + size_t best_n_chunk = 0; + size_t best_vtcm_size = 0; + + int act_threads = n_threads; + while (act_threads >= 1) { + const size_t f32_scratch_size = use_dma_activation ? hex_align_up(act_threads * HTP_MM_DMA_ACT_MULTIPLIER * ne00_padded * sizeof(float), HTP_MM_HMX_TILE_SIZE) : 0; + size_t group_overhead = 256 + f32_scratch_size; + size_t group_size_per_n, group_size_per_m, group_size_per_mn; + htp_mm_hmx_get_batched_chunk_costs(ne00_padded, group_size, &group_size_per_n, &group_size_per_m, &group_size_per_mn); + + size_t m_chunk_candidate = 0; + size_t n_chunk_candidate = 0; + size_t vtcm_size_candidate = 0; + + if (htp_mm_hmx_compute_chunks(vtcm_budget, group_overhead, group_size_per_n, group_size_per_m, group_size_per_mn, hex_align_up(ne11, 32), ne01_padded, + (size_t) ne01_padded * HTP_MM_HMX_COST_W_DEQUANT, (size_t) ne11 * HTP_MM_HMX_COST_A_CONVERT, + &m_chunk_candidate, &n_chunk_candidate, &vtcm_size_candidate) == 0) { + size_t exact_size = htp_mm_hmx_get_batched_vtcm_size(wtype, ne00_padded, m_chunk_candidate, n_chunk_candidate, group_size, use_dma_activation, pipeline, act_threads); + if (exact_size <= vtcm_budget) { + size_t mblocks = ((size_t) ne11 + m_chunk_candidate - 1) / m_chunk_candidate; + if (mblocks < best_mblocks || (mblocks == best_mblocks && act_threads > best_act_threads)) { + best_mblocks = mblocks; + best_act_threads = act_threads; + best_m_chunk = m_chunk_candidate; + best_n_chunk = n_chunk_candidate; + best_vtcm_size = exact_size; + } + } + } + if (act_threads == 1) { + act_threads = 0; + } else { + act_threads /= 2; + } + } + + if (best_act_threads > 0) { + m_chunk = best_m_chunk; + n_chunk = best_n_chunk; + vtcm_size = best_vtcm_size; + act_threads_selected = best_act_threads; + use_grouped = true; + } + } + + if (!use_grouped) { + // Fallback to simple 2D path (group_size = 1) + size_t best_mblocks = SIZE_MAX; + int best_act_threads = 0; + size_t best_m_chunk = 0; + size_t best_n_chunk = 0; + size_t best_vtcm_size = 0; + + // For MUL_MAT_ID the kernel runs one 2D matmul per expert, with M equal to the number of rows routed to that expert. + // A single expert can receive up to all routed rows (dst->ne[1]*dst->ne[2] = n_expert_used*n_tokens), so size the chunk + // search for that upper bound rather than ne12 (token positions only). + // We recompute m_chunk per expert against the actual count in the NPU kernel. + const int m_id_rows = (int) ((size_t) dst->ne[1] * dst->ne[2]); + const int m_for_chunks = is_matmul_id ? hex_align_up(m_id_rows, 32) : ne11_padded; + const int m_for_cost = is_matmul_id ? m_id_rows : ne11; + + int act_threads = n_threads; + while (act_threads >= 1) { + const size_t act_f32_size = is_matmul_id ? 0 : hex_align_up(act_threads * HTP_MM_DMA_ACT_MULTIPLIER * ne00_padded * sizeof(float), HTP_MM_HMX_TILE_SIZE); + size_t simple_2d_overhead = 256 + act_f32_size; + size_t simple_2d_size_per_n, simple_2d_size_per_m, simple_2d_size_per_mn; + htp_mm_hmx_get_2d_chunk_costs(wtype, ne00_padded, pipeline, aligned_tile_size, &simple_2d_size_per_n, &simple_2d_size_per_m, &simple_2d_size_per_mn); + + size_t m_chunk_candidate = 0; + size_t n_chunk_candidate = 0; + size_t vtcm_size_candidate = 0; + + if (htp_mm_hmx_compute_chunks(vtcm_budget, simple_2d_overhead, simple_2d_size_per_n, simple_2d_size_per_m, simple_2d_size_per_mn, m_for_chunks, ne01_padded, + (size_t) ne01_padded * HTP_MM_HMX_COST_W_DEQUANT, (size_t) m_for_cost * HTP_MM_HMX_COST_A_CONVERT, + &m_chunk_candidate, &n_chunk_candidate, &vtcm_size_candidate) == 0) { + size_t exact_size = htp_mm_hmx_get_2d_vtcm_size(wtype, ne00_padded, m_chunk_candidate, n_chunk_candidate, pipeline, is_matmul_id ? 0 : act_threads, aligned_tile_size); + if (exact_size <= vtcm_budget) { + size_t mblocks = ((size_t) m_for_cost + m_chunk_candidate - 1) / m_chunk_candidate; + if (mblocks < best_mblocks || (mblocks == best_mblocks && act_threads > best_act_threads)) { + best_mblocks = mblocks; + best_act_threads = act_threads; + best_m_chunk = m_chunk_candidate; + best_n_chunk = n_chunk_candidate; + best_vtcm_size = exact_size; + } + } + } + if (act_threads == 1) { + act_threads = 0; + } else { + act_threads /= 2; + } + } + + if (best_act_threads > 0) { + m_chunk = best_m_chunk; + n_chunk = best_n_chunk; + vtcm_size = best_vtcm_size; + act_threads_selected = best_act_threads; + } else { + return false; + } + } + + kparams->n_hmx = 1; + kparams->pipeline = pipeline ? 1 : 0; + kparams->m_chunk = m_chunk; + kparams->n_chunk = n_chunk; + kparams->n_threads = n_threads; + kparams->n_act_threads = act_threads_selected; + kparams->tile_size = htp_mm_get_weight_tile_size(wtype); + kparams->aligned_tile_size = aligned_tile_size; + kparams->src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + kparams->vtcm_size = vtcm_size; + kparams->vtcm_src0_size = 0; + kparams->vtcm_src1_size = 0; + kparams->vtcm_dst_size = 0; + + if (is_batched && !is_matmul_id) { + kparams->kernel_type = HTP_MM_KERNEL_HMX_F16_BATCHED; + } else { + kparams->kernel_type = HTP_MM_KERNEL_HMX_2D; + } + return true; +} + +static void ggml_hexagon_precompute_hvx_mm_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * dst, + int wtype, + int ne02, + int ne03, + int ne10, + int ne11, + int ne12, + int ne13, + bool is_matmul_id, + size_t vtcm_budget, + struct htp_mm_kernel_params * kparams +) { + kparams->n_hmx = 0; + + const bool is_quant = (wtype != GGML_TYPE_F16 && wtype != GGML_TYPE_F32); + const int src1_nrows = ne11 * ne12 * ne13; + + if (is_quant) { + // Quantized HVX + kparams->tile_size = htp_mm_get_weight_tile_size(wtype); + kparams->aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + + const bool k_align = (ne10 % 32 == 0); + + if (is_matmul_id) { + kparams->kernel_type = (src1_nrows < (int) sess->n_threads) ? HTP_MM_KERNEL_HVX_QUANT_BLOCK : HTP_MM_KERNEL_HVX_QUANT_ROW; + kparams->src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + + size_t vtcm_src0_size = 0, vtcm_src1_size = 0; + uint32_t max_prefetch = (src1_nrows > HTP_MM_HMX_MIN_NROWS) ? 2 : 16; + uint32_t best_n_prefetch = 2; + size_t total_size = 0; + for (uint32_t d = max_prefetch; d >= 2; d /= 2) { + total_size = htp_mm_hvx_id_get_vtcm_sizes( + wtype, ne10, src1_nrows, sess->n_threads, src0->nb[1], d, + &vtcm_src0_size, &vtcm_src1_size + ); + if (total_size <= vtcm_budget) { + best_n_prefetch = d; + break; + } + } + if (best_n_prefetch == 2 && total_size > vtcm_budget) { + total_size = htp_mm_hvx_id_get_vtcm_sizes( + wtype, ne10, src1_nrows, sess->n_threads, src0->nb[1], 2, + &vtcm_src0_size, &vtcm_src1_size + ); + } + kparams->n_prefetch = best_n_prefetch; + kparams->vtcm_size = total_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = 0; + } else { + bool try_tiled = (k_align && opt_mm_select >= 2); + if (try_tiled) { + kparams->src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + if (src1_nrows < (int)sess->n_threads) { + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_BLOCK; + } else { + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW; + } + + uint32_t max_prefetch = (src1_nrows > HTP_MM_HMX_MIN_NROWS) ? 2 : 16; + uint32_t best_n_prefetch = 2; + size_t vtcm_src0_size = 0, vtcm_src1_size = 0, vtcm_dst_size = 0; + size_t total_size = 0; + for (uint32_t d = max_prefetch; d >= 2; d /= 2) { + total_size = htp_mm_hvx_get_vtcm_sizes( + kparams->kernel_type, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], d, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + if (total_size <= vtcm_budget) { + best_n_prefetch = d; + break; + } + } + if (best_n_prefetch == 2 && total_size > vtcm_budget) { + total_size = htp_mm_hvx_get_vtcm_sizes( + kparams->kernel_type, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], 2, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + } + + kparams->n_prefetch = best_n_prefetch; + + if (total_size <= vtcm_budget) { + kparams->vtcm_size = total_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = vtcm_dst_size; + goto done_quant; + } + HEX_VERBOSE("ggml-hex: %s HVX tiled path VTCM size needed (%zu) > budget (%zu), falling back to HVX flat\n", sess->name.c_str(), total_size, vtcm_budget); + } + + // Flat HVX fallback + { + kparams->src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10); + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT; + + size_t vtcm_src0_size = 0, vtcm_src1_size = 0, vtcm_dst_size = 0; + size_t total_size = htp_mm_hvx_get_vtcm_sizes( + kparams->kernel_type, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], 16, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + + kparams->n_prefetch = 16; + kparams->vtcm_size = total_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = vtcm_dst_size; + } + } + + done_quant:; + } else if (wtype == GGML_TYPE_F16) { + // F16 HVX + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = ggml_is_permuted(src0) || ggml_is_permuted(src1); + + size_t vtcm_src0_size = 0, vtcm_src1_size = 0, vtcm_dst_size = 0; + size_t vtcm_size = htp_mm_hvx_get_vtcm_sizes( + HTP_MM_KERNEL_HVX_F16_F16_VTCM, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], 16, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + + if (!is_batched && !is_permuted && vtcm_size <= vtcm_budget) { + kparams->kernel_type = HTP_MM_KERNEL_HVX_F16_F16_VTCM; + kparams->src1_row_size = hex_round_up(ne10 * 2, 128); + kparams->vtcm_size = vtcm_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = vtcm_dst_size; + kparams->n_prefetch = 16; + } else { + if (src1->type == GGML_TYPE_F32) { + kparams->kernel_type = HTP_MM_KERNEL_HVX_F16_F32_DDR; + } else { + kparams->kernel_type = HTP_MM_KERNEL_HVX_F16_F16_DDR; + } + kparams->src1_row_size = src1->nb[1]; + size_t ddr_size = htp_mm_hvx_get_vtcm_sizes( + kparams->kernel_type, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], 16, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + kparams->vtcm_size = ddr_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = vtcm_dst_size; + kparams->n_prefetch = 16; + } + } else { + // F32 HVX + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = ggml_is_permuted(src0) || ggml_is_permuted(src1); + + size_t vtcm_src0_size = 0, vtcm_src1_size = 0, vtcm_dst_size = 0; + size_t vtcm_size = htp_mm_hvx_get_vtcm_sizes( + HTP_MM_KERNEL_HVX_F32_F32_VTCM, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], 16, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + + if (!is_batched && !is_permuted && vtcm_size <= vtcm_budget) { + kparams->kernel_type = HTP_MM_KERNEL_HVX_F32_F32_VTCM; + kparams->src1_row_size = hex_round_up(ne10 * 4, 128); + kparams->vtcm_size = vtcm_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = vtcm_dst_size; + kparams->n_prefetch = 16; + } else { + kparams->kernel_type = HTP_MM_KERNEL_HVX_F32_F32_DDR; + kparams->src1_row_size = src1->nb[1]; + size_t ddr_size = htp_mm_hvx_get_vtcm_sizes( + kparams->kernel_type, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], 16, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + kparams->vtcm_size = ddr_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = vtcm_dst_size; + kparams->n_prefetch = 16; + } + } +} + +static void ggml_hexagon_precompute_matmul_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * dst, + struct htp_mm_kernel_params * kparams +) { + memset(kparams, 0, sizeof(*kparams)); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const int wtype = src0->type; + const bool is_repack = ggml_hexagon_is_repack_type((ggml_type) wtype); + const int ne00_padded = is_repack ? hex_round_up(ne00, 32) : ne00; + const int ne01_padded = is_repack ? hex_round_up(ne01, 32) : ne01; + const int ne11_padded = hex_round_up(ne11, 32); + + const bool is_matmul_id = (dst->op == GGML_OP_MUL_MAT_ID); + const bool is_batched = (ne02 * ne03 > 1 || ne12 * ne13 > 1); + + const size_t vtcm_budget = sess->vtcm_size; + + // Check HMX eligibility and try precomputing HMX parameters + bool hmx_enabled = (sess->n_hmx > 0) && (opt_mm_select >= 3); + if (hmx_enabled && ggml_hexagon_matmul_is_hmx_eligible(src0, src1, dst, ne01_padded, is_matmul_id, is_batched)) { + if (ggml_hexagon_precompute_hmx_mm_params(sess, src0, src1, dst, wtype, ne00_padded, ne01_padded, ne02, ne11, ne12, ne11_padded, is_matmul_id, is_batched, vtcm_budget, kparams)) { + goto finalize; + } + } + + // Fallback to HVX parameter computation + ggml_hexagon_precompute_hvx_mm_params(sess, src0, src1, dst, wtype, ne02, ne03, ne10, ne11, ne12, ne13, is_matmul_id, vtcm_budget, kparams); + +finalize: + kparams->div_ne12_ne1 = init_fastdiv_values(ne12 * ne11); + kparams->div_ne1 = init_fastdiv_values(ne11); + kparams->div_r2 = init_fastdiv_values(ne02 > 0 ? ne12 / ne02 : 1); + kparams->div_r3 = init_fastdiv_values(ne03 > 0 ? ne13 / ne03 : 1); + kparams->div_ne11 = init_fastdiv_values(ne11); +} + +static void ggml_hexagon_precompute_fused_qkv_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, // Wk + const struct ggml_tensor * src1, // x + struct htp_mm_kernel_params * kparams +) { + memset(kparams, 0, sizeof(*kparams)); + + const int wtype = src0->type; + const bool is_repack = ggml_hexagon_is_repack_type((ggml_type) wtype); + + const int ne10 = src1->ne[0]; + const int src1_nrows = src1->ne[1] * src1->ne[2] * src1->ne[3]; + const size_t src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + const size_t src0_row_size = src0->nb[1]; + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + + size_t src0_sz_per_thread = 0; + size_t src2_sz_per_thread = 0; + size_t src3_sz_per_thread = 0; + uint32_t best_n_prefetch = 16; + + if (is_repack) { + uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + uint32_t n_k_tiles = hex_round_up(ne10, 32) / 32; + uint32_t tile_row_size = n_k_tiles * aligned_tile_size; + size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float)); + size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128); + size_t src1_sz = src1_sz_per_thread; + + const uint32_t max_prefetch = (src1_nrows > HTP_MM_HMX_MIN_NROWS) ? 2 : 16; + best_n_prefetch = 2; + for (uint32_t d = max_prefetch; d >= 2; d /= 2) { + size_t repacked_vtcm_size = hex_round_up(d * tile_row_size, 128); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + size_t src0_sz = repacked_vtcm_size * sess->n_threads; + size_t src2_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads; + size_t src3_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads; + size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz; + + if (tiled_vtcm_size <= sess->vtcm_size) { + best_n_prefetch = d; + src0_sz_per_thread = repacked_vtcm_size; + src2_sz_per_thread = hex_round_up(d * tile_row_size, 128); + src3_sz_per_thread = hex_round_up(d * tile_row_size, 128); + break; + } + } + if (best_n_prefetch == 2 && src0_sz_per_thread == 0) { + size_t repacked_vtcm_size = hex_round_up(2 * tile_row_size, 128); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + src0_sz_per_thread = repacked_vtcm_size; + src2_sz_per_thread = hex_round_up(2 * tile_row_size, 128); + src3_sz_per_thread = hex_round_up(2 * tile_row_size, 128); + } + } else { + best_n_prefetch = 16; + src0_sz_per_thread = hex_round_up(best_n_prefetch * src0_row_size_padded, 128); + src2_sz_per_thread = hex_round_up(best_n_prefetch * src0_row_size_padded, 128); + src3_sz_per_thread = hex_round_up(best_n_prefetch * src0_row_size_padded, 128); + } + + size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128); + + size_t src0_sz = src0_sz_per_thread * sess->n_threads; + size_t src1_sz = src1_sz_per_thread; + size_t src2_sz = src2_sz_per_thread * sess->n_threads; + size_t src3_sz = src3_sz_per_thread * sess->n_threads; + + size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz; + bool try_tiled = (opt_mm_select >= 2); + if (try_tiled && tiled_vtcm_size <= sess->vtcm_size) { + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW; + kparams->vtcm_src0_size = src0_sz; + kparams->vtcm_src1_size = src1_sz; + kparams->vtcm_src2_size = src2_sz; + kparams->vtcm_src3_size = src3_sz; + kparams->vtcm_size = tiled_vtcm_size; + kparams->n_prefetch = best_n_prefetch; + } else { + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT; + size_t flat_src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10); + size_t flat_src1_sz = hex_round_up(flat_src1_row_size * src1_nrows, 128); + kparams->vtcm_src0_size = src0_sz; + kparams->vtcm_src1_size = flat_src1_sz; + kparams->vtcm_src2_size = src2_sz; + kparams->vtcm_src3_size = src3_sz; + kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + src3_sz; + kparams->n_prefetch = best_n_prefetch; + } +} + +static void ggml_hexagon_precompute_fused_ffn_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, // Wgate + const struct ggml_tensor * src1, // y + struct htp_mm_kernel_params * kparams +) { + memset(kparams, 0, sizeof(*kparams)); + + const int wtype = src0->type; + const bool is_repack = ggml_hexagon_is_repack_type((ggml_type) wtype); + + const int ne10 = src1->ne[0]; + const int src1_nrows = src1->ne[1] * src1->ne[2] * src1->ne[3]; + const size_t src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + const size_t src0_row_size = src0->nb[1]; + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + + size_t src0_sz_per_thread = 0; + size_t src2_sz_per_thread = 0; + uint32_t best_n_prefetch = 16; + + if (is_repack) { + uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + uint32_t n_k_tiles = hex_round_up(ne10, 32) / 32; + uint32_t tile_row_size = n_k_tiles * aligned_tile_size; + size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float)); + size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128); + size_t src1_sz = src1_sz_per_thread; + + const uint32_t max_prefetch = (src1_nrows > HTP_MM_HMX_MIN_NROWS) ? 2 : 16; + best_n_prefetch = 2; + for (uint32_t d = max_prefetch; d >= 2; d /= 2) { + size_t repacked_vtcm_size = hex_round_up(d * tile_row_size, 128); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + size_t src0_sz = repacked_vtcm_size * sess->n_threads; + size_t src2_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads; + size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz; + + if (tiled_vtcm_size <= sess->vtcm_size) { + best_n_prefetch = d; + src0_sz_per_thread = repacked_vtcm_size; + src2_sz_per_thread = hex_round_up(d * tile_row_size, 128); + break; + } + } + if (best_n_prefetch == 2 && src0_sz_per_thread == 0) { + size_t repacked_vtcm_size = hex_round_up(2 * tile_row_size, 128); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + src0_sz_per_thread = repacked_vtcm_size; + src2_sz_per_thread = hex_round_up(2 * tile_row_size, 128); + } + } else { + best_n_prefetch = 16; + src0_sz_per_thread = hex_round_up(best_n_prefetch * src0_row_size_padded, 128); + src2_sz_per_thread = hex_round_up(best_n_prefetch * src0_row_size_padded, 128); + } + + size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128); + + size_t src0_sz = src0_sz_per_thread * sess->n_threads; + size_t src1_sz = src1_sz_per_thread; + size_t src2_sz = src2_sz_per_thread * sess->n_threads; + + size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz; + bool try_tiled = (opt_mm_select >= 2); + if (try_tiled && tiled_vtcm_size <= sess->vtcm_size) { + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW; + kparams->vtcm_src0_size = src0_sz; + kparams->vtcm_src1_size = src1_sz; + kparams->vtcm_src2_size = src2_sz; + kparams->vtcm_size = tiled_vtcm_size; + kparams->n_prefetch = best_n_prefetch; + } else { + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT; + size_t flat_src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10); + size_t flat_src1_sz = hex_round_up(flat_src1_row_size * src1_nrows, 128); + kparams->vtcm_src0_size = src0_sz; + kparams->vtcm_src1_size = flat_src1_sz; + kparams->vtcm_src2_size = src2_sz; + kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz; + kparams->n_prefetch = best_n_prefetch; + } +} + static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; @@ -2675,12 +2625,13 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s return false; } - if (ggml_nrows(src0) > 16 * 1024) { - return false; // typically the lm-head which would be too large for VTCM + // hardcoded limit to refuse the lm-head for now + if (src0->ne[1] > 32768) { + return false; } - if (ggml_nrows(src1) > 1024 || src1->ne[2] != 1 || src1->ne[3] != 1) { - return false; // no huge batches or broadcasting (for now) + if (src1->ne[2] != 1 || src1->ne[3] != 1) { + return false; // no broadcasting (for now) } // src0 (weights) must be repacked @@ -2691,16 +2642,11 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s case GGML_TYPE_F16: if (src0->nb[1] < src0->nb[0]) { - GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n"); return false; } if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) { - GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n"); return false; } - if (ggml_nrows(src1) > 1024) { - return false; // no huge batches (for now) - } break; case GGML_TYPE_F32: @@ -2708,22 +2654,24 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s return false; } if (src0->nb[1] < src0->nb[0]) { - GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F32 src0 not supported\n"); return false; } if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) { - GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n"); return false; } - if (ggml_nrows(src1) > 1024) { - return false; // no huge batches (for now) - } break; default: return false; } + struct htp_mm_kernel_params kparams; + ggml_hexagon_precompute_matmul_params(sess, src0, src1, dst, &kparams); + if ((size_t)kparams.vtcm_size > sess->vtcm_size) { + HEX_VERBOSE("ggml-hex: %s supported MUL_MAT VTCM size needed (%d) > budget (%zu)\n", sess->c_name(), kparams.vtcm_size, sess->vtcm_size); + return false; + } + return true; } @@ -2757,6 +2705,13 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session return false; } + struct htp_mm_kernel_params kparams; + ggml_hexagon_precompute_matmul_params(sess, src0, src1, dst, &kparams); + if ((size_t)kparams.vtcm_size > sess->vtcm_size) { + HEX_VERBOSE("ggml-hex: %s supported MUL_MAT_ID VTCM size needed (%d) > budget (%zu)\n", sess->c_name(), kparams.vtcm_size, sess->vtcm_size); + return false; + } + return true; } @@ -3288,47 +3243,172 @@ static inline bool op_is_compute(ggml_tensor *node) return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE); } +static bool is_hmx_eligible(const ggml_tensor * t) { + if (opt_nhmx == 0) { return false; } + + const ggml_tensor * src0 = t->src[0]; + const ggml_tensor * src1 = t->src[1]; + + const int wtype = src0->type; + const bool is_repack = ggml_hexagon_is_repack_type((ggml_type) wtype); + const bool is_matmul_id = (t->op == GGML_OP_MUL_MAT_ID); + const bool is_batched = (src0->ne[2] * src0->ne[3] > 1 || src1->ne[2] * src1->ne[3] > 1); + + const int ne01_padded = is_repack ? hex_round_up(src0->ne[1], 32) : src0->ne[1]; + + return ggml_hexagon_matmul_is_hmx_eligible(src0, src1, t, ne01_padded, is_matmul_id, is_batched); +} + +static bool is_mergeable_mul_mat(const ggml_tensor * t) { + if (!t || t->op != GGML_OP_MUL_MAT) return false; + if (t->src[1]->type != GGML_TYPE_F32) return false; + return ggml_is_quantized(t->src[0]->type) && !is_hmx_eligible(t); +} + +static bool is_mergeable_mul_mat_pair(const ggml_tensor * n1, const ggml_tensor * n2) { + if (!is_mergeable_mul_mat(n1) || !is_mergeable_mul_mat(n2)) { + return false; + } + if (n1->src[1] != n2->src[1]) { + return false; + } + if (n1->src[0]->ne[0] != n2->src[0]->ne[0] || + n1->src[0]->ne[1] != n2->src[0]->ne[1]) { + return false; + } + if (n1->src[0]->type != n2->src[0]->type) { + return false; + } + return true; +} + +static bool is_qkv_mergeable(const ggml_tensor * n_q, const ggml_tensor * n_k, const ggml_tensor * n_v) { + if (!is_mergeable_mul_mat(n_q) || !is_mergeable_mul_mat(n_k) || !is_mergeable_mul_mat(n_v)) { + return false; + } + if (n_q->src[1] != n_k->src[1] || n_q->src[1] != n_v->src[1]) { + return false; + } + if (n_q->src[0]->type != n_k->src[0]->type || n_q->src[0]->type != n_v->src[0]->type) { + return false; + } + if (n_k->src[0]->ne[0] != n_v->src[0]->ne[0] || + n_k->src[0]->ne[1] != n_v->src[0]->ne[1]) { + return false; + } + if (n_q->src[0]->ne[0] != n_k->src[0]->ne[0]) { + return false; + } + return true; +} + +static bool try_fuse_node(const ggml_hexagon_session * sess, const ggml_cgraph * graph, int & i, std::vector & nodes) { + if (!opt_opfusion) { + return false; + } + + ggml_tensor * n = graph->nodes[i]; + ggml_tensor * next_node = (i + 1 < graph->n_nodes) ? graph->nodes[i + 1] : nullptr; + + if (n->op == GGML_OP_RMS_NORM && next_node) { + if (next_node->op == GGML_OP_MUL && op_is_compute(next_node) && ggml_can_fuse(graph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + htp_opnode node(n, {}, HTP_OP_RMS_NORM_MUL); + node.add_fused(next_node); + nodes.push_back(std::move(node)); + i++; // skip the fused MUL node + return true; + } + } + + if (is_mergeable_mul_mat(n)) { + ggml_tensor * n1 = (i + 1 < graph->n_nodes) ? graph->nodes[i + 1] : nullptr; + ggml_tensor * n2 = (i + 2 < graph->n_nodes) ? graph->nodes[i + 2] : nullptr; + if (is_qkv_mergeable(n, n1, n2)) { + struct htp_mm_kernel_params kparams; + ggml_hexagon_precompute_fused_qkv_params(sess, n1->src[0], n1->src[1], &kparams); + if ((size_t)kparams.vtcm_size <= sess->vtcm_size) { + // Reorder to KVQ: K (n1), V (n2), Q (n) + htp_opnode node(n1, {}, HTP_OP_MUL_MAT_QKV); + node.add_fused(n2, true); + node.add_fused(n, true); + memcpy(node.kernel_params, &kparams, sizeof(kparams)); + nodes.push_back(std::move(node)); + i += 2; + return true; + } else { + HEX_VERBOSE("ggml-hex: skip QKV fusion because VTCM needed (%d) > budget (%zu)\n", + kparams.vtcm_size, sess->vtcm_size); + } + } + if (is_mergeable_mul_mat_pair(n, n1)) { + struct htp_mm_kernel_params kparams; + ggml_hexagon_precompute_fused_ffn_params(sess, n->src[0], n->src[1], &kparams); + if ((size_t)kparams.vtcm_size <= sess->vtcm_size) { + htp_opnode node(n, {}, HTP_OP_MUL_MAT_FFN); + node.add_fused(n1, true); + memcpy(node.kernel_params, &kparams, sizeof(kparams)); + nodes.push_back(std::move(node)); + i += 1; + return true; + } else { + HEX_VERBOSE("ggml-hex: skip FFN fusion because VTCM needed (%d) > budget (%zu)\n", + kparams.vtcm_size, sess->vtcm_size); + } + } + } + + return false; +} + static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) { auto sess = static_cast(backend->context); HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->c_name(), graph->n_nodes); - std::vector nodes; - nodes.reserve(graph->n_nodes); + const std::vector * nodes_ptr = nullptr; + std::vector computed_nodes; - // Fusion - for (int i = 0; i < graph->n_nodes; ++i) { - ggml_tensor * n = graph->nodes[i]; - if (!op_is_compute(n)) { - continue; - } + // Check for cache hit + bool cache_hit = (graph->uid != 0 && sess->cached_graph.uid == graph->uid); + if (cache_hit) { + nodes_ptr = &sess->cached_graph.htp_nodes; + } else { + computed_nodes.reserve(graph->n_nodes); - ggml_tensor * next_node = (i + 1 < graph->n_nodes) ? graph->nodes[i + 1] : nullptr; - - htp_opnode node = { - /*.node =*/ n, - /*.fused =*/ {}, - /*.opcode =*/ HTP_OP_INVALID - }; - - if (n->op == GGML_OP_RMS_NORM && next_node) { - if (next_node->op == GGML_OP_MUL && op_is_compute(next_node) && ggml_can_fuse(graph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { - node.add_fused(next_node); - node.opcode = HTP_OP_RMS_NORM_MUL; - i++; // skip the fused MUL node + // Fuse and finalize + for (int i = 0; i < graph->n_nodes; ++i) { + ggml_tensor * n = graph->nodes[i]; + if (!op_is_compute(n)) { + continue; } - } - if (node.opcode == HTP_OP_INVALID) { + if (try_fuse_node(sess, graph, i, computed_nodes)) { + continue; + } + + htp_opnode node(n, {}, HTP_OP_INVALID); node.opcode = op_remap_to_htp(n); + if (node.opcode == HTP_OP_MUL_MAT || node.opcode == HTP_OP_MUL_MAT_ID) { + ggml_hexagon_precompute_matmul_params(sess, + node.node->src[0], node.node->src[1], node.node, + (struct htp_mm_kernel_params *)node.kernel_params + ); + } + computed_nodes.push_back(std::move(node)); } - nodes.push_back(std::move(node)); + if (graph->uid != 0) { + sess->cached_graph.uid = graph->uid; + sess->cached_graph.htp_nodes = std::move(computed_nodes); + nodes_ptr = &sess->cached_graph.htp_nodes; + } else { + nodes_ptr = &computed_nodes; + } } // Queue and execute if (opt_opstage & HTP_OPSTAGE_QUEUE) { - for (const auto & node : nodes) { + for (const auto & node : *nodes_ptr) { sess->enqueue_op(node); } } @@ -3991,16 +4071,19 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); const char * str_oppoll = getenv("GGML_HEXAGON_OPPOLL"); - const char * str_optrace = getenv("GGML_HEXAGON_OPTRACE"); + const char * str_opfusion = getenv("GGML_HEXAGON_OPFUSION"); const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER"); const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); const char * str_etm = getenv("GGML_HEXAGON_ETM"); const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX"); + const char * str_nhmx = getenv("GGML_HEXAGON_NHMX"); + const char * str_mm_select = getenv("GGML_HEXAGON_MM_SELECT"); const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); const char * str_arch = getenv("GGML_HEXAGON_ARCH"); const char * str_vmem = getenv("GGML_HEXAGON_VMEM"); const char * str_mbuf = getenv("GGML_HEXAGON_MBUF"); + const char * str_optrace = getenv("GGML_HEXAGON_OPTRACE"); // Init Arch first since it affects other defaults if (!str_arch) { @@ -4029,12 +4112,14 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage; opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; - opt_oppoll = str_oppoll ? strtoul(str_oppoll, NULL, 0) : opt_oppoll; opt_optrace = str_optrace ? strtoul(str_optrace, NULL, 0) : (opt_opbatch * 128); + opt_oppoll = str_oppoll ? strtoul(str_oppoll, NULL, 0) : opt_oppoll; + opt_opfusion = str_opfusion ? atoi(str_opfusion) : opt_opfusion; opt_profile = str_profile ? atoi(str_profile) : 0; opt_etm = str_etm ? atoi(str_etm) : 0; opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; - opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; + opt_nhmx = str_nhmx ? atoi(str_nhmx) : (str_use_hmx ? atoi(str_use_hmx) : opt_nhmx); + opt_mm_select = str_mm_select ? atoi(str_mm_select) : opt_mm_select; opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; opt_mbuf = str_mbuf ? strtoul(str_mbuf, NULL, 0) * MiB : opt_mbuf; diff --git a/ggml/src/ggml-hexagon/htp-opnode.h b/ggml/src/ggml-hexagon/htp-opnode.h index 52c727c620..6fe23b0d6a 100644 --- a/ggml/src/ggml-hexagon/htp-opnode.h +++ b/ggml/src/ggml-hexagon/htp-opnode.h @@ -5,10 +5,12 @@ #include "ggml-backend-impl.h" #include "ggml-common.h" +#include #include #include #include #include "htp-ops.h" +#include "htp/matmul-ops.h" struct htp_opnode { ggml_tensor * node = nullptr; @@ -17,6 +19,13 @@ struct htp_opnode { htp_op_code opcode = HTP_OP_INVALID; + std::vector extra_dsts; + + int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS] = {0}; + + htp_opnode(ggml_tensor * node = nullptr, std::vector fused = {}, htp_op_code opcode = HTP_OP_INVALID, std::vector extra_dsts = {}) + : node(node), fused(std::move(fused)), opcode(opcode), extra_dsts(std::move(extra_dsts)) {} + ggml_op op() const { return node->op; } @@ -25,6 +34,26 @@ struct htp_opnode { return fused.empty() ? node : fused.back(); } + void add_fused(ggml_tensor * t, bool extra_dst = false) { + fused.push_back(t); + if (extra_dst) { + extra_dsts.push_back(t); + } + } + + std::vector get_outputs() const { + std::vector res; + if (extra_dsts.empty()) { + res.push_back(dst()); + } else { + res.push_back(node); + for (const auto * x : extra_dsts) { + res.push_back(x); + } + } + return res; + } + const ggml_tensor * src0() const { return node->src[0]; } @@ -37,10 +66,6 @@ struct htp_opnode { return ggml_op_is_empty(node->op); } - void add_fused(ggml_tensor * t) { - fused.push_back(t); - } - bool stackable() const { switch (this->op()) { case GGML_OP_MUL_MAT: @@ -131,87 +156,117 @@ struct htp_opformat { char types[16 * GGML_MAX_SRC]; char buffs[64 * GGML_MAX_SRC]; char names[64 * GGML_MAX_SRC]; + char kparams[128]; - int format_tensor_dims(char * str, const struct ggml_tensor * t) { + int format_tensor_dims(char * str, size_t max_size, const struct ggml_tensor * t) { if (!t) { - return sprintf(str, "NONE"); + return snprintf(str, max_size, "NONE"); } if (t->ne[2] == 1 && t->ne[3] == 1) { - return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]); + return snprintf(str, max_size, "%d:%d", (int) t->ne[0], (int) t->ne[1]); } else { - return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); + return snprintf(str, max_size, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); } } - void format_op_dims(char * str, const htp_opnode & node) { + void format_op_dims(char * str, size_t max_size, const htp_opnode & node) { char * p = str; + char * p_end = str + max_size; auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += format_tensor_dims(p, inputs[0]); + p += std::min((size_t)format_tensor_dims(p, p_end - p, inputs[0]), (size_t)(p_end - p)); for (size_t i = 1; i < inputs.size(); i++) { - p += sprintf(p, " x "); - p += format_tensor_dims(p, inputs[i]); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p)); + } + if (p < p_end) { + p += std::min((size_t)format_tensor_dims(p, p_end - p, inputs[i]), (size_t)(p_end - p)); + } } - p += sprintf(p, " -> "); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p)); + } } char self[64]; - format_tensor_dims(self, node.dst()); - p += sprintf(p, "%s", self); + format_tensor_dims(self, sizeof(self), node.dst()); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", self), (size_t)(p_end - p)); + } } - int format_tensor_strides(char * str, const struct ggml_tensor * t) { + int format_tensor_strides(char * str, size_t max_size, const struct ggml_tensor * t) { if (!t) { - return sprintf(str, "NONE"); + return snprintf(str, max_size, "NONE"); } const char * c = ggml_is_contiguous(t) ? "" : "!"; if (t->ne[2] == 1 && t->ne[3] == 1) { - return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c); + return snprintf(str, max_size, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c); } else { - return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c); + return snprintf(str, max_size, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c); } } - void format_op_strides(char * str, const htp_opnode & node) { + void format_op_strides(char * str, size_t max_size, const htp_opnode & node) { char * p = str; + char * p_end = str + max_size; auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += format_tensor_strides(p, inputs[0]); + p += std::min((size_t)format_tensor_strides(p, p_end - p, inputs[0]), (size_t)(p_end - p)); for (size_t i = 1; i < inputs.size(); i++) { - p += sprintf(p, " x "); - p += format_tensor_strides(p, inputs[i]); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p)); + } + if (p < p_end) { + p += std::min((size_t)format_tensor_strides(p, p_end - p, inputs[i]), (size_t)(p_end - p)); + } } - p += sprintf(p, " -> "); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p)); + } } char self[64]; - format_tensor_strides(self, node.dst()); - p += sprintf(p, "%s", self); + format_tensor_strides(self, sizeof(self), node.dst()); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", self), (size_t)(p_end - p)); + } } - void format_op_types(char * str, const htp_opnode & node) { + void format_op_types(char * str, size_t max_size, const htp_opnode & node) { char * p = str; + char * p_end = str + max_size; auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += sprintf(p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE"); - - for (size_t i = 1; i < inputs.size(); i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE"); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE"), (size_t)(p_end - p)); } - p += sprintf(p, " -> "); + for (size_t i = 1; i < inputs.size(); i++) { + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p)); + } + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE"), (size_t)(p_end - p)); + } + } + + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p)); + } } - p += sprintf(p, "%s", ggml_type_name(node.dst()->type)); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", ggml_type_name(node.dst()->type)), (size_t)(p_end - p)); + } } const char * tensor_buff_name(const struct ggml_tensor * t) { @@ -221,51 +276,102 @@ struct htp_opformat { return "NONE"; } - void format_op_buffs(char * str, const htp_opnode & node) { + void format_op_buffs(char * str, size_t max_size, const htp_opnode & node) { char * p = str; + char * p_end = str + max_size; auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += sprintf(p, "%s", tensor_buff_name(inputs[0])); - - for (size_t i = 1; i < inputs.size(); i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", tensor_buff_name(inputs[i])); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(inputs[0])), (size_t)(p_end - p)); } - p += sprintf(p, " -> "); + for (size_t i = 1; i < inputs.size(); i++) { + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p)); + } + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(inputs[i])), (size_t)(p_end - p)); + } + } + + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p)); + } } - p += sprintf(p, "%s", tensor_buff_name(node.dst())); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(node.dst())), (size_t)(p_end - p)); + } } - void format_op_names(char * str, const htp_opnode & node) { + void format_op_names(char * str, size_t max_size, const htp_opnode & node) { char * p = str; + char * p_end = str + max_size; auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += sprintf(p, "%s", inputs[0] ? inputs[0]->name : "NONE"); - - for (size_t i = 1; i < inputs.size(); i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", inputs[i] ? inputs[i]->name : "NONE"); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[0] ? inputs[0]->name : "NONE"), (size_t)(p_end - p)); } - p += sprintf(p, " -> "); + for (size_t i = 1; i < inputs.size(); i++) { + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p)); + } + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[i] ? inputs[i]->name : "NONE"), (size_t)(p_end - p)); + } + } + + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p)); + } } - p += sprintf(p, "%s", node.dst()->name); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", node.dst()->name), (size_t)(p_end - p)); + } + } + void format_kernel_params(char * str, size_t max_size, const htp_opnode & node) { + if (node.opcode == HTP_OP_MUL_MAT || node.opcode == HTP_OP_MUL_MAT_ID || + node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN) { + const auto * kparams = (const struct htp_mm_kernel_params *) node.kernel_params; + const char * path = "unknown"; + int32_t type = kparams->kernel_type; + if (type == HTP_MM_KERNEL_HMX_2D || type == HTP_MM_KERNEL_HMX_F16_BATCHED) { + path = "hmx-tiled"; + } else if (type == HTP_MM_KERNEL_HVX_F16_F16_VTCM || type == HTP_MM_KERNEL_HVX_F32_F32_VTCM || + type == HTP_MM_KERNEL_HVX_QUANT_ROW || type == HTP_MM_KERNEL_HVX_QUANT_BLOCK) { + path = "hvx-tiled"; + } else if (type == HTP_MM_KERNEL_HVX_F16_F16_DDR || type == HTP_MM_KERNEL_HVX_F16_F32_DDR || + type == HTP_MM_KERNEL_HVX_F32_F32_DDR || type == HTP_MM_KERNEL_HVX_F32_F16_DDR || + type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + path = "hvx-flat"; + } + snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size); + } else { + snprintf(str, max_size, "----"); + } } void format(const htp_opnode & node) { - format_op_dims(dims, node); - format_op_strides(strides, node); - format_op_types(types, node); - format_op_buffs(buffs, node); - format_op_names(names, node); + format_op_dims(dims, sizeof(dims), node); + format_op_strides(strides, sizeof(strides), node); + format_op_types(types, sizeof(types), node); + format_op_buffs(buffs, sizeof(buffs), node); + format_op_names(names, sizeof(names), node); + format_kernel_params(kparams, sizeof(kparams), node); } - htp_opformat() {} + htp_opformat() { + strides[0] = '\0'; + dims[0] = '\0'; + types[0] = '\0'; + buffs[0] = '\0'; + names[0] = '\0'; + kparams[0] = '\0'; + } htp_opformat(const htp_opnode & node) { format(node); } }; diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 31ba527623..c48a5b86e3 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -19,43 +19,9 @@ add_library(${HTP_LIB} SHARED htp_iface_skel.c worker-pool.c hex-dma.c -) - -target_compile_definitions(${HTP_LIB} PRIVATE - $,HTP_DEBUG=1,NDEBUG=1> - $,FARF_HIGH=1,> - FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) - -if (GGML_HEXAGON_FA_EXP2_HF) - message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)") - target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1) -endif() - -# HMX acceleration: available on v73+ architectures -set(HTP_HMX_VERSIONS v73 v75 v79 v81) -list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) - -if (_hmx_idx GREATER_EQUAL 0) - target_sources(${HTP_LIB} PRIVATE - hmx-flash-attn-ops.c - hmx-matmul-ops.c - hmx-queue.c - ) - - # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) - set_source_files_properties( - hmx-flash-attn-ops.c - hmx-matmul-ops.c - hmx-queue.c - PROPERTIES COMPILE_OPTIONS "-mhmx" - ) - - target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1) -endif() - -build_idl(htp_iface.idl ${HTP_LIB}) - -target_sources(${HTP_LIB} PRIVATE + hmx-queue.c + flash-attn-ops.c + hmx-flash-attn-ops.c matmul-ops.c binary-ops.c unary-ops.c @@ -63,7 +29,6 @@ target_sources(${HTP_LIB} PRIVATE softmax-ops.c act-ops.c rope-ops.c - flash-attn-ops.c set-rows-ops.c get-rows-ops.c cpy-ops.c @@ -79,6 +44,17 @@ target_sources(${HTP_LIB} PRIVATE pad-ops.c ) +target_compile_definitions(${HTP_LIB} PRIVATE + $,HTP_DEBUG=1,NDEBUG=1> + $,FARF_HIGH=1,>) + +if (GGML_HEXAGON_FA_EXP2_HF) + message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)") + target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1) +endif() + +build_idl(htp_iface.idl ${HTP_LIB}) + set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON) install(TARGETS ${HTP_LIB}) diff --git a/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake b/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake index ed5c198468..3eff2a3986 100644 --- a/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +++ b/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake @@ -3,7 +3,7 @@ if (HEXAGON_TOOLCHAIN_INCLUDED) endif() set(HEXAGON_TOOLCHAIN_INCLUDED true) -#Cross Compiling for Hexagon +# Cross Compiling for Hexagon set(HEXAGON TRUE) set(CMAKE_SYSTEM_NAME QURT) set(CMAKE_SYSTEM_PROCESSOR Hexagon) @@ -14,7 +14,6 @@ set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) set(CUSTOM_RUNELF_PATH "") -#To fix backward compatibility with EAI addon. if (NOT HEXAGON_SDK_ROOT) set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT}) endif() @@ -31,7 +30,6 @@ endif() file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT) file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT) -#Get the Binary extension of the Hexagon Toolchain if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows) set(HEXAGON_TOOLCHAIN_SUFFIX .exe) endif() @@ -48,12 +46,12 @@ set(CMAKE_TRY_COMPILE_PLATFORM_VARIABLES HEXAGON_TOOLS_ROOT ) -#QURT Related includes and linker flags +# QURT Related includes and linker flags set(V_ARCH ${HEXAGON_ARCH}) set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/ADSP${V_ARCH}MP${V_ARCH_EXTN}") set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/compute${V_ARCH}${V_ARCH_EXTN}") -if( ${TREE} MATCHES PAKMAN ) +if (${TREE} MATCHES PAKMAN) set(_QURT_INSTALL_DIR "${QURT_IMAGE_DIR}/compute${V_ARCH}${V_ARCH_EXTN}") endif() message(DEBUG "_QURT_INSTALL_DIR:${_QURT_INSTALL_DIR}") @@ -83,11 +81,9 @@ set(QURT_START_LINK_LIBS ) STRING(REPLACE ";" " " QURT_START_LINK_LIBS "${QURT_START_LINK_LIBS}") -set(QURT_END_LINK_LIBS - ${TARGET_DIR}/fini.o - ) +set(QURT_END_LINK_LIBS ${TARGET_DIR}/fini.o) -#Non QURT related includes and linker flags +# Non QURT related includes and linker flags set(TARGET_DIR_NOOS "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/${HEXAGON_ARCH}") @@ -99,8 +95,10 @@ if (NOT NO_WRAP_MEM_API) set(WRAP_MEMALIGN -Wl,--wrap=memalign) endif() +set(ARCH_FLAGS "-mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -mhmx") + set(PIC_SHARED_LD_FLAGS - -mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} + ${ARCH_FLAGS} -G0 -fpic -Wl,-Bsymbolic @@ -120,13 +118,13 @@ STRING(REPLACE ";" " " PIC_SHARED_LD_FLAGS "${PIC_SHARED_LD_FLAGS}") set(HEXAGON_PIC_SHARED_LINK_OPTIONS "${PIC_SHARED_LD_FLAGS}") -#System include paths +# System include paths include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs) include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/stddef) include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/ipc/fastrpc/incs) -#LLVM toolchain setup -#Compiler paths, options and architecture +# LLVM toolchain setup +# Compiler paths, options and architecture set(CMAKE_C_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang${HEXAGON_TOOLCHAIN_SUFFIX}) set(CMAKE_CXX_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX}) set(CMAKE_AR ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-ar${HEXAGON_TOOLCHAIN_SUFFIX}) @@ -137,8 +135,8 @@ set(CMAKE_PREFIX_PATH ${HEXAGON_TOOLCHAIN}/Tools/target/hexagon) set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,") set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,") -#Compiler Options -set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}") +# Compiler Options +set(COMMON_FLAGS "${ARCH_FLAGS} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}") set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g") diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index b7511cd644..65f7844ae3 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -18,7 +18,8 @@ #include "htp-ctx.h" #include "htp-ops.h" #include "htp-ops.h" -#include "hmx-ops.h" + +int hmx_flash_attn_ext(struct htp_ops_context * octx); // Must be multiple of 32 #define FLASH_ATTN_BLOCK_SIZE (32 * 2) @@ -633,7 +634,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } -#ifdef HTP_HAS_HMX // HMX path: head_dim multiple of 64, F16 KV, and no sinks if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 64 == 0 && v->ne[0] % 64 == 0 && octx->src[4] == NULL) { int ret = hmx_flash_attn_ext(octx); @@ -642,7 +642,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { } // VTCM too small or other failure -> fall through to HVX path } -#endif struct htp_fa_context factx; factx.octx = octx; diff --git a/ggml/src/ggml-hexagon/htp/hex-common.h b/ggml/src/ggml-hexagon/htp/hex-common.h new file mode 100644 index 0000000000..4714486a04 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hex-common.h @@ -0,0 +1,80 @@ +#ifndef HEX_COMMON_H +#define HEX_COMMON_H + +#include +#include +#include + +#ifndef SIZE_MAX +#define SIZE_MAX ((size_t)-1) +#endif + +#ifndef MAX +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#endif + +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +static inline uint32_t hex_ceil_pow2(uint32_t x) { + if (x <= 1) { return 1; } + int p = 2; + x--; + while (x >>= 1) { p <<= 1; } + return p; +} + +static inline size_t hmx_ceil_div(size_t num, size_t den) { + return (num + den - 1) / den; +} + +static inline int32_t hex_is_aligned(const void * addr, uint32_t align) { + return ((size_t) addr & (align - 1)) == 0; +} + +static inline size_t hex_align_up(size_t v, size_t align) { + return hmx_ceil_div(v, align) * align; +} + +static inline size_t hex_align_down(size_t v, size_t align) { + return (v / align) * align; +} + +static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { + uint32_t left_off = (size_t) addr & (chunk_size - 1); + uint32_t right_off = left_off + n; + return right_off <= chunk_size; +} + +static inline uint32_t hex_round_up(uint32_t n, uint32_t m) { + return m * ((n + m - 1) / m); +} + +static inline size_t hex_smin(size_t a, size_t b) { + return a < b ? a : b; +} + +static inline size_t hex_smax(size_t a, size_t b) { + return a > b ? a : b; +} + +static inline void hex_swap_ptr(void ** p1, void ** p2) { + void * t = *p1; + *p1 = *p2; + *p2 = t; +} + +static inline bool hex_mul_overflow(size_t a, size_t b, size_t *out) { + if (a != 0 && b > SIZE_MAX / a) return true; + *out = a * b; + return false; +} + +static inline bool hex_add_overflow(size_t a, size_t b, size_t *out) { + if (a > SIZE_MAX - b) return true; + *out = a + b; + return false; +} + +#endif // HEX_COMMON_H diff --git a/ggml/src/ggml-hexagon/htp/hex-dma.h b/ggml/src/ggml-hexagon/htp/hex-dma.h index 93c21ebe5e..8031a5679c 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dma.h +++ b/ggml/src/ggml-hexagon/htp/hex-dma.h @@ -5,6 +5,7 @@ #include #include #include +#include "hex-utils.h" #include "hex-profile.h" @@ -127,13 +128,8 @@ static inline dma_ptr dma_make_ptr(void *dst, const void *src) return p; } -#if __HVX_ARCH__ < 73 -static const uint32_t dma_src_l2_bypass_on = 1; -static const uint32_t dma_dst_l2_bypass_on = 0; -#else static const uint32_t dma_src_l2_bypass_on = 1; static const uint32_t dma_dst_l2_bypass_on = 1; -#endif static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t size) { if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) { diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h index 8e6e3ea750..07930bef6e 100644 --- a/ggml/src/ggml-hexagon/htp/hex-utils.h +++ b/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -11,14 +11,7 @@ #include "hex-fastdiv.h" #include "hex-dump.h" - -#ifndef MAX -#define MAX(a, b) ((a) > (b) ? (a) : (b)) -#endif - -#ifndef MIN -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#endif +#include "hex-common.h" static inline uint64_t hex_get_cycles() { uint64_t cycles = 0; @@ -32,54 +25,6 @@ static inline uint64_t hex_get_pktcnt() { return pktcnt; } -static inline uint32_t hex_ceil_pow2(uint32_t x) { - if (x <= 1) { return 1; } - int p = 2; - x--; - while (x >>= 1) { p <<= 1; } - return p; -} - -static inline size_t hmx_ceil_div(size_t num, size_t den) { - return (num + den - 1) / den; -} - -static inline int32_t hex_is_aligned(const void * addr, uint32_t align) { - return ((size_t) addr & (align - 1)) == 0; -} - -static inline size_t hex_align_up(size_t v, size_t align) { - return hmx_ceil_div(v, align) * align; -} - -static inline size_t hex_align_down(size_t v, size_t align) { - return (v / align) * align; -} - -static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { - uint32_t left_off = (size_t) addr & (chunk_size - 1); - uint32_t right_off = left_off + n; - return right_off <= chunk_size; -} - -static inline uint32_t hex_round_up(uint32_t n, uint32_t m) { - return m * ((n + m - 1) / m); -} - -static inline size_t hex_smin(size_t a, size_t b) { - return a < b ? a : b; -} - -static inline size_t hex_smax(size_t a, size_t b) { - return a > b ? a : b; -} - -static inline void hex_swap_ptr(void ** p1, void ** p2) { - void * t = *p1; - *p1 = *p2; - *p2 = t; -} - static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) { const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height)); Q6_l2fetch_AP((void *) p, control); diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index 986dde148d..996fd59757 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -49,7 +49,7 @@ // g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions. // Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales // Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax. -static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool use_pipeline) { +static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool pipeline) { const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS); const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK] const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong @@ -70,7 +70,7 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, + k_dma_size * 2 // K DMA x2 + v_dma_size * 2 // V DMA x2 + k_tile_size * 1 // K tiles - + v_tile_size * (use_pipeline ? 2 : 1) // V tiles (double-buffered if pipelining) + + v_tile_size * (pipeline ? 2 : 1) // V tiles (double-buffered if pipelining) + s_tile_size * 2 // S + P + d_tile_size * 1 // D (diagonal matrix) + col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum @@ -290,7 +290,7 @@ static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) = struct hmx_fa_context { const struct htp_ops_context * octx; - bool use_pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2 + bool pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2 uint32_t n_threads; // Op parameters @@ -409,7 +409,7 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data) return; } - __fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0]; + __fp16 * v_tiles_dest = factx->pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0]; struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL; htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start); @@ -1312,13 +1312,13 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS); const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc; - const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2); + const bool pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2); // Bypass thread pool dispatch for small prompts/non-pipelined prefill by setting n_threads = 1 - const uint32_t n_threads = use_pipeline ? n_threads_init : 1; + const uint32_t n_threads = pipeline ? n_threads_init : 1; FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu", - neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget); + neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, pipeline, vtcm_budget); // ======== Build context ======== struct hmx_fa_context factx; @@ -1339,7 +1339,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.n_kv_blocks = n_kv_blocks; factx.is_q_fp32 = (q->type == HTP_TYPE_F32); factx.is_dst_fp32 = (dst->type == HTP_TYPE_F32); - factx.use_pipeline = use_pipeline; + factx.pipeline = pipeline; factx.mask_broadcast = (mask != NULL && mask->ne[2] == 1); // Extract op parameters (mutable during softcap adjustment, then stored as const in factx) @@ -1405,7 +1405,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes); factx.vtcm_v_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); - if (use_pipeline) { + if (pipeline) { factx.vtcm_v_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); } else { factx.vtcm_v_tiles[1] = NULL; @@ -1456,7 +1456,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { // ======== HMX lock strategy ======== // Pipeline: queue thread auto-acquires HMX lock on first push; released by suspend. // Fallback: main thread holds the lock (original behavior). - if (!factx.use_pipeline) { + if (!factx.pipeline) { HAP_compute_res_hmx_lock(ctx->vtcm_rctx); } @@ -1550,7 +1550,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { const size_t k_src_stride = size_k_row_padded / sizeof(__fp16); const size_t v_src_stride = size_v_row_padded / sizeof(__fp16); - if (factx.use_pipeline) { + if (factx.pipeline) { // ================================================================== // Pipeline path: HVX phases ‖ HMX queue worker // ================================================================== @@ -1780,7 +1780,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br); // HMX: O_final = diag(1/l) @ O_prev - if (factx.use_pipeline) { + if (factx.pipeline) { on_job.o_curr = o_tile_curr; on_job.o_prev = o_tile_prev; on_job.d_tiles = factx.vtcm_d_tiles; @@ -1826,7 +1826,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { } // end KV head loop } // end batch loop - if (factx.use_pipeline) { + if (factx.pipeline) { hmx_queue_suspend(ctx->hmx_queue); } else { HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c deleted file mode 100644 index 5c37f24ff0..0000000000 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ /dev/null @@ -1,2080 +0,0 @@ -#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" -#pragma clang diagnostic ignored "-Wunused-function" -#pragma clang diagnostic ignored "-Wunused-variable" -#pragma clang diagnostic ignored "-Wunused-but-set-variable" - -#include -#include -#include -#include -#include - -#include -#include - -#define GGML_COMMON_DECL_C -#include "ggml-common.h" - -#include "hex-dma.h" -#include "hex-fastdiv.h" -#include "worker-pool.h" - -#include "hvx-utils.h" -#include "hvx-dump.h" -#include "htp-ctx.h" -#include "htp-ops.h" - -#include "hmx-ops.h" -#include "hmx-utils.h" -#include "hmx-queue.h" -#include "hex-profile.h" - -#include "vtcm-utils.h" - -static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { - -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, -}; - -static const __fp16 q4_1_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { - 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, -}; - -// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value -// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6 -static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { - 0, 0, 0.5, 0, 1, 0, 1.5, 0, 2, 0, 3, 0, 4, 0, 6, 0, 0, 0, -0.5, 0, -1, 0, -1.5, 0, -2, 0, -3, 0, -4, 0, -6, 0, -}; - -static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { - -127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0, - 1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0, -}; - -// Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes -#define HMX_X4X2_SCALES_PER_BLK 8 -#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes (fp16 scales for Q4_0/Q8_0/IQ4_NL) -#define HMX_X4X2_MXFP4_EBLK_SIZE 8 // 8 * 1 byte (E8M0 scales for MXFP4) - -// Compute the byte stride of one row in x4x2 format. -// Numerically equals ggml_row_size(type, k) when k is 256-aligned, because -// x4x2 packing has the same density as block_q4_0 / block_q8_0. -// Layout per row: [quants: nb*128 (Q4) or nb*256 (Q8)][scales: nb*16 bytes] -// Total per row = nb * (128+16) = 144*nb (Q4) or nb * (256+16) = 272*nb (Q8). -// Callers must ensure k is a multiple of 256 (enforced by proc_hmx_matmul_req). -static inline size_t get_x4x2_row_stride(int weight_type, int k) { - int nb = (k + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; - switch (weight_type) { - case HTP_TYPE_Q4_0: - case HTP_TYPE_IQ4_NL: - return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb - case HTP_TYPE_Q4_1: - return (size_t) nb * (QK_Q4_0x4x2 / 2 + 32); // 160 * nb - case HTP_TYPE_Q8_0: - return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb - case HTP_TYPE_MXFP4: - return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE); // 136 * nb - case HTP_TYPE_F16: - return (size_t) k * sizeof(__fp16); - case HTP_TYPE_F32: - return (size_t) k * sizeof(float); - default: - return 0; - } -} - -// --- Overflow-safe arithmetic for VTCM budget calculation --- - -static inline bool hmx_mul_overflow(size_t a, size_t b, size_t *out) { - if (a != 0 && b > SIZE_MAX / a) return true; - *out = a * b; - return false; -} - -static inline bool hmx_add_overflow(size_t a, size_t b, size_t *out) { - if (a > SIZE_MAX - b) return true; - *out = a + b; - return false; -} - -// Search for optimal (mc, nc) chunk sizes within VTCM budget. -// -// VTCM model: nc * per_n_cost + mc * per_m_cost + mc * nc * per_mn_cost + overhead -// -// Minimize ceil(m/mc) * m_block_cost + ceil(n/nc) * n_block_cost. -// All matmul paths repeat weight processing per M-block and activation loading -// per N-block, so discrete block counts drive total overhead. -// Tie-break: when cost is equal, prefer larger mc * nc. -// -// Caller-provided coefficients: -// m_block_cost: penalty per extra M-block (weight redundancy, scales with n). -// n_block_cost: penalty per extra N-block (activation redundancy, scales with m). -// -// Algorithm: nc sweeps from n_max down by 32, analytically solving for mc_max. -// Returns 0 on success, -1 if VTCM is insufficient. -static int hmx_compute_chunks(size_t vtcm_total, - size_t overhead, - size_t per_n_cost, - size_t per_m_cost, - size_t per_mn_cost, - int m, - int n, - size_t m_block_cost, - size_t n_block_cost, - size_t * m_chunk_out, - size_t * n_chunk_out, - size_t * total_out) { - if (m <= 0 || n <= 0) return -1; - if (vtcm_total <= overhead) return -1; - if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1; - - const size_t usable = vtcm_total - overhead; - - size_t best_cost = SIZE_MAX; - size_t best_mn = 0; - size_t best_m = 0, best_n = 0; - - const size_t n_max = hex_align_down((size_t)n, HMX_FP16_TILE_N_COLS); - for (size_t nc = n_max; nc >= HMX_FP16_TILE_N_COLS; nc -= HMX_FP16_TILE_N_COLS) { - size_t n_fixed = 0, ncmn = 0, mc_denom = 0; - if (hmx_mul_overflow(nc, per_n_cost, &n_fixed)) continue; - if (n_fixed >= usable) goto next_nc; - - if (hmx_mul_overflow(nc, per_mn_cost, &ncmn)) goto next_nc; - if (hmx_add_overflow(per_m_cost, ncmn, &mc_denom) || mc_denom == 0) goto next_nc; - - { - size_t remain = usable - n_fixed; - size_t mc = remain / mc_denom; - mc = hex_align_down(mc, HMX_FP16_TILE_N_ROWS); - mc = hex_smin(mc, (size_t)m); - - if (mc == 0) { - goto next_nc; - } - - size_t mblocks = ((size_t) m + mc - 1) / mc; - size_t nblocks = ((size_t) n + nc - 1) / nc; - size_t cost = mblocks * m_block_cost + nblocks * n_block_cost; - size_t mn = mc * nc; - if (cost < best_cost || (cost == best_cost && mn > best_mn)) { - best_cost = cost; - best_mn = mn; - best_m = mc; - best_n = nc; - } - } - -next_nc: - if (nc == HMX_FP16_TILE_N_COLS) break; // avoid size_t underflow - } - - if (best_m == 0 || best_n == 0) return -1; - - // Compute exact total (with overflow checks) - size_t t0 = 0, t1 = 0, t2 = 0, mn = 0, total = 0; - if (hmx_mul_overflow(best_n, per_n_cost, &t0)) return -1; - if (hmx_mul_overflow(best_m, per_m_cost, &t1)) return -1; - if (hmx_mul_overflow(best_m, best_n, &mn)) return -1; - if (hmx_mul_overflow(mn, per_mn_cost, &t2)) return -1; - if (hmx_add_overflow(t0, t1, &total)) return -1; - if (hmx_add_overflow(total, t2, &total)) return -1; - if (hmx_add_overflow(total, overhead, &total)) return -1; - - *m_chunk_out = best_m; - *n_chunk_out = best_n; - *total_out = total; - return 0; -} - -// --- x4x2 format dequantizers --- - -// Dequantize one x4x2 Q4_0 group (32 elements from 32 packed bytes) -> 32 FP16 in first 64 bytes. -// In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles -// of the same 32 packed bytes. -static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) { - (void)vlut_cvt; - HVX_Vector vq = hvx_vmemu(packed_32); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); - - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8); - HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_int8)); - HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); - - return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); -} - -// Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using -// full HVX vector width. -// Output: vector_x2 each hold 32 FP16 values in the first 64 bytes. -static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( - const uint8_t *packed_128, bool upper_nibbles, - const __fp16 *scales_4, const HVX_Vector vlut_cvt) { - (void)vlut_cvt; - HVX_Vector vq = hvx_vmemu(packed_128); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8); - - HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_int8); - HVX_Vector v_lo = Q6_V_lo_W(vp_int16); - HVX_Vector v_hi = Q6_V_hi_W(vp_int16); - - v_lo = Q6_Vhf_equals_Vh(v_lo); - v_hi = Q6_Vhf_equals_Vh(v_hi); - - HVX_Vector vscale = hvx_vmemu(scales_4); - HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); - HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); - - v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); - v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); - - HVX_Vector_x2 r = { v_lo, v_hi }; - return r; -} - -static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale_offset, const HVX_Vector vlut_cvt) { - (void)vlut_cvt; - HVX_Vector vq = hvx_vmemu(packed_32); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_dm = hvx_vmemu(scale_offset); - HVX_Vector v_scales = hvx_vec_repl_f16(v_dm); - HVX_Vector v_offsets = hvx_vec_repl_f16(Q6_V_vror_VR(v_dm, 2)); - - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_quants)); - HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); - - return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales), v_offsets)); -} - -static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx( - const uint8_t *packed_128, bool upper_nibbles, - const __fp16 *scales_offsets_4, const HVX_Vector vlut_cvt) { - (void)vlut_cvt; - HVX_Vector vq = hvx_vmemu(packed_128); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_quants); - HVX_Vector v_lo = Q6_V_lo_W(vp_int16); - HVX_Vector v_hi = Q6_V_hi_W(vp_int16); - - v_lo = Q6_Vhf_equals_Vh(v_lo); - v_hi = Q6_Vhf_equals_Vh(v_hi); - - HVX_Vector vscale_offset = hvx_vmemu(scales_offsets_4); - HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2); - HVX_Vector vd = Q6_V_lo_W(dm_deal); - HVX_Vector vm = Q6_V_hi_W(dm_deal); - - HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vd); - HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vd, 4)); - - HVX_Vector v_os01 = hvx_vec_repl_2x_f16(vm); - HVX_Vector v_os23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vm, 4)); - - v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01), v_os01)); - v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23), v_os23)); - - HVX_Vector_x2 r = { v_lo, v_hi }; - return r; -} - -// LUT-based dequantizers for non-linear IQ4_NL format. -static inline HVX_Vector dequantize_x4x2_iq4_nl_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) { - HVX_Vector vq = hvx_vmemu(packed_32); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - v_quants = Q6_Vb_vshuff_Vb(v_quants); - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_hf = Q6_V_lo_W(vp); - - return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); -} - -static inline HVX_Vector_x2 dequantize_x4x2_iq4_nl_x4groups_hvx( - const uint8_t *packed_128, bool upper_nibbles, - const __fp16 *scales_4, const HVX_Vector vlut_cvt) { - HVX_Vector vq = hvx_vmemu(packed_128); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - v_quants = Q6_Vb_vshuff_Vb(v_quants); - - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_lo = Q6_V_lo_W(vp); - HVX_Vector v_hi = Q6_V_hi_W(vp); - - HVX_Vector vscale = hvx_vmemu(scales_4); - HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); - HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); - - v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); - v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); - - HVX_Vector_x2 r = { v_lo, v_hi }; - return r; -} - -// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. -static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) { - HVX_Vector vq = hvx_vmemu(quants_32); - HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); - HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq)); - HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); - return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); -} - -// --- MXFP4 E8M0 scale conversion and dequantization --- -// -// HVX batch-convert 8 E8M0 bytes (one x4x2 block's scales) to __fp16[8] on stack. -// Scalar loads from the stack array execute on the scalar pipeline, in parallel -// with HVX vlut16/vmpy/vscatter — freeing HVX slots in the hot loop. -// Arithmetic: fp16_bits = clamp(e - 112, 0, 30) << 10 -// e=0..112 -> 0 (underflow), e=113..142 -> valid fp16, e>=143 -> clamped to 2^15. - -typedef struct { - __fp16 v[8] __attribute__((aligned(16))); -} mxfp4_scales_t; - -static inline mxfp4_scales_t mxfp4_convert_scales(const uint8_t * e8m0_8) { - mxfp4_scales_t s; - HVX_Vector v = hvx_vmemu(e8m0_8); - HVX_Vector vh = Q6_V_lo_W(Q6_Wuh_vunpack_Vub(v)); - vh = Q6_Vh_vsub_VhVh(vh, Q6_Vh_vsplat_R(112)); - vh = Q6_Vh_vmax_VhVh(vh, Q6_V_vzero()); - vh = Q6_Vh_vmin_VhVh(vh, Q6_Vh_vsplat_R(30)); - vh = Q6_Vh_vasl_VhR(vh, 10); - hvx_vec_store_u(s.v, 16, vh); - return s; -} - -static inline HVX_Vector mxfp4_extract_splat(mxfp4_scales_t scales, int idx) { - return hvx_vec_splat_f16(scales.v[idx]); -} - -// Dequantize one x4x2 MXFP4 group (32 elements from 32 packed bytes) -> 32 FP16. -static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed_32, - bool upper_nibbles, - int sub_blk, - const HVX_Vector vlut_cvt, - mxfp4_scales_t scales) { - HVX_Vector vq = hvx_vmemu(packed_32); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - HVX_Vector v_sc = mxfp4_extract_splat(scales, sub_blk); - - v_quants = Q6_Vb_vshuff_Vb(v_quants); - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_hf = Q6_V_lo_W(vp); - - return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_sc)); -} - -// Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes). -static inline HVX_Vector_x4 dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128, - bool upper_nibbles, - int sub_blk_base, - const HVX_Vector vlut_cvt, - mxfp4_scales_t scales) { - HVX_Vector vq = hvx_vmemu(packed_128); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - v_quants = Q6_Vb_vshuff_Vb(v_quants); - - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_lo = Q6_V_lo_W(vp); - HVX_Vector v_hi = Q6_V_hi_W(vp); - - HVX_VectorPred q64 = Q6_Q_vsetq_R(64); - HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 0), - mxfp4_extract_splat(scales, sub_blk_base + 1)); - HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 2), - mxfp4_extract_splat(scales, sub_blk_base + 3)); - - v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); - v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); - - HVX_Vector_x4 r = { v_lo, Q6_V_vror_VR(v_lo, 64), v_hi, Q6_V_vror_VR(v_hi, 64) }; - return r; -} - -typedef struct { - __fp16 *dst; - const uint8_t *src; - int n_cols; - int k_block; - size_t row_stride; - int weight_type; - int n_tot_tiles; - int n_tiles_per_task; - int n_tasks; - int n_k_tiles; - struct fastdiv_values n_k_tiles_div; - struct htp_thread_trace * traces; -} x4x2_dequantize_state_t; - -// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16. -// Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes. -// Output: vtcm_dst in tile-major FP16 layout. - -#define DEFINE_DEQUANTIZE_Q4_TASK(suffix, lut_name, helper_prefix, dblk_size, scale_step) \ -static void dequantize_x4x2_weight_to_fp16_tiles_task_##suffix( \ - const x4x2_dequantize_state_t *state, \ - int start_tile, int end_tile) { \ - \ - const int n_k_tiles = state->n_k_tiles; \ - const int qrow_size = (unsigned)state->k_block / 2; \ - const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; \ - const HVX_Vector vlut_cvt = hvx_vmem(lut_name); \ - \ - const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); \ - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); \ - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); \ - \ - unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); \ - unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); \ - \ - for (unsigned t = start_tile; t < (unsigned)end_tile; ) { \ - if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } \ - \ - if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) { \ - unsigned blk_idx = ((kt * 32) / QK_Q4_0x4x2); \ - unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; \ - bool upper = (sub_blk_base >= 4); \ - unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); \ - unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk_base * (scale_step); \ - \ - __fp16 *tile_bases[4]; \ - for (unsigned g = 0; g < 4; g++) { \ - tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS; \ - } \ - \ - HVX_Vector v_off = v_scat_base; \ - unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \ - \ - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { \ - const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \ - const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \ - \ - HVX_Vector_x2 dv0 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \ - r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \ - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); \ - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); \ - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ - \ - HVX_Vector_x2 dv1 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \ - r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); \ - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); \ - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); \ - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ - } \ - \ - for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } \ - t += 4; kt += 4; \ - continue; \ - } \ - \ - __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; \ - { \ - unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; \ - unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; \ - bool upper = (sub_blk >= 4); \ - unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; \ - unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk * (scale_step); \ - \ - HVX_Vector v_off = v_scat_base; \ - unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \ - unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; \ - \ - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { \ - const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \ - const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \ - \ - HVX_Vector v0 = dequantize_x4x2_##helper_prefix##_group_hvx( \ - r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \ - HVX_Vector v1 = (row1 < (unsigned)state->n_cols) \ - ? dequantize_x4x2_##helper_prefix##_group_hvx( \ - r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) \ - : Q6_V_vzero(); \ - \ - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); \ - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); \ - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ - } \ - (void) *(volatile HVX_Vector *)(tile_base); \ - } \ - ++t; ++kt; \ - } \ - \ - if (start_tile < end_tile) { \ - (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); \ - } \ -} \ - \ -static void dequantize_x4x2_worker_loop_##suffix(unsigned int n, unsigned int i, void *data) { \ - x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; \ - struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; \ - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \ - for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { \ - int start = task_id * state->n_tiles_per_task; \ - int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); \ - dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(state, start, end); \ - } \ - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \ -} - -DEFINE_DEQUANTIZE_Q4_TASK(q4_0, q4_0_to_fp16_lut, q4_0, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16)) -DEFINE_DEQUANTIZE_Q4_TASK(q4_1, q4_1_to_fp16_lut, q4_1, 32, 4) -DEFINE_DEQUANTIZE_Q4_TASK(iq4_nl, iq4_nl_to_fp16_lut, iq4_nl, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16)) - -static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4( - const x4x2_dequantize_state_t *state, - int start_tile, int end_tile) { - - const int n_k_tiles = state->n_k_tiles; - const int qrow_size = (unsigned)state->k_block / 2; - const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; - const HVX_Vector vlut_cvt = hvx_vmem(mxfp4_to_fp16_lut); - - const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); - - unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); - unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); - - for (unsigned t = start_tile; t < (unsigned)end_tile; ) { - if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } - - // Batch-4 fast path for MXFP4 - if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) { - int blk_idx = (kt * 32) / QK_MXFP4x4x2; - int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; - bool upper = (sub_blk_base >= 4); - int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); - int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; - - __fp16 * tile_bases[4]; - for (int g = 0; g < 4; g++) { - tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS; - } - - HVX_Vector v_off = v_scat_base; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - const uint8_t * r0 = state->src + row0 * state->row_stride; - const uint8_t * r1 = state->src + row1 * state->row_stride; - - mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); - - HVX_Vector_x4 dv0, dv1; - dv0 = dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8); - if (row1 < state->n_cols) { - mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); - dv1 = dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8); - } else { - dv1.v[0] = dv1.v[1] = dv1.v[2] = dv1.v[3] = Q6_V_vzero(); - } - - for (int g = 0; g < 4; g++) { - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[g]); - } - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - for (int g = 0; g < 4; g++) { - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[g]); - } - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - - for (int g = 0; g < 4; g++) { - (void) *(volatile HVX_Vector *) (tile_bases[g]); - } - - t += 4; kt += 4; - continue; - } - - // Single-tile fallback - __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; - { - int blk_idx = (kt * 32) / QK_MXFP4x4x2; - int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32; - bool upper = (sub_blk >= 4); - int byte_off = blk_idx * (QK_MXFP4x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; - int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; - - HVX_Vector v_off = v_scat_base; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - - const uint8_t * r0 = state->src + row0 * state->row_stride; - const uint8_t * r1 = state->src + row1 * state->row_stride; - - mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); - - HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8); - HVX_Vector v1; - if (row1 < state->n_cols) { - mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); - v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8); - } else { - v1 = Q6_V_vzero(); - } - - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - (void) *(volatile HVX_Vector *) (tile_base); - } - ++t; ++kt; - } - - if (start_tile < end_tile) { - (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); - } -} - -static void dequantize_x4x2_worker_loop_mxfp4(unsigned int n, unsigned int i, void *data) { - x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; - struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); - for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { - int start = task_id * state->n_tiles_per_task; - int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); - dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(state, start, end); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); -} - -static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0( - const x4x2_dequantize_state_t *state, - int start_tile, int end_tile) { - - const int n_k_tiles = state->n_k_tiles; - const int qrow_size = state->k_block; - const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; - - const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); - - unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); - unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); - - for (unsigned t = start_tile; t < (unsigned)end_tile; ) { - if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } - - __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; - { - int blk_idx = (kt * 32) / QK_Q8_0x4x2; - int sub_blk = ((kt * 32) % QK_Q8_0x4x2) / 32; - int byte_off = blk_idx * QK_Q8_0x4x2 + sub_blk * 32; - int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); - - HVX_Vector v_off = v_scat_base; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - - const uint8_t *r0 = state->src + row0 * state->row_stride; - const uint8_t *r1 = state->src + row1 * state->row_stride; - - HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off)); - HVX_Vector v1 = (row1 < state->n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero(); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - (void) *(volatile HVX_Vector *)(tile_base); - } - ++t; ++kt; - } - - if (start_tile < end_tile) { - (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); - } -} - -static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, void *data) { - x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; - struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); - for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { - int start = task_id * state->n_tiles_per_task; - int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); - dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(state, start, end); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); -} - -static void convert_f16_weight_to_fp16_tiles_task( - const x4x2_dequantize_state_t *state, - int start_tile, int end_tile) { - - const int n_k_tiles = state->n_k_tiles; - const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; - - const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); - - unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); - unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); - - for (unsigned t = start_tile; t < (unsigned)end_tile; ) { - if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } - - __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; - { - int byte_off = kt * 32 * sizeof(__fp16); - - HVX_Vector v_off = v_scat_base; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - - const uint8_t *r0 = state->src + row0 * state->row_stride; - const uint8_t *r1 = state->src + row1 * state->row_stride; - - HVX_Vector v0 = hvx_vmemu((const __fp16 *)(r0 + byte_off)); - HVX_Vector v1 = (row1 < state->n_cols) ? hvx_vmemu((const __fp16 *)(r1 + byte_off)) : Q6_V_vzero(); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - (void) *(volatile HVX_Vector *)(tile_base); - } - ++t; ++kt; - } - - if (start_tile < end_tile) { - (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); - } -} - -static void convert_f16_worker_loop(unsigned int n, unsigned int i, void *data) { - x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; - struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); - for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { - int start = task_id * state->n_tiles_per_task; - int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); - convert_f16_weight_to_fp16_tiles_task(state, start, end); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); -} - -static void quantize_f32_weight_to_fp16_tiles_task( - const x4x2_dequantize_state_t *state, - int start_tile, int end_tile) { - - const int n_k_tiles = state->n_k_tiles; - const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; - - const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); - - unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); - unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); - - for (unsigned t = start_tile; t < (unsigned)end_tile; ) { - if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } - - __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; - { - int byte_off = kt * 32 * sizeof(float); - - HVX_Vector v_off = v_scat_base; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - - const uint8_t *r0 = state->src + row0 * state->row_stride; - const uint8_t *r1 = state->src + row1 * state->row_stride; - - HVX_Vector v0_f32 = hvx_vmemu((const float *)(r0 + byte_off)); - HVX_Vector v1_f32 = (row1 < state->n_cols) ? hvx_vmemu((const float *)(r1 + byte_off)) : Q6_V_vzero(); - - HVX_Vector v_out = hvx_vec_f32_to_f16(v0_f32, v1_f32); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v_out); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - - HVX_Vector v_out_hi = Q6_V_vror_VR(v_out, 64); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v_out_hi); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - (void) *(volatile HVX_Vector *)(tile_base); - } - ++t; ++kt; - } - - if (start_tile < end_tile) { - (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); - } -} - -static void quantize_f32_worker_loop(unsigned int n, unsigned int i, void *data) { - x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; - struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); - for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { - int start = task_id * state->n_tiles_per_task; - int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); - quantize_f32_weight_to_fp16_tiles_task(state, start, end); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); -} - - -static void dequantize_x4x2_weight_chunk_to_fp16_tiles( - struct htp_context *ctx, __fp16 *vtcm_dst, - const void *vtcm_src, int n_cols, int k_block, - size_t row_stride, int weight_type, - int n_k_tiles, struct fastdiv_values n_k_tiles_div, - worker_callback_t dequant_worker_fn, int n_threads) { - - assert(n_cols % HMX_FP16_TILE_N_COLS == 0); - assert(k_block % HMX_FP16_TILE_N_COLS == 0); - - size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; - size_t n_tot_tiles = n_col_tiles * n_k_tiles; - - size_t n_tiles_per_task = (n_threads == 1) ? n_tot_tiles : hmx_ceil_div(n_tot_tiles, n_threads); - - x4x2_dequantize_state_t state; - state.n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task; - state.n_tot_tiles = n_tot_tiles; - state.n_tiles_per_task = n_tiles_per_task; - state.dst = vtcm_dst; - state.src = (const uint8_t *)vtcm_src; - state.n_cols = n_cols; - state.k_block = k_block; - state.row_stride = row_stride; - state.weight_type = weight_type; - state.n_k_tiles = n_k_tiles; - state.n_k_tiles_div = n_k_tiles_div; - state.traces = ctx ? ctx->trace : NULL; - - if (state.n_tasks == 1 || n_threads == 1) { - dequant_worker_fn(1, 0, &state); - } else { - worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, n_threads); - } -} - -// --- End x4x2 dequantizers --- - -#pragma clang diagnostic ignored "-Wbackend-plugin" // spurios warning for hmx intrinsics - -// requires external HMX lock -static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales, - int n_row_tiles, int n_col_tiles, int n_dot_tiles) { - __builtin_assume(n_row_tiles > 0); - __builtin_assume(n_col_tiles > 0); - __builtin_assume(n_dot_tiles > 0); - - Q6_bias_mxmem2_A((void *)scales); - for (int r = 0; r < n_row_tiles; ++r) { - for (size_t c = 0; c < n_col_tiles; ++c) { - Q6_mxclracc_hf(); - - const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS; - const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS; - - for (int k = 0, k_block; k < n_dot_tiles; k += k_block) { - k_block = hex_smin(n_dot_tiles - k, 32); - const uint32_t range = 2048u * (uint32_t)k_block - 1; - Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); - row_tiles += k_block * HMX_FP16_TILE_N_ELMS; - col_tiles += k_block * HMX_FP16_TILE_N_ELMS; - } - - __fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS; - Q6_mxmem_AR_after_hf(out_tile, 0); - } - } -} - -// --- Async HMX matmul job (for pipeline overlap) --- - -typedef struct { - __fp16 * output; - const __fp16 * activation; - const __fp16 * weight; - const __fp16 * scales; - uint32_t n_row_tiles; - uint32_t n_col_tiles; - uint32_t n_dot_tiles; -} hmx_matmul_job_t; - -static void hmx_matmul_worker_fn(void * data) { - hmx_matmul_job_t * job = (hmx_matmul_job_t *) data; - FARF(HIGH, "hmx-mm-job: n_row_tiles %u n_col_tiles %u n_dot_tiles %u", job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); - core_dot_chunk_fp16(job->output, job->activation, job->weight, job->scales, job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); -} - -static inline void hmx_matmul_job_init(hmx_matmul_job_t * job, - __fp16 * output, - const __fp16 * activation, - const __fp16 * weight, - const __fp16 * scales, - int n_row_tiles, - int n_col_tiles, - int n_dot_tiles) { - job->output = output; - job->activation = activation; - job->weight = weight; - job->scales = scales; - job->n_row_tiles = n_row_tiles; - job->n_col_tiles = n_col_tiles; - job->n_dot_tiles = n_dot_tiles; -} - -// output : fp16 -> f32p - -static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) { - assert(n_cols % HMX_FP16_TILE_N_COLS == 0); - const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS; - - const HVX_Vector one = hvx_vec_splat_f16(1.0); - - for (size_t r = 0; r < n_rows; r += 2) { - const size_t r0 = r / HMX_FP16_TILE_N_ROWS; - const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; // index of the row pair within the tile - const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; - float *output_row_base = dst + r * n; // global memory row base for row r (and r+1) - - #pragma unroll(4) - for (size_t c = 0; c < n_cols; c += HMX_FP16_TILE_N_COLS) { - const size_t c0 = c / HMX_FP16_TILE_N_COLS; - const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS; - HVX_Vector v = ((const HVX_Vector *) tile)[r1]; - HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); - - volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row_base + c + 0); - volatile HVX_Vector *pv_out1 = (volatile HVX_Vector *) (output_row_base + c + n); // next row in global memory - - *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); - if (r + 1 < n_rows) { - *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); - } - } - } -} - -typedef struct { - const __fp16 *vtcm_src; - float *dst; - int n_tasks; - int n_tot_chunks; - int n_chunks_per_task; - int n_cols; - int n; // DDR row stride (total output columns) - struct htp_thread_trace * traces; -} output_transfer_task_state_t; - -static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { - output_transfer_task_state_t *st = (output_transfer_task_state_t *) data; - struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, i); - - for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { - int chunk_idx = task_id * st->n_chunks_per_task; - size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); - - float *dst = st->dst + chunk_idx * st->n; - const __fp16 *vtcm_src = st->vtcm_src + chunk_idx * st->n_cols; - transfer_output_chunk_fp16_to_fp32(dst, vtcm_src, chunk_size, st->n_cols, st->n); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, i); -} - -static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src, - int n_rows, int n_cols, int n, int n_threads) { - assert(n_cols % HMX_FP16_TILE_N_COLS == 0); - - size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32) - - output_transfer_task_state_t state; - state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; - state.n_tot_chunks = n_tot_chunks; - state.n_chunks_per_task = n_chunks_per_task; - state.dst = dst; - state.vtcm_src = vtcm_src; - state.n_cols = n_cols; - state.n = n; - state.traces = ctx ? ctx->trace : NULL; - - if (state.n_tasks == 1 || n_threads == 1) { - transfer_output_chunk_worker_fn(1, 0, &state); - } else { - worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, n_threads); - } -} - -// activations : fp32 -> fp16 - -static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, int k_block, int k_stride) { - const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS); - const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS; - - int r = 0; - - #pragma unroll(2) - for (r = 0; r < n_rows_tiled; r += 2) { - int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index - int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx - - const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride); - const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride); - for (int c = 0; c < k_block; c += 32) { - HVX_Vector v0 = *pv_in0++; - HVX_Vector v1 = *pv_in1++; - - HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); - - // compute output position - int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index - int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; - - HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); - tile[r1 / 2] = v_out; - } - } - - for (; r < n_rows_padded; r += 2) { - int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index - int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx - - const bool row0_valid = r < n_rows; - const bool row1_valid = (r + 1) < n_rows; - - const HVX_Vector *pv_in0 = row0_valid ? (const HVX_Vector *) (src + (r + 0) * k_stride) : NULL; - const HVX_Vector *pv_in1 = row1_valid ? (const HVX_Vector *) (src + (r + 1) * k_stride) : NULL; - for (int c = 0; c < k_block; c += 32) { - HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero(); - HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero(); - - HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); - - // compute output position - int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index - int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; - - HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); - tile[r1 / 2] = v_out; - } - } -} - -typedef struct { - __fp16 *dst; - const float *src; - int n_tasks; - int n_tot_chunks; - int n_chunks_per_task; - int k_block; - int k_stride; - struct htp_thread_trace * traces; -} activation_transfer_task_state_t; - -static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { - activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data; - struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, i); - - for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { - // one chunk: one row - int chunk_idx = task_id * st->n_chunks_per_task; - size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); - - __fp16 *dst = st->dst + chunk_idx * st->k_block; - const float *src = st->src + chunk_idx * st->k_stride; - transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, i); -} - -static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride, int n_threads) { - assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0); - assert(VLEN == 32 * sizeof(float)); - - size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : 32; // must be multiple of 32 to ensure correct destination address - - activation_transfer_task_state_t state; - state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; - state.n_tot_chunks = n_tot_chunks; - state.n_chunks_per_task = n_chunks_per_task; - state.dst = dst; - state.src = src; - state.k_block = k_block; - state.k_stride = k_stride; - state.traces = ctx ? ctx->trace : NULL; - - if (state.n_tasks == 1 || n_threads == 1) { - transfer_activation_chunk_worker_fn(1, 0, &state); - } else { - worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, n_threads); - } -} - -// C += AB -static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, - const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, - int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) { - __builtin_assume(n_row_tiles > 0); - __builtin_assume(n_col_tiles > 0); - __builtin_assume(n_dot_tiles > 0); - - Q6_bias_mxmem2_A((void *)col_scales); - - const size_t dot_tile_stride = n_dot_tiles * HMX_FP16_TILE_N_ELMS; - for (size_t i = 0; i < n_row_tiles; ++i) { - const __fp16 *row_base = a + i * dot_tile_stride; - __fp16 *res_base = c + i * n_col_tiles * HMX_FP16_TILE_N_ELMS; - for (size_t j = 0; j < n_col_tiles; ++j) { - Q6_mxclracc_hf(); - - const __fp16 *col_tiles = b + j * dot_tile_stride; - const __fp16 *row_tiles = row_base; - __fp16 *accum_tile = res_base + j * HMX_FP16_TILE_N_ELMS; - if (!zero_init) { - Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); - } - - for (int k = 0, k_block; k < n_dot_tiles; k += k_block) { - k_block = hex_smin(n_dot_tiles - k, 32); - const uint32_t range = 2048u * (uint32_t)k_block - 1; - Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); - row_tiles += k_block * HMX_FP16_TILE_N_ELMS; - col_tiles += k_block * HMX_FP16_TILE_N_ELMS; - } - - Q6_mxmem_AR_after_hf(accum_tile, 0); - } - } -} - -int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, - const uint8_t *restrict permuted_weight, int m, int k, int n, - int act_stride, int weight_stride, int weight_type) { - if (k % 32 != 0 || n % 32 != 0) { return -1; } - - if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { - return -1; - } - - size_t row_stride = get_x4x2_row_stride(weight_type, k); - if (row_stride == 0) { - return -1; - } - - worker_callback_t dequant_worker_fn = NULL; - switch (weight_type) { - case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break; - case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break; - case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; - case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; - case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; - case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; - case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; - default: - return -1; - } - - const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; - const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); - - // --- Dynamic Mode Configuration --- - const bool use_pipeline = (m > 32); - const int num_threads = (m <= 32) ? 1 : ctx->n_threads; - - // --- Dynamic VTCM layout --- - const size_t vec_dot_size = k * sizeof(__fp16); - const size_t vtcm_budget = ctx->vtcm_size; - size_t vtcm_used = 0; - - // Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap. - const size_t size_per_n = row_stride + (use_pipeline ? 2 * vec_dot_size : vec_dot_size); // Q + S0 + S1 (dequant bufs) - const size_t size_per_mn = (use_pipeline ? 2 : 1) * sizeof(__fp16); // O x 2 (output double buffer) - - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { - FARF(HIGH, "hmx-mm-2d: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); - return -1; - } - - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); - const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - - size_t scratch0_size, scratch1_size, scratch2_size; - scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 - scratch1_size = use_pipeline ? scratch0_size : 0; // dequant buf 1 - scratch2_size = use_pipeline ? output_area_size : 0; // output buf 1 - - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); - void *vtcm_scratch1 = scratch1_size ? vtcm_seq_alloc(&vtcm_ptr, scratch1_size) : NULL; - void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - - vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; - if (vtcm_used > vtcm_budget) { - FARF(ERROR, "hmx-mm-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); - return -1; - } - - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - - FARF(HIGH, "hmx-mm-2d: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", - m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget); - - - - int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); - - if (use_pipeline) { - // --- Asynchronous Pipelined Loop (Current implementation) --- - hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - - void *vtcm_qweight = vtcm_weight; - void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; - void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; - - // prologue: A0 - const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); - { - const uint8_t *qweight_chunk_A0 = permuted_weight; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, weight_stride, row_stride, n_cols_A0); - } - - { - const float *activation_chunk = activation + mr * act_stride; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, act_stride, num_threads); - } - - // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) - { - // B0: wait for DMA, dequant weight chunk 0 - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - - // A1: issue DMA for weight chunk 1 - const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); - if (1 < n_chunk_cnt) { - const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * weight_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, weight_stride, row_stride, n_cols_A1); - } - - // submit C0 (non-blocking — HMX worker executes in parallel) - hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, - (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, - hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); - - // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) - if (1 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - } - } - - // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) - for (int i = 0; i < n_chunk_cnt; ++i) { - const size_t nc = i * n_chunk_n_cols; - const size_t nc_p1 = nc + 1 * n_chunk_n_cols; - const size_t nc_p2 = nc + 2 * n_chunk_n_cols; - - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); - const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); - - // issue A_{i+2}: DMA push (non-blocking) - if (i + 2 < n_chunk_cnt) { - const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * weight_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, weight_stride, row_stride, n_cols_p2); - } - - // wait C_i: block until prologue/previous C completes - hmx_queue_pop(ctx->hmx_queue); - - // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) - if (i + 1 < n_chunk_cnt) { - hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], - (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], - vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); - } - - // D_i: store output (multi-thread HVX, parallel with C_{i+1}) - float *output_chunk = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n, num_threads); - - // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) - if (i + 2 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - } - } - } - hmx_queue_suspend(ctx->hmx_queue); - } else { - // --- Synchronous Loop (Optimized for small/non-pipelined cases) --- - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - - // Load Activation - const float *activation_chunk = activation + mr * act_stride; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, act_stride, num_threads); - - for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - - // A: DMA Load Weight - const uint8_t *qweight_chunk = permuted_weight + nc * weight_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, qweight_chunk), row_stride, weight_stride, row_stride, n_cols); - dma_queue_pop(ctx->dma[0]); - - // B: Dequantize / Convert Weight - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - - // C: HMX Compute (Synchronous) - { - struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS); - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS); - } - - // D: Output Store - float *output_chunk = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output, n_rows, n_cols, n, num_threads); - } - } - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - } - - - - return 0; -} - -// - -static inline int hmx_matmul_batch_r2(const hmx_matmul_f16_f32_batched_params_t *params) { - return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; -} - -static inline int hmx_matmul_batch_r3(const hmx_matmul_f16_f32_batched_params_t *params) { - return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; -} - -static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, - int dst_b2, int dst_b3) { - const int r2 = hmx_matmul_batch_r2(params); - const int r3 = hmx_matmul_batch_r3(params); - return (const __fp16 *) ((const uint8_t *) params->permuted_weight + - (size_t) (dst_b2 / r2) * params->src0_nb2 + - (size_t) (dst_b3 / r3) * params->src0_nb3); -} - -static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, - int dst_b2, int dst_b3) { - return (const float *) ((const uint8_t *) params->activation + - (size_t) dst_b2 * params->src1_nb2 + - (size_t) dst_b3 * params->src1_nb3); -} - -static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, - int dst_b2, int dst_b3) { - return (float *) ((uint8_t *) params->dst + - (size_t) dst_b2 * params->dst_nb2 + - (size_t) dst_b3 * params->dst_nb3); -} - -static int hmx_matmul_f16_f32_batched_legacy(struct htp_context *ctx, - const hmx_matmul_f16_f32_batched_params_t *params) { - int ret = 0; - for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { - for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { - ret = hmx_matmul_f16_f32(ctx, hmx_matmul_dst_batch_ptr(params, b2, b3), - hmx_matmul_activation_batch_ptr(params, b2, b3), - hmx_matmul_weight_batch_ptr(params, b2, b3), - params->m, params->k, params->n, - params->act_stride, params->weight_stride); - } - } - return ret; -} - -int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params) { - if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; } - if (!params->m || !params->k || !params->n) { return -1; } - if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } - if (params->ne02 <= 0 || params->ne03 <= 0 || params->ne12 <= 0 || params->ne13 <= 0) { return -1; } - if (params->ne12 % params->ne02 != 0 || params->ne13 % params->ne03 != 0) { return -1; } - if (params->k % 32 != 0 || params->n % 32 != 0) { return -1; } - - if (!hex_is_aligned(params->dst, VLEN) || - !hex_is_aligned(params->activation, VLEN) || - !hex_is_aligned(params->permuted_weight, VLEN)) { - return -1; - } - - const int group_size = hmx_matmul_batch_r2(params); - - if (group_size <= 1) { - FARF(HIGH, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size); - return hmx_matmul_f16_f32_batched_legacy(ctx, params); - } - - // Grouped path: reuse interleaved weight across all q_heads sharing a - // kv_head. Each q_head gets its own activation buffer in VTCM (so - // activation is loaded once per m_chunk and reused across all n_chunks), - // and each q_head is computed individually to avoid tile-major packing - // issues. m_chunk_n_rows is always a multiple of 32 (from - // hmx_compute_chunks), so per-head tile arrays don't overlap. - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = params->k * sizeof(__fp16); - - // When the activation has a large stride (e.g. permuted Q tensor with - // act_stride >> k), HVX vector loads from strided DDR thrash L2 cache. - // Allocate an F32 scratch buffer in VTCM and use 2D DMA to gather - // strided rows into a contiguous block before the F32->F16 conversion. - const bool use_dma_activation = (params->act_stride > params->k); - const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0; - - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - // FP16 weight: interleave and activation load have similar per-element cost. - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, - /*per_n=*/3 * vec_dot_size, - /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, - /*per_mn=*/sizeof(__fp16), - hex_align_up(params->m, HMX_FP16_TILE_N_ROWS), params->n, - /*m_block_cost=*/(size_t) params->n, - /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); - return hmx_matmul_f16_f32_batched_legacy(ctx, params); - } - - const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(group_size * m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t f32_scratch_size = use_dma_activation - ? hex_align_up(m_chunk_n_rows * (size_t) params->k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; - - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; - - if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { - FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__); - return hmx_matmul_f16_f32_batched_legacy(ctx, params); - } - - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - - FARF(HIGH, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, params->m, params->k, params->n, group_size, params->ne13, - m_chunk_n_rows, n_chunk_n_cols, - (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); - - - - const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16); - const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16); - - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - - for (int b3 = 0; b3 < params->ne13; ++b3) { - for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) { - const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3); - - for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS); - - // Pre-load activations for all heads in the group (once per m_chunk). - // When the source is strided (permuted Q), use 2D DMA to gather - // contiguous rows into a VTCM scratch buffer first, then HVX - // converts from the contiguous VTCM buffer. This avoids L2 cache - // thrashing from HVX loads at large strides. - for (int g = 0; g < group_size; ++g) { - const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride; - __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; - if (use_dma_activation) { - const size_t row_bytes = (size_t) params->k * sizeof(float); - const size_t stride_bytes = (size_t) params->act_stride * sizeof(float); - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_f32_act, activation_chunk), - row_bytes, stride_bytes, row_bytes, n_rows); - dma_queue_pop(ctx->dma[0]); - transfer_activation_chunk_threaded(ctx, vtcm_act_g, - vtcm_f32_act, (int) n_rows, - params->k, params->k, ctx->n_threads); - } else { - transfer_activation_chunk_threaded(ctx, vtcm_act_g, - activation_chunk, (int) n_rows, - params->k, params->act_stride, ctx->n_threads); - } - } - - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; - - { - const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols); - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); - } - - for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS); - - { - dma_queue_pop(ctx->dma[0]); - - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < (size_t) params->n) { - const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols); - const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride; - - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); - } - - hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, params->k, params->k, - 0, n_cols); - hex_swap_ptr(&buf_curr, &buf_next); - } - - // Reuse the interleaved weight for every q_head in this GQA group - for (int g = 0; g < group_size; ++g) { - { - const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; - struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS); - core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, - params->k / 32); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS); - } - - { - float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; - transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride, ctx->n_threads); - } - } - } - } - } - } - - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - - - - return 0; -} - -int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, - const __fp16 *restrict permuted_weight, int m, int k, int n, - int act_stride, int weight_stride) { - if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } - return hmx_matmul_2d_f32(ctx, dst, activation, (const uint8_t *)permuted_weight, m, k, n, - act_stride, weight_stride * (int)sizeof(__fp16), HTP_TYPE_F16); -} - -struct mmid_row_mapping { - uint32_t i1; - uint32_t i2; -}; - -typedef struct { - __fp16 *dst; - const float *src; - int n_tasks; - int n_tot_chunks; - int n_chunks_per_task; - int k_block; - const struct mmid_row_mapping *matrix_rows; - int cur_a; - int mapping_stride; - int ne11; - struct fastdiv_values ne11_div; - size_t nb11; - size_t nb12; - int start_row; - int cne1; - struct htp_thread_trace *traces; -} activation_transfer_gathered_task_state_t; - -typedef struct { - const __fp16 *vtcm_src; - float *dst; - int n_tasks; - int n_tot_chunks; - int n_chunks_per_task; - int n_cols; - const struct mmid_row_mapping *matrix_rows; - int cur_a; - int mapping_stride; - size_t dst_nb1; - size_t dst_nb2; - int start_row; - int cne1; - struct htp_thread_trace *traces; -} output_transfer_scattered_task_state_t; - -static void transfer_activation_chunk_fp32_to_fp16_gathered( - __fp16 *restrict vtcm_dst, - const float *restrict src, - int start_row, - int n_rows, - int k_block, - const struct mmid_row_mapping *matrix_rows, - int cur_a, - int mapping_stride, - int ne11, - const struct fastdiv_values * ne11_div, - size_t nb11, - size_t nb12, - int cne1) { - const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS); - const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS; - - int r = 0; - - #pragma unroll(2) - for (r = 0; r < n_rows_tiled; r += 2) { - int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index - int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx - - int r_idx0 = start_row + r + 0; - int r_idx1 = start_row + r + 1; - - struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; - struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; - - int i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); - int i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); - - const float *row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); - const float *row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); - - const HVX_Vector *pv_in0 = (const HVX_Vector *) row0_ptr; - const HVX_Vector *pv_in1 = (const HVX_Vector *) row1_ptr; - - for (int c = 0; c < k_block; c += 32) { - HVX_Vector v0 = *pv_in0++; - HVX_Vector v1 = *pv_in1++; - - HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); - - int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index - int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; - - HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); - tile[r1 / 2] = v_out; - } - } - - for (; r < n_rows_padded; r += 2) { - int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index - int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx - - const bool row0_valid = (start_row + r + 0) < cne1; - const bool row1_valid = (start_row + r + 1) < cne1; - - const float *row0_ptr = NULL; - const float *row1_ptr = NULL; - - if (row0_valid) { - struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + (start_row + r + 0)]; - int i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); - row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); - } - if (row1_valid) { - struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + (start_row + r + 1)]; - int i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); - row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); - } - - const HVX_Vector *pv_in0 = (const HVX_Vector *) row0_ptr; - const HVX_Vector *pv_in1 = (const HVX_Vector *) row1_ptr; - - for (int c = 0; c < k_block; c += 32) { - HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero(); - HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero(); - - HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); - - int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index - int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; - - HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); - tile[r1 / 2] = v_out; - } - } -} - -static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigned int i, void *data) { - activation_transfer_gathered_task_state_t *st = data; - struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, i); - - int chunk_idx = i; - int chunk_size = st->n_chunks_per_task; - int start_row = st->start_row + chunk_idx * chunk_size; - int n_rows = hex_smin(st->cne1 - start_row, chunk_size); - if (n_rows > 0) { - __fp16 *dst = st->dst + (size_t)(start_row - st->start_row) * st->k_block; - transfer_activation_chunk_fp32_to_fp16_gathered( - dst, st->src, start_row, n_rows, st->k_block, - st->matrix_rows, st->cur_a, st->mapping_stride, - st->ne11, &st->ne11_div, st->nb11, st->nb12, st->cne1); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, i); -} - -static void transfer_activation_chunk_gathered_threaded( - struct htp_context *ctx, - __fp16 *dst, - const float *src, - int start_row, - int n_rows, - int k_block, - const struct mmid_row_mapping *matrix_rows, - int cur_a, - int mapping_stride, - int ne11, - size_t nb11, - size_t nb12, - int cne1, - int n_threads) { - if (n_rows <= 0) return; - int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); - chunks_per_thread = hex_align_up(chunks_per_thread, HMX_FP16_TILE_N_ROWS); - - int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); - - activation_transfer_gathered_task_state_t state = { - .dst = dst, - .src = src, - .n_tasks = actual_threads, - .n_tot_chunks = n_rows, - .n_chunks_per_task = chunks_per_thread, - .k_block = k_block, - .matrix_rows = matrix_rows, - .cur_a = cur_a, - .mapping_stride = mapping_stride, - .ne11 = ne11, - .ne11_div = init_fastdiv_values(ne11), - .nb11 = nb11, - .nb12 = nb12, - .start_row = start_row, - .cne1 = cne1, - .traces = ctx ? ctx->trace : NULL, - }; - - if (actual_threads <= 1) { - transfer_activation_chunk_gathered_worker_fn(1, 0, &state); - } else { - worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_gathered_worker_fn, &state, actual_threads); - } -} - -static void transfer_output_chunk_fp16_to_fp32_scattered( - float *restrict dst, - const __fp16 *restrict vtcm_src, - int start_row, - int n_rows, - int n_cols, - const struct mmid_row_mapping *matrix_rows, - int cur_a, - int mapping_stride, - size_t dst_nb1, - size_t dst_nb2, - int cne1) { - assert(n_cols % HMX_FP16_TILE_N_COLS == 0); - const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS; - - const HVX_Vector one = hvx_vec_splat_f16(1.0); - - for (size_t r = 0; r < n_rows; r += 2) { - const size_t r0 = r / HMX_FP16_TILE_N_ROWS; - const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; // index of the row pair within the tile - const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; - - int r_idx0 = start_row + (int)r + 0; - int r_idx1 = start_row + (int)r + 1; - - if (r_idx0 >= cne1) break; - - struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; - float *output_row0 = (float *) ((uint8_t *) dst + mapping0.i1 * dst_nb1 + mapping0.i2 * dst_nb2); - - float *output_row1 = NULL; - if (r_idx1 < cne1) { - struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; - output_row1 = (float *) ((uint8_t *) dst + mapping1.i1 * dst_nb1 + mapping1.i2 * dst_nb2); - } - - #pragma unroll(4) - for (size_t c = 0; c < (size_t)n_cols; c += HMX_FP16_TILE_N_COLS) { - const size_t c0 = c / HMX_FP16_TILE_N_COLS; - const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS; - HVX_Vector v = ((const HVX_Vector *) tile)[r1]; - HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); - - volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row0 + c); - volatile HVX_Vector *pv_out1 = output_row1 ? (volatile HVX_Vector *) (output_row1 + c) : NULL; - - *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); - if (pv_out1) { - *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); - } - } - } -} - -static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned int i, void *data) { - output_transfer_scattered_task_state_t *st = data; - struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, i); - - int chunk_idx = i; - int chunk_size = st->n_chunks_per_task; - int start_row = st->start_row + chunk_idx * chunk_size; - int n_rows = hex_smin(st->cne1 - start_row, chunk_size); - if (n_rows > 0) { - const __fp16 *src = st->vtcm_src + (size_t)(start_row - st->start_row) * st->n_cols; - transfer_output_chunk_fp16_to_fp32_scattered( - st->dst, src, start_row, n_rows, st->n_cols, - st->matrix_rows, st->cur_a, st->mapping_stride, - st->dst_nb1, st->dst_nb2, st->cne1); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, i); -} - -static void transfer_output_chunk_scattered_threaded( - struct htp_context *ctx, - float *dst, - const __fp16 *vtcm_src, - int start_row, - int n_rows, - int n_cols, - const struct mmid_row_mapping *matrix_rows, - int cur_a, - int mapping_stride, - size_t dst_nb1, - size_t dst_nb2, - int cne1, - int n_threads) { - if (n_rows <= 0) return; - int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); - chunks_per_thread = hex_align_up(chunks_per_thread, HMX_FP16_TILE_N_ROWS); - - int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); - - output_transfer_scattered_task_state_t state = { - .vtcm_src = vtcm_src, - .dst = dst, - .n_tasks = actual_threads, - .n_tot_chunks = n_rows, - .n_chunks_per_task = chunks_per_thread, - .n_cols = n_cols, - .matrix_rows = matrix_rows, - .cur_a = cur_a, - .mapping_stride = mapping_stride, - .dst_nb1 = dst_nb1, - .dst_nb2 = dst_nb2, - .start_row = start_row, - .cne1 = cne1, - .traces = ctx ? ctx->trace : NULL, - }; - - if (actual_threads <= 1) { - transfer_output_chunk_scattered_worker_fn(1, 0, &state); - } else { - worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_scattered_worker_fn, &state, actual_threads); - } -} - -int hmx_matmul_id_2d_f32(struct htp_context *ctx, - float *restrict dst, - const float *activation, - const uint8_t *permuted_weight, - int m, int k, int n, - int ne11, - size_t act_nb1, size_t act_nb2, - size_t dst_nb1, size_t dst_nb2, - int weight_stride, - int weight_type, - const struct mmid_row_mapping *matrix_rows, - int cur_a, - int mapping_stride) { - const int cne1 = m; - const int m_padded = hex_align_up(m, 32); - - if (k % 32 != 0 || n % 32 != 0) { return -1; } - - if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { - return -1; - } - - size_t row_stride = get_x4x2_row_stride(weight_type, k); - if (row_stride == 0) { - return -1; - } - - worker_callback_t dequant_worker_fn = NULL; - switch (weight_type) { - case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break; - case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break; - case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; - case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; - case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; - case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; - case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; - default: - return -1; - } - - const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; - const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); - - const int num_threads = ctx->n_threads; - - const size_t vec_dot_size = k * sizeof(__fp16); - const size_t vtcm_budget = ctx->vtcm_size; - size_t vtcm_used = 0; - - const size_t size_per_n = row_stride + vec_dot_size; - const size_t size_per_mn = sizeof(__fp16); - - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, - m_padded, n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m_padded * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { - FARF(HIGH, "hmx-mm-id-2d: VTCM too small : m %d k %d n %d budget %zu", m_padded, k, n, vtcm_budget); - return -1; - } - - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); - const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - - size_t scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - - vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; - if (vtcm_used > vtcm_budget) { - FARF(ERROR, "hmx-mm-id-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); - return -1; - } - - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); - - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - - for (size_t mr = 0; mr < (size_t) m_padded; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m_padded - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - - transfer_activation_chunk_gathered_threaded( - ctx, vtcm_activation, activation, (int) mr, (int) n_rows, k, - matrix_rows, cur_a, mapping_stride, ne11, act_nb1, act_nb2, cne1, num_threads); - - for (size_t nc = 0; nc < (size_t) n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin((size_t) n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - - const uint8_t *qweight_chunk = permuted_weight + nc * weight_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, qweight_chunk), row_stride, weight_stride, row_stride, n_cols); - dma_queue_pop(ctx->dma[0]); - - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - - { - struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS); - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS); - } - - transfer_output_chunk_scattered_threaded( - ctx, dst, vtcm_output, (int) mr, (int) n_rows, (int) n_cols, - matrix_rows, cur_a, mapping_stride, dst_nb1, dst_nb2, cne1, num_threads); - } - } - - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - return 0; -} diff --git a/ggml/src/ggml-hexagon/htp/hmx-mm-kernels-tiled.h b/ggml/src/ggml-hexagon/htp/hmx-mm-kernels-tiled.h new file mode 100644 index 0000000000..b7fba22a87 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-mm-kernels-tiled.h @@ -0,0 +1,1306 @@ +#include "hmx-utils.h" +#include "hmx-queue.h" + +// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value +// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6 +static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + 0, 0, 0.5, 0, 1, 0, 1.5, 0, 2, 0, 3, 0, 4, 0, 6, 0, 0, 0, -0.5, 0, -1, 0, -1.5, 0, -2, 0, -3, 0, -4, 0, -6, 0, +}; + +static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + -127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0, + 1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0, +}; + +// --- tiled format dequantizers --- + +typedef struct { + struct htp_context * ctx; + struct htp_thread_trace * traces; + __fp16 * dst; + const uint8_t * src; + + struct fastdiv_values n_k_tiles_div; + uint32_t n_k_tiles; + uint32_t n_tot_tiles; + uint32_t n_tiles_per_task; + uint32_t tile_size; + uint32_t aligned_tile_size; + uint32_t n_tasks; + uint32_t n_cols; + uint32_t k_block; + size_t row_stride; + uint32_t weight_type; +} tiled_dequantize_state_t; + +// Dequantize a single tile from tiled weight data (already in VTCM) to tile-major FP16. +static void dequantize_tiled_weight_to_fp16_task_q4_0( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + for (uint32_t t = start_tile; t < end_tile; t++) { + const uint8_t * tile_src = state->src + t * state->aligned_tile_size; + __fp16 * dst_ptr = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + + HVX_Vector v_sc = hvx_vmem(tile_src + 512); + HVX_Vector v_scale_duplicated = Q6_V_lo_W(Q6_W_vshuff_VVR(v_sc, v_sc, -2)); + + // Load all 4 groups in parallel + HVX_Vector vq0 = hvx_vmem(tile_src + 0 * 128); + HVX_Vector vq1 = hvx_vmem(tile_src + 1 * 128); + HVX_Vector vq2 = hvx_vmem(tile_src + 2 * 128); + HVX_Vector vq3 = hvx_vmem(tile_src + 3 * 128); + + // Nibble extraction + HVX_Vector v_lo0 = Q6_V_vand_VV(vq0, mask_h4); + HVX_Vector v_hi0 = Q6_Vub_vlsr_VubR(vq0, 4); + HVX_Vector v_lo1 = Q6_V_vand_VV(vq1, mask_h4); + HVX_Vector v_hi1 = Q6_Vub_vlsr_VubR(vq1, 4); + HVX_Vector v_lo2 = Q6_V_vand_VV(vq2, mask_h4); + HVX_Vector v_hi2 = Q6_Vub_vlsr_VubR(vq2, 4); + HVX_Vector v_lo3 = Q6_V_vand_VV(vq3, mask_h4); + HVX_Vector v_hi3 = Q6_Vub_vlsr_VubR(vq3, 4); + + // Offsetting (-8) + v_lo0 = Q6_Vb_vsub_VbVb(v_lo0, i8); + v_hi0 = Q6_Vb_vsub_VbVb(v_hi0, i8); + v_lo1 = Q6_Vb_vsub_VbVb(v_lo1, i8); + v_hi1 = Q6_Vb_vsub_VbVb(v_hi1, i8); + v_lo2 = Q6_Vb_vsub_VbVb(v_lo2, i8); + v_hi2 = Q6_Vb_vsub_VbVb(v_hi2, i8); + v_lo3 = Q6_Vb_vsub_VbVb(v_lo3, i8); + v_hi3 = Q6_Vb_vsub_VbVb(v_hi3, i8); + + // Shuffling + HVX_VectorPair vp_shuf0 = Q6_W_vshuff_VVR(v_hi0, v_lo0, -1); + HVX_VectorPair vp_shuf1 = Q6_W_vshuff_VVR(v_hi1, v_lo1, -1); + HVX_VectorPair vp_shuf2 = Q6_W_vshuff_VVR(v_hi2, v_lo2, -1); + HVX_VectorPair vp_shuf3 = Q6_W_vshuff_VVR(v_hi3, v_lo3, -1); + + // Unpack to 16-bit + HVX_VectorPair vp_int16_lo0 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf0)); + HVX_VectorPair vp_int16_hi0 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf0)); + HVX_VectorPair vp_int16_lo1 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf1)); + HVX_VectorPair vp_int16_hi1 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf1)); + HVX_VectorPair vp_int16_lo2 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf2)); + HVX_VectorPair vp_int16_hi2 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf2)); + HVX_VectorPair vp_int16_lo3 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf3)); + HVX_VectorPair vp_int16_hi3 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf3)); + + // Convert and scale multiplication + HVX_Vector v_grp0_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo0)), v_scale_duplicated)); + HVX_Vector v_grp0_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo0)), v_scale_duplicated)); + HVX_Vector v_grp0_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi0)), v_scale_duplicated)); + HVX_Vector v_grp0_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi0)), v_scale_duplicated)); + + HVX_Vector v_grp1_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo1)), v_scale_duplicated)); + HVX_Vector v_grp1_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo1)), v_scale_duplicated)); + HVX_Vector v_grp1_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi1)), v_scale_duplicated)); + HVX_Vector v_grp1_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi1)), v_scale_duplicated)); + + HVX_Vector v_grp2_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo2)), v_scale_duplicated)); + HVX_Vector v_grp2_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo2)), v_scale_duplicated)); + HVX_Vector v_grp2_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi2)), v_scale_duplicated)); + HVX_Vector v_grp2_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi2)), v_scale_duplicated)); + + HVX_Vector v_grp3_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo3)), v_scale_duplicated)); + HVX_Vector v_grp3_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo3)), v_scale_duplicated)); + HVX_Vector v_grp3_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi3)), v_scale_duplicated)); + HVX_Vector v_grp3_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi3)), v_scale_duplicated)); + + hvx_vmem(dst_ptr + 0 * 64) = v_grp0_0; + hvx_vmem(dst_ptr + 1 * 64) = v_grp0_1; + hvx_vmem(dst_ptr + 2 * 64) = v_grp0_2; + hvx_vmem(dst_ptr + 3 * 64) = v_grp0_3; + + hvx_vmem(dst_ptr + 4 * 64) = v_grp1_0; + hvx_vmem(dst_ptr + 5 * 64) = v_grp1_1; + hvx_vmem(dst_ptr + 6 * 64) = v_grp1_2; + hvx_vmem(dst_ptr + 7 * 64) = v_grp1_3; + + hvx_vmem(dst_ptr + 8 * 64) = v_grp2_0; + hvx_vmem(dst_ptr + 9 * 64) = v_grp2_1; + hvx_vmem(dst_ptr + 10 * 64) = v_grp2_2; + hvx_vmem(dst_ptr + 11 * 64) = v_grp2_3; + + hvx_vmem(dst_ptr + 12 * 64) = v_grp3_0; + hvx_vmem(dst_ptr + 13 * 64) = v_grp3_1; + hvx_vmem(dst_ptr + 14 * 64) = v_grp3_2; + hvx_vmem(dst_ptr + 15 * 64) = v_grp3_3; + } +} + +static void dequantize_tiled_weight_to_fp16_task_q4_1( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + for (uint32_t t = start_tile; t < end_tile; t++) { + const uint8_t * tile_src = state->src + t * state->aligned_tile_size; + __fp16 * dst_ptr = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + + HVX_Vector vscale_offset = hvx_vmem(tile_src + 512); + HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2); + HVX_Vector vd = Q6_V_lo_W(dm_deal); + HVX_Vector vm = Q6_V_hi_W(dm_deal); + + HVX_Vector v_scale_duplicated = Q6_V_lo_W(Q6_W_vshuff_VVR(vd, vd, -2)); + HVX_Vector v_offset_duplicated = Q6_V_lo_W(Q6_W_vshuff_VVR(vm, vm, -2)); + + // Load all 4 groups in parallel + HVX_Vector vq0 = hvx_vmem(tile_src + 0 * 128); + HVX_Vector vq1 = hvx_vmem(tile_src + 1 * 128); + HVX_Vector vq2 = hvx_vmem(tile_src + 2 * 128); + HVX_Vector vq3 = hvx_vmem(tile_src + 3 * 128); + + // Nibble extraction + HVX_Vector v_lo0 = Q6_V_vand_VV(vq0, mask_h4); + HVX_Vector v_hi0 = Q6_Vub_vlsr_VubR(vq0, 4); + HVX_Vector v_lo1 = Q6_V_vand_VV(vq1, mask_h4); + HVX_Vector v_hi1 = Q6_Vub_vlsr_VubR(vq1, 4); + HVX_Vector v_lo2 = Q6_V_vand_VV(vq2, mask_h4); + HVX_Vector v_hi2 = Q6_Vub_vlsr_VubR(vq2, 4); + HVX_Vector v_lo3 = Q6_V_vand_VV(vq3, mask_h4); + HVX_Vector v_hi3 = Q6_Vub_vlsr_VubR(vq3, 4); + + // Shuffling + HVX_VectorPair vp_shuf0 = Q6_W_vshuff_VVR(v_hi0, v_lo0, -1); + HVX_VectorPair vp_shuf1 = Q6_W_vshuff_VVR(v_hi1, v_lo1, -1); + HVX_VectorPair vp_shuf2 = Q6_W_vshuff_VVR(v_hi2, v_lo2, -1); + HVX_VectorPair vp_shuf3 = Q6_W_vshuff_VVR(v_hi3, v_lo3, -1); + + // Unpack to 16-bit + HVX_VectorPair vp_int16_lo0 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf0)); + HVX_VectorPair vp_int16_hi0 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf0)); + HVX_VectorPair vp_int16_lo1 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf1)); + HVX_VectorPair vp_int16_hi1 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf1)); + HVX_VectorPair vp_int16_lo2 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf2)); + HVX_VectorPair vp_int16_hi2 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf2)); + HVX_VectorPair vp_int16_lo3 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf3)); + HVX_VectorPair vp_int16_hi3 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf3)); + + // Convert, multiply, add offset + HVX_Vector v_grp0_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo0)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp0_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo0)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp0_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi0)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp0_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi0)), v_scale_duplicated), v_offset_duplicated)); + + HVX_Vector v_grp1_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo1)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp1_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo1)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp1_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi1)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp1_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi1)), v_scale_duplicated), v_offset_duplicated)); + + HVX_Vector v_grp2_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo2)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp2_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo2)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp2_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi2)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp2_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi2)), v_scale_duplicated), v_offset_duplicated)); + + HVX_Vector v_grp3_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo3)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp3_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo3)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp3_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi3)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp3_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi3)), v_scale_duplicated), v_offset_duplicated)); + + // Parallel Stores + hvx_vmem(dst_ptr + 0 * 64) = v_grp0_0; + hvx_vmem(dst_ptr + 1 * 64) = v_grp0_1; + hvx_vmem(dst_ptr + 2 * 64) = v_grp0_2; + hvx_vmem(dst_ptr + 3 * 64) = v_grp0_3; + + hvx_vmem(dst_ptr + 4 * 64) = v_grp1_0; + hvx_vmem(dst_ptr + 5 * 64) = v_grp1_1; + hvx_vmem(dst_ptr + 6 * 64) = v_grp1_2; + hvx_vmem(dst_ptr + 7 * 64) = v_grp1_3; + + hvx_vmem(dst_ptr + 8 * 64) = v_grp2_0; + hvx_vmem(dst_ptr + 9 * 64) = v_grp2_1; + hvx_vmem(dst_ptr + 10 * 64) = v_grp2_2; + hvx_vmem(dst_ptr + 11 * 64) = v_grp2_3; + + hvx_vmem(dst_ptr + 12 * 64) = v_grp3_0; + hvx_vmem(dst_ptr + 13 * 64) = v_grp3_1; + hvx_vmem(dst_ptr + 14 * 64) = v_grp3_2; + hvx_vmem(dst_ptr + 15 * 64) = v_grp3_3; + } +} + +static void dequantize_tiled_weight_to_fp16_task_iq4_nl( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector vlut_cvt = hvx_vmem(iq4_nl_to_fp16_lut); + + for (uint32_t t = start_tile; t < end_tile; t++) { + const uint8_t * tile_src = state->src + t * state->aligned_tile_size; + __fp16 * dst_ptr = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + + HVX_Vector v_sc = hvx_vmem(tile_src + 512); + HVX_Vector v_scale_duplicated = Q6_V_lo_W(Q6_W_vshuff_VVR(v_sc, v_sc, -2)); + + // Load all 4 groups in parallel + HVX_Vector vq0 = hvx_vmem(tile_src + 0 * 128); + HVX_Vector vq1 = hvx_vmem(tile_src + 1 * 128); + HVX_Vector vq2 = hvx_vmem(tile_src + 2 * 128); + HVX_Vector vq3 = hvx_vmem(tile_src + 3 * 128); + + // Nibble extraction + HVX_Vector v_lo0 = Q6_V_vand_VV(vq0, mask_h4); + HVX_Vector v_hi0 = Q6_Vub_vlsr_VubR(vq0, 4); + HVX_Vector v_lo1 = Q6_V_vand_VV(vq1, mask_h4); + HVX_Vector v_hi1 = Q6_Vub_vlsr_VubR(vq1, 4); + HVX_Vector v_lo2 = Q6_V_vand_VV(vq2, mask_h4); + HVX_Vector v_hi2 = Q6_Vub_vlsr_VubR(vq2, 4); + HVX_Vector v_lo3 = Q6_V_vand_VV(vq3, mask_h4); + HVX_Vector v_hi3 = Q6_Vub_vlsr_VubR(vq3, 4); + + // Shuffling + HVX_VectorPair vp_shuf0 = Q6_W_vshuff_VVR(v_hi0, v_lo0, -1); + HVX_VectorPair vp_shuf1 = Q6_W_vshuff_VVR(v_hi1, v_lo1, -1); + HVX_VectorPair vp_shuf2 = Q6_W_vshuff_VVR(v_hi2, v_lo2, -1); + HVX_VectorPair vp_shuf3 = Q6_W_vshuff_VVR(v_hi3, v_lo3, -1); + + // Shuffle for LUT lookup + HVX_Vector v_q_lo0 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf0)); + HVX_Vector v_q_hi0 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf0)); + HVX_Vector v_q_lo1 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf1)); + HVX_Vector v_q_hi1 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf1)); + HVX_Vector v_q_lo2 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf2)); + HVX_Vector v_q_hi2 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf2)); + HVX_Vector v_q_lo3 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf3)); + HVX_Vector v_q_hi3 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf3)); + + // LUT lookup + HVX_VectorPair vp_lo0 = Q6_Wh_vlut16_VbVhR(v_q_lo0, vlut_cvt, 0); + HVX_VectorPair vp_hi0 = Q6_Wh_vlut16_VbVhR(v_q_hi0, vlut_cvt, 0); + HVX_VectorPair vp_lo1 = Q6_Wh_vlut16_VbVhR(v_q_lo1, vlut_cvt, 0); + HVX_VectorPair vp_hi1 = Q6_Wh_vlut16_VbVhR(v_q_hi1, vlut_cvt, 0); + HVX_VectorPair vp_lo2 = Q6_Wh_vlut16_VbVhR(v_q_lo2, vlut_cvt, 0); + HVX_VectorPair vp_hi2 = Q6_Wh_vlut16_VbVhR(v_q_hi2, vlut_cvt, 0); + HVX_VectorPair vp_lo3 = Q6_Wh_vlut16_VbVhR(v_q_lo3, vlut_cvt, 0); + HVX_VectorPair vp_hi3 = Q6_Wh_vlut16_VbVhR(v_q_hi3, vlut_cvt, 0); + + // Convert and scale multiplication + HVX_Vector v_grp0_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo0), v_scale_duplicated)); + HVX_Vector v_grp0_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo0), v_scale_duplicated)); + HVX_Vector v_grp0_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi0), v_scale_duplicated)); + HVX_Vector v_grp0_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi0), v_scale_duplicated)); + + HVX_Vector v_grp1_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo1), v_scale_duplicated)); + HVX_Vector v_grp1_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo1), v_scale_duplicated)); + HVX_Vector v_grp1_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi1), v_scale_duplicated)); + HVX_Vector v_grp1_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi1), v_scale_duplicated)); + + HVX_Vector v_grp2_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo2), v_scale_duplicated)); + HVX_Vector v_grp2_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo2), v_scale_duplicated)); + HVX_Vector v_grp2_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi2), v_scale_duplicated)); + HVX_Vector v_grp2_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi2), v_scale_duplicated)); + + HVX_Vector v_grp3_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo3), v_scale_duplicated)); + HVX_Vector v_grp3_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo3), v_scale_duplicated)); + HVX_Vector v_grp3_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi3), v_scale_duplicated)); + HVX_Vector v_grp3_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi3), v_scale_duplicated)); + + hvx_vmem(dst_ptr + 0 * 64) = v_grp0_0; + hvx_vmem(dst_ptr + 1 * 64) = v_grp0_1; + hvx_vmem(dst_ptr + 2 * 64) = v_grp0_2; + hvx_vmem(dst_ptr + 3 * 64) = v_grp0_3; + + hvx_vmem(dst_ptr + 4 * 64) = v_grp1_0; + hvx_vmem(dst_ptr + 5 * 64) = v_grp1_1; + hvx_vmem(dst_ptr + 6 * 64) = v_grp1_2; + hvx_vmem(dst_ptr + 7 * 64) = v_grp1_3; + + hvx_vmem(dst_ptr + 8 * 64) = v_grp2_0; + hvx_vmem(dst_ptr + 9 * 64) = v_grp2_1; + hvx_vmem(dst_ptr + 10 * 64) = v_grp2_2; + hvx_vmem(dst_ptr + 11 * 64) = v_grp2_3; + + hvx_vmem(dst_ptr + 12 * 64) = v_grp3_0; + hvx_vmem(dst_ptr + 13 * 64) = v_grp3_1; + hvx_vmem(dst_ptr + 14 * 64) = v_grp3_2; + hvx_vmem(dst_ptr + 15 * 64) = v_grp3_3; + } +} + +static void dequantize_tiled_weight_to_fp16_task_mxfp4( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector vlut_cvt = hvx_vmem(mxfp4_to_fp16_lut); + + for (uint32_t t = start_tile; t < end_tile; t++) { + const uint8_t * tile_src = state->src + t * state->aligned_tile_size; + __fp16 * dst_ptr = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + + HVX_Vector v = hvx_vmem(tile_src + 512); + HVX_Vector vh = Q6_V_lo_W(Q6_Wuh_vunpack_Vub(v)); + vh = Q6_Vh_vsub_VhVh(vh, Q6_Vh_vsplat_R(112)); + vh = Q6_Vh_vmax_VhVh(vh, Q6_V_vzero()); + vh = Q6_Vh_vmin_VhVh(vh, Q6_Vh_vsplat_R(30)); + vh = Q6_Vh_vasl_VhR(vh, 10); + + HVX_Vector v_scale_duplicated = Q6_V_lo_W(Q6_W_vshuff_VVR(vh, vh, -2)); + + // Load all 4 groups in parallel + HVX_Vector vq0 = hvx_vmem(tile_src + 0 * 128); + HVX_Vector vq1 = hvx_vmem(tile_src + 1 * 128); + HVX_Vector vq2 = hvx_vmem(tile_src + 2 * 128); + HVX_Vector vq3 = hvx_vmem(tile_src + 3 * 128); + + // Nibble extraction + HVX_Vector v_lo0 = Q6_V_vand_VV(vq0, mask_h4); + HVX_Vector v_hi0 = Q6_Vub_vlsr_VubR(vq0, 4); + HVX_Vector v_lo1 = Q6_V_vand_VV(vq1, mask_h4); + HVX_Vector v_hi1 = Q6_Vub_vlsr_VubR(vq1, 4); + HVX_Vector v_lo2 = Q6_V_vand_VV(vq2, mask_h4); + HVX_Vector v_hi2 = Q6_Vub_vlsr_VubR(vq2, 4); + HVX_Vector v_lo3 = Q6_V_vand_VV(vq3, mask_h4); + HVX_Vector v_hi3 = Q6_Vub_vlsr_VubR(vq3, 4); + + // Shuffling + HVX_VectorPair vp_shuf0 = Q6_W_vshuff_VVR(v_hi0, v_lo0, -1); + HVX_VectorPair vp_shuf1 = Q6_W_vshuff_VVR(v_hi1, v_lo1, -1); + HVX_VectorPair vp_shuf2 = Q6_W_vshuff_VVR(v_hi2, v_lo2, -1); + HVX_VectorPair vp_shuf3 = Q6_W_vshuff_VVR(v_hi3, v_lo3, -1); + + // Shuffle for LUT lookup + HVX_Vector v_q_lo0 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf0)); + HVX_Vector v_q_hi0 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf0)); + HVX_Vector v_q_lo1 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf1)); + HVX_Vector v_q_hi1 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf1)); + HVX_Vector v_q_lo2 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf2)); + HVX_Vector v_q_hi2 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf2)); + HVX_Vector v_q_lo3 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf3)); + HVX_Vector v_q_hi3 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf3)); + + // LUT lookup + HVX_VectorPair vp_lo0 = Q6_Wh_vlut16_VbVhR(v_q_lo0, vlut_cvt, 0); + HVX_VectorPair vp_hi0 = Q6_Wh_vlut16_VbVhR(v_q_hi0, vlut_cvt, 0); + HVX_VectorPair vp_lo1 = Q6_Wh_vlut16_VbVhR(v_q_lo1, vlut_cvt, 0); + HVX_VectorPair vp_hi1 = Q6_Wh_vlut16_VbVhR(v_q_hi1, vlut_cvt, 0); + HVX_VectorPair vp_lo2 = Q6_Wh_vlut16_VbVhR(v_q_lo2, vlut_cvt, 0); + HVX_VectorPair vp_hi2 = Q6_Wh_vlut16_VbVhR(v_q_hi2, vlut_cvt, 0); + HVX_VectorPair vp_lo3 = Q6_Wh_vlut16_VbVhR(v_q_lo3, vlut_cvt, 0); + HVX_VectorPair vp_hi3 = Q6_Wh_vlut16_VbVhR(v_q_hi3, vlut_cvt, 0); + + // Convert and scale multiplication + HVX_Vector v_grp0_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo0), v_scale_duplicated)); + HVX_Vector v_grp0_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo0), v_scale_duplicated)); + HVX_Vector v_grp0_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi0), v_scale_duplicated)); + HVX_Vector v_grp0_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi0), v_scale_duplicated)); + + HVX_Vector v_grp1_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo1), v_scale_duplicated)); + HVX_Vector v_grp1_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo1), v_scale_duplicated)); + HVX_Vector v_grp1_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi1), v_scale_duplicated)); + HVX_Vector v_grp1_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi1), v_scale_duplicated)); + + HVX_Vector v_grp2_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo2), v_scale_duplicated)); + HVX_Vector v_grp2_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo2), v_scale_duplicated)); + HVX_Vector v_grp2_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi2), v_scale_duplicated)); + HVX_Vector v_grp2_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi2), v_scale_duplicated)); + + HVX_Vector v_grp3_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo3), v_scale_duplicated)); + HVX_Vector v_grp3_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo3), v_scale_duplicated)); + HVX_Vector v_grp3_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi3), v_scale_duplicated)); + HVX_Vector v_grp3_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi3), v_scale_duplicated)); + + hvx_vmem(dst_ptr + 0 * 64) = v_grp0_0; + hvx_vmem(dst_ptr + 1 * 64) = v_grp0_1; + hvx_vmem(dst_ptr + 2 * 64) = v_grp0_2; + hvx_vmem(dst_ptr + 3 * 64) = v_grp0_3; + + hvx_vmem(dst_ptr + 4 * 64) = v_grp1_0; + hvx_vmem(dst_ptr + 5 * 64) = v_grp1_1; + hvx_vmem(dst_ptr + 6 * 64) = v_grp1_2; + hvx_vmem(dst_ptr + 7 * 64) = v_grp1_3; + + hvx_vmem(dst_ptr + 8 * 64) = v_grp2_0; + hvx_vmem(dst_ptr + 9 * 64) = v_grp2_1; + hvx_vmem(dst_ptr + 10 * 64) = v_grp2_2; + hvx_vmem(dst_ptr + 11 * 64) = v_grp2_3; + + hvx_vmem(dst_ptr + 12 * 64) = v_grp3_0; + hvx_vmem(dst_ptr + 13 * 64) = v_grp3_1; + hvx_vmem(dst_ptr + 14 * 64) = v_grp3_2; + hvx_vmem(dst_ptr + 15 * 64) = v_grp3_3; + } +} + +static void dequantize_tiled_weight_to_fp16_task_q8_0( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + for (uint32_t t = start_tile; t < end_tile; t++) { + const uint8_t * tile_src = state->src + t * state->aligned_tile_size; + __fp16 * dst_ptr = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + + HVX_Vector v_sc = hvx_vmem(tile_src + 1024); + HVX_Vector v_scale_duplicated = Q6_V_lo_W(Q6_W_vshuff_VVR(v_sc, v_sc, -2)); + + // Load groups 0-3 in parallel + HVX_Vector vq0 = hvx_vmem(tile_src + 0 * 128); + HVX_Vector vq1 = hvx_vmem(tile_src + 1 * 128); + HVX_Vector vq2 = hvx_vmem(tile_src + 2 * 128); + HVX_Vector vq3 = hvx_vmem(tile_src + 3 * 128); + + HVX_VectorPair vp_int16_0 = Q6_Wh_vunpack_Vb(vq0); + HVX_VectorPair vp_int16_1 = Q6_Wh_vunpack_Vb(vq1); + HVX_VectorPair vp_int16_2 = Q6_Wh_vunpack_Vb(vq2); + HVX_VectorPair vp_int16_3 = Q6_Wh_vunpack_Vb(vq3); + + // Load groups 4-7 in parallel + HVX_Vector vq4 = hvx_vmem(tile_src + 4 * 128); + HVX_Vector vq5 = hvx_vmem(tile_src + 5 * 128); + HVX_Vector vq6 = hvx_vmem(tile_src + 6 * 128); + HVX_Vector vq7 = hvx_vmem(tile_src + 7 * 128); + + HVX_VectorPair vp_int16_4 = Q6_Wh_vunpack_Vb(vq4); + HVX_VectorPair vp_int16_5 = Q6_Wh_vunpack_Vb(vq5); + HVX_VectorPair vp_int16_6 = Q6_Wh_vunpack_Vb(vq6); + HVX_VectorPair vp_int16_7 = Q6_Wh_vunpack_Vb(vq7); + + // Convert and scale multiply for groups 0-3 + HVX_Vector v_grp0_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_0)), v_scale_duplicated)); + HVX_Vector v_grp0_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_0)), v_scale_duplicated)); + HVX_Vector v_grp1_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_1)), v_scale_duplicated)); + HVX_Vector v_grp1_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_1)), v_scale_duplicated)); + HVX_Vector v_grp2_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_2)), v_scale_duplicated)); + HVX_Vector v_grp2_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_2)), v_scale_duplicated)); + HVX_Vector v_grp3_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_3)), v_scale_duplicated)); + HVX_Vector v_grp3_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_3)), v_scale_duplicated)); + + // Store groups 0-3 + hvx_vmem(dst_ptr + 0 * 64) = v_grp0_0; + hvx_vmem(dst_ptr + 1 * 64) = v_grp0_1; + hvx_vmem(dst_ptr + 2 * 64) = v_grp1_0; + hvx_vmem(dst_ptr + 3 * 64) = v_grp1_1; + hvx_vmem(dst_ptr + 4 * 64) = v_grp2_0; + hvx_vmem(dst_ptr + 5 * 64) = v_grp2_1; + hvx_vmem(dst_ptr + 6 * 64) = v_grp3_0; + hvx_vmem(dst_ptr + 7 * 64) = v_grp3_1; + + // Convert and scale multiply for groups 4-7 + HVX_Vector v_grp4_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_4)), v_scale_duplicated)); + HVX_Vector v_grp4_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_4)), v_scale_duplicated)); + HVX_Vector v_grp5_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_5)), v_scale_duplicated)); + HVX_Vector v_grp5_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_5)), v_scale_duplicated)); + HVX_Vector v_grp6_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_6)), v_scale_duplicated)); + HVX_Vector v_grp6_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_6)), v_scale_duplicated)); + HVX_Vector v_grp7_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_7)), v_scale_duplicated)); + HVX_Vector v_grp7_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_7)), v_scale_duplicated)); + + // Store groups 4-7 + hvx_vmem(dst_ptr + 8 * 64) = v_grp4_0; + hvx_vmem(dst_ptr + 9 * 64) = v_grp4_1; + hvx_vmem(dst_ptr + 10 * 64) = v_grp5_0; + hvx_vmem(dst_ptr + 11 * 64) = v_grp5_1; + hvx_vmem(dst_ptr + 12 * 64) = v_grp6_0; + hvx_vmem(dst_ptr + 13 * 64) = v_grp6_1; + hvx_vmem(dst_ptr + 14 * 64) = v_grp7_0; + hvx_vmem(dst_ptr + 15 * 64) = v_grp7_1; + } +} + +static void convert_f16_weight_to_fp16_tiles_task( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + const uint32_t n_k_tiles = state->n_k_tiles; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + { + uint32_t byte_off = kt * 32 * sizeof(__fp16); + + HVX_Vector v_off = v_scat_base; + for (uint32_t r = 0; r < HTP_MM_HMX_TILE_N_ROWS; r += 2) { + uint32_t row0 = ct * HTP_MM_HMX_TILE_N_COLS + r; + uint32_t row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0 = hvx_vmemu((const __fp16 *)(r0 + byte_off)); + HVX_Vector v1 = (row1 < state->n_cols) ? hvx_vmemu((const __fp16 *)(r1 + byte_off)) : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HTP_MM_HMX_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HTP_MM_HMX_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HTP_MM_HMX_TILE_N_ELMS); + } +} + +static void quantize_f32_weight_to_fp16_tiles_task( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + const uint32_t n_k_tiles = state->n_k_tiles; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + { + uint32_t byte_off = kt * 32 * sizeof(float); + + HVX_Vector v_off = v_scat_base; + for (uint32_t r = 0; r < HTP_MM_HMX_TILE_N_ROWS; r += 2) { + uint32_t row0 = ct * HTP_MM_HMX_TILE_N_COLS + r; + uint32_t row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0_f32 = hvx_vmem((const float *)(r0 + byte_off)); + HVX_Vector v1_f32 = (row1 < state->n_cols) ? hvx_vmem((const float *)(r1 + byte_off)) : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16(v0_f32, v1_f32); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HTP_MM_HMX_TILE_SIZE - 1, v_off, v_out); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + + HVX_Vector v_out_hi = Q6_V_vror_VR(v_out, 64); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HTP_MM_HMX_TILE_SIZE - 1, v_off, v_out_hi); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HTP_MM_HMX_TILE_N_ELMS); + } +} + +// --- End tiled dequantizers --- + +// requires external HMX lock +static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales, + uint32_t n_row_tiles, uint32_t n_col_tiles, uint32_t n_dot_tiles) { + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *)scales); + for (uint32_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < n_col_tiles; ++c) { + Q6_mxclracc_hf(); + + const __fp16 *row_tiles = activation + r * n_dot_tiles * HTP_MM_HMX_TILE_N_ELMS; + const __fp16 *col_tiles = weight + c * n_dot_tiles * HTP_MM_HMX_TILE_N_ELMS; + + for (uint32_t k = 0, k_block; k < n_dot_tiles; k += k_block) { + k_block = hex_smin(n_dot_tiles - k, 32); + const uint32_t range = 2048u * (uint32_t)k_block - 1; + Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); + row_tiles += k_block * HTP_MM_HMX_TILE_N_ELMS; + col_tiles += k_block * HTP_MM_HMX_TILE_N_ELMS; + } + + __fp16 *out_tile = output + (r * n_col_tiles + c) * HTP_MM_HMX_TILE_N_ELMS; + Q6_mxmem_AR_after_hf(out_tile, 0); + } + } +} + +// C += AB +static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, + const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, + uint32_t n_row_tiles, uint32_t n_col_tiles, uint32_t n_dot_tiles, bool zero_init) { + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *)col_scales); + + const size_t dot_tile_stride = n_dot_tiles * HTP_MM_HMX_TILE_N_ELMS; + for (size_t i = 0; i < n_row_tiles; ++i) { + const __fp16 *row_base = a + i * dot_tile_stride; + __fp16 *res_base = c + i * n_col_tiles * HTP_MM_HMX_TILE_N_ELMS; + for (size_t j = 0; j < n_col_tiles; ++j) { + Q6_mxclracc_hf(); + + const __fp16 *col_tiles = b + j * dot_tile_stride; + const __fp16 *row_tiles = row_base; + __fp16 *accum_tile = res_base + j * HTP_MM_HMX_TILE_N_ELMS; + if (!zero_init) { + Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); + } + + for (uint32_t k = 0, k_block; k < n_dot_tiles; k += k_block) { + k_block = hex_smin(n_dot_tiles - k, 32); + const uint32_t range = 2048u * k_block - 1; + Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); + row_tiles += k_block * HTP_MM_HMX_TILE_N_ELMS; + col_tiles += k_block * HTP_MM_HMX_TILE_N_ELMS; + } + + Q6_mxmem_AR_after_hf(accum_tile, 0); + } + } +} + +// --- Async HMX matmul job (for pipeline overlap) --- + +typedef struct { + __fp16 * output; + const __fp16 * activation; + const __fp16 * weight; + const __fp16 * scales; + uint32_t n_row_tiles; + uint32_t n_col_tiles; + uint32_t n_dot_tiles; +} hmx_matmul_job_t; + +static void hmx_matmul_worker_fn(void * data) { + hmx_matmul_job_t * job = (hmx_matmul_job_t *) data; + FARF(HIGH, "hmx-mm-job: n_row_tiles %u n_col_tiles %u n_dot_tiles %u", job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); + core_dot_chunk_fp16(job->output, job->activation, job->weight, job->scales, job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); +} + +static inline void hmx_matmul_job_init(hmx_matmul_job_t * job, + __fp16 * output, + const __fp16 * activation, + const __fp16 * weight, + const __fp16 * scales, + uint32_t n_row_tiles, + uint32_t n_col_tiles, + uint32_t n_dot_tiles) { + job->output = output; + job->activation = activation; + job->weight = weight; + job->scales = scales; + job->n_row_tiles = n_row_tiles; + job->n_col_tiles = n_col_tiles; + job->n_dot_tiles = n_dot_tiles; +} + +// output : fp16 -> f32p + +static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, uint32_t start_row, uint32_t n_rows, uint32_t n_cols, uint32_t dst_stride, uint32_t dst_cols) { + assert(n_cols % HTP_MM_HMX_TILE_N_COLS == 0); + const size_t tile_row_stride = (n_cols / HTP_MM_HMX_TILE_N_COLS) * HTP_MM_HMX_TILE_N_ELMS; + + const HVX_Vector one = hvx_vec_splat_f16(1.0); + + const size_t limit_c = hex_smin(n_cols, dst_cols); + const size_t limit_c_aligned = (limit_c & ~31); + + for (size_t r = 0; r < n_rows; r += 2) { + const size_t r_idx0 = start_row + r + 0; + const size_t r0 = r_idx0 / HTP_MM_HMX_TILE_N_ROWS; + const size_t r1 = (r_idx0 % HTP_MM_HMX_TILE_N_ROWS) / 2; // index of the row pair within the tile + const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; + float *output_row_base = dst + r * dst_stride; // global memory row base for row r (and r+1) + + #pragma unroll(4) + for (size_t c = 0; c < limit_c_aligned; c += HTP_MM_HMX_TILE_N_COLS) { + const size_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HTP_MM_HMX_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + HVX_Vector *pv_out0 = (HVX_Vector *) (output_row_base + c + 0); + HVX_Vector *pv_out1 = (HVX_Vector *) (output_row_base + c + dst_stride); + + *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); + if (r + 1 < n_rows) { + *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); + } + } + + if (limit_c_aligned < limit_c) { + size_t c = limit_c_aligned; + size_t valid_c = limit_c - c; + const size_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HTP_MM_HMX_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + hvx_vec_store_u(output_row_base + c, valid_c * sizeof(float), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp))); + if (r + 1 < n_rows) { + hvx_vec_store_u(output_row_base + c + dst_stride, valid_c * sizeof(float), Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp))); + } + } + } +} + +typedef struct { + const __fp16 *vtcm_src; + float *dst; + uint32_t n_tasks; + uint32_t n_tot_chunks; + uint32_t n_chunks_per_task; + uint32_t n_cols; + uint32_t dst_stride; // DDR row stride + uint32_t dst_cols; // Actual output columns + struct htp_thread_trace * traces; +} output_transfer_task_state_t; + +// activations : fp32 -> fp16 + +static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, uint32_t n_rows, uint32_t k_block, uint32_t k_stride, uint32_t k_valid) { + const uint32_t n_rows_padded = hex_align_up(n_rows, HTP_MM_HMX_TILE_N_ROWS); + const uint32_t n_rows_tiled = (n_rows / HTP_MM_HMX_TILE_N_ROWS) * HTP_MM_HMX_TILE_N_ROWS; + + uint32_t r = 0; + + #pragma unroll(2) + for (r = 0; r < n_rows_tiled; r += 2) { + uint32_t r0 = r / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = r % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + + const float *ptr_in0 = src + (r + 0) * k_stride; + const float *ptr_in1 = src + (r + 1) * k_stride; + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = *(const HVX_Vector *)(ptr_in0 + c); + HVX_Vector v1 = *(const HVX_Vector *)(ptr_in1 + c); + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; // tile column index + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = *(const HVX_Vector *)(ptr_in0 + c); + HVX_Vector v1 = *(const HVX_Vector *)(ptr_in1 + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; // tile column index + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } + + for (; r < n_rows_padded; r += 2) { + uint32_t r0 = r / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = r % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + + const bool row0_valid = r < n_rows; + const bool row1_valid = (r + 1) < n_rows; + + const float *ptr_in0 = row0_valid ? (src + (r + 0) * k_stride) : NULL; + const float *ptr_in1 = row1_valid ? (src + (r + 1) * k_stride) : NULL; + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(ptr_in0 + c); + if (row1_valid) v1 = *(const HVX_Vector *)(ptr_in1 + c); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; // tile column index + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(ptr_in0 + c); + if (row1_valid) v1 = *(const HVX_Vector *)(ptr_in1 + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; // tile column index + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } +} + +typedef struct { + __fp16 *dst; + const float *src; + uint32_t n_tasks; + uint32_t n_tot_chunks; + uint32_t n_chunks_per_task; + uint32_t k_block; + uint32_t k_stride; + uint32_t k_valid; + struct htp_thread_trace * traces; + struct htp_context * ctx; + float * vtcm_f32_act; +} activation_transfer_task_state_t; + +static void transfer_activation_chunk_fp32_to_fp16_dma_pipelined( + dma_queue *dma_q, + __fp16 *restrict vtcm_dst, + const float *restrict src, + uint32_t n_rows, + uint32_t k_block, + uint32_t k_stride, + uint32_t k_valid, + float *thread_f32_act) { + + const uint32_t R = HTP_MM_DMA_ACT_ROWS_PER_STEP; + const uint32_t n_rows_padded = hex_align_up(n_rows, HTP_MM_HMX_TILE_N_ROWS); + + const uint32_t n_steps = n_rows_padded / R; + + // pre-fetch step 0 + if (n_steps > 0 && n_rows > 0) { + uint32_t nrows_to_fetch = hex_smin(n_rows, R); + dma_queue_push(dma_q, dma_make_ptr(thread_f32_act, src), + k_block * sizeof(float), k_stride * sizeof(float), k_valid * sizeof(float), nrows_to_fetch); + } + + for (uint32_t s = 0; s < n_steps; ++s) { + uint32_t r = R * s; + float *curr_buf = thread_f32_act + (s % 2) * R * k_block; + + if (r < n_rows) { + dma_queue_pop(dma_q); + } + + uint32_t next_s = s + 1; + uint32_t next_r = R * next_s; + if (next_r < n_rows) { + uint32_t nrows_to_fetch = hex_smin(n_rows - next_r, R); + const float *next_src = src + next_r * k_stride; + float *next_buf = thread_f32_act + (next_s % 2) * R * k_block; + dma_queue_push(dma_q, dma_make_ptr(next_buf, next_src), + k_block * sizeof(float), k_stride * sizeof(float), k_valid * sizeof(float), nrows_to_fetch); + } + + #pragma unroll + for (uint32_t i = 0; i < HTP_MM_DMA_ACT_ROWS_PER_STEP; i += 2) { + uint32_t curr_r = r + i; + const bool row0_valid = (curr_r < n_rows); + const bool row1_valid = (curr_r + 1) < n_rows; + + const float *ptr_in0 = curr_buf + i * k_block; + const float *ptr_in1 = curr_buf + (i + 1) * k_block; + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(ptr_in0 + c); + if (row1_valid) v1 = *(const HVX_Vector *)(ptr_in1 + c); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t r0 = curr_r / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = curr_r % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; // tile column index + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(ptr_in0 + c); + if (row1_valid) v1 = *(const HVX_Vector *)(ptr_in1 + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t r0 = curr_r / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = curr_r % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; // tile column index + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } + } +} + +typedef struct { + const struct mmid_row_mapping *matrix_rows; + __fp16 *dst; + const float *src; + uint32_t n_tasks; + uint32_t n_tot_chunks; + uint32_t n_chunks_per_task; + uint32_t k_block; + uint32_t cur_a; + uint32_t mapping_stride; + uint32_t ne11; + struct fastdiv_values ne11_div; + size_t nb11; + size_t nb12; + uint32_t start_row; + uint32_t cne1; + uint32_t k_valid; + struct htp_thread_trace *traces; +} activation_transfer_gathered_task_state_t; + +typedef struct { + const struct mmid_row_mapping *matrix_rows; + const __fp16 *vtcm_src; + float *dst; + uint32_t n_tasks; + uint32_t n_tot_chunks; + uint32_t n_chunks_per_task; + uint32_t n_cols; + uint32_t cur_a; + uint32_t mapping_stride; + size_t dst_nb1; + size_t dst_nb2; + uint32_t start_row; + uint32_t cne1; + struct htp_thread_trace *traces; +} output_transfer_scattered_task_state_t; + +static void transfer_activation_chunk_fp32_to_fp16_gathered( + __fp16 *restrict vtcm_dst, + const float *restrict src, + uint32_t start_row, + uint32_t n_rows, + uint32_t k_block, + const struct mmid_row_mapping *matrix_rows, + uint32_t cur_a, + uint32_t mapping_stride, + uint32_t ne11, + const struct fastdiv_values * ne11_div, + size_t nb11, + size_t nb12, + uint32_t cne1, + uint32_t k_valid) { + const uint32_t n_rows_padded = hex_align_up(n_rows, HTP_MM_HMX_TILE_N_ROWS); + const uint32_t n_rows_tiled = (n_rows / HTP_MM_HMX_TILE_N_ROWS) * HTP_MM_HMX_TILE_N_ROWS; + + uint32_t r = 0; + + #pragma unroll(2) + for (r = 0; r < n_rows_tiled; r += 2) { + uint32_t r_idx0 = start_row + r + 0; + uint32_t r_idx1 = start_row + r + 1; + uint32_t r0 = r_idx0 / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = r_idx0 % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + + uint32_t i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); + uint32_t i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); + + const float *row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); + const float *row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = *(const HVX_Vector *)(row0_ptr + c); + HVX_Vector v1 = *(const HVX_Vector *)(row1_ptr + c); + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = *(const HVX_Vector *)(row0_ptr + c); + HVX_Vector v1 = *(const HVX_Vector *)(row1_ptr + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } + + for (; r < n_rows_padded; r += 2) { + uint32_t r_idx0 = start_row + r; + uint32_t r0 = r_idx0 / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = r_idx0 % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + + const bool row0_valid = (start_row + r + 0) < cne1; + const bool row1_valid = (start_row + r + 1) < cne1; + + const float *row0_ptr = NULL; + const float *row1_ptr = NULL; + + if (row0_valid) { + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + (start_row + r + 0)]; + uint32_t i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); + row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); + } + if (row1_valid) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + (start_row + r + 1)]; + uint32_t i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); + row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); + } + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(row0_ptr + c); + if (row1_valid) v1 = *(const HVX_Vector *)(row1_ptr + c); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(row0_ptr + c); + if (row1_valid) v1 = *(const HVX_Vector *)(row1_ptr + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } +} + +static void transfer_activation_chunk_fp32_to_fp16_gathered_flat( + __fp16 *restrict vtcm_dst, + const float *restrict src, + uint32_t start_row, + uint32_t n_rows, + uint32_t k_block, + const struct mmid_row_mapping *matrix_rows, + uint32_t cur_a, + uint32_t mapping_stride, + size_t nb12, + uint32_t cne1, + uint32_t k_valid) { + const uint32_t n_rows_padded = hex_align_up(n_rows, HTP_MM_HMX_TILE_N_ROWS); + const uint32_t n_rows_tiled = (n_rows / HTP_MM_HMX_TILE_N_ROWS) * HTP_MM_HMX_TILE_N_ROWS; + + uint32_t r = 0; + + #pragma unroll(2) + for (r = 0; r < n_rows_tiled; r += 2) { + uint32_t r_idx0 = start_row + r + 0; + uint32_t r_idx1 = start_row + r + 1; + uint32_t r0 = r_idx0 / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = r_idx0 % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + + const float *row0_ptr = (const float *) ((const uint8_t *) src + mapping0.i2 * nb12); + const float *row1_ptr = (const float *) ((const uint8_t *) src + mapping1.i2 * nb12); + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = *(const HVX_Vector *)(row0_ptr + c); + HVX_Vector v1 = *(const HVX_Vector *)(row1_ptr + c); + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = *(const HVX_Vector *)(row0_ptr + c); + HVX_Vector v1 = *(const HVX_Vector *)(row1_ptr + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } + + for (; r < n_rows_padded; r += 2) { + uint32_t r_idx0 = start_row + r; + uint32_t r0 = r_idx0 / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = r_idx0 % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + + const bool row0_valid = (start_row + r + 0) < cne1; + const bool row1_valid = (start_row + r + 1) < cne1; + + const float *row0_ptr = NULL; + const float *row1_ptr = NULL; + + if (row0_valid) { + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + (start_row + r + 0)]; + row0_ptr = (const float *) ((const uint8_t *) src + mapping0.i2 * nb12); + } + if (row1_valid) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + (start_row + r + 1)]; + row1_ptr = (const float *) ((const uint8_t *) src + mapping1.i2 * nb12); + } + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(row0_ptr + c); + if (row1_valid) v1 = *(const HVX_Vector *)(row1_ptr + c); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(row0_ptr + c); + if (row1_valid) v1 = *(const HVX_Vector *)(row1_ptr + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } +} + +static void transfer_output_chunk_fp16_to_fp32_scattered( + float *restrict dst, + const __fp16 *restrict vtcm_src, + uint32_t start_row, + uint32_t n_rows, + uint32_t n_cols, + const struct mmid_row_mapping *matrix_rows, + uint32_t cur_a, + uint32_t mapping_stride, + size_t dst_nb1, + size_t dst_nb2, + uint32_t cne1) { + assert(n_cols % HTP_MM_HMX_TILE_N_COLS == 0); + const size_t tile_row_stride = (n_cols / HTP_MM_HMX_TILE_N_COLS) * HTP_MM_HMX_TILE_N_ELMS; + + const HVX_Vector one = hvx_vec_splat_f16(1.0); + + for (size_t r = 0; r < n_rows; r += 2) { + uint32_t r_idx0 = start_row + r + 0; + uint32_t r_idx1 = start_row + r + 1; + const size_t r0 = r_idx0 / HTP_MM_HMX_TILE_N_ROWS; + const size_t r1 = (r_idx0 % HTP_MM_HMX_TILE_N_ROWS) / 2; // index of the row pair within the tile + const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; + + if (r_idx0 >= cne1) break; + + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + float *output_row0 = (float *) ((uint8_t *) dst + mapping0.i1 * dst_nb1 + mapping0.i2 * dst_nb2); + + float *output_row1 = NULL; + if (r_idx1 < cne1) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + output_row1 = (float *) ((uint8_t *) dst + mapping1.i1 * dst_nb1 + mapping1.i2 * dst_nb2); + } + + #pragma unroll(4) + for (size_t c = 0; c < (size_t)n_cols; c += HTP_MM_HMX_TILE_N_COLS) { + const size_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HTP_MM_HMX_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + HVX_Vector *pv_out0 = (HVX_Vector *) (output_row0 + c); + HVX_Vector *pv_out1 = output_row1 ? (HVX_Vector *) (output_row1 + c) : NULL; + + *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); + if (pv_out1) { + *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); + } + } + } +} diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.c b/ggml/src/ggml-hexagon/htp/hmx-ops.c deleted file mode 100644 index 114d8c1481..0000000000 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.c +++ /dev/null @@ -1,6 +0,0 @@ -// HMX operations compiled as a single translation unit. -// This allows interprocedural optimizations within HMX ops without requiring global HTP LTO. - -#include "hmx-queue.c" -#include "hmx-matmul-ops.c" -#include "hmx-flash-attn-ops.c" diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h deleted file mode 100644 index a67842f3ff..0000000000 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.h +++ /dev/null @@ -1,88 +0,0 @@ -// HMX operation entry-point declarations. -// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib) - -#ifndef HMX_OPS_H -#define HMX_OPS_H - -#include -#include - -#include "htp-ops.h" - -#ifdef __cplusplus -extern "C" { -#endif - -typedef struct { - float *dst; - const float *activation; - const __fp16 *permuted_weight; - int m; - int k; - int n; - int act_stride; - int weight_stride; - int dst_stride; - int ne02; - int ne03; - int ne12; - int ne13; - size_t src0_nb2; - size_t src0_nb3; - size_t src1_nb2; - size_t src1_nb3; - size_t dst_nb2; - size_t dst_nb3; -} hmx_matmul_f16_f32_batched_params_t; - -// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output -// act_stride: activation row stride in elements (= k for contiguous, or -// nb[1]/sizeof(float) for permuted tensors like attention Q). -// weight_stride: weight row stride in elements (= k for compact weights, or -// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK). -int hmx_matmul_f16_f32(struct htp_context *ctx, - float *restrict dst, - const float *activation, - const __fp16 *permuted_weight, - int m, int k, int n, - int act_stride, - int weight_stride); - -// Batched F16 wrapper over hmx_mat_mul_f16_f32. -// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3. -int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params); - -// HMX matrix multiplication — all supported weight types (F16/F32/Q4_0/Q4_1/Q8_0/IQ4_NL/MXFP4) -int hmx_matmul_2d_f32(struct htp_context *ctx, - float *restrict dst, - const float *activation, - const uint8_t *permuted_weight, - int m, int k, int n, - int act_stride, - int weight_stride, - int weight_type); - -struct mmid_row_mapping; - -int hmx_matmul_id_2d_f32(struct htp_context *ctx, - float *restrict dst, - const float *activation, - const uint8_t *permuted_weight, - int m, int k, int n, - int ne11, - size_t act_nb1, size_t act_nb2, - size_t dst_nb1, size_t dst_nb2, - int weight_stride, - int weight_type, - const struct mmid_row_mapping *matrix_rows, - int cur_a, - int mapping_stride); - -// HMX flash attention -int hmx_flash_attn_ext(struct htp_ops_context * octx); - -#ifdef __cplusplus -} -#endif - -#endif // HMX_OPS_H diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index cbb5d08786..6ad77d3daa 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -13,7 +13,9 @@ #include #include +#ifndef HTP_MAX_NTHREADS #define HTP_MAX_NTHREADS 10 +#endif #define HTP_MAX_MMAPS 16 // Memory mapping @@ -42,9 +44,13 @@ struct htp_ops_context { enum htp_op_code op; // FIXME: rename to opcode int32_t op_params[HTP_OP_MAX_PARAMS]; + int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS]; const struct htp_tensor * src[HTP_OP_MAX_INPUTS]; - const struct htp_tensor * dst; + union { + const struct htp_tensor * dst; + const struct htp_tensor * dsts[HTP_OP_MAX_OUTPUTS]; + }; // TODO convert these to an array struct htp_spad src0_spad; @@ -87,13 +93,13 @@ struct htp_context { struct htp_ops_context octx; -#ifdef HTP_HAS_HMX struct hmx_queue * hmx_queue; // Async HMX queue for pipeline overlap -#endif }; int op_matmul(struct htp_ops_context * octx); int op_matmul_id(struct htp_ops_context * octx); +int op_matmul_qkv(struct htp_ops_context * octx); +int op_matmul_ffn(struct htp_ops_context * octx); int op_binary(struct htp_ops_context * octx); int op_unary(struct htp_ops_context * octx); int op_sum_rows(struct htp_ops_context * octx); diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 0f4b74a93a..d040901357 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -28,18 +28,19 @@ enum htp_data_type { HTP_TYPE_MXFP4 = 39, // types used internally for repack, dyn.quant, etc - HTP_TYPE_Q4_0x4x2 = 200, - HTP_TYPE_Q4_1x4x2, - HTP_TYPE_Q8_0x4x2, - HTP_TYPE_MXFP4x4x2, + HTP_TYPE_Q4_0_TILED = 200, + HTP_TYPE_Q4_1_TILED, + HTP_TYPE_Q8_0_TILED, + HTP_TYPE_MXFP4_TILED, HTP_TYPE_INVALID }; // Constats for internal types -#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) -#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks -#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks +#define QK_Q4_0_TILED 256 // 32x32 Q4_0 tiled layout +#define QK_Q8_0_TILED 128 // 32x32 Q8_0 tiled layout +#define QK_MXFP4_TILED 256 // 32x32 MXFP4 tiled layout + // Mask to enable various stages of the Ops. @@ -57,6 +58,8 @@ enum htp_op_code { HTP_OP_DIV = 3, HTP_OP_MUL_MAT, HTP_OP_MUL_MAT_ID, + HTP_OP_MUL_MAT_QKV, + HTP_OP_MUL_MAT_FFN, HTP_OP_RMS_NORM, HTP_OP_RMS_NORM_MUL, HTP_OP_UNARY_SILU, @@ -99,7 +102,9 @@ enum htp_op_code { #define HTP_OP_MAX_DIMS 4 // aka GGML_MAX_DIMS #define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS +#define HTP_OP_MAX_OUTPUTS 4 #define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS +#define HTP_OP_MAX_KERN_PARAMS 32 #define HTP_OP_MAX_BUFS 16 #define HTP_OP_MAX_REQS 256 @@ -142,8 +147,10 @@ struct htp_op_desc { uint32_t opcode; // GGML/HTP Op uint32_t flags; // Op flags int32_t params[HTP_OP_MAX_PARAMS]; // Params for the op, e.g. epsilon of RMS norm + int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS]; // generic blob for host-precomputed parameters uint16_t src[HTP_OP_MAX_INPUTS]; // Input tensors indices - uint16_t dst; // Output tensor index + uint16_t dst[HTP_OP_MAX_OUTPUTS]; // Output tensor indices + uint16_t pad[2]; // padding to align to 64 bits }; #ifndef HTP_MAX_NTHREADS diff --git a/ggml/src/ggml-hexagon/htp/htp_iface.idl b/ggml/src/ggml-hexagon/htp/htp_iface.idl index d696a5fba0..47693d8b8b 100644 --- a/ggml/src/ggml-hexagon/htp/htp_iface.idl +++ b/ggml/src/ggml-hexagon/htp/htp_iface.idl @@ -11,12 +11,13 @@ struct htp_iface_pmu_conf { }; interface htp_iface : remote_handle64 { - AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx, in uint64 max_vmem); + AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 n_hmx, in uint64 max_vmem); AEEResult stop(); AEEResult mmap(in uint32 fd, in uint32 size); AEEResult munmap(in uint32 fd); AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu); AEEResult etm(in uint32 enable); + AEEResult hwinfo(rout uint32 n_threads, rout uint32 n_hvx, rout uint32 n_hmx, rout uint64 vtcm_size); }; #endif /* HTP_IDL */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index f6cb02951d..493b26c6e7 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -170,25 +170,7 @@ static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) { } #endif -/* Q6_Vsf_equals_Vw is only available on v73+.*/ -#if __HVX_ARCH__ < 73 -static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in) -{ - HVX_Vector const vzero = Q6_V_vzero(); - HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero); - HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in); - HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift); - HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift); - HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized); - HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp)); - return ret; -} -static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in) -{ - return Q6_Vsf_equals_Vqf32(hvx_vec_i32_to_qf32(in)); -} -#endif static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) { // This looks complicated. @@ -305,4 +287,17 @@ static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) { #endif // __HVX_ARCH__ < 79 +static inline HVX_Vector hvx_vec_load_act_tile(const uint8_t * y_q, uint32_t kt, HVX_Vector * v_act_all) { + if (kt % 4 == 0) { + *v_act_all = hvx_vmem(y_q + kt * 32); + return *v_act_all; + } else if (kt % 4 == 1) { + return Q6_V_vror_VR(*v_act_all, 32); + } else if (kt % 4 == 2) { + return Q6_V_vror_VR(*v_act_all, 64); + } else { + return Q6_V_vror_VR(*v_act_all, 96); + } +} + #endif /* HVX_BASE_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-flat.h b/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-flat.h new file mode 100644 index 0000000000..52351b1039 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-flat.h @@ -0,0 +1,1024 @@ +// Dynamic quantizers that produce flat (non-tiled) activations + +static inline void quantize_block_f32_q8_0_flat( + float * restrict x, + uint8_t * restrict y_quants, + __fp16 * restrict y_scales, + uint32_t block_idx +) { + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); + HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); + + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); + + HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); + HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); + + HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); + HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); + + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); + HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + * (HVX_Vector *) (y_quants + block_idx * 128) = vx_i8; + + HVX_VectorPair vp1 = Q6_W_vshuff_VVR(vd23_hf, vd01_hf, -2); + HVX_VectorPair vp2 = Q6_W_vshuff_VVR(Q6_V_hi_W(vp1), Q6_V_lo_W(vp1), -2); + HVX_Vector v_scales = Q6_V_lo_W(vp2); + hvx_vec_store_u(y_scales + block_idx * 4, 8, v_scales); +} + +static inline void quantize_block_f32_q8_1_flat( + float * restrict x, + uint8_t * restrict y_quants, + __fp16 * restrict y_scales, + uint32_t block_idx +) { + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); + HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); + + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); + + HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); + HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); + + HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); + HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); + + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); + HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + const HVX_Vector ones = Q6_Vb_vsplat_R(1); + HVX_Vector v_sums = Q6_Vw_vrmpy_VbVb(vx_i8, ones); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 4)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 8)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 16)); + + * (HVX_Vector *) (y_quants + block_idx * 128) = vx_i8; + + HVX_VectorPair vp1 = Q6_W_vshuff_VVR(vd23_hf, vd01_hf, -2); + HVX_VectorPair vp2 = Q6_W_vshuff_VVR(Q6_V_hi_W(vp1), Q6_V_lo_W(vp1), -2); + HVX_Vector v_scales = Q6_V_lo_W(vp2); + + HVX_VectorPair v_deal1 = Q6_W_vdeal_VVR(v_sums, v_sums, -4); + HVX_Vector v_even1 = Q6_V_lo_W(v_deal1); + HVX_VectorPair v_deal2 = Q6_W_vdeal_VVR(v_even1, v_even1, -4); + HVX_Vector v_even2 = Q6_V_lo_W(v_deal2); + HVX_VectorPair v_deal3 = Q6_W_vdeal_VVR(v_even2, v_even2, -4); + HVX_Vector v_sums_shuffled = Q6_V_lo_W(v_deal3); + + HVX_Vector v_sums_sf = Q6_Vsf_equals_Vw(v_sums_shuffled); + HVX_Vector v_sums_hf = hvx_vec_f32_to_f16(v_sums_sf, Q6_V_vzero()); + + HVX_Vector v_prod = hvx_vec_mul_f16_f16(v_scales, v_sums_hf); + + HVX_VectorPair vp_scales = Q6_W_vshuff_VVR(v_prod, v_scales, -2); + HVX_Vector v_final = Q6_V_lo_W(vp_scales); + + hvx_vec_store_u(y_scales + block_idx * 8, 16, v_final); +} + +static inline void quantize_row_f32_q8_0_flat(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t quants_size = hex_round_up(k, 128); + uint8_t * restrict y_quants = y; + __fp16 * restrict y_scales = (__fp16 *) (y + quants_size); + + const uint32_t nb = (k + 127) / 128; + for (uint32_t i = 0; i < nb; i++) { + quantize_block_f32_q8_0_flat(x + i * 128, y_quants, y_scales, i); + } +} + +static inline void quantize_row_f32_q8_1_flat(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t quants_size = hex_round_up(k, 128); + uint8_t * restrict y_quants = y; + __fp16 * restrict y_scales = (__fp16 *) (y + quants_size); + + const uint32_t nb = (k + 127) / 128; + for (uint32_t i = 0; i < nb; i++) { + quantize_block_f32_q8_1_flat(x + i * 128, y_quants, y_scales, i); + } +} + +static inline void quantize_f32_q8_0_flat_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_row_size, + size_t dst_row_size +) { + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0_TILED * sizeof(float)); + hvx_splat_f32_a(tmp_data, 0.0f, src_row_size_padded / sizeof(float)); + + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); + + quantize_row_f32_q8_0_flat((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } +} + +static inline void quantize_f32_q8_1_flat_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_row_size, + size_t dst_row_size +) { + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0_TILED * sizeof(float)); + hvx_splat_f32_a(tmp_data, 0.0f, src_row_size_padded / sizeof(float)); + + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); + + quantize_row_f32_q8_1_flat((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } +} + +static inline void quantize_f32_f32_flat_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_stride, + size_t dst_stride +) { + (void) tmp_data; + const size_t src_row_size = ne0 * sizeof(float); + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f32_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } +} + +static inline void quantize_f32_f16_flat_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_stride, + size_t dst_stride +) { + (void) tmp_data; + const size_t src_row_size = ne0 * sizeof(float); + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f16_f32_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } +} + +static inline void quantize_f16_f16_flat_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_stride, + size_t dst_stride +) { + (void) tmp_data; + const size_t src_row_size = ne0 * sizeof(float); + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f16_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } +} + +// Dot kernels that consume flat (non-tiled) activations + +static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx_i8 = * (const HVX_Vector *) (y_q + block_idx * 128); + HVX_Vector v_act_raw = Q6_V_vror_VR(vx_i8, sub_idx * 32); + + HVX_Vector v_act_rep[8]; + v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl); + v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl); + v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl); + v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl); + v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl); + v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl); + v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl); + v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl); + + HVX_Vector v_sum = accum_4bit_32x1(vptr, v_act_rep, i8); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = vptr[4]; + + __fp16 scale_a_val = y_scales[kt]; + HVX_Vector v_scale_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val)); + + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size); + const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx0_i8 = * (const HVX_Vector *) (y0_q + block_idx * 128); + HVX_Vector vx1_i8 = * (const HVX_Vector *) (y1_q + block_idx * 128); + + HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0_i8, sub_idx * 32); + HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1_i8, sub_idx * 32); + + HVX_Vector v_act0_rep[8]; + v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl); + v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl); + v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl); + v_act0_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl); + v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl); + v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl); + v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl); + v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl); + + HVX_Vector v_act1_rep[8]; + v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl); + v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl); + v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl); + v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl); + v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl); + v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl); + v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl); + v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl); + + HVX_VectorPair v_sums = accum_4bit_32x2(vptr, v_act0_rep, v_act1_rep, i8); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = vptr[4]; + + __fp16 scale_a0_val = y0_scales[kt]; + __fp16 scale_a1_val = y1_scales[kt]; + HVX_Vector v_scale_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val)); + HVX_Vector v_scale_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val)); + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx_i8 = * (const HVX_Vector *) (y_q + block_idx * 128); + HVX_Vector v_act_raw = Q6_V_vror_VR(vx_i8, sub_idx * 32); + + HVX_Vector v_act_rep[8]; + v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl); + v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl); + v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl); + v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl); + v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl); + v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl); + v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl); + v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl); + + HVX_Vector v_sum = accum_4bit_32x1(vptr, v_act_rep, Q6_V_vzero()); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_offset = vptr[4]; + HVX_VectorPair p_deal = Q6_W_vdeal_VVR(v_scale_offset, v_scale_offset, -2); + HVX_Vector v_scale = Q6_V_lo_W(p_deal); + HVX_Vector v_offset = Q6_V_hi_W(p_deal); + + __fp16 scale_a_val = y_scales[kt * 2 + 0]; + __fp16 sum_a_val = y_scales[kt * 2 + 1]; + HVX_Vector v_scale_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val)); + HVX_Vector v_sum_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&sum_a_val)); + + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a); + HVX_Vector v_offset_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a); + + HVX_Vector v_scaled_dot = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + HVX_Vector v_sum_scaled = hvx_vec_add_f32_f32(v_scaled_dot, v_offset_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size); + const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx0_i8 = * (const HVX_Vector *) (y0_q + block_idx * 128); + HVX_Vector vx1_i8 = * (const HVX_Vector *) (y1_q + block_idx * 128); + + HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0_i8, sub_idx * 32); + HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1_i8, sub_idx * 32); + + HVX_Vector v_act0_rep[8]; + v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl); + v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl); + v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl); + v_act0_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl); + v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl); + v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl); + v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl); + v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl); + + HVX_Vector v_act1_rep[8]; + v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl); + v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl); + v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl); + v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl); + v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl); + v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl); + v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl); + v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl); + + HVX_VectorPair v_sums = accum_4bit_32x2(vptr, v_act0_rep, v_act1_rep, Q6_V_vzero()); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_offset = vptr[4]; + HVX_VectorPair p_deal = Q6_W_vdeal_VVR(v_scale_offset, v_scale_offset, -2); + HVX_Vector v_scale = Q6_V_lo_W(p_deal); + HVX_Vector v_offset = Q6_V_hi_W(p_deal); + + __fp16 scale_a0_val = y0_scales[kt * 2 + 0]; + __fp16 sum_a0_val = y0_scales[kt * 2 + 1]; + __fp16 scale_a1_val = y1_scales[kt * 2 + 0]; + __fp16 sum_a1_val = y1_scales[kt * 2 + 1]; + + HVX_Vector v_scale_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val)); + HVX_Vector v_sum_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&sum_a0_val)); + HVX_Vector v_scale_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val)); + HVX_Vector v_sum_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&sum_a1_val)); + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a0); + HVX_Vector v_offset_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a1); + HVX_Vector v_offset_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a1); + + HVX_Vector v_scaled_dot_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c0 = hvx_vec_add_f32_f32(v_scaled_dot_c0, v_offset_comb_c0); + + HVX_Vector v_scaled_dot_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + HVX_Vector v_sum_scaled_c1 = hvx_vec_add_f32_f32(v_scaled_dot_c1, v_offset_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 1152); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx_i8 = * (const HVX_Vector *) (y_q + block_idx * 128); + HVX_Vector v_act_raw = Q6_V_vror_VR(vx_i8, sub_idx * 32); + + HVX_Vector v_act_rep[8]; + v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl); + v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl); + v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl); + v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl); + v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl); + v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl); + v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl); + v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl); + + HVX_Vector v_sum = accum_q8_0_32x1(vptr, v_act_rep); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = vptr[8]; + + __fp16 scale_a_val = y_scales[kt]; + HVX_Vector v_scale_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val)); + + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size); + const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 1152); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx0_i8 = * (const HVX_Vector *) (y0_q + block_idx * 128); + HVX_Vector vx1_i8 = * (const HVX_Vector *) (y1_q + block_idx * 128); + + HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0_i8, sub_idx * 32); + HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1_i8, sub_idx * 32); + + HVX_Vector v_act0_rep[8]; + v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl); + v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl); + v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl); + v_act0_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl); + v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl); + v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl); + v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl); + v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl); + + HVX_Vector v_act1_rep[8]; + v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl); + v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl); + v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl); + v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl); + v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl); + v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl); + v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl); + v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl); + + HVX_VectorPair v_sums = accum_q8_0_32x2(vptr, v_act0_rep, v_act1_rep); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = vptr[8]; + + __fp16 scale_a0_val = y0_scales[kt]; + __fp16 scale_a1_val = y1_scales[kt]; + HVX_Vector v_scale_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val)); + HVX_Vector v_scale_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val)); + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx = * (const HVX_Vector *) (y_q + block_idx * 128); + HVX_Vector v_act_raw = Q6_V_vror_VR(vx, sub_idx * 32); + + HVX_Vector v_act_rep[8]; + v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl); + v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl); + v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl); + v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl); + v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl); + v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl); + v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl); + v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl); + + HVX_Vector v_sum = accum_4bit_32x1_lut(vptr, v_act_rep, mask_h4, lut); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = vptr[4]; + + __fp16 scale_a_val = y_scales[kt]; + HVX_Vector v_scale_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val)); + + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size); + const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx0 = * (const HVX_Vector *) (y0_q + block_idx * 128); + HVX_Vector vx1 = * (const HVX_Vector *) (y1_q + block_idx * 128); + + HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0, sub_idx * 32); + HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1, sub_idx * 32); + + HVX_Vector v_act0_rep[8]; + v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl); + v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl); + v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl); + v_act0_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl); + v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl); + v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl); + v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl); + v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl); + + HVX_Vector v_act1_rep[8]; + v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl); + v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl); + v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl); + v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl); + v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl); + v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl); + v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl); + v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl); + + HVX_VectorPair v_sums = accum_4bit_32x2_lut(vptr, v_act0_rep, v_act1_rep, mask_h4, lut); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = vptr[4]; + + __fp16 scale_a0_val = y0_scales[kt]; + __fp16 scale_a1_val = y1_scales[kt]; + HVX_Vector v_scale_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val)); + HVX_Vector v_scale_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val)); + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx = * (const HVX_Vector *) (y_q + block_idx * 128); + HVX_Vector v_act_raw = Q6_V_vror_VR(vx, sub_idx * 32); + + HVX_Vector v_act_rep[8]; + v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl); + v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl); + v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl); + v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl); + v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl); + v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl); + v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl); + v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl); + + HVX_Vector v_sum = accum_4bit_32x1_lut(vptr, v_act_rep, mask_h4, lut); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = hvx_vmem(tile_ptr + kt * 640 + 512); + HVX_Vector r0_d = Q6_V_vdelta_VV(v_scale_w, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + HVX_Vector v_scale_w_f32 = Q6_Vw_vasl_VwR(r0_d, 23); + + __fp16 scale_a_val = y_scales[kt]; + HVX_Vector v_scale_a_f16 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val)); + HVX_VectorPair p_scale_a_f32 = hvx_vec_f16_to_f32(v_scale_a_f16); + HVX_Vector v_scale_a = Q6_V_lo_W(p_scale_a_f32); + + HVX_Vector v_scale_comb = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f)); + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size); + const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx0 = * (const HVX_Vector *) (y0_q + block_idx * 128); + HVX_Vector vx1 = * (const HVX_Vector *) (y1_q + block_idx * 128); + + HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0, sub_idx * 32); + HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1, sub_idx * 32); + + HVX_Vector v_act0_rep[8]; + v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl); + v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl); + v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl); + v_act0_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl); + v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl); + v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl); + v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl); + v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl); + + HVX_Vector v_act1_rep[8]; + v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl); + v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl); + v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl); + v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl); + v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl); + v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl); + v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl); + v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl); + + HVX_VectorPair v_sums = accum_4bit_32x2_lut(vptr, v_act0_rep, v_act1_rep, mask_h4, lut); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = hvx_vmem(tile_ptr + kt * 640 + 512); + HVX_Vector r0_d = Q6_V_vdelta_VV(v_scale_w, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + HVX_Vector v_scale_w_f32 = Q6_Vw_vasl_VwR(r0_d, 23); + + __fp16 scale_a0_val = y0_scales[kt]; + __fp16 scale_a1_val = y1_scales[kt]; + HVX_Vector v_scale_a0_f16 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val)); + HVX_Vector v_scale_a1_f16 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val)); + HVX_VectorPair p_scale_a0_f32 = hvx_vec_f16_to_f32(v_scale_a0_f16); + HVX_VectorPair p_scale_a1_f32 = hvx_vec_f16_to_f32(v_scale_a1_f16); + HVX_Vector v_scale_a0 = Q6_V_lo_W(p_scale_a0_f32); + HVX_Vector v_scale_a1 = Q6_V_lo_W(p_scale_a1_f32); + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f)); + v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f)); + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} diff --git a/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-tiled.h b/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-tiled.h new file mode 100644 index 0000000000..bcb0b8f9e4 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-tiled.h @@ -0,0 +1,1140 @@ +// Dynamic quantizers that produce tiled activations + +static inline void quantize_block_f32_q8_1_tiled(float * restrict x, uint8_t * restrict y_block) { + assert((unsigned long) x % 128 == 0); + assert((unsigned long) y_block % 128 == 0); + + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); + HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); + + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); + + HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); + HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); + + HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); + HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); + + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); + HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + const HVX_Vector ones = Q6_Vb_vsplat_R(1); + HVX_Vector v_sums = Q6_Vw_vrmpy_VbVb(vx_i8, ones); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 4)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 8)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 16)); + + float vmax0[32] __attribute__((aligned(128))); + float vmax1[32] __attribute__((aligned(128))); + float vmax2[32] __attribute__((aligned(128))); + float vmax3[32] __attribute__((aligned(128))); + int32_t sums[32] __attribute__((aligned(128))); + + hvx_vec_store_u(vmax0, 128, vmax0_sf); + hvx_vec_store_u(vmax1, 128, vmax1_sf); + hvx_vec_store_u(vmax2, 128, vmax2_sf); + hvx_vec_store_u(vmax3, 128, vmax3_sf); + hvx_vec_store_u(sums, 128, v_sums); + + float d0 = vmax0[0] / 127.0f; + float d1 = vmax1[0] / 127.0f; + float d2 = vmax2[0] / 127.0f; + float d3 = vmax3[0] / 127.0f; + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + for (int b = 0; b < 4; b++) { + HVX_Vector v_act = Q6_V_vror_VR(vx_i8, b * 32); + + HVX_Vector r0 = Q6_V_vdelta_VV(v_act, v_repl_ctrl); + HVX_Vector r1 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 4), v_repl_ctrl); + HVX_Vector r2 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 8), v_repl_ctrl); + HVX_Vector r3 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 12), v_repl_ctrl); + HVX_Vector r4 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 16), v_repl_ctrl); + HVX_Vector r5 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 20), v_repl_ctrl); + HVX_Vector r6 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 24), v_repl_ctrl); + HVX_Vector r7 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 28), v_repl_ctrl); + + __fp16 scale_h, offset_h; + if (b == 0) { + scale_h = (__fp16) d0; + offset_h = (__fp16) (sums[0] * d0); + } else if (b == 1) { + scale_h = (__fp16) d1; + offset_h = (__fp16) (sums[8] * d1); + } else if (b == 2) { + scale_h = (__fp16) d2; + offset_h = (__fp16) (sums[16] * d2); + } else { + scale_h = (__fp16) d3; + offset_h = (__fp16) (sums[24] * d3); + } + + HVX_Vector r_scale = Q6_Vh_vsplat_R(*(int16_t *)&scale_h); + HVX_Vector r_offset = Q6_Vh_vsplat_R(*(int16_t *)&offset_h); + + HVX_Vector * restrict dst = (HVX_Vector *) (y_block + b * 1280); + dst[0] = r0; + dst[1] = r1; + dst[2] = r2; + dst[3] = r3; + dst[4] = r4; + dst[5] = r5; + dst[6] = r6; + dst[7] = r7; + dst[8] = r_scale; + dst[9] = r_offset; + } +} + +static inline void quantize_block_f32_q8_0_tiled(float * restrict x, uint8_t * restrict y_block) { + assert((unsigned long) x % 128 == 0); + assert((unsigned long) y_block % 128 == 0); + + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); + + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); + vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); + + HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); + HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16); + + HVX_Vector vd_inv_hf = hvx_vec_inverse_f16(vd_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf)); + + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + HVX_Vector r_scale = hvx_vec_repl_f16(vd_hf); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + for (int b = 0; b < 4; b++) { + HVX_Vector v_act = Q6_V_vror_VR(vx_i8, b * 32); + + HVX_Vector r0 = Q6_V_vdelta_VV(v_act, v_repl_ctrl); + HVX_Vector r1 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 4), v_repl_ctrl); + HVX_Vector r2 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 8), v_repl_ctrl); + HVX_Vector r3 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 12), v_repl_ctrl); + HVX_Vector r4 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 16), v_repl_ctrl); + HVX_Vector r5 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 20), v_repl_ctrl); + HVX_Vector r6 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 24), v_repl_ctrl); + HVX_Vector r7 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 28), v_repl_ctrl); + + HVX_Vector * restrict dst = (HVX_Vector *) (y_block + b * 1152); + dst[0] = r0; + dst[1] = r1; + dst[2] = r2; + dst[3] = r3; + dst[4] = r4; + dst[5] = r5; + dst[6] = r6; + dst[7] = r7; + dst[8] = r_scale; + } +} + +static void quantize_row_f32_q8_0_tiled(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (k + qk - 1) / qk; + + for (uint32_t i = 0; i < nb; i++) { + uint8_t * restrict y_block = y + i * 4 * 1152; + quantize_block_f32_q8_0_tiled(x + i * qk, y_block); + } +} + +static void quantize_row_f32_q8_1_tiled(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (k + qk - 1) / qk; + + for (uint32_t i = 0; i < nb; i++) { + uint8_t * restrict y_block = y + i * 4 * 1280; + quantize_block_f32_q8_1_tiled(x + i * qk, y_block); + } +} + +// Dot kernels & helpers that consume tiled activations + +static inline HVX_Vector hvx_vec_mul_f16_f16_to_f32_lower32(HVX_Vector v1, HVX_Vector v2) { +#if __HVX_ARCH__ >= 79 + HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(v1, v2); + return Q6_V_lo_W(Q6_W_vshuff_VVR(Q6_V_hi_W(p), Q6_V_lo_W(p), -4)); +#else + HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(v1, v2); + HVX_Vector hi = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)); + HVX_Vector lo = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)); + return Q6_V_lo_W(Q6_W_vshuff_VVR(hi, lo, -4)); +#endif +} + +static inline HVX_Vector unpack_and_interleave_4bit(HVX_Vector v_a, HVX_Vector v_b, HVX_Vector mask_h4) { + HVX_Vector v_W0 = Q6_V_vand_VV(v_a, mask_h4); + HVX_Vector v_W1 = Q6_Vub_vlsr_VubR(v_a, 4); + HVX_Vector v_W2 = Q6_V_vand_VV(v_b, mask_h4); + HVX_Vector v_W3 = Q6_Vub_vlsr_VubR(v_b, 4); + + HVX_VectorPair v01_pair = Q6_W_vshuff_VVR(v_W1, v_W0, -1); + HVX_VectorPair v23_pair = Q6_W_vshuff_VVR(v_W3, v_W2, -1); + HVX_VectorPair v0123_pair = Q6_W_vshuff_VVR(Q6_V_lo_W(v23_pair), Q6_V_lo_W(v01_pair), -2); + return Q6_V_lo_W(v0123_pair); +} + +static inline HVX_VectorPair unpack_and_interleave_4bit_x2(HVX_Vector v_src, HVX_Vector mask_h4) { + HVX_Vector v_lo = Q6_V_vand_VV(v_src, mask_h4); + HVX_Vector v_hi = Q6_Vub_vlsr_VubR(v_src, 4); + HVX_VectorPair v01_pair = Q6_W_vshuff_VVR(v_hi, v_lo, -1); + HVX_Vector v01_lo = Q6_V_lo_W(v01_pair); + HVX_Vector v01_hi = Q6_V_hi_W(v01_pair); + + HVX_Vector v23_lo = Q6_V_valign_VVR(v01_hi, v01_lo, 64); + HVX_Vector v_W0 = Q6_V_lo_W(Q6_W_vshuff_VVR(v23_lo, v01_lo, -2)); + + HVX_Vector v67_lo = Q6_V_valign_VVR(v01_lo, v01_hi, 64); + HVX_Vector v_W1 = Q6_V_lo_W(Q6_W_vshuff_VVR(v67_lo, v01_hi, -2)); + + return Q6_W_vcombine_VV(v_W1, v_W0); +} + +static inline HVX_Vector accum_4bit_32x1( + const HVX_Vector * restrict vptr, + const HVX_Vector * restrict v_act, + HVX_Vector i8 +) { + HVX_Vector v_sum0 = Q6_V_vzero(); + HVX_Vector v_sum1 = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + #pragma unroll + for (int i = 0; i < 4; i++) { + HVX_VectorPair v_W_pair = unpack_and_interleave_4bit_x2(vptr[i], mask_h4); + HVX_Vector v_W0 = Q6_Vb_vsub_VbVb(Q6_V_lo_W(v_W_pair), i8); + HVX_Vector v_W1 = Q6_Vb_vsub_VbVb(Q6_V_hi_W(v_W_pair), i8); + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W0, v_act[i * 2 + 0]); + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W1, v_act[i * 2 + 1]); + } + + return Q6_Vw_vadd_VwVw(v_sum0, v_sum1); +} + +static inline HVX_Vector accum_4bit_32x1_lut( + const HVX_Vector * restrict vptr, + const HVX_Vector * restrict v_act, + HVX_Vector mask_h4, + HVX_Vector lut +) { + HVX_Vector v_sum0 = Q6_V_vzero(); + HVX_Vector v_sum1 = Q6_V_vzero(); + + #pragma unroll + for (int i = 0; i < 4; i++) { + HVX_VectorPair v_W_pair = unpack_and_interleave_4bit_x2(vptr[i], mask_h4); + HVX_Vector v_W0 = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v_W_pair), lut, 0); + HVX_Vector v_W1 = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v_W_pair), lut, 0); + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W0, v_act[i * 2 + 0]); + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W1, v_act[i * 2 + 1]); + } + + return Q6_Vw_vadd_VwVw(v_sum0, v_sum1); +} + +static inline HVX_VectorPair accum_4bit_32x2( + const HVX_Vector * restrict vptr, + const HVX_Vector * restrict v_act0, + const HVX_Vector * restrict v_act1, + HVX_Vector i8 +) { + HVX_Vector v_sum0 = Q6_V_vzero(); + HVX_Vector v_sum1 = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + #pragma unroll + for (int i = 0; i < 4; i++) { + HVX_VectorPair v_W_pair = unpack_and_interleave_4bit_x2(vptr[i], mask_h4); + HVX_Vector v_W0 = Q6_Vb_vsub_VbVb(Q6_V_lo_W(v_W_pair), i8); + HVX_Vector v_W1 = Q6_Vb_vsub_VbVb(Q6_V_hi_W(v_W_pair), i8); + + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W0, v_act0[i * 2 + 0]); + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W1, v_act0[i * 2 + 1]); + + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W0, v_act1[i * 2 + 0]); + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W1, v_act1[i * 2 + 1]); + } + + return Q6_W_vcombine_VV(v_sum1, v_sum0); +} + +static inline HVX_VectorPair accum_4bit_32x2_lut( + const HVX_Vector * restrict vptr, + const HVX_Vector * restrict v_act0, + const HVX_Vector * restrict v_act1, + HVX_Vector mask_h4, + HVX_Vector lut +) { + HVX_Vector v_sum0 = Q6_V_vzero(); + HVX_Vector v_sum1 = Q6_V_vzero(); + + #pragma unroll + for (int i = 0; i < 4; i++) { + HVX_VectorPair v_W_pair = unpack_and_interleave_4bit_x2(vptr[i], mask_h4); + HVX_Vector v_W0 = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v_W_pair), lut, 0); + HVX_Vector v_W1 = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v_W_pair), lut, 0); + + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W0, v_act0[i * 2 + 0]); + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W1, v_act0[i * 2 + 1]); + + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W0, v_act1[i * 2 + 0]); + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W1, v_act1[i * 2 + 1]); + } + + return Q6_W_vcombine_VV(v_sum1, v_sum0); +} + +static inline HVX_Vector accum_q8_0_32x1( + const HVX_Vector * restrict vptr, + const HVX_Vector * restrict v_act +) { + HVX_Vector v_sum = Q6_V_vzero(); + #pragma unroll + for (int g = 0; g < 8; g++) { + HVX_Vector v_rot = Q6_V_vror_VR(vptr[g], 64); + HVX_Vector v_W = Q6_V_lo_W(Q6_W_vshuff_VVR(v_rot, vptr[g], -2)); + v_sum = Q6_Vw_vrmpyacc_VwVbVb(v_sum, v_W, v_act[g]); + } + return v_sum; +} + +static inline HVX_VectorPair accum_q8_0_32x2( + const HVX_Vector * restrict vptr, + const HVX_Vector * restrict v_act0, + const HVX_Vector * restrict v_act1 +) { + HVX_Vector v_sum0 = Q6_V_vzero(); + HVX_Vector v_sum1 = Q6_V_vzero(); + #pragma unroll + for (int g = 0; g < 8; g++) { + HVX_Vector v_rot = Q6_V_vror_VR(vptr[g], 64); + HVX_Vector v_W = Q6_V_lo_W(Q6_W_vshuff_VVR(v_rot, vptr[g], -2)); + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W, v_act0[g]); + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W, v_act1[g]); + } + return Q6_W_vcombine_VV(v_sum1, v_sum0); +} + +static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act = (const HVX_Vector *) (y_q + kt * 1152); + + HVX_Vector v_sum = accum_4bit_32x1(vptr, v_act, i8); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = vptr[4]; + HVX_Vector v_scale_a = v_act[8]; + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + uint32_t n_k_tiles = n / 32; + uint32_t kt = 0; + for (; kt + 1 < n_k_tiles; kt += 2) { + const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 640); + const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1152); + const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 1152); + + const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 640); + const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1152); + const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1152); + + HVX_VectorPair v_sums0 = accum_4bit_32x2(vptr0, v_act0_0, v_act1_0, i8); + HVX_VectorPair v_sums1 = accum_4bit_32x2(vptr1, v_act0_1, v_act1_1, i8); + + HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0); + HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0); + HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1); + HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1); + + HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0); + HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0); + HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1); + HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1); + + HVX_Vector v_scale_w0 = vptr0[4]; + HVX_Vector v_scale_w1 = vptr1[4]; + HVX_Vector v_scale_a_c0_0 = v_act0_0[8]; + HVX_Vector v_scale_a_c1_0 = v_act1_0[8]; + HVX_Vector v_scale_a_c0_1 = v_act0_1[8]; + HVX_Vector v_scale_a_c1_1 = v_act1_1[8]; + + HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c0_0); + HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c1_0); + HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c0_1); + HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c1_1); + + HVX_Vector v_sum_scaled_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0); + HVX_Vector v_sum_scaled_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0); + HVX_Vector v_sum_scaled_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1); + HVX_Vector v_sum_scaled_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1)); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1)); + } + + for (; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1152); + const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1152); + + HVX_VectorPair v_sums = accum_4bit_32x2(vptr, v_act0, v_act1, i8); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = vptr[4]; + HVX_Vector v_scale_a_c0 = v_act0[8]; + HVX_Vector v_scale_a_c1 = v_act1[8]; + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act = (const HVX_Vector *) (y_q + kt * 1280); + + HVX_Vector v_sum = accum_4bit_32x1(vptr, v_act, Q6_V_vzero()); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_offset = vptr[4]; + HVX_VectorPair p_deal = Q6_W_vdeal_VVR(v_scale_offset, v_scale_offset, -2); + HVX_Vector v_scale = Q6_V_lo_W(p_deal); + HVX_Vector v_offset = Q6_V_hi_W(p_deal); + + HVX_Vector v_scale_a = v_act[8]; + HVX_Vector v_sum_a = v_act[9]; + + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a); + HVX_Vector v_offset_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a); + + HVX_Vector v_scaled_dot = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + HVX_Vector v_sum_scaled = hvx_vec_add_f32_f32(v_scaled_dot, v_offset_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + + uint32_t n_k_tiles = n / 32; + uint32_t kt = 0; + for (; kt + 1 < n_k_tiles; kt += 2) { + const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 640); + const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1280); + const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 1280); + + const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 640); + const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1280); + const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1280); + + HVX_VectorPair v_sums0 = accum_4bit_32x2(vptr0, v_act0_0, v_act1_0, Q6_V_vzero()); + HVX_VectorPair v_sums1 = accum_4bit_32x2(vptr1, v_act0_1, v_act1_1, Q6_V_vzero()); + + HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0); + HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0); + HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1); + HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1); + + HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0); + HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0); + HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1); + HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1); + + HVX_Vector v_scale_offset0 = vptr0[4]; + HVX_VectorPair p_deal0 = Q6_W_vdeal_VVR(v_scale_offset0, v_scale_offset0, -2); + HVX_Vector v_scale0 = Q6_V_lo_W(p_deal0); + HVX_Vector v_offset0 = Q6_V_hi_W(p_deal0); + + HVX_Vector v_scale_offset1 = vptr1[4]; + HVX_VectorPair p_deal1 = Q6_W_vdeal_VVR(v_scale_offset1, v_scale_offset1, -2); + HVX_Vector v_scale1 = Q6_V_lo_W(p_deal1); + HVX_Vector v_offset1 = Q6_V_hi_W(p_deal1); + + HVX_Vector v_scale_a_c0_0 = v_act0_0[8]; + HVX_Vector v_sum_a_c0_0 = v_act0_0[9]; + HVX_Vector v_scale_a_c1_0 = v_act1_0[8]; + HVX_Vector v_sum_a_c1_0 = v_act1_0[9]; + + HVX_Vector v_scale_a_c0_1 = v_act0_1[8]; + HVX_Vector v_sum_a_c0_1 = v_act0_1[9]; + HVX_Vector v_scale_a_c1_1 = v_act1_1[8]; + HVX_Vector v_sum_a_c1_1 = v_act1_1[9]; + + HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale0, v_scale_a_c0_0); + HVX_Vector v_offset_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset0, v_sum_a_c0_0); + HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale0, v_scale_a_c1_0); + HVX_Vector v_offset_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset0, v_sum_a_c1_0); + + HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale1, v_scale_a_c0_1); + HVX_Vector v_offset_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset1, v_sum_a_c0_1); + HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale1, v_scale_a_c1_1); + HVX_Vector v_offset_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset1, v_sum_a_c1_1); + + HVX_Vector v_scaled_dot_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0); + HVX_Vector v_sum_scaled_c0_0 = hvx_vec_add_f32_f32(v_scaled_dot_c0_0, v_offset_comb_c0_0); + + HVX_Vector v_scaled_dot_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0); + HVX_Vector v_sum_scaled_c1_0 = hvx_vec_add_f32_f32(v_scaled_dot_c1_0, v_offset_comb_c1_0); + + HVX_Vector v_scaled_dot_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1); + HVX_Vector v_sum_scaled_c0_1 = hvx_vec_add_f32_f32(v_scaled_dot_c0_1, v_offset_comb_c0_1); + + HVX_Vector v_scaled_dot_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1); + HVX_Vector v_sum_scaled_c1_1 = hvx_vec_add_f32_f32(v_scaled_dot_c1_1, v_offset_comb_c1_1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1)); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1)); + } + + for (; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1280); + const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1280); + + HVX_VectorPair v_sums = accum_4bit_32x2(vptr, v_act0, v_act1, Q6_V_vzero()); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_offset = vptr[4]; + HVX_VectorPair p_deal = Q6_W_vdeal_VVR(v_scale_offset, v_scale_offset, -2); + HVX_Vector v_scale = Q6_V_lo_W(p_deal); + HVX_Vector v_offset = Q6_V_hi_W(p_deal); + + HVX_Vector v_scale_a_c0 = v_act0[8]; + HVX_Vector v_sum_a_c0 = v_act0[9]; + HVX_Vector v_scale_a_c1 = v_act1[8]; + HVX_Vector v_sum_a_c1 = v_act1[9]; + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a_c0); + HVX_Vector v_offset_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a_c0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a_c1); + HVX_Vector v_offset_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a_c1); + + HVX_Vector v_scaled_dot_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c0 = hvx_vec_add_f32_f32(v_scaled_dot_c0, v_offset_comb_c0); + + HVX_Vector v_scaled_dot_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + HVX_Vector v_sum_scaled_c1 = hvx_vec_add_f32_f32(v_scaled_dot_c1, v_offset_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 1152); + const HVX_Vector * restrict v_act = (const HVX_Vector *) (y_q + kt * 1152); + + HVX_Vector v_sum = accum_q8_0_32x1(vptr, v_act); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = vptr[8]; + HVX_Vector v_scale_a = v_act[8]; + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + + uint32_t n_k_tiles = n / 32; + uint32_t kt = 0; + for (; kt + 1 < n_k_tiles; kt += 2) { + const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 1152); + const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1152); + const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 1152); + + const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 1152); + const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1152); + const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1152); + + HVX_VectorPair v_sums0 = accum_q8_0_32x2(vptr0, v_act0_0, v_act1_0); + HVX_VectorPair v_sums1 = accum_q8_0_32x2(vptr1, v_act0_1, v_act1_1); + + HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0); + HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0); + HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1); + HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1); + + HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0); + HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0); + HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1); + HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1); + + HVX_Vector v_scale_w0 = vptr0[8]; + HVX_Vector v_scale_w1 = vptr1[8]; + HVX_Vector v_scale_a_c0_0 = v_act0_0[8]; + HVX_Vector v_scale_a_c1_0 = v_act1_0[8]; + HVX_Vector v_scale_a_c0_1 = v_act0_1[8]; + HVX_Vector v_scale_a_c1_1 = v_act1_1[8]; + + HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c0_0); + HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c1_0); + HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c0_1); + HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c1_1); + + HVX_Vector v_sum_scaled_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0); + HVX_Vector v_sum_scaled_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0); + HVX_Vector v_sum_scaled_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1); + HVX_Vector v_sum_scaled_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1)); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1)); + } + + for (; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 1152); + const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1152); + const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1152); + + HVX_VectorPair v_sums = accum_q8_0_32x2(vptr, v_act0, v_act1); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = vptr[8]; + HVX_Vector v_scale_a_c0 = v_act0[8]; + HVX_Vector v_scale_a_c1 = v_act1[8]; + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act = (const HVX_Vector *) (y_q + kt * 1152); + + HVX_Vector v_sum = accum_4bit_32x1_lut(vptr, v_act, mask_h4, lut); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = vptr[4]; + HVX_Vector v_scale_a = v_act[8]; + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + uint32_t n_k_tiles = n / 32; + uint32_t kt = 0; + for (; kt + 1 < n_k_tiles; kt += 2) { + const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 640); + const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1152); + const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 1152); + + const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 640); + const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1152); + const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1152); + + HVX_VectorPair v_sums0 = accum_4bit_32x2_lut(vptr0, v_act0_0, v_act1_0, mask_h4, lut); + HVX_VectorPair v_sums1 = accum_4bit_32x2_lut(vptr1, v_act0_1, v_act1_1, mask_h4, lut); + + HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0); + HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0); + HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1); + HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1); + + HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0); + HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0); + HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1); + HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1); + + HVX_Vector v_scale_w0 = vptr0[4]; + HVX_Vector v_scale_w1 = vptr1[4]; + HVX_Vector v_scale_a_c0_0 = v_act0_0[8]; + HVX_Vector v_scale_a_c1_0 = v_act1_0[8]; + HVX_Vector v_scale_a_c0_1 = v_act0_1[8]; + HVX_Vector v_scale_a_c1_1 = v_act1_1[8]; + + HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c0_0); + HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c1_0); + HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c0_1); + HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c1_1); + + HVX_Vector v_sum_scaled_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0); + HVX_Vector v_sum_scaled_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0); + HVX_Vector v_sum_scaled_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1); + HVX_Vector v_sum_scaled_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1)); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1)); + } + + for (; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1152); + const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1152); + + HVX_VectorPair v_sums = accum_4bit_32x2_lut(vptr, v_act0, v_act1, mask_h4, lut); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = vptr[4]; + HVX_Vector v_scale_a_c0 = v_act0[8]; + HVX_Vector v_scale_a_c1 = v_act1[8]; + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act = (const HVX_Vector *) (y_q + kt * 1152); + + HVX_Vector v_sum = accum_4bit_32x1_lut(vptr, v_act, mask_h4, lut); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = hvx_vmem(tile_ptr + kt * 640 + 512); + HVX_Vector r0_d = Q6_V_vdelta_VV(v_scale_w, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + HVX_Vector v_scale_w_f32 = Q6_Vw_vasl_VwR(r0_d, 23); + + HVX_Vector v_scale_a_f16 = v_act[8]; + HVX_VectorPair p_scale_a_f32 = hvx_vec_f16_to_f32_shuff(v_scale_a_f16); + HVX_Vector v_scale_a = Q6_V_lo_W(p_scale_a_f32); + + HVX_Vector v_scale_comb = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f)); + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + + uint32_t n_k_tiles = n / 32; + uint32_t kt = 0; + for (; kt + 1 < n_k_tiles; kt += 2) { + const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 640); + const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1152); + const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 1152); + + const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 640); + const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1152); + const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1152); + + HVX_VectorPair v_sums0 = accum_4bit_32x2_lut(vptr0, v_act0_0, v_act1_0, mask_h4, lut); + HVX_VectorPair v_sums1 = accum_4bit_32x2_lut(vptr1, v_act0_1, v_act1_1, mask_h4, lut); + + HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0); + HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0); + HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1); + HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1); + + HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0); + HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0); + HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1); + HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1); + + HVX_Vector v_scale_w0 = hvx_vmem(tile_ptr + (kt + 0) * 640 + 512); + HVX_Vector r0_d0 = Q6_V_vdelta_VV(v_scale_w0, expand); + r0_d0 = Q6_V_vand_VV(r0_d0, e8m0_mask); + HVX_Vector v_scale_w_f32_0 = Q6_Vw_vasl_VwR(r0_d0, 23); + + HVX_Vector v_scale_w1 = hvx_vmem(tile_ptr + (kt + 1) * 640 + 512); + HVX_Vector r0_d1 = Q6_V_vdelta_VV(v_scale_w1, expand); + r0_d1 = Q6_V_vand_VV(r0_d1, e8m0_mask); + HVX_Vector v_scale_w_f32_1 = Q6_Vw_vasl_VwR(r0_d1, 23); + + HVX_Vector v_scale_a_c0_f16_0 = v_act0_0[8]; + HVX_Vector v_scale_a_c1_f16_0 = v_act1_0[8]; + HVX_Vector v_scale_a_c0_f16_1 = v_act0_1[8]; + HVX_Vector v_scale_a_c1_f16_1 = v_act1_1[8]; + + HVX_VectorPair p_scale_a_c0_f32_0 = hvx_vec_f16_to_f32_shuff(v_scale_a_c0_f16_0); + HVX_VectorPair p_scale_a_c1_f32_0 = hvx_vec_f16_to_f32_shuff(v_scale_a_c1_f16_0); + HVX_VectorPair p_scale_a_c0_f32_1 = hvx_vec_f16_to_f32_shuff(v_scale_a_c0_f16_1); + HVX_VectorPair p_scale_a_c1_f32_1 = hvx_vec_f16_to_f32_shuff(v_scale_a_c1_f16_1); + + HVX_Vector v_scale_a_c0_0 = Q6_V_lo_W(p_scale_a_c0_f32_0); + HVX_Vector v_scale_a_c1_0 = Q6_V_lo_W(p_scale_a_c1_f32_0); + HVX_Vector v_scale_a_c0_1 = Q6_V_lo_W(p_scale_a_c0_f32_1); + HVX_Vector v_scale_a_c1_1 = Q6_V_lo_W(p_scale_a_c1_f32_1); + + HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f32_f32(v_scale_w_f32_0, v_scale_a_c0_0); + HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f32_f32(v_scale_w_f32_0, v_scale_a_c1_0); + HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f32_f32(v_scale_w_f32_1, v_scale_a_c0_1); + HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f32_f32(v_scale_w_f32_1, v_scale_a_c1_1); + + HVX_Vector v_sum_scaled_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0); + HVX_Vector v_sum_scaled_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0); + HVX_Vector v_sum_scaled_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1); + HVX_Vector v_sum_scaled_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1)); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1)); + } + + for (; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1152); + const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1152); + + HVX_VectorPair v_sums = accum_4bit_32x2_lut(vptr, v_act0, v_act1, mask_h4, lut); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = hvx_vmem(tile_ptr + kt * 640 + 512); + HVX_Vector r0_d = Q6_V_vdelta_VV(v_scale_w, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + HVX_Vector v_scale_w_f32 = Q6_Vw_vasl_VwR(r0_d, 23); + + HVX_Vector v_scale_a_c0_f16 = v_act0[8]; + HVX_Vector v_scale_a_c1_f16 = v_act1[8]; + + HVX_VectorPair p_scale_a_c0_f32 = hvx_vec_f16_to_f32_shuff(v_scale_a_c0_f16); + HVX_VectorPair p_scale_a_c1_f32 = hvx_vec_f16_to_f32_shuff(v_scale_a_c1_f16); + + HVX_Vector v_scale_a_c0 = Q6_V_lo_W(p_scale_a_c0_f32); + HVX_Vector v_scale_a_c1 = Q6_V_lo_W(p_scale_a_c1_f32); + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a_c0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a_c1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f)); + v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f)); + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static inline void quantize_f32_q8_0_tiled_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_row_size, + size_t dst_row_size +) { + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0_TILED * sizeof(float)); + hvx_splat_f32_a(tmp_data, 0.0f, src_row_size_padded / sizeof(float)); + + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); + + quantize_row_f32_q8_0_tiled((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } +} + +static inline void quantize_f32_q8_1_tiled_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_row_size, + size_t dst_row_size +) { + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0_TILED * sizeof(float)); + hvx_splat_f32_a(tmp_data, 0.0f, src_row_size_padded / sizeof(float)); + + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); + + quantize_row_f32_q8_1_tiled((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } +} + +static inline void quantize_f32_q8_0_tiled_block_kernel( + const float * restrict src, + uint8_t * restrict dst, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t ib_first, + uint32_t ib_last, + size_t src_row_size, + size_t dst_row_size, + uint32_t r, + uint32_t c +) { + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (ne0 + qk - 1) / qk; + + for (uint32_t ib = ib_first; ib < ib_last; ++ib) { + const uint8_t * restrict src_ptr = (const uint8_t *) src + r * src_row_size + c * qk * sizeof(float); + uint8_t * restrict dst_ptr = dst + r * dst_row_size + c * 4 * 1152; + + hex_l2fetch(src_ptr, qk * sizeof(float), qk * sizeof(float), 1); + + if (c == nb - 1) { + uint32_t active_elements = ne0 - c * qk; + hvx_splat_f32_a(tmp_data, 0.0f, qk); + hvx_copy_f32_aa(tmp_data, src_ptr, active_elements); + } else { + hvx_copy_f32_aa(tmp_data, src_ptr, qk); + } + + quantize_block_f32_q8_0_tiled((float *) tmp_data, dst_ptr); + + c++; + if (c == nb) { + c = 0; + r++; + } + } +} + +static inline void quantize_f32_q8_1_tiled_block_kernel( + const float * restrict src, + uint8_t * restrict dst, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t ib_first, + uint32_t ib_last, + size_t src_row_size, + size_t dst_row_size, + uint32_t r, + uint32_t c +) { + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (ne0 + qk - 1) / qk; + + for (uint32_t ib = ib_first; ib < ib_last; ++ib) { + const uint8_t * restrict src_ptr = (const uint8_t *) src + r * src_row_size + c * qk * sizeof(float); + uint8_t * restrict dst_ptr = dst + r * dst_row_size + c * 4 * 1280; + + hex_l2fetch(src_ptr, qk * sizeof(float), qk * sizeof(float), 1); + + if (c == nb - 1) { + uint32_t active_elements = ne0 - c * qk; + hvx_splat_f32_a(tmp_data, 0.0f, qk); + hvx_copy_f32_aa(tmp_data, src_ptr, active_elements); + } else { + hvx_copy_f32_aa(tmp_data, src_ptr, qk); + } + + quantize_block_f32_q8_1_tiled((float *) tmp_data, dst_ptr); + + c++; + if (c == nb) { + c = 0; + r++; + } + } +} diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 53ab33c07b..d76512ea4a 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -361,7 +361,7 @@ static void vtcm_free(struct htp_context * ctx) { static void htp_packet_callback(dspqueue_t queue, int error, void * context); static void htp_error_callback(dspqueue_t queue, int error, void * context); -AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx, uint64_t max_vmem) { +AEEResult htp_iface_start(remote_handle64 handle, uint32_t sess_id, uint64_t dsp_queue_id, uint32_t n_hvx, uint32_t n_hmx, uint64_t max_vmem) { struct htp_context * ctx = (struct htp_context *) handle; if (!ctx) { @@ -395,10 +395,9 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que return AEE_ENOMEMORY; } -#ifdef HTP_HAS_HMX - ctx->hmx_enabled = use_hmx; + ctx->hmx_enabled = n_hmx; ctx->hmx_queue = NULL; - if (use_hmx) { + if (n_hmx) { ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx); if (ctx->hmx_queue) { ctx->hmx_queue->trace = &ctx->trace[HTP_MAX_NTHREADS]; @@ -407,8 +406,7 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que ctx->hmx_enabled = false; } } - FARF(HIGH, "HMX %s (use_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", use_hmx); -#endif + FARF(HIGH, "HMX %s (n_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", n_hmx); qurt_sysenv_max_hthreads_t hw_threads; qurt_sysenv_get_max_hw_threads(&hw_threads); @@ -481,13 +479,11 @@ AEEResult htp_iface_stop(remote_handle64 handle) { dma_queue_delete(ctx->dma[i]); } -#ifdef HTP_HAS_HMX if (ctx->hmx_queue) { hmx_queue_delete(ctx->hmx_queue); ctx->hmx_queue = NULL; } ctx->hmx_enabled = false; -#endif vtcm_free(ctx); @@ -500,6 +496,36 @@ AEEResult htp_iface_stop(remote_handle64 handle) { return AEE_SUCCESS; } +AEEResult htp_iface_hwinfo(remote_handle64 handle, uint32_t * n_threads, uint32_t * n_hvx, uint32_t * n_hmx, uint64_t * vtcm_size) { + (void)handle; + if (!n_threads || !n_hvx || !n_hmx || !vtcm_size) { + return AEE_EBADPARM; + } + + qurt_sysenv_max_hthreads_t hw_threads; + qurt_sysenv_get_max_hw_threads(&hw_threads); + uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF; + + uint32_t n_hvx_val = hw_nhvx; + if (n_hvx_val > hw_threads.max_hthreads) { + n_hvx_val = hw_threads.max_hthreads; + } + if (n_hvx_val > HTP_MAX_NTHREADS) { + n_hvx_val = HTP_MAX_NTHREADS; + } + + // for now we force n_threads == n_hvx + *n_threads = n_hvx_val; + *n_hvx = n_hvx_val; + *n_hmx = 1; + + uint32_t vtcm_sz = 8 * 1024 * 1024; // 8MB default fallback + HAP_compute_res_query_VTCM(0, (unsigned int *)&vtcm_sz, NULL, NULL, NULL); + *vtcm_size = vtcm_sz; + + return AEE_SUCCESS; +} + static void htp_error_callback(dspqueue_t queue, int error, void * context) { // No errors expected on the DSP. FARF(ERROR, "Error callback: 0x%08x", (unsigned) error); @@ -554,6 +580,12 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_MUL_MAT_ID: return op_matmul_id(octx); + case HTP_OP_MUL_MAT_QKV: + return op_matmul_qkv(octx); + + case HTP_OP_MUL_MAT_FFN: + return op_matmul_ffn(octx); + case HTP_OP_MUL: case HTP_OP_ADD: case HTP_OP_SUB: @@ -762,8 +794,9 @@ static void prep_tensors(struct htp_context *ctx, struct htp_buf_desc *bufs, str } } -static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) { +static int proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) { memcpy(octx->op_params, op->params, sizeof(octx->op_params)); + memcpy(octx->kernel_params, op->kernel_params, sizeof(octx->kernel_params)); octx->flags = op->flags; octx->op = op->opcode; @@ -785,22 +818,41 @@ static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, src->ne[0], src->ne[1], src->ne[3], src->ne[3]); } - // Prep output tensor - struct htp_tensor *dst = tens + op->dst; + // Prep output tensors + for (uint32_t i = 0; i < HTP_OP_MAX_OUTPUTS; i++) { + uint16_t dst_idx = op->dst[i]; + if (dst_idx == 0xffff) { + octx->dsts[i] = NULL; + continue; + } + struct htp_tensor *dst = tens + dst_idx; + octx->dsts[i] = dst; - octx->dst = dst; + FARF(HIGH, "prep-dst[%u] #%u: data %p size %u : %u:%u:%u:%u", i, dst_idx, (void*) dst->data, dst->size, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]); + } - FARF(HIGH, "prep-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size, - dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]); + int status = execute_op(octx); - (void) execute_op(octx); + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; + octx->src2_spad.src = NULL; + octx->src3_spad.src = NULL; + octx->dst_spad.src = NULL; // flush buffers on output - hex_l2flush((void *) dst->data, dst->size); - dst->flags |= HTP_TENSOR_FLUSHED; + for (uint32_t i = 0; i < HTP_OP_MAX_OUTPUTS; i++) { + if (octx->dsts[i]) { + struct htp_tensor *dst = (struct htp_tensor *)octx->dsts[i]; + hex_l2flush((void *) dst->data, dst->size); + dst->flags |= HTP_TENSOR_FLUSHED; - FARF(HIGH, "post-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size, - dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]); + FARF(HIGH, "post-dst[%u] #%u: data %p size %u : %u:%u:%u:%u", i, op->dst[i], (void*) dst->data, dst->size, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]); + } + } + + return status; } #define DSPQUEUE_POLL_TIMEOUT_USEC 100 @@ -892,20 +944,26 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { } } + int op_status = HTP_STATUS_OK; + uint32_t op_wakeup = n_ops / 2; // half-way throgh the batch + for (uint32_t i=0; i < n_ops; i++) { struct profile_data prof; - if (i == (n_ops-1)) { - // wake up the host before starting the last op + if (i == op_wakeup) { dspqueue_write_early_wakeup_noblock(queue, 0, 0); } profile_start(ctx->profiler, &prof); - proc_op_req(octx, tens, i, &ops[i]); + op_status = proc_op_req(octx, tens, i, &ops[i]); profile_stop(ctx->profiler, &prof); + if (op_status != HTP_STATUS_OK) { + break; + } + if (ctx->profiler) { pds[i].opcode = ops[i].opcode; pds[i].usecs = prof.usecs; @@ -919,7 +977,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { struct htp_opbatch_rsp rsp; rsp.id = req.id; - rsp.status = HTP_STATUS_OK; + rsp.status = op_status; rsp.n_bufs = n_bufs; rsp.n_tensors = n_tens; rsp.n_ops = n_ops; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 8e016c1be5..81a0ffbebb 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -17,33 +18,50 @@ #include "ggml-common.h" #include "htp-ctx.h" #include "htp-ops.h" -#include "htp-ops.h" -#include "hmx-ops.h" +#include "matmul-ops.h" +#include "vtcm-utils.h" -#define MM_SPAD_SRC0_NROWS 16 -#define MM_SPAD_SRC1_NROWS 16 -#define MM_SPAD_DST_NROWS 2 +typedef struct { + float *dst; + const float *activation; + const __fp16 *weight; + int m; + int k; + int n; + int act_stride; + int weight_stride; + int dst_stride; + int ne02; + int ne03; + int ne12; + int ne13; + size_t src0_nb2; + size_t src0_nb3; + size_t src1_nb2; + size_t src1_nb3; + size_t dst_nb2; + size_t dst_nb3; +} hmx_mm_f16_f32_batched_params_t; -struct htp_matmul_context { +struct htp_mm_context { const char * type; struct htp_ops_context * octx; - void (*vec_dot_1x1)(const int n, float * restrict s0, + void (*vec_dot_1x1)(const uint32_t n, float * restrict s0, const void * restrict vx0, const void * restrict vy0); - void (*vec_dot_2x1)(const int n, float * restrict s0, + void (*vec_dot_2x1)(const uint32_t n, float * restrict s0, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0); - void (*vec_dot_2x2)(const int n, float * restrict s0, float * restrict s1, + void (*vec_dot_2x2)(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1); - void (*vec_dot_4x1)(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vx2, const void * restrict vx3, - const void * restrict vy0); + void (*vec_dot_32x1)(const uint32_t n, float * restrict s, + const void * restrict vx, + const void * restrict vy, uint32_t valid_rows); // Precomputed values uint32_t src0_nrows_per_thread; @@ -53,11 +71,37 @@ struct htp_matmul_context { struct fastdiv_values mm_div_ne1; struct fastdiv_values mm_div_r2; struct fastdiv_values mm_div_r3; + struct fastdiv_values mm_div_ne11; + + // Precomputed block-parallel quantization values + uint32_t quant_ib_first[MAX_NUM_WORKERS]; + uint32_t quant_ib_last[MAX_NUM_WORKERS]; + uint32_t quant_r[MAX_NUM_WORKERS]; + uint32_t quant_c[MAX_NUM_WORKERS]; // Fields for scattered mapping & HMX support in MUL_MAT_ID const uint32_t * matrix_row_counts; const struct mmid_row_mapping * matrix_rows; - bool hmx_eligible; + + // Dynamic VTCM pointers allocated sequentially + uint8_t * vtcm_src0; + uint8_t * vtcm_src1; + uint8_t * vtcm_src2; + uint8_t * vtcm_src3; + uint8_t * vtcm_dst; + + // Cached strides + uint32_t vtcm_src0_stride; + uint32_t vtcm_src1_stride; + uint32_t vtcm_src2_stride; + uint32_t vtcm_src3_stride; + + // Cached thread offsets/sizes + uint32_t vtcm_src0_size_per_thread; + uint32_t vtcm_src1_size_per_thread; + uint32_t vtcm_src2_size_per_thread; + uint32_t vtcm_src3_size_per_thread; + uint32_t vtcm_dst_size_per_thread; }; // vdelta control to expand first 32 e8m0 values into 32 uint32 elements @@ -89,2835 +133,6 @@ static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }; -static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_full(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) - HVX_Vector v2_3 = vptr[1]; // ... - HVX_Vector v4_5 = vptr[2]; // ... - HVX_Vector v6_7 = vptr[3]; // ... - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; - - HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 - HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F - HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 - HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F - HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 - HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F - HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 - - v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); - v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); - v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); - v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); - v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); - v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); - v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); - v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); - - HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; - return r; -} - -static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_partial(const uint8_t * restrict ptr, uint32_t n) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - const uint32_t qk = QK_Q4_0x4x2; // 256 - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; - - HVX_Vector_x8 r; - uint32_t i = 0; - - #pragma unroll(2) - for (i = 0; i < nb; i++) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements - r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0); - r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0); - } - - if (nloe) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements - HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... - r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0); - r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0); - } - - return r; -} - -// q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales - -static inline size_t q8x4x2_row_size(uint32_t ne) { - // ensures perfect alignment of quants and full row - const uint32_t qk = QK_Q8_0x4x2; - const uint32_t nb = (ne + qk - 1) / qk; - return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128); -} - -static inline size_t q8_1x4x2_row_size(uint32_t ne) { - // ensures perfect alignment of quants and full row - const uint32_t qk = QK_Q8_0x4x2; - const uint32_t nb = (ne + qk - 1) / qk; - return hex_round_up(ne + nb * 8 * 2 * sizeof(__fp16), 128); -} - -static inline HVX_Vector_x8 hvx_vec_load_q4x4x8_full(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) - HVX_Vector v2_3 = vptr[1]; // ... - HVX_Vector v4_5 = vptr[2]; // ... - HVX_Vector v6_7 = vptr[3]; // ... - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - - HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements - HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ... - HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 - HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F - HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 - HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F - HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 - - // Convert uint4 to int4 (i.e. x - 8) - v0 = Q6_Vb_vsub_VbVb(v0, i8); - v1 = Q6_Vb_vsub_VbVb(v1, i8); - v2 = Q6_Vb_vsub_VbVb(v2, i8); - v3 = Q6_Vb_vsub_VbVb(v3, i8); - v4 = Q6_Vb_vsub_VbVb(v4, i8); - v5 = Q6_Vb_vsub_VbVb(v5, i8); - v6 = Q6_Vb_vsub_VbVb(v6, i8); - v7 = Q6_Vb_vsub_VbVb(v7, i8); - - HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; - return r; -} - -static HVX_Vector_x8 hvx_vec_load_q4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - const uint32_t qk = QK_Q4_0x4x2; // 256 - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - - HVX_Vector_x8 r; - uint32_t i = 0; - - #pragma unroll(2) - for (i=0; i < nb; i++) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements - r.v[i*2+0] = Q6_Vb_vsub_VbVb(v0, i8); - r.v[i*2+1] = Q6_Vb_vsub_VbVb(v1, i8); - } - - if (nloe) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements - HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... - r.v[i*2+0] = Q6_Vb_vsub_VbVb(Q6_V_lo_W(v0_1_p), i8); - r.v[i*2+1] = Q6_Vb_vsub_VbVb(Q6_V_hi_W(v0_1_p), i8); - } - - return r; -} - -static inline HVX_Vector_x8 hvx_vec_load_q4_1x4x8_full(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) - HVX_Vector v2_3 = vptr[1]; // ... - HVX_Vector v4_5 = vptr[2]; // ... - HVX_Vector v6_7 = vptr[3]; // ... - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - - HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements - HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ... - HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 - HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F - HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 - HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F - HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 - - HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; - return r; -} - -static HVX_Vector_x8 hvx_vec_load_q4_1x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - const uint32_t qk = QK_Q4_0x4x2; // 256 - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - - HVX_Vector_x8 r; - uint32_t i = 0; - - #pragma unroll(2) - for (i=0; i < nb; i++) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements - r.v[i*2+0] = v0; - r.v[i*2+1] = v1; - } - - if (nloe) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements - HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... - r.v[i*2+0] = Q6_V_lo_W(v0_1_p); - r.v[i*2+1] = Q6_V_hi_W(v0_1_p); - } - - return r; -} - -static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_full(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) - HVX_Vector v2_3 = vptr[1]; // ... - HVX_Vector v4_5 = vptr[2]; // ... - HVX_Vector v6_7 = vptr[3]; // ... - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; - - HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 - HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F - HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 - HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F - HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 - HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F - HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 - - v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); - v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); - v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); - v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); - v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); - v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); - v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); - v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); - - HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; - return r; -} - -static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - const uint32_t qk = QK_Q4_0x4x2; // 256 - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; - - HVX_Vector_x8 r; - uint32_t i = 0; - - #pragma unroll(2) - for (i=0; i < nb; i++) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements - r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0); - r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0); - } - - if (nloe) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements - HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... - r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0); - r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0); - } - - return r; -} - -static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_full(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0 = vptr[0]; // first 128 vals - HVX_Vector v1 = vptr[1]; // ... - HVX_Vector v2 = vptr[2]; // ... - HVX_Vector v3 = vptr[3]; // ... - HVX_Vector v4 = vptr[4]; // ... - HVX_Vector v5 = vptr[5]; // ... - HVX_Vector v6 = vptr[6]; // ... - HVX_Vector v7 = vptr[7]; // ... - - HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; - return r; -} - -static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_partial(const uint8_t * restrict ptr, uint32_t nloe) { - return hvx_vec_load_q8x4x8_full(ptr); -} - -// Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors). -// Accumulate each block into a single int32 value. -// Return a single HVX vector with 32x int32 accumulators. -// This version is parameterized to support less than 1024 elements. -// if() checks are optimized out at compile time -- make sure to pass N as a constexpr. - -static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { - HVX_Vector r0 = Q6_V_vzero(); - HVX_Vector r1 = Q6_V_vzero(); - HVX_Vector r2 = Q6_V_vzero(); - HVX_Vector r3 = Q6_V_vzero(); - HVX_Vector r4 = Q6_V_vzero(); - HVX_Vector r5 = Q6_V_vzero(); - HVX_Vector r6 = Q6_V_vzero(); - HVX_Vector r7 = Q6_V_vzero(); - - HVX_VectorPair p3; - HVX_VectorPair p2; - HVX_VectorPair p1; - HVX_VectorPair p0; - - if (n >= 128) { r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); } - if (n >= 256) { r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); } - if (n >= 384) { r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); } - if (n >= 512) { r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); } - if (n >= 640) { r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); } - if (n >= 768) { r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); } - if (n >= 896) { r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); } - if (n >= 1024) { r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); } - - if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); } - if (n >= 384) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); } - if (n >= 640) { p2 = Q6_W_vdeal_VVR(r5, r4, -4); } - if (n >= 896) { p3 = Q6_W_vdeal_VVR(r7, r6, -4); } - - if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); } - if (n >= 384) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); } - if (n >= 640) { r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); } - if (n >= 896) { r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); } - - if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); } - if (n >= 640) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); } - - if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); } - if (n >= 640) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); } - - if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); } - if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); } - - return r0; -} - -static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) { - HVX_Vector r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); - HVX_Vector r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); - HVX_Vector r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); - HVX_Vector r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); - HVX_Vector r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); - HVX_Vector r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); - HVX_Vector r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); - HVX_Vector r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); - - HVX_VectorPair p0 = Q6_W_vdeal_VVR(r1, r0, -4); - HVX_VectorPair p1 = Q6_W_vdeal_VVR(r3, r2, -4); - HVX_VectorPair p2 = Q6_W_vdeal_VVR(r5, r4, -4); - HVX_VectorPair p3 = Q6_W_vdeal_VVR(r7, r6, -4); - - r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); - r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); - r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); - r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); - - p0 = Q6_W_vdeal_VVR(r1, r0, -4); - p1 = Q6_W_vdeal_VVR(r3, r2, -4); - - r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); - r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); - - p0 = Q6_W_vdeal_VVR(r1, r0, -4); - r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); - - return r0; -} - -static inline HVX_Vector hvx_vec_rmpy_x8_partial(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { - if (n >= 512) - return hvx_vec_rmpy_x8_full(x, y); - - return hvx_vec_rmpy_x8_partial(x, y, 512); -} - -static void vec_dot_q4_1x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales/offsets - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elemements - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - - HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); - HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); - - HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - - HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); - HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); - - HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ms = Q6_V_vand_QV(bmask, r0_ms); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); - } - - r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(s0, 4, r0_sum); -} - -static void vec_dot_q4_1x4x2_q8x4x2_2x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elemements - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - - HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); - HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); - - HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - - HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); - HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - - HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); - HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); - - HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - - HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); - HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ms = Q6_V_vand_QV(bmask, r0_ms); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r1_ms = Q6_V_vand_QV(bmask, r1_ms); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); - } - - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); -} - -static void vec_dot_q4_1x4x2_q8x4x2_4x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vx2, const void * restrict vx3, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vx2 % 128 == 0); - assert((unsigned long) vx3 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first - const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales - const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first - const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - HVX_Vector r2_sum = Q6_V_vzero(); - HVX_Vector r3_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elements - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); - HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_full(r2_x_q + i * x_qblk_size); - HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_full(r3_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - - HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); - HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); - - HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - - HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); - HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - - HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); - HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal)); - HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal)); - - HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); - HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal)); - HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); - - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s))); - - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms); - - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_partial(r2_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_partial(r3_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); - - HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); - HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); - - HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - - HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); - HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - - HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); - HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal)); - HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal)); - - HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); - HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal)); - HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); - - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s))); - - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ms = Q6_V_vand_QV(bmask, r0_ms); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r1_ms = Q6_V_vand_QV(bmask, r1_ms); - r2_dd = Q6_V_vand_QV(bmask, r2_dd); - r2_ms = Q6_V_vand_QV(bmask, r2_ms); - r3_dd = Q6_V_vand_QV(bmask, r3_dd); - r3_ms = Q6_V_vand_QV(bmask, r3_ms); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - r2_ia = Q6_V_vand_QV(bmask, r2_ia); - r3_ia = Q6_V_vand_QV(bmask, r3_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms); - - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum)); - } - - HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; - HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); - hvx_vec_store_u(s0, 16, rsum); -} - - -static void vec_dot_q4_1x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0, const void * restrict vy1) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - assert((unsigned long) vy1 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales/sums - const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first - const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales/sums - - // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_Vector r0_c0_sum = Q6_V_vzero(); - HVX_Vector r0_c1_sum = Q6_V_vzero(); - HVX_Vector r1_c0_sum = Q6_V_vzero(); - HVX_Vector r1_c1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elements - - uint32_t i = 0; - for (; i < nb; i++) { - // Load src1 columns - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); - - // Load src0 rows - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); - - // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); - - // Load scales - HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size); - HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2); - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal)); - HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal)); - - HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size); - HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal)); - HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal)); - - HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - - HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); - HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - - // Compute combined scales - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s))); - - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s))); - - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s))); - - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s))); - - // Apply scales and accumulate - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms); - HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms); - HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms); - HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - - HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size); - HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2); - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal)); - HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal)); - - HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size); - HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal)); - HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal)); - - HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - - HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); - HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s))); - - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s))); - - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s))); - - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); - r0_c0_ms = Q6_V_vand_QV(bmask, r0_c0_ms); - r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); - r0_c1_ms = Q6_V_vand_QV(bmask, r0_c1_ms); - r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); - r1_c0_ms = Q6_V_vand_QV(bmask, r1_c0_ms); - r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); - r1_c1_ms = Q6_V_vand_QV(bmask, r1_c1_ms); - - r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); - r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); - r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); - r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); - - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms); - HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms); - HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms); - HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum)); - } - - // Reduce and store results - HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); - HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); - - hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0 - hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 -} - -static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elemements - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - - hvx_vec_store_u(s0, 4, r0_sum); -} - -static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elemements - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); -} - -static void vec_dot_q4x4x2_q8x4x2_4x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vx2, const void * restrict vx3, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vx2 % 128 == 0); - assert((unsigned long) vx3 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; - const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; - const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; - const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; - const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - HVX_Vector r2_sum = Q6_V_vzero(); - HVX_Vector r3_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elements - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); - HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_full(r2_x_q + i * x_qblk_size); - HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_full(r3_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_partial(r2_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_partial(r3_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r2_dd = Q6_V_vand_QV(bmask, r2_dd); - r3_dd = Q6_V_vand_QV(bmask, r3_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - r2_ia = Q6_V_vand_QV(bmask, r2_ia); - r3_ia = Q6_V_vand_QV(bmask, r3_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; - HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); - hvx_vec_store_u(s0, 16, rsum); -} - - -static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0, const void * restrict vy1) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - assert((unsigned long) vy1 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales - const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first - const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales - - // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_Vector r0_c0_sum = Q6_V_vzero(); - HVX_Vector r0_c1_sum = Q6_V_vzero(); - HVX_Vector r1_c0_sum = Q6_V_vzero(); - HVX_Vector r1_c1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elements - - uint32_t i = 0; - for (; i < nb; i++) { - // Load src1 columns (reused across both src0 rows) - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); - - // Load src0 rows (reused across both src1 columns) - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); - - // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); - - // Load scales - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - // Compute combined scales - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - - // Apply scales and accumulate - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - - // Zero out unused scales - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); - r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); - r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); - r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); - r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); - r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); - r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); - r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); - - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - // Reduce and store results - HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); - HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); - - hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0 - hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 -} - -static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). - - const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - - hvx_vec_store_u(s0, 4, r0_sum); -} - -static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (qf32) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). - - const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); -} - -static void vec_dot_q8x4x2_q8x4x2_4x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vx2, const void * restrict vx3, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vx2 % 128 == 0); - assert((unsigned long) vx3 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first - const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales - const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first - const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (qf32) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - HVX_Vector r2_sum = Q6_V_vzero(); - HVX_Vector r3_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); - HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_full(r2_x_q + i * x_qblk_size); - HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_full(r3_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_partial(r2_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_partial(r3_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r2_dd = Q6_V_vand_QV(bmask, r2_dd); - r3_dd = Q6_V_vand_QV(bmask, r3_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - r2_ia = Q6_V_vand_QV(bmask, r2_ia); - r3_ia = Q6_V_vand_QV(bmask, r3_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; - HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); - hvx_vec_store_u(s0, 16, rsum); -} - - -static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0, const void * restrict vy1) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - assert((unsigned long) vy1 % 128 == 0); - - const uint32_t qk = QK_Q8_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales - const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first - const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales - - // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_Vector r0_c0_sum = Q6_V_vzero(); - HVX_Vector r0_c1_sum = Q6_V_vzero(); - HVX_Vector r1_c0_sum = Q6_V_vzero(); - HVX_Vector r1_c1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elements - - uint32_t i = 0; - for (; i < nb; i++) { - // Load src1 columns (reused across both src0 rows) - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); - - // Load src0 rows (reused across both src1 columns) - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); - - // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); - - // Load scales - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - // Compute combined scales - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - - // Apply scales and accumulate - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); - r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); - r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); - r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); - r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); - r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); - r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); - r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); - - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - // Reduce and store results - HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); - HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); - - hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 - hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 -} - -// ======== IQ4_NL x Q8_0 vec_dot kernels ======== -// Same structure as Q4_0 vec_dot but uses IQ4_NL LUT-based load (4-bit index -> int8 kvalue). -// Scale format is identical to Q4_0 (fp16 scales). - -static void vec_dot_iq4nlx4x2_q8x4x2_1x1(const int n, - float * restrict s0, - const void * restrict vx0, - const void * restrict vy0) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - HVX_Vector r0_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - - hvx_vec_store_u(s0, 4, r0_sum); -} - -static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n, - float * restrict s0, - const void * restrict vx0, - const void * restrict vx1, - const void * restrict vy0) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); -} - -static void vec_dot_iq4nlx4x2_q8x4x2_4x1(const int n, - float * restrict s0, - const void * restrict vx0, - const void * restrict vx1, - const void * restrict vx2, - const void * restrict vx3, - const void * restrict vy0) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vx2 % 128 == 0); - assert((unsigned long) vx3 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first - const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales - const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first - const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - HVX_Vector r2_sum = Q6_V_vzero(); - HVX_Vector r3_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); - HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_full(r2_x_q + i * x_qblk_size); - HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_full(r3_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_partial(r2_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_partial(r3_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r2_dd = Q6_V_vand_QV(bmask, r2_dd); - r3_dd = Q6_V_vand_QV(bmask, r3_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - r2_ia = Q6_V_vand_QV(bmask, r2_ia); - r3_ia = Q6_V_vand_QV(bmask, r3_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; - HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); - hvx_vec_store_u(s0, 16, rsum); -} - - -static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n, - float * restrict s0, - float * restrict s1, - const void * restrict vx0, - const void * restrict vx1, - const void * restrict vy0, - const void * restrict vy1) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - assert((unsigned long) vy1 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; - - const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; - const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; - const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; - const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; - - HVX_Vector r0_c0_sum = Q6_V_vzero(); - HVX_Vector r0_c1_sum = Q6_V_vzero(); - HVX_Vector r1_c0_sum = Q6_V_vzero(); - HVX_Vector r1_c1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); - - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); - r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); - r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); - r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); - r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); - r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); - r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); - r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); - - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); - HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); - - hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); - hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); -} - -static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_MXFP4x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk / 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). - - const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); - - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... - // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); - - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... - // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - - // Zero-out unused scales - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - - hvx_vec_store_u(s0, 4, r0_sum); -} - -static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_MXFP4x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk / 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (f32). - - const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); - - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... - // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); - - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... - // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); - - // Zero-out unused values - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); -} - -static void vec_dot_mxfp4x4x2_q8x4x2_4x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vx2, const void * restrict vx3, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vx2 % 128 == 0); - assert((unsigned long) vx3 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_MXFP4x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk / 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first - const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales - const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first - const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - HVX_Vector r2_sum = Q6_V_vzero(); - HVX_Vector r3_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); - HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_full(r2_x_q + i * x_qblk_size); - HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_full(r3_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); - HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); - - // Convert rX_d scales from e8m0 to fp32 - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - r2_d = Q6_V_vdelta_VV(r2_d, expand); - r2_d = Q6_V_vand_VV(r2_d, e8m0_mask); - r2_d = Q6_Vw_vasl_VwR(r2_d, 23); - r3_d = Q6_V_vdelta_VV(r3_d, expand); - r3_d = Q6_V_vand_VV(r3_d, e8m0_mask); - r3_d = Q6_Vw_vasl_VwR(r3_d, 23); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d)); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d)); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_partial(r2_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_partial(r3_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); - HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); - - // Convert rX_d scales from e8m0 to fp32 - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - r2_d = Q6_V_vdelta_VV(r2_d, expand); - r2_d = Q6_V_vand_VV(r2_d, e8m0_mask); - r2_d = Q6_Vw_vasl_VwR(r2_d, 23); - r3_d = Q6_V_vdelta_VV(r3_d, expand); - r3_d = Q6_V_vand_VV(r3_d, e8m0_mask); - r3_d = Q6_Vw_vasl_VwR(r3_d, 23); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d)); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d)); - - // Zero-out unused values - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r2_dd = Q6_V_vand_QV(bmask, r2_dd); - r3_dd = Q6_V_vand_QV(bmask, r3_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - r2_ia = Q6_V_vand_QV(bmask, r2_ia); - r3_ia = Q6_V_vand_QV(bmask, r3_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; - HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); - hvx_vec_store_u(s0, 16, rsum); -} - - -static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0, const void * restrict vy1) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - assert((unsigned long) vy1 % 128 == 0); - - const uint32_t qk = QK_MXFP4x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk / 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales - const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first - const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales - - // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_Vector r0_c0_sum = Q6_V_vzero(); - HVX_Vector r0_c1_sum = Q6_V_vzero(); - HVX_Vector r1_c0_sum = Q6_V_vzero(); - HVX_Vector r1_c1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elements - - uint32_t i = 0; - for (; i < nb; i++) { - // Load src1 columns (reused across both src0 rows) - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); - - // Load src0 rows (reused across both src1 columns) - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); - - // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); - - // Load scales - HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); - HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half)); - vy0_d = Q6_Vsf_equals_Vqf32(vy0_d); - vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half)); - vy1_d = Q6_Vsf_equals_Vqf32(vy1_d); - - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... - // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - - // Compute combined scales - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d)); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d)); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d)); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d)); - - // Apply scales and accumulate - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial( y0_q + i * y_qblk_size, nloe); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial( y1_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - - HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); - HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half)); - vy0_d = Q6_Vsf_equals_Vqf32(vy0_d); - vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half)); - vy1_d = Q6_Vsf_equals_Vqf32(vy1_d); - - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... - // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d)); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d)); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d)); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d)); - - // Zero out unused scales - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); - r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); - r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); - r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); - r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); - r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); - r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); - r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); - - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - // Reduce and store results - HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); - HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); - - hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 - hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 -} - #if __HVX_ARCH__ < 79 #define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) #define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) @@ -2926,7 +141,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float #define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) #endif -static void vec_dot_f32_f32_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_f32_f32_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_Vector * restrict x = (const HVX_Vector *) vx; const HVX_Vector * restrict y = (const HVX_Vector *) vy; @@ -2954,7 +169,7 @@ static void vec_dot_f32_f32_aa_1x1(const int n, float * restrict s, const void * *s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum)); } -static void vec_dot_f32_f32_aa_2x1(const int n, float * restrict s0, +static void vec_dot_f32_f32_aa_2x1(const uint32_t n, float * restrict s0, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0) { const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; @@ -2996,7 +211,7 @@ static void vec_dot_f32_f32_aa_2x1(const int n, float * restrict s0, s0[1] = va.fp32[1]; } -static void vec_dot_f32_f32_aa_2x2(const int n, float * restrict s0, float * restrict s1, +static void vec_dot_f32_f32_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1) { const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; @@ -3054,7 +269,7 @@ static void vec_dot_f32_f32_aa_2x2(const int n, float * restrict s0, float * res s1[1] = va1.fp32[1]; } -static void vec_dot_f32_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { +static void vec_dot_f32_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) { const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; @@ -3088,7 +303,7 @@ static void vec_dot_f32_f32_uu_1x1(const int n, float * restrict s, const void * hvx_vec_store_u(&s[0], 4, rsum); } -static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_f16_f16_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_Vector * restrict x = (const HVX_Vector *) vx; const HVX_Vector * restrict y = (const HVX_Vector *) vy; @@ -3115,7 +330,7 @@ static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum)); } -static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, +static void vec_dot_f16_f16_aa_2x1(const uint32_t n, float * restrict s0, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0) { const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; @@ -3152,7 +367,7 @@ static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * restrict s1, +static void vec_dot_f16_f16_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1) { const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; @@ -3212,7 +427,7 @@ static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * res hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 } -static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_f16_f16_uu_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_UVector * restrict x = (const HVX_UVector *) vx; const HVX_UVector * restrict y = (const HVX_UVector *) vy; @@ -3242,7 +457,7 @@ static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * hvx_vec_store_u(&s[0], 4, rsum); } -static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { +static void vec_dot_f16_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) { const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; @@ -3295,65 +510,58 @@ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * hvx_vec_store_u(&s[0], 4, rsum); } -#define htp_matmul_tensors_preamble \ - const struct htp_tensor * restrict src0 = octx->src[0]; \ - const struct htp_tensor * restrict src1 = octx->src[1]; \ - const struct htp_tensor * restrict src2 = octx->src[2]; \ - const struct htp_tensor * restrict dst = octx->dst; \ - struct htp_spad * restrict src0_spad = &octx->src0_spad; \ - struct htp_spad * restrict src1_spad = &octx->src1_spad; \ - struct htp_spad * restrict dst_spad = &octx->dst_spad; \ - \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t ne10 = src1->ne[0]; \ - const uint32_t ne11 = src1->ne[1]; \ - const uint32_t ne12 = src1->ne[2]; \ - const uint32_t ne13 = src1->ne[3]; \ - \ - const uint32_t ne20 = src2->ne[0]; \ - const uint32_t ne21 = src2->ne[1]; \ - const uint32_t ne22 = src2->ne[2]; \ - const uint32_t ne23 = src2->ne[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t nb10 = src1->nb[0]; \ - const uint32_t nb11 = src1->nb[1]; \ - const uint32_t nb12 = src1->nb[2]; \ - const uint32_t nb13 = src1->nb[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ +#define htp_matmul_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict src1 = octx->src[1]; \ + const struct htp_tensor * restrict src2 = octx->src[2]; \ + const struct htp_tensor * restrict dst = octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t ne20 = src2->ne[0]; \ + const uint32_t ne21 = src2->ne[1]; \ + const uint32_t ne22 = src2->ne[2]; \ + const uint32_t ne23 = src2->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb10 = src1->nb[0]; \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -#define htp_matmul_preamble \ - struct htp_matmul_context * mmctx = data; \ - struct htp_ops_context * octx = mmctx->octx; \ - htp_matmul_tensors_preamble; \ - dma_queue *dma_queue = octx->ctx->dma[ith]; \ - uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread; +#define htp_matmul_preamble \ + struct htp_mm_context * mmctx = data; \ + struct htp_ops_context * octx = mmctx->octx; \ + dma_queue *dma_queue = octx->ctx->dma[ith]; \ + uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread; \ + htp_matmul_tensors_preamble; // *** matmul with support for 4d tensors and full broadcasting -static void matmul_4d(unsigned int nth, unsigned int ith, void * data) { +static void hvx_mm_4d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; - - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); assert(ne12 % ne02 == 0); assert(ne13 % ne03 == 0); @@ -3388,7 +596,9 @@ static void matmul_4d(unsigned int nth, unsigned int ith, void * data) { return; } - // block-tiling attempt + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0_start); + const uint32_t blck_0 = 64; const uint32_t blck_1 = 64; @@ -3412,28 +622,606 @@ static void matmul_4d(unsigned int nth, unsigned int ith, void * data) { float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, iir0); for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) { const uint8_t * restrict src0_row = src0_base + ir0 * nb01; mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col); } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, iir0); } } } - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "matmul-4d %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0], - src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0_start); } -// src1 tensor is already in VTCM spad -static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { - htp_matmul_preamble; +#include "hmx-mm-kernels-tiled.h" +#include "hvx-mm-kernels-tiled.h" +#include "hvx-mm-kernels-flat.h" + +// Specialized repacked matmul macros +#define MATMUL_2D_REPACKED_IMPL(SUFFIX, TILE_SIZE, DOT_2X2, DOT_2X1) \ +static void hvx_mm_2d_repacked_##SUFFIX(unsigned int nth, unsigned int ith, void * data) { \ + htp_matmul_preamble; \ + \ + const uint32_t src0_nrows = ne01 * ne02 * ne03; \ + const uint32_t src1_nrows = ne11 * ne12 * ne13; \ + \ + const uint32_t src0_start_row = src0_nrows_per_thread * ith; \ + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); \ + \ + if (src0_start_row >= src0_end_row) { \ + return; \ + } \ + \ + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; \ + \ + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; \ + const uint32_t n_prefetch = kparams->n_prefetch; \ + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); \ + \ + const size_t dst_row_size = nb1; \ + const size_t src1_row_size = nb11; \ + const size_t src1_stride = mmctx->vtcm_src1_stride; \ + \ + uint8_t * restrict vtcm_dst_ptr = mmctx->vtcm_dst + mmctx->vtcm_dst_size_per_thread * ith; \ + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; \ + uint8_t * restrict src1_data = mmctx->vtcm_src1; \ + \ + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; \ + \ + const uint32_t tile_size = TILE_SIZE; \ + const uint32_t aligned_tile_size = hex_align_up(tile_size, 128); \ + \ + uint32_t n_k_tiles_w = ne00 / 32; \ + uint32_t n_k_tiles_a = ne10 / 32; \ + uint32_t tile_row_stride = n_k_tiles_w * tile_size; \ + uint32_t tile_row_transfer_size_aligned = n_k_tiles_a * aligned_tile_size; \ + \ + uint32_t ct_start = src0_start_row / 32; \ + uint32_t ct_end = (src0_end_row + 31) / 32; \ + \ + uint32_t push_ct = ct_start; \ + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end; d++, push_ct++) { \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + d * tile_row_transfer_size_aligned, \ + src0_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + } \ + \ + for (uint32_t ct = ct_start; ct < ct_end; ct++) { \ + const uint8_t * w_tile = dma_queue_pop(dma_queue).dst; \ + \ + int valid_rows = (int)ne0 - (int)(ct * 32); \ + valid_rows = MIN(32, MAX(0, valid_rows)); \ + \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + uint32_t ir1 = 0; \ + for (; ir1 + 1 < src1_nrows; ir1 += 2) { \ + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); \ + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); \ + float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size)); \ + float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size)); \ + \ + float * dst_ptr0 = &dst_row0[ct * 32]; \ + float * dst_ptr1 = &dst_row1[ct * 32]; \ + \ + DOT_2X2(ne10, dst_ptr0, dst_ptr1, w_tile, src1_col0, src1_col1, valid_rows); \ + } \ + \ + for (; ir1 < src1_nrows; ++ir1) { \ + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); \ + float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); \ + float * dst_ptr = &dst_row[ct * 32]; \ + \ + DOT_2X1(ne10, dst_ptr, w_tile, src1_col, valid_rows); \ + } \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + \ + if (push_ct < ct_end) { \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile, src0_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + push_ct++; \ + } \ + } \ +} + +#define MATVEC_2D_REPACKED_IMPL(SUFFIX, TILE_SIZE, DOT_2X1) \ +static void hvx_mv_2d_repacked_##SUFFIX(unsigned int nth, unsigned int ith, void * data) { \ + htp_matmul_preamble; \ + \ + const uint32_t src0_nrows = ne01; \ + \ + const uint32_t src0_start_row = src0_nrows_per_thread * ith; \ + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); \ + \ + if (src0_start_row >= src0_end_row) { \ + return; \ + } \ + \ + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; \ + \ + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; \ + const uint32_t n_prefetch = kparams->n_prefetch; \ + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); \ + \ + const size_t dst_row_size = nb1; \ + const size_t src1_row_size = nb11; \ + const size_t src1_stride = mmctx->vtcm_src1_stride; \ + \ + uint8_t * vtcm_dst_ptr = mmctx->vtcm_dst + mmctx->vtcm_dst_size_per_thread * ith; \ + uint8_t * vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; \ + uint8_t * src1_data = mmctx->vtcm_src1; \ + \ + float * tmp = (float *) vtcm_dst_ptr; \ + \ + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; \ + const uint8_t * restrict src1_col = (const uint8_t *) src1_data; \ + float * restrict dst_col = (float *) dst->data; \ + \ + const uint32_t tile_size = TILE_SIZE; \ + const uint32_t aligned_tile_size = hex_align_up(tile_size, 128); \ + \ + uint32_t n_k_tiles_w = ne00 / 32; \ + uint32_t n_k_tiles_a = ne10 / 32; \ + uint32_t tile_row_stride = n_k_tiles_w * tile_size; \ + uint32_t tile_row_transfer_size_aligned = n_k_tiles_a * aligned_tile_size; \ + \ + uint32_t ct_start = src0_start_row / 32; \ + uint32_t ct_end = (src0_end_row + 31) / 32; \ + \ + uint32_t push_ct = ct_start; \ + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end; d++, push_ct++) { \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + d * tile_row_transfer_size_aligned, \ + src0_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + } \ + \ + for (uint32_t ct = ct_start; ct < ct_end; ct++) { \ + const uint8_t * w_tile = dma_queue_pop(dma_queue).dst; \ + \ + float * dst_ptr = &tmp[ct * 32 - src0_start_row]; \ + int valid_rows = (int)ne0 - (int)(ct * 32); \ + valid_rows = MIN(32, MAX(0, valid_rows)); \ + \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + DOT_2X1(ne10, dst_ptr, w_tile, src1_col, valid_rows); \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + \ + if (push_ct < ct_end) { \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile, src0_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + push_ct++; \ + } \ + } \ + \ + int copy_cnt = (int)MIN(src0_end_row, ne0) - (int)src0_start_row; \ + if (copy_cnt > 0) { \ + hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, copy_cnt); \ + } \ +} + +#define MATMUL_QKV_2D_REPACKED_IMPL(SUFFIX, TILE_SIZE, DOT_2X2, DOT_2X1) \ +static void hvx_mm_qkv_2d_repacked_##SUFFIX(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_mm_context * mmctx = data; \ + struct htp_ops_context * octx = mmctx->octx; \ + \ + const struct htp_tensor * restrict src0 = octx->src[0]; /* Wk */ \ + const struct htp_tensor * restrict src1 = octx->src[1]; /* x */ \ + const struct htp_tensor * restrict src2 = octx->src[2]; /* Wv */ \ + const struct htp_tensor * restrict src3 = octx->src[3]; /* Wq */ \ + const struct htp_tensor * restrict dst_k = octx->dsts[0]; \ + const struct htp_tensor * restrict dst_v = octx->dsts[1]; \ + const struct htp_tensor * restrict dst_q = octx->dsts[2]; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t src1_nrows = src1->ne[1] * src1->ne[2] * src1->ne[3]; \ + \ + const size_t dst_k_row_size = dst_k->nb[1]; /* K and V share output width */ \ + const size_t dst_q_row_size = dst_q->nb[1]; /* Q may be wider (GQA) */ \ + const size_t src1_stride = mmctx->vtcm_src1_stride; \ + \ + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; \ + uint8_t * restrict vtcm_src2_ptr = mmctx->vtcm_src2 + mmctx->vtcm_src2_size_per_thread * ith; \ + uint8_t * restrict vtcm_src3_ptr = mmctx->vtcm_src3 + mmctx->vtcm_src3_size_per_thread * ith; \ + uint8_t * restrict src1_data = mmctx->vtcm_src1; \ + \ + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; \ + \ + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; \ + const uint32_t n_prefetch = kparams->n_prefetch; \ + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); \ + \ + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; \ + const uint8_t * restrict src2_row = (const uint8_t *) src2->data; \ + const uint8_t * restrict src3_row = (const uint8_t *) src3->data; \ + \ + const uint32_t tile_size = TILE_SIZE; \ + const uint32_t aligned_tile_size = hex_align_up(tile_size, 128); \ + \ + uint32_t n_k_tiles_w = ne00 / 32; \ + uint32_t n_k_tiles_a = ne10 / 32; \ + uint32_t tile_row_stride = n_k_tiles_w * tile_size; \ + uint32_t tile_row_transfer_size_aligned = n_k_tiles_a * aligned_tile_size; \ + \ + dma_queue * dma_queue = octx->ctx->dma[ith]; \ + \ + /* 1. Process K and V together */ \ + const uint32_t src0_nrows_kv = src0->ne[1] * src0->ne[2] * src0->ne[3]; /* src0 is Wk */ \ + uint32_t src0_nrows_per_thread_kv = (src0_nrows_kv + nth - 1) / nth; \ + src0_nrows_per_thread_kv = hex_round_up(src0_nrows_per_thread_kv, 32); \ + \ + const uint32_t start_row_kv = src0_nrows_per_thread_kv * ith; \ + const uint32_t end_row_kv = MIN(start_row_kv + src0_nrows_per_thread_kv, src0_nrows_kv); \ + \ + if (start_row_kv < end_row_kv) { \ + uint32_t ct_start_kv = start_row_kv / 32; \ + uint32_t ct_end_kv = (end_row_kv + 31) / 32; \ + \ + uint32_t push_ct = ct_start_kv; \ + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end_kv; d++, push_ct++) { \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + d * tile_row_transfer_size_aligned, \ + src0_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + d * tile_row_transfer_size_aligned, \ + src2_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + } \ + \ + for (uint32_t ct = ct_start_kv; ct < ct_end_kv; ct++) { \ + const uint8_t * w_tile_k = dma_queue_pop(dma_queue).dst; \ + const uint8_t * w_tile_v = dma_queue_pop(dma_queue).dst; \ + \ + int valid_rows = (int)src0->ne[1] - (int)(ct * 32); \ + valid_rows = MIN(32, MAX(0, valid_rows)); \ + \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ith); \ + uint32_t ir1 = 0; \ + for (; ir1 + 1 < src1_nrows; ir1 += 2) { \ + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); \ + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); \ + \ + float * restrict dst_row0_k = (float *) (dst_k->data + ((ir1+0) * dst_k_row_size)); \ + float * restrict dst_row1_k = (float *) (dst_k->data + ((ir1+1) * dst_k_row_size)); \ + float * dst_ptr0_k = &dst_row0_k[ct * 32]; \ + float * dst_ptr1_k = &dst_row1_k[ct * 32]; \ + \ + float * restrict dst_row0_v = (float *) (dst_v->data + ((ir1+0) * dst_k_row_size)); \ + float * restrict dst_row1_v = (float *) (dst_v->data + ((ir1+1) * dst_k_row_size)); \ + float * dst_ptr0_v = &dst_row0_v[ct * 32]; \ + float * dst_ptr1_v = &dst_row1_v[ct * 32]; \ + \ + DOT_2X2(ne10, dst_ptr0_k, dst_ptr1_k, w_tile_k, src1_col0, src1_col1, valid_rows); \ + DOT_2X2(ne10, dst_ptr0_v, dst_ptr1_v, w_tile_v, src1_col0, src1_col1, valid_rows); \ + } \ + \ + for (; ir1 < src1_nrows; ++ir1) { \ + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); \ + \ + float * restrict dst_row_k = (float *) (dst_k->data + (ir1 * dst_k_row_size)); \ + float * dst_ptr_k = &dst_row_k[ct * 32]; \ + \ + float * restrict dst_row_v = (float *) (dst_v->data + (ir1 * dst_k_row_size)); \ + float * dst_ptr_v = &dst_row_v[ct * 32]; \ + \ + DOT_2X1(ne10, dst_ptr_k, w_tile_k, src1_col, valid_rows); \ + DOT_2X1(ne10, dst_ptr_v, w_tile_v, src1_col, valid_rows); \ + } \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ith); \ + \ + if (push_ct < ct_end_kv) { \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile_k, src0_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile_v, src2_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + push_ct++; \ + } \ + } \ + } \ + \ + /* 2. Process Q separately */ \ + const uint32_t src0_nrows_q = src3->ne[1] * src3->ne[2] * src3->ne[3]; /* src3 is Wq */ \ + uint32_t src0_nrows_per_thread_q = (src0_nrows_q + nth - 1) / nth; \ + src0_nrows_per_thread_q = hex_round_up(src0_nrows_per_thread_q, 32); \ + \ + const uint32_t start_row_q = src0_nrows_per_thread_q * ith; \ + const uint32_t end_row_q = MIN(start_row_q + src0_nrows_per_thread_q, src0_nrows_q); \ + \ + if (start_row_q < end_row_q) { \ + uint32_t ct_start_q = start_row_q / 32; \ + uint32_t ct_end_q = (end_row_q + 31) / 32; \ + \ + uint32_t push_ct = ct_start_q; \ + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end_q; d++, push_ct++) { \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src3_ptr + d * tile_row_transfer_size_aligned, \ + src3_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + } \ + \ + for (uint32_t ct = ct_start_q; ct < ct_end_q; ct++) { \ + const uint8_t * w_tile_q = dma_queue_pop(dma_queue).dst; \ + \ + int valid_rows = (int)src3->ne[1] - (int)(ct * 32); \ + valid_rows = MIN(32, MAX(0, valid_rows)); \ + \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + uint32_t ir1 = 0; \ + for (; ir1 + 1 < src1_nrows; ir1 += 2) { \ + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); \ + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); \ + \ + float * restrict dst_row0_q = (float *) (dst_q->data + ((ir1+0) * dst_q_row_size)); \ + float * restrict dst_row1_q = (float *) (dst_q->data + ((ir1+1) * dst_q_row_size)); \ + float * dst_ptr0_q = &dst_row0_q[ct * 32]; \ + float * dst_ptr1_q = &dst_row1_q[ct * 32]; \ + \ + DOT_2X2(ne10, dst_ptr0_q, dst_ptr1_q, w_tile_q, src1_col0, src1_col1, valid_rows); \ + } \ + \ + for (; ir1 < src1_nrows; ++ir1) { \ + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); \ + \ + float * restrict dst_row_q = (float *) (dst_q->data + (ir1 * dst_q_row_size)); \ + float * dst_ptr_q = &dst_row_q[ct * 32]; \ + \ + DOT_2X1(ne10, dst_ptr_q, w_tile_q, src1_col, valid_rows); \ + } \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + \ + if (push_ct < ct_end_q) { \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile_q, src3_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + push_ct++; \ + } \ + } \ + } \ +} + +#define MATMUL_FFN_2D_REPACKED_IMPL(SUFFIX, TILE_SIZE, DOT_2X2, DOT_2X1) \ +static void hvx_mm_ffn_2d_repacked_##SUFFIX(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_mm_context * mmctx = data; \ + struct htp_ops_context * octx = mmctx->octx; \ + \ + const struct htp_tensor * restrict src0 = octx->src[0]; /* Wgate */ \ + const struct htp_tensor * restrict src1 = octx->src[1]; /* y */ \ + const struct htp_tensor * restrict src2 = octx->src[2]; /* Wup */ \ + const struct htp_tensor * restrict dst_gate = octx->dsts[0]; \ + const struct htp_tensor * restrict dst_up = octx->dsts[1]; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t src1_nrows = src1->ne[1] * src1->ne[2] * src1->ne[3]; \ + \ + const size_t dst_row_size = dst_gate->nb[1]; \ + const size_t src1_stride = mmctx->vtcm_src1_stride; \ + \ + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; \ + uint8_t * restrict vtcm_src2_ptr = mmctx->vtcm_src2 + mmctx->vtcm_src2_size_per_thread * ith; \ + uint8_t * restrict src1_data = mmctx->vtcm_src1; \ + \ + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; \ + \ + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; \ + const uint8_t * restrict src2_row = (const uint8_t *) src2->data; \ + \ + const uint32_t tile_size = TILE_SIZE; \ + const uint32_t aligned_tile_size = hex_align_up(tile_size, 128); \ + \ + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; \ + const uint32_t n_prefetch = kparams->n_prefetch; \ + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); \ + \ + uint32_t n_k_tiles_w = ne00 / 32; \ + uint32_t n_k_tiles_a = ne10 / 32; \ + uint32_t tile_row_stride = n_k_tiles_w * tile_size; \ + uint32_t tile_row_transfer_size_aligned = n_k_tiles_a * aligned_tile_size; \ + dma_queue * dma_queue = octx->ctx->dma[ith]; \ + \ + const uint32_t src0_nrows = ne01 * src0->ne[2] * src0->ne[3]; \ + const uint32_t src0_start_row = mmctx->src0_nrows_per_thread * ith; \ + const uint32_t src0_end_row = MIN(src0_start_row + mmctx->src0_nrows_per_thread, src0_nrows); \ + \ + uint32_t ct_start = src0_start_row / 32; \ + uint32_t ct_end = (src0_end_row + 31) / 32; \ + \ + uint32_t push_ct = ct_start; \ + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end; d++, push_ct++) { \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + d * tile_row_transfer_size_aligned, \ + src0_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + d * tile_row_transfer_size_aligned, \ + src2_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + } \ + \ + for (uint32_t ct = ct_start; ct < ct_end; ct++) { \ + const uint8_t * w_tile_gate = dma_queue_pop(dma_queue).dst; \ + const uint8_t * w_tile_up = dma_queue_pop(dma_queue).dst; \ + \ + int valid_rows = (int)ne01 - (int)(ct * 32); \ + valid_rows = MIN(32, MAX(0, valid_rows)); \ + \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + uint32_t ir1 = 0; \ + for (; ir1 + 1 < src1_nrows; ir1 += 2) { \ + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); \ + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); \ + \ + float * restrict dst_row0_gate = (float *) (dst_gate->data + ((ir1+0) * dst_row_size)); \ + float * restrict dst_row1_gate = (float *) (dst_gate->data + ((ir1+1) * dst_row_size)); \ + float * dst_ptr0_gate = &dst_row0_gate[ct * 32]; \ + float * dst_ptr1_gate = &dst_row1_gate[ct * 32]; \ + \ + float * restrict dst_row0_up = (float *) (dst_up->data + ((ir1+0) * dst_row_size)); \ + float * restrict dst_row1_up = (float *) (dst_up->data + ((ir1+1) * dst_row_size)); \ + float * dst_ptr0_up = &dst_row0_up[ct * 32]; \ + float * dst_ptr1_up = &dst_row1_up[ct * 32]; \ + \ + DOT_2X2(ne10, dst_ptr0_gate, dst_ptr1_gate, w_tile_gate, src1_col0, src1_col1, valid_rows); \ + DOT_2X2(ne10, dst_ptr0_up, dst_ptr1_up, w_tile_up, src1_col0, src1_col1, valid_rows); \ + } \ + \ + for (; ir1 < src1_nrows; ++ir1) { \ + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); \ + \ + float * restrict dst_row_gate = (float *) (dst_gate->data + (ir1 * dst_row_size)); \ + float * dst_ptr_gate = &dst_row_gate[ct * 32]; \ + \ + float * restrict dst_row_up = (float *) (dst_up->data + (ir1 * dst_row_size)); \ + float * dst_ptr_up = &dst_row_up[ct * 32]; \ + \ + DOT_2X1(ne10, dst_ptr_gate, w_tile_gate, src1_col, valid_rows); \ + DOT_2X1(ne10, dst_ptr_up, w_tile_up, src1_col, valid_rows); \ + } \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + \ + if (push_ct < ct_end) { \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile_gate, src0_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile_up, src2_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + push_ct++; \ + } \ + } \ +} + +MATMUL_2D_REPACKED_IMPL(q4_0, 576, tiled_vec_dot_q4_0_32x2, tiled_vec_dot_q4_0_32x1) +MATMUL_2D_REPACKED_IMPL(q4_1, 640, tiled_vec_dot_q4_1_32x2, tiled_vec_dot_q4_1_32x1) +MATMUL_2D_REPACKED_IMPL(q8_0, 1088, tiled_vec_dot_q8_0_32x2, tiled_vec_dot_q8_0_32x1) +MATMUL_2D_REPACKED_IMPL(iq4nl, 576, tiled_vec_dot_iq4nl_32x2, tiled_vec_dot_iq4nl_32x1) +MATMUL_2D_REPACKED_IMPL(mxfp4, 544, tiled_vec_dot_mxfp4_32x2, tiled_vec_dot_mxfp4_32x1) + +MATMUL_2D_REPACKED_IMPL(q4_0_flat, 576, flat_vec_dot_q4_0_32x2, flat_vec_dot_q4_0_32x1) +MATMUL_2D_REPACKED_IMPL(q4_1_flat, 640, flat_vec_dot_q4_1_32x2, flat_vec_dot_q4_1_32x1) +MATMUL_2D_REPACKED_IMPL(q8_0_flat, 1088, flat_vec_dot_q8_0_32x2, flat_vec_dot_q8_0_32x1) +MATMUL_2D_REPACKED_IMPL(iq4nl_flat, 576, flat_vec_dot_iq4nl_32x2, flat_vec_dot_iq4nl_32x1) +MATMUL_2D_REPACKED_IMPL(mxfp4_flat, 544, flat_vec_dot_mxfp4_32x2, flat_vec_dot_mxfp4_32x1) + +#define QUANTIZE_IMPL(name, log_name, kernel_fn, dst_row_size_expr) \ +static void name(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_mm_context * mmctx = data; \ + struct htp_ops_context * octx = mmctx->octx; \ + const struct htp_tensor * src = octx->src[1]; \ + const uint32_t ne0 = src->ne[0]; \ + const uint32_t ne1 = src->ne[1]; \ + const uint32_t ne2 = src->ne[2]; \ + const uint32_t ne3 = src->ne[3]; \ + const uint32_t nrows = ne1 * ne2 * ne3; \ + const uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; \ + \ + const uint32_t ir_first = nrows_per_thread * ith; \ + if (ir_first >= nrows) { \ + return; \ + } \ + \ + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); \ + \ + uint8_t * restrict dst = mmctx->vtcm_src1; \ + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); \ + const size_t src_row_size = src->nb[1]; \ + const size_t dst_row_size = (dst_row_size_expr); \ + const uint8_t * restrict src_data = (const uint8_t *) src->data + (src_row_size * ir_first); \ + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); \ + uint8_t * restrict tmp_data = (uint8_t *) mmctx->vtcm_src0 + (mmctx->vtcm_src0_size_per_thread * ith); \ + kernel_fn(src_data, dst_data, tmp_data, ne0, ir_last - ir_first, src_row_size, dst_row_size); \ + \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); \ +} + +QUANTIZE_IMPL(quantize_f32_q8_0_tiled, "quantize-f32-q8_0_tiled", quantize_f32_q8_0_tiled_kernel, htp_mm_q8_0_tiled_row_size(ne0)) +QUANTIZE_IMPL(quantize_f32_q8_1_tiled, "quantize-f32-q8_1_tiled", quantize_f32_q8_1_tiled_kernel, htp_mm_q8_1_tiled_row_size(ne0)) +QUANTIZE_IMPL(quantize_f32_q8_0_flat, "quantize-f32-q8_0_flat", quantize_f32_q8_0_flat_kernel, htp_mm_q8_0_flat_row_size(ne0)) +QUANTIZE_IMPL(quantize_f32_q8_1_flat, "quantize-f32-q8_1_flat", quantize_f32_q8_1_flat_kernel, htp_mm_q8_1_flat_row_size(ne0)) +QUANTIZE_IMPL(quantize_f32_f32_flat, "quantize-f32-f32", quantize_f32_f32_flat_kernel, mmctx->vtcm_src1_stride) +QUANTIZE_IMPL(quantize_f32_f16_flat, "quantize-f32-f16", quantize_f32_f16_flat_kernel, mmctx->vtcm_src1_stride) +QUANTIZE_IMPL(quantize_f16_f16_flat, "quantize-f16-f16", quantize_f16_f16_flat_kernel, mmctx->vtcm_src1_stride) + +static void quantize_f32_q8_0_tiled_block(unsigned int nth, unsigned int ith, void * data) { + struct htp_mm_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, mmctx->quant_ib_first[ith]); + + const struct htp_tensor * src = octx->src[1]; + + quantize_f32_q8_0_tiled_block_kernel( + (const float *) src->data, + mmctx->vtcm_src1, + (uint8_t *) mmctx->vtcm_src0 + (mmctx->vtcm_src0_size_per_thread * ith), + src->ne[0], + mmctx->quant_ib_first[ith], + mmctx->quant_ib_last[ith], + src->nb[1], + htp_mm_q8_0_tiled_row_size(src->ne[0]), + mmctx->quant_r[ith], + mmctx->quant_c[ith] + ); + + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, mmctx->quant_ib_first[ith]); +} + +static void quantize_f32_q8_1_tiled_block(unsigned int nth, unsigned int ith, void * data) { + struct htp_mm_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, mmctx->quant_ib_first[ith]); + + const struct htp_tensor * src = octx->src[1]; + + quantize_f32_q8_1_tiled_block_kernel( + (const float *) src->data, + mmctx->vtcm_src1, + (uint8_t *) mmctx->vtcm_src0 + (mmctx->vtcm_src0_size_per_thread * ith), + src->ne[0], + mmctx->quant_ib_first[ith], + mmctx->quant_ib_last[ith], + src->nb[1], + htp_mm_q8_1_tiled_row_size(src->ne[0]), + mmctx->quant_r[ith], + mmctx->quant_c[ith] + ); + + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, mmctx->quant_ib_first[ith]); +} + +MATVEC_2D_REPACKED_IMPL(q4_0, 576, tiled_vec_dot_q4_0_32x1) +MATVEC_2D_REPACKED_IMPL(q4_1, 640, tiled_vec_dot_q4_1_32x1) +MATVEC_2D_REPACKED_IMPL(q8_0, 1088, tiled_vec_dot_q8_0_32x1) +MATVEC_2D_REPACKED_IMPL(iq4nl, 576, tiled_vec_dot_iq4nl_32x1) +MATVEC_2D_REPACKED_IMPL(mxfp4, 544, tiled_vec_dot_mxfp4_32x1) + +MATVEC_2D_REPACKED_IMPL(q4_0_flat, 576, flat_vec_dot_q4_0_32x1) +MATVEC_2D_REPACKED_IMPL(q4_1_flat, 640, flat_vec_dot_q4_1_32x1) +MATVEC_2D_REPACKED_IMPL(q8_0_flat, 1088, flat_vec_dot_q8_0_32x1) +MATVEC_2D_REPACKED_IMPL(iq4nl_flat, 576, flat_vec_dot_iq4nl_32x1) +MATVEC_2D_REPACKED_IMPL(mxfp4_flat, 544, flat_vec_dot_mxfp4_32x1) + + +MATMUL_QKV_2D_REPACKED_IMPL(q4_0, 576, tiled_vec_dot_q4_0_32x2, tiled_vec_dot_q4_0_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(q4_1, 640, tiled_vec_dot_q4_1_32x2, tiled_vec_dot_q4_1_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(q8_0, 1088, tiled_vec_dot_q8_0_32x2, tiled_vec_dot_q8_0_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(iq4nl, 576, tiled_vec_dot_iq4nl_32x2, tiled_vec_dot_iq4nl_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(mxfp4, 544, tiled_vec_dot_mxfp4_32x2, tiled_vec_dot_mxfp4_32x1) + +MATMUL_QKV_2D_REPACKED_IMPL(q4_0_flat, 576, flat_vec_dot_q4_0_32x2, flat_vec_dot_q4_0_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(q4_1_flat, 640, flat_vec_dot_q4_1_32x2, flat_vec_dot_q4_1_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(q8_0_flat, 1088, flat_vec_dot_q8_0_32x2, flat_vec_dot_q8_0_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(iq4nl_flat, 576, flat_vec_dot_iq4nl_32x2, flat_vec_dot_iq4nl_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(mxfp4_flat, 544, flat_vec_dot_mxfp4_32x2, flat_vec_dot_mxfp4_32x1) + + +MATMUL_FFN_2D_REPACKED_IMPL(q4_0, 576, tiled_vec_dot_q4_0_32x2, tiled_vec_dot_q4_0_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(q4_1, 640, tiled_vec_dot_q4_1_32x2, tiled_vec_dot_q4_1_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(q8_0, 1088, tiled_vec_dot_q8_0_32x2, tiled_vec_dot_q8_0_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(iq4nl, 576, tiled_vec_dot_iq4nl_32x2, tiled_vec_dot_iq4nl_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(mxfp4, 544, tiled_vec_dot_mxfp4_32x2, tiled_vec_dot_mxfp4_32x1) + +MATMUL_FFN_2D_REPACKED_IMPL(q4_0_flat, 576, flat_vec_dot_q4_0_32x2, flat_vec_dot_q4_0_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(q4_1_flat, 640, flat_vec_dot_q4_1_32x2, flat_vec_dot_q4_1_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(q8_0_flat, 1088, flat_vec_dot_q8_0_32x2, flat_vec_dot_q8_0_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(iq4nl_flat, 576, flat_vec_dot_iq4nl_32x2, flat_vec_dot_iq4nl_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(mxfp4_flat, 544, flat_vec_dot_mxfp4_32x2, flat_vec_dot_mxfp4_32x1) + +static void hvx_mm_2d(unsigned int nth, unsigned int ith, void * data) { + htp_matmul_preamble; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + const uint32_t prefetch_mask = n_prefetch - 1; const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows @@ -3447,34 +1235,31 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { return; } + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; + const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; const size_t src1_row_size = nb11; - const size_t src0_stride = src0_spad->stride; - const size_t src1_stride = src1_spad->stride; + const size_t src0_stride = mmctx->vtcm_src0_stride; + const size_t src1_stride = mmctx->vtcm_src1_stride; - // Per-thread VTCM scratchpads for all tensors - // Note that the entire src1 tensor is already in VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; - uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; - uint8_t * restrict src1_data = src1_spad->data; - - volatile uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); + // Per-thread VTCMs for all tensors + uint8_t * restrict vtcm_dst_ptr = mmctx->vtcm_dst + mmctx->vtcm_dst_size_per_thread * ith; + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; + uint8_t * restrict src1_data = mmctx->vtcm_src1; const uint8_t * restrict src0_row = (const uint8_t *) src0->data; - // Prefill spad with src0 rows + // Prefill vtcm with src0 rows #pragma unroll(4) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const int is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { + if (is0 >= (int)n_prefetch) { break; } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); } // Process src0 rows @@ -3482,7 +1267,6 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - // Process src1 columns in pairs (2×2 tiling) uint32_t ir1 = 0; for (; ir1 + 1 < src1_nrows; ir1 += 2) { @@ -3499,24 +1283,23 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col); } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + // Prefetch next (n + vtcm_nrows) row + const int pr0 = (ir0 + n_prefetch); + const int is0 = (pr0 - src0_start_row) & prefetch_mask; if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), - src0_stride, src0_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); } } // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const int is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 1); + const int is0 = (ir0 - src0_start_row) & prefetch_mask; + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); @@ -3528,19 +1311,10 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { } htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); } - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], - src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// q8x4x2 src1 tensor is already in VTCM spad -static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { +static void hvx_mv_2d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; const uint32_t src0_nrows = ne01; @@ -3552,164 +1326,101 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { return; } + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; + const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; const size_t src1_row_size = nb11; - const size_t src0_stride = src0_spad->stride; - const size_t src1_stride = src1_spad->stride; + const size_t src0_stride = mmctx->vtcm_src0_stride; + const size_t src1_stride = mmctx->vtcm_src1_stride; - // Per-thread VTCM scratchpads for all tensors - // Note that the entire src1 tensor is already in VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - uint8_t * spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; - uint8_t * spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; - uint8_t * src1_data = src1_spad->data; + // Per-thread VTCMs for all tensors + uint8_t * vtcm_dst_ptr = mmctx->vtcm_dst + mmctx->vtcm_dst_size_per_thread * ith; + uint8_t * vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; + uint8_t * src1_data = mmctx->vtcm_src1; - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); - - float * tmp = (float *) spad_dst; + float * tmp = (float *) vtcm_dst_ptr; const uint8_t * restrict src0_row = (const uint8_t *) src0->data; const uint8_t * restrict src1_col = (const uint8_t *) src1_data; float * restrict dst_col = (float *) dst->data; - if (mmctx->vec_dot_4x1 != NULL) { - const uint32_t src0_end_row_x4 = src0_start_row + ((src0_end_row - src0_start_row) & ~3U); + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); - // Prefill spad with 4x src0 rows - #pragma unroll(4) - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) { - const uint32_t is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { - break; - } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 4); + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + const uint32_t prefetch_mask = n_prefetch - 1; + + // Prefill vtcm with 2x src0 rows + #pragma unroll(2) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint32_t is0 = (ir0 - src0_start_row); + if (is0 >= n_prefetch) { + break; } + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); + } - // Process src0 rows - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) { - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_4x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, ss0 + 2 * src0_stride, ss0 + 3 * src0_stride, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); + mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - // Prefetch next (n + spad_nrows) row - const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - if (pr0 < src0_end_row_x4) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), - src0_stride, src0_row_size, 4); - } - } - - // Process leftovers - uint32_t ir0 = src0_end_row_x4; - if (ir0 + 2 <= src0_end_row) { - const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 2); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - ir0 += 2; - } - if (ir0 < src0_end_row) { - const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - ir0 += 1; - } - } else { - const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); - - // Prefill spad with 2x src0 rows - #pragma unroll(2) - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint32_t is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { - break; - } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 2); - } - - // Process src0 rows - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - - // Prefetch next (n + spad_nrows) row - const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), - src0_stride, src0_row_size, 2); - } - } - - // Process the last row (if any) - if (src0_end_row != src0_end_row_x2) { - const uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); + // Prefetch next (n + vtcm_nrows) row + const uint32_t pr0 = (ir0 + n_prefetch); + const uint32_t is0 = (pr0 - src0_start_row) & prefetch_mask; + if (pr0 < src0_end_row_x2) { + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); } } + // Process the last row (if any) + if (src0_end_row != src0_end_row_x2) { + const uint32_t ir0 = src0_end_row_x2; + const uint32_t is0 = (ir0 - src0_start_row) & prefetch_mask; + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); + mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); + } + hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row); - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], - src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ids->ne[0] * ids->ne[1] + (i1)] -struct mmid_row_mapping { - uint32_t i1; - uint32_t i2; -}; - -// src1 tensor is already in VTCM spad -static void matmul_id(unsigned int nth, unsigned int ith, void * data) { +static void hvx_mm_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; const struct htp_tensor * restrict ids = octx->src[2]; - struct htp_spad * restrict src2_spad = &octx->src2_spad; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const uint32_t src0_nrows = ne01; // src0 rows per expert - const uint32_t src1_nrows = ne11; - + const uint32_t src0_nrows = ne01; // src0 rows per expert + const uint32_t src1_nrows = ne11; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); - const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); // no work for this thread if (src0_start_row >= src0_end_row) { return; } + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + const uint32_t n_ids = ids->ne[0]; // n_expert_used const uint32_t n_as = ne02; // n_expert @@ -3717,807 +1428,195 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { const struct mmid_row_mapping * matrix_rows = mmctx->matrix_rows; const size_t dst_row_size = nb1; - const size_t src0_row_size = nb01; - const size_t src1_row_size = q8x4x2_row_size(ne10); + const size_t src1_row_size = htp_mm_q8_0_tiled_row_size(ne10); - const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + const size_t src1_stride = mmctx->vtcm_src1_stride; - // Per-thread VTCM scratchpads for all tensors - // Note that the entire src1 tensor is already in VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; - uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; - uint8_t * restrict src1_data = src1_spad->data; + // Per-thread VTCMs for all tensors + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; + uint8_t * restrict src1_data = mmctx->vtcm_src1; for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) { const int32_t cne1 = matrix_row_counts[cur_a]; - if (cne1 == 0) { continue; } - if (mmctx->hmx_eligible) { - continue; + const uint8_t * src0_row = (const uint8_t *) src0->data + cur_a * nb02; + + const uint32_t tile_size = htp_mm_get_weight_tile_size(src0->type); + const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(src0->type); + const uint32_t n_k_tiles_w = ne00 / 32; + const uint32_t n_k_tiles_a = ne10 / 32; + const uint32_t tile_row_stride = n_k_tiles_w * tile_size; + const uint32_t tile_row_transfer_size_aligned = n_k_tiles_a * aligned_tile_size; + + const uint32_t ct_start = src0_start_row / 32; + const uint32_t ct_end = (src0_end_row + 31) / 32; + + uint32_t push_ct = ct_start; + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end; d++, push_ct++) { + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + d * tile_row_transfer_size_aligned, src0_row + push_ct * tile_row_stride), + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); } - const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0); + for (uint32_t ct = ct_start; ct < ct_end; ct++) { + const uint8_t * w_tile = dma_queue_pop(dma_queue).dst; - // Prefill spad with src0 rows - #pragma unroll(4) - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const int is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { - break; - } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); - } + int valid_rows = (int)ne01 - (int)(ct * 32); + valid_rows = MIN(32, MAX(0, valid_rows)); - // Process src0 rows - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ct); for (uint32_t cid = 0; cid < cne1; ++cid) { struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid); const int rm1 = row_mapping.i1; // expert idx const int rm2 = row_mapping.i2; // token idx - const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx - const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); - float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); + const uint32_t ir1 = fastmodulo(rm1, ne11, &mmctx->mm_div_ne11); // src1 row idx + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_stride); + float * restrict dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); - mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col); + mmctx->vec_dot_32x1(ne10, &dst_row[ct * 32], w_tile, src1_col, valid_rows); } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ct); - // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + if (push_ct < ct_end) { + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile, src0_row + push_ct * tile_row_stride), + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); + push_ct++; } } - - // Process the last row (if any) - if (src0_end_row != src0_end_row_x2) { - uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - for (uint32_t cid = 0; cid < cne1; ++cid) { - struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid); - const int rm1 = row_mapping.i1; // expert idx - const int rm2 = row_mapping.i2; // token idx - - const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx - const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); - float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); - - mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - } } - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type, - ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], - src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], - dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// src1 tensor is already in VTCM spad -static void matvec_id(unsigned int nth, unsigned int ith, void * data) { +static void hvx_mv_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; const struct htp_tensor * restrict ids = octx->src[2]; - struct htp_spad * restrict src2_spad = &octx->src2_spad; - - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); - - const uint32_t src0_nrows = ne01; // src0 rows per expert + const uint32_t src0_nrows = ne01; // src0 rows per expert const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); - const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); // no work for this thread if (src0_start_row >= src0_end_row) { return; } + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + assert(ne13 % ne03 == 0); const size_t dst_row_size = nb1; - const size_t src0_row_size = nb01; - const size_t src1_row_size = q8x4x2_row_size(ne10); - - const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + const size_t src1_row_size = htp_mm_q8_0_tiled_row_size(ne10); const uint32_t n_aids = src2->ne[0]; // num activated experts const uint32_t n_ids = ne02; // num experts - // Per-thread VTCM scratchpads for all tensors - // Note that the entire src1 tensor is already in VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; - uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; - uint8_t * restrict src1_data = src1_spad->data; + // Per-thread VTCMs for all tensors + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; + uint8_t * restrict src1_data = mmctx->vtcm_src1; for (uint32_t ie1 = 0; ie1 < n_aids; ++ie1) { // for each expert - const uint32_t eid = *(const int32_t *) ((const uint8_t *) src2->data + ie1 * src2->nb[0]); - assert(eid < n_ids); + const int32_t eid = *(const int32_t *) ((const uint8_t *) src2->data + ie1 * src2->nb[0]); + if (eid < 0) { + continue; + } + assert(eid < (int32_t) n_ids); const uint8_t * restrict src0_row = (const uint8_t *) src0->data + eid * nb02; const uint8_t * restrict src1_col = (const uint8_t *) src1_data; float * restrict dst_row = (float *) (dst->data + ie1 * nb1); - // Prefill spad with src0 rows - #pragma unroll(4) - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const int is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { - break; - } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + const uint32_t tile_size = htp_mm_get_weight_tile_size(src0->type); + const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(src0->type); + const uint32_t n_k_tiles_w = ne00 / 32; + const uint32_t n_k_tiles_a = ne10 / 32; + const uint32_t tile_row_stride = n_k_tiles_w * tile_size; + const uint32_t tile_row_transfer_size_aligned = n_k_tiles_a * aligned_tile_size; + + const uint32_t ct_start = src0_start_row / 32; + const uint32_t ct_end = (src0_end_row + 31) / 32; + + uint32_t push_ct = ct_start; + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end; d++, push_ct++) { + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + d * tile_row_transfer_size_aligned, src0_row + push_ct * tile_row_stride), + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); } - // Process src0 rows - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); + for (uint32_t ct = ct_start; ct < ct_end; ct++) { + const uint8_t * w_tile = dma_queue_pop(dma_queue).dst; - // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + int valid_rows = (int)ne01 - (int)(ct * 32); + valid_rows = MIN(32, MAX(0, valid_rows)); + + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ct); + mmctx->vec_dot_32x1(ne10, &dst_row[ct * 32], w_tile, src1_col, valid_rows); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ct); + + if (push_ct < ct_end) { + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile, src0_row + push_ct * tile_row_stride), + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); + push_ct++; } } - - // Process the last row (if any) - if (src0_end_row != src0_end_row_x2) { - uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - } } - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type, - ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], - src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], - dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// *** dynamic quant - -static inline void quantize_block_f32_q8_1x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { - assert((unsigned long) x % 128 == 0); - assert((unsigned long) y_q % 128 == 0); - - HVX_Vector * vx = (HVX_Vector *) x; - HVX_Vector zero = Q6_V_vzero(); - - // Use reduce max fp32 to find max(abs(e)) first - HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); - HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); - HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); - HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); - - // Load and convert into QF32 - HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements - HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements - HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements - HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements - - // Convert to QF32 - HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); - HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); - HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); - HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); - - // Combine and convert to fp16 - HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); - HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); - - // Convert into fp16 - HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); - HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); - - HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); - HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); - - // Divide input by the scale - HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); - HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); - vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); - vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); - - // Convert to int8 - HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); - HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); - HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); - - *(HVX_Vector *) y_q = vx_i8; - - // --- Sum calculation --- - const HVX_Vector ones = Q6_Vb_vsplat_R(1); - HVX_Vector v_sums = Q6_Vw_vrmpy_VbVb(vx_i8, ones); // sum every 4 consecutive elements - // Sum 8 elements: - v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 4)); - v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 8)); - v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 16)); - - // Copy to stack to extract sums and vmaxes - float vmax0[32] __attribute__((aligned(128))); - float vmax1[32] __attribute__((aligned(128))); - float vmax2[32] __attribute__((aligned(128))); - float vmax3[32] __attribute__((aligned(128))); - int32_t sums[32] __attribute__((aligned(128))); - - hvx_vec_store_u(vmax0, 128, vmax0_sf); - hvx_vec_store_u(vmax1, 128, vmax1_sf); - hvx_vec_store_u(vmax2, 128, vmax2_sf); - hvx_vec_store_u(vmax3, 128, vmax3_sf); - hvx_vec_store_u(sums, 128, v_sums); - - float d0 = vmax0[0] / 127.0f; - float d1 = vmax1[0] / 127.0f; - float d2 = vmax2[0] / 127.0f; - float d3 = vmax3[0] / 127.0f; - - __fp16 * y_d_half = (__fp16 *) y_d; - y_d_half[0] = d0; - y_d_half[1] = (float) sums[0] * d0; - y_d_half[2] = d1; - y_d_half[3] = (float) sums[8] * d1; - y_d_half[4] = d2; - y_d_half[5] = (float) sums[16] * d2; - y_d_half[6] = d3; - y_d_half[7] = (float) sums[24] * d3; -} - -static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { - assert((unsigned long) x % 128 == 0); - assert((unsigned long) y_q % 128 == 0); - - HVX_Vector * vx = (HVX_Vector *) x; - HVX_Vector zero = Q6_V_vzero(); - - // Use reduce max fp32 to find max(abs(e)) first - HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); - HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); - HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); - HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); - // Load and convert into QF32 - HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements - HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements - HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements - HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements - - // Convert to QF32 - HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); // replicated over all lanes - HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); // replicated over all lanes - HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes - HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes - - // Combine and convert to fp16 - HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); - HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); - - // Convert into fp16 - HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); - HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); - - HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); - HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); - - hvx_vec_store_u(y_d + 0, 2, vd01_hf); - HVX_Vector rotated_vd_hf = Q6_V_vror_VR(vd01_hf, 64); - hvx_vec_store_u(y_d + 2, 2, rotated_vd_hf); - - hvx_vec_store_u(y_d + 4, 2, vd23_hf); - rotated_vd_hf = Q6_V_vror_VR(vd23_hf, 64); - hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf); - - // Divide input by the scale - HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); - HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); - vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); - vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); - - // Convert to int8 - HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); - HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); - HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); - - *(HVX_Vector *) y_q = vx_i8; -} - -static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { - assert((unsigned long) x % 128 == 0); - assert((unsigned long) y_q % 128 == 0); - - HVX_Vector * vx = (HVX_Vector *) x; - - // Load and convert into QF32 - HVX_Vector zero = Q6_V_vzero(); - HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements - HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements - HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements - HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements - - // Convert into fp16 - HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); - HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); - - // Compute max and scale - HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); // replicated over all lanes - HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); // replicated over all lanes - - HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); - HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); - - hvx_vec_store_u(y_d + 0, 4, vd01_hf); - hvx_vec_store_u(y_d + 4, 4, vd23_hf); - - // Divide input by the scale - HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); - HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); - vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); - vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); - - // Convert to int8 - HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); - HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); - HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); - - *(HVX_Vector *) y_q = vx_i8; -} - -static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { - assert((unsigned long) x % 128 == 0); - assert((unsigned long) y_q % 128 == 0); - - HVX_Vector * vx = (HVX_Vector *) x; - - // Load and convert into QF32 - HVX_Vector zero = Q6_V_vzero(); - HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements - HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements - HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements - HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements - - // Convert into fp16 - HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); - HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); - - // Compute max and scale - HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); - vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); // replicated over all lanes - - HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16); - - *(HVX_UVector *) y_d = vd_hf; - - // Divide input by the scale - HVX_Vector vd_inv_hf = hvx_vec_inverse_f16(vd_hf); - vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf)); - vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf)); - - // Convert to int8 - HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); - HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); - HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); - - *(HVX_Vector *) y_q = vx_i8; -} - -// Overrides input x -static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { - assert(k % 32 == 0); - const uint32_t qk = QK_Q8_0x4x2; - const uint32_t nb = (k + qk - 1) / qk; - - const uint32_t qrow_size = k; // int8 - - const uint32_t dblk_size = 8 * 2; // 8x __fp16 - const uint32_t qblk_size = QK_Q8_0x4x2; // int8 - - uint8_t * restrict y_q = (y + 0); // quants first - uint8_t * restrict y_d = (y + qrow_size); // then scales - - // Temp scales override input since we're working off of the aligned temp buffer in VTCM - uint8_t * restrict t_d = (uint8_t *) x; - - for (uint32_t i = 0; i < nb; i++) { -#if FP32_QUANTIZE_GROUP_SIZE == 32 - quantize_block_f32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); - quantize_block_f32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); -#elif FP32_QUANTIZE_GROUP_SIZE == 64 - quantize_block_f32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); - quantize_block_f32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); -#elif FP32_QUANTIZE_GROUP_SIZE == 128 - quantize_block_f32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); - quantize_block_f32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); -#else -#error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128" -#endif - } - - // now copy the scales into final location - hvx_copy_f16_ua(y_d, t_d, nb * 8); -} - -static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) { - struct htp_matmul_context * mmctx = data; - struct htp_ops_context * octx = mmctx->octx; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; - - const struct htp_tensor * src = octx->src[1]; - uint8_t * restrict dst = octx->src1_spad.data; - struct htp_spad * spad = &octx->src0_spad; - uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; - - uint64_t t1 = HAP_perf_get_qtimer_count(); - - const uint32_t ne0 = src->ne[0]; - const uint32_t ne1 = src->ne[1]; - const uint32_t ne2 = src->ne[2]; - const uint32_t ne3 = src->ne[3]; - - const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows - - const uint32_t ir_first = nrows_per_thread * ith; // first row - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); - const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row - - const size_t src_row_size = src->nb[1]; - const size_t dst_row_size = q8x4x2_row_size(ne0); - - uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first); - uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); - uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith); - - const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float)); - memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding - - for (uint32_t i = ir_first; i < ir_last; ++i) { - hex_l2fetch(src_data, src_row_size, src_row_size, 2); - hvx_copy_f32_aa(tmp_data, src_data, ne0); - - // FARF(HIGH, "quantize-q8x4-row: %u\n", i); - quantize_row_f32_q8x4x2((float *) tmp_data, dst_data, ne0); - dst_data += dst_row_size; - src_data += src_row_size; - } - - uint64_t t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first, - ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); -} - -static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { - assert(k % 32 == 0); - const uint32_t qk = QK_Q8_0x4x2; - const uint32_t nb = (k + qk - 1) / qk; - - const uint32_t qrow_size = k; // int8 - - const uint32_t dblk_size = 8 * 4; // 8x (d, s) __fp16 = 32 bytes - const uint32_t qblk_size = QK_Q8_0x4x2; // int8 - - uint8_t * restrict y_q = (y + 0); // quants first - uint8_t * restrict y_d = (y + qrow_size); // then scales/sums - - // Temp scales override input since we're working off of the aligned temp buffer in VTCM - uint8_t * restrict t_d = (uint8_t *) x; - - for (uint32_t i = 0; i < nb; i++) { - quantize_block_f32_q8_1x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); - quantize_block_f32_q8_1x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); - } - - // now copy the scales/sums into final location - hvx_copy_f16_ua(y_d, t_d, nb * 16); -} - -static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * data) { - struct htp_matmul_context * mmctx = data; - struct htp_ops_context * octx = mmctx->octx; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; - - const struct htp_tensor * src = octx->src[1]; - uint8_t * restrict dst = octx->src1_spad.data; - struct htp_spad * spad = &octx->src0_spad; - uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; - - uint64_t t1 = HAP_perf_get_qtimer_count(); - - const uint32_t ne0 = src->ne[0]; - const uint32_t ne1 = src->ne[1]; - const uint32_t ne2 = src->ne[2]; - const uint32_t ne3 = src->ne[3]; - - const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows - - const uint32_t ir_first = nrows_per_thread * ith; // first row - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); - const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row - - const size_t src_row_size = src->nb[1]; - const size_t dst_row_size = q8_1x4x2_row_size(ne0); - - uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first); - uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); - uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith); - - const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float)); - memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding - - for (uint32_t i = ir_first; i < ir_last; ++i) { - hex_l2fetch(src_data, src_row_size, src_row_size, 2); - hvx_copy_f32_aa(tmp_data, src_data, ne0); - - quantize_row_f32_q8_1x4x2((float *) tmp_data, dst_data, ne0); - dst_data += dst_row_size; - src_data += src_row_size; - } - - uint64_t t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "quantize-f32-q8_1x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first, - ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); -} - -static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) { - struct htp_matmul_context * mmctx = data; - struct htp_ops_context * octx = mmctx->octx; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; - - const struct htp_tensor * src = octx->src[1]; - uint8_t * restrict dst = octx->src1_spad.data; - uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; - uint32_t dst_stride = octx->src1_spad.stride; - - uint64_t t1 = HAP_perf_get_qtimer_count(); - - const uint32_t ne0 = src->ne[0]; - const uint32_t ne1 = src->ne[1]; - const uint32_t ne2 = src->ne[2]; - const uint32_t ne3 = src->ne[3]; - - const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows - - const uint32_t ir_first = nrows_per_thread * ith; // first row - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); - const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row - - const size_t src_row_size = ne0 * sizeof(float); - const size_t src_stride = src->nb[1]; - - uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); - uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); - - for (uint32_t i = ir_first; i < ir_last; ++i) { - hex_l2fetch(src_data, src_row_size, src_stride, 2); - hvx_copy_f32_au(dst_data, src_data, ne0); - - dst_data += dst_stride; - src_data += src_stride; - } - - uint64_t t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "quantize-f32-f32: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, - ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); -} - -static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { - struct htp_matmul_context * mmctx = data; - struct htp_ops_context * octx = mmctx->octx; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; - - const struct htp_tensor * src = octx->src[1]; - uint8_t * restrict dst = octx->src1_spad.data; - uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; - uint32_t dst_stride = octx->src1_spad.stride; - - uint64_t t1 = HAP_perf_get_qtimer_count(); - - const uint32_t ne0 = src->ne[0]; - const uint32_t ne1 = src->ne[1]; - const uint32_t ne2 = src->ne[2]; - const uint32_t ne3 = src->ne[3]; - - const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows - - const uint32_t ir_first = nrows_per_thread * ith; // first row - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); - const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row - - const size_t src_row_size = ne0 * sizeof(float); - const size_t src_stride = src->nb[1]; - - uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); - uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); - - for (uint32_t i = ir_first; i < ir_last; ++i) { - hex_l2fetch(src_data, src_row_size, src_stride, 2); - hvx_copy_f16_f32_au(dst_data, src_data, ne0); - - dst_data += dst_stride; - src_data += src_stride; - } - - uint64_t t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, - ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); -} - -// TODO just a plain copy that should be done via the DMA during the Op setup -static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) { - struct htp_matmul_context * mmctx = data; - struct htp_ops_context * octx = mmctx->octx; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; - - const struct htp_tensor * src = octx->src[1]; - uint8_t * restrict dst = octx->src1_spad.data; - uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; - uint32_t dst_stride = octx->src1_spad.stride; - - uint64_t t1 = HAP_perf_get_qtimer_count(); - - const uint32_t ne0 = src->ne[0]; - const uint32_t ne1 = src->ne[1]; - const uint32_t ne2 = src->ne[2]; - const uint32_t ne3 = src->ne[3]; - - const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows - - const uint32_t ir_first = nrows_per_thread * ith; // first row - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); - const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row - - const size_t src_row_size = ne0 * sizeof(float); - const size_t src_stride = src->nb[1]; - - uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); - uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); - - for (uint32_t i = ir_first; i < ir_last; ++i) { - hex_l2fetch(src_data, src_row_size, src_stride, 2); - hvx_copy_f16_au(dst_data, src_data, ne0); - - dst_data += dst_stride; - src_data += src_stride; - } - - uint64_t t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, - ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); -} - - -static inline bool htp_is_permuted(const struct htp_tensor * t) { - return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3]; -} - -static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_type type) { +static int hvx_mm_init_vec_dot(struct htp_mm_context * mmctx, enum htp_data_type type) { switch (type) { case HTP_TYPE_Q4_0: - mmctx->type = "q4x4x2-f32"; - mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1; - mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1; - mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2; - mmctx->vec_dot_4x1 = vec_dot_q4x4x2_q8x4x2_4x1; + mmctx->type = "q4_0_tiled-f32"; + mmctx->vec_dot_32x1 = tiled_vec_dot_q4_0_32x1; return 0; case HTP_TYPE_Q4_1: - mmctx->type = "q4_1x4x2-f32"; - mmctx->vec_dot_1x1 = vec_dot_q4_1x4x2_q8x4x2_1x1; - mmctx->vec_dot_2x1 = vec_dot_q4_1x4x2_q8x4x2_2x1; - mmctx->vec_dot_2x2 = vec_dot_q4_1x4x2_q8x4x2_2x2; - mmctx->vec_dot_4x1 = vec_dot_q4_1x4x2_q8x4x2_4x1; + mmctx->type = "q4_1_tiled-f32"; + mmctx->vec_dot_32x1 = tiled_vec_dot_q4_1_32x1; return 0; case HTP_TYPE_Q8_0: - mmctx->type = "q8x4x2-f32"; - mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1; - mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1; - mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2; - mmctx->vec_dot_4x1 = vec_dot_q8x4x2_q8x4x2_4x1; + mmctx->type = "q8_0_tiled-f32"; + mmctx->vec_dot_32x1 = tiled_vec_dot_q8_0_32x1; return 0; case HTP_TYPE_IQ4_NL: - mmctx->type = "iq4nlx4x2-f32"; - mmctx->vec_dot_1x1 = vec_dot_iq4nlx4x2_q8x4x2_1x1; - mmctx->vec_dot_2x1 = vec_dot_iq4nlx4x2_q8x4x2_2x1; - mmctx->vec_dot_2x2 = vec_dot_iq4nlx4x2_q8x4x2_2x2; - mmctx->vec_dot_4x1 = vec_dot_iq4nlx4x2_q8x4x2_4x1; + mmctx->type = "iq4nl_tiled-f32"; + mmctx->vec_dot_32x1 = tiled_vec_dot_iq4nl_32x1; return 0; case HTP_TYPE_MXFP4: - mmctx->type = "mxfp4x4x2-f32"; - mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1; - mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1; - mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2; - mmctx->vec_dot_4x1 = vec_dot_mxfp4x4x2_q8x4x2_4x1; + mmctx->type = "mxfp4_tiled-f32"; + mmctx->vec_dot_32x1 = tiled_vec_dot_mxfp4_32x1; return 0; default: return -1; } } -static void htp_mminit_spad(struct htp_ops_context * octx, - size_t dst_row_size, - size_t src0_row_size_padded, - size_t src1_row_size, - uint32_t src1_nrows, - size_t src2_spad_size_per_thread) { - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - - if (src2_spad_size_per_thread > 0) { - octx->src2_spad.size_per_thread = src2_spad_size_per_thread; - octx->src2_spad.size = octx->src2_spad.size_per_thread; - } - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; -} - -static int op_matmul_hvx(struct htp_ops_context * octx) { +static int hvx_mm_matmul(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; - struct htp_matmul_context mmctx_struct = {0}; - struct htp_matmul_context * mmctx = &mmctx_struct; + struct htp_mm_context mmctx_struct = {0}; + struct htp_mm_context * mmctx = &mmctx_struct; mmctx->octx = octx; + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t src0_nrows = ne01 * ne02 * ne03; const uint32_t src1_nrows = ne11 * ne12 * ne13; + bool is_repacked = (src0->type == HTP_TYPE_Q4_0 || src0->type == HTP_TYPE_Q4_1 || + src0->type == HTP_TYPE_Q8_0 || src0->type == HTP_TYPE_IQ4_NL || + src0->type == HTP_TYPE_MXFP4); + // Compute src0_nrows_per_thread mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; - mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + if (is_repacked) { + mmctx->src0_nrows_per_thread = hex_round_up(mmctx->src0_nrows_per_thread, 32); + } else { + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + } const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; @@ -4527,178 +1626,213 @@ static int op_matmul_hvx(struct htp_ops_context * octx) { size_t src1_row_size_padded; worker_callback_t quant_job_func; - worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d; + worker_callback_t matmul_job_func; + uint32_t n_quant_jobs = 1; + if (src1_nrows > 1) { + if (is_repacked) { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mm_2d_repacked_q4_0; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mm_2d_repacked_q4_1; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mm_2d_repacked_q8_0; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mm_2d_repacked_iq4nl; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mm_2d_repacked_mxfp4; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } else { + matmul_job_func = hvx_mm_2d; + } + } else { + if (is_repacked) { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mv_2d_repacked_q4_0; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mv_2d_repacked_q4_1; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mv_2d_repacked_q8_0; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mv_2d_repacked_iq4nl; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mv_2d_repacked_mxfp4; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } else { + matmul_job_func = hvx_mv_2d; + } + } bool need_quant = true; - if (src0->type == HTP_TYPE_F16) { - // Try optimized f16-f16 path first (src1 in VTCM) - const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128); - const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256); - const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; - const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; + switch (kparams->kernel_type) { + case HTP_MM_KERNEL_HVX_F16_F16_VTCM: + quant_job_func = (src1->type == HTP_TYPE_F32) ? quantize_f32_f16_flat : quantize_f16_f16_flat; + mmctx->type = "f16-f16"; + mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1; + mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1; + mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2; + src1_row_size = hex_round_up(ne10 * 2, 128); + break; - const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size; + case HTP_MM_KERNEL_HVX_F16_F32_DDR: + mmctx->type = "f16-f32"; + mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1; + matmul_job_func = hvx_mm_4d; + mmctx->mm_div_ne12_ne1 = kparams->div_ne12_ne1; + mmctx->mm_div_ne1 = kparams->div_ne1; + mmctx->mm_div_r2 = kparams->div_r2; + mmctx->mm_div_r3 = kparams->div_r3; + need_quant = false; + quant_job_func = NULL; + src1_row_size = nb11; + break; - // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). - // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. - const bool is_batched = (ne02 > 1) || (ne03 > 1); - const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]); + case HTP_MM_KERNEL_HVX_F16_F16_DDR: + mmctx->type = "f16-f16"; + mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1; + matmul_job_func = hvx_mm_4d; + mmctx->mm_div_ne12_ne1 = kparams->div_ne12_ne1; + mmctx->mm_div_ne1 = kparams->div_ne1; + mmctx->mm_div_r2 = kparams->div_r2; + mmctx->mm_div_r3 = kparams->div_r3; + src1_row_size = nb11; + need_quant = false; + quant_job_func = NULL; + break; - if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { - // Optimized path - quant_job_func = (src1->type == HTP_TYPE_F32) ? quantize_f32_f16 : quantize_f16_f16; - mmctx->type = "f16-f16"; - mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1; - mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1; - mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2; + case HTP_MM_KERNEL_HVX_F32_F32_VTCM: + quant_job_func = quantize_f32_f32_flat; + mmctx->type = "f32-f32"; + mmctx->vec_dot_1x1 = vec_dot_f32_f32_aa_1x1; + mmctx->vec_dot_2x1 = vec_dot_f32_f32_aa_2x1; + mmctx->vec_dot_2x2 = vec_dot_f32_f32_aa_2x2; + src1_row_size = hex_round_up(ne10 * 4, 128); + break; - src1_row_size = f16_src1_row_size; // row size post quantization + case HTP_MM_KERNEL_HVX_F32_F32_DDR: + quant_job_func = NULL; + mmctx->type = "f32-f32"; + mmctx->vec_dot_1x1 = vec_dot_f32_f32_uu_1x1; + mmctx->mm_div_ne12_ne1 = kparams->div_ne12_ne1; + mmctx->mm_div_ne1 = kparams->div_ne1; + mmctx->mm_div_r2 = kparams->div_r2; + mmctx->mm_div_r3 = kparams->div_r3; + src1_row_size = nb11; + need_quant = false; + matmul_job_func = hvx_mm_4d; + break; - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + case HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT: { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_flat : quantize_f32_q8_0_flat; + src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10); - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - } else { - // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required - quant_job_func = NULL; - if (src1->type == HTP_TYPE_F32) { - mmctx->type = "f16-f32"; - mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1; - matmul_job_func = matmul_4d; + if (src1_nrows > 1) { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mm_2d_repacked_q4_0_flat; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mm_2d_repacked_q4_1_flat; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mm_2d_repacked_q8_0_flat; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mm_2d_repacked_iq4nl_flat; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mm_2d_repacked_mxfp4_flat; break; + default: return HTP_STATUS_NO_SUPPORT; + } } else { - mmctx->type = "f16-f16"; - mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1; - matmul_job_func = matmul_4d; + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mv_2d_repacked_q4_0_flat; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mv_2d_repacked_q4_1_flat; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mv_2d_repacked_q8_0_flat; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mv_2d_repacked_iq4nl_flat; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mv_2d_repacked_mxfp4_flat; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } + break; + } + + case HTP_MM_KERNEL_HVX_QUANT_BLOCK: + case HTP_MM_KERNEL_HVX_QUANT_ROW: + default: + if (hvx_mm_init_vec_dot(mmctx, src0->type) != 0) { + return HTP_STATUS_NO_SUPPORT; } - src1_row_size = nb11; // original row size in DDR + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (ne10 + qk - 1) / qk; + const uint32_t total_nb = src1_nrows * nb; - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); - octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); - - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - - // Init fastdiv for matmul_4d (supports broadcasting) - mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); - mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); - mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); - mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); - - need_quant = false; - } - } else if (src0->type == HTP_TYPE_F32) { - // Try optimized f32-f32 path first (src1 in VTCM) - const size_t f32_src1_row_size = hex_round_up(ne10 * 4, 128); - const size_t f32_src1_spad_size = hex_round_up(f32_src1_row_size * src1_nrows, 256); - const size_t f32_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; - const size_t f32_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - - const size_t f32_total_size = f32_src1_spad_size + f32_src0_spad_size + f32_dst_spad_size; - - const bool is_batched = (ne02 > 1) || (ne03 > 1); - const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]); - - if (!is_batched && !is_permuted && f32_total_size <= octx->ctx->vtcm_size) { - // Optimized path - quant_job_func = quantize_f32_f32; - mmctx->type = "f32-f32"; - mmctx->vec_dot_1x1 = vec_dot_f32_f32_aa_1x1; - mmctx->vec_dot_2x1 = vec_dot_f32_f32_aa_2x1; - mmctx->vec_dot_2x2 = vec_dot_f32_f32_aa_2x2; - - src1_row_size = f32_src1_row_size; - - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - } else { - // Fallback to DDR / broadcasting - quant_job_func = NULL; - mmctx->type = "f32-f32"; - mmctx->vec_dot_1x1 = vec_dot_f32_f32_uu_1x1; - matmul_job_func = matmul_4d; - - src1_row_size = nb11; - - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); - octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); - - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - - // Init fastdiv for matmul_4d (supports broadcasting) - mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); - mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); - mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); - mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); - - need_quant = false; - } - } else { - if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { - return HTP_STATUS_NO_SUPPORT; - } - - if (src0->type == HTP_TYPE_Q4_1) { - quant_job_func = quantize_f32_q8_1x4x2; - src1_row_size = q8_1x4x2_row_size(ne10); - } else { - quant_job_func = quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); - } - htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0); + if (src1_nrows < octx->n_threads) { + n_quant_jobs = MIN(total_nb, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled_block : quantize_f32_q8_0_tiled_block; + for (uint32_t ith = 0; ith < n_quant_jobs; ++ith) { + uint32_t ib_first = (total_nb * ith) / n_quant_jobs; + uint32_t ib_last = (total_nb * (ith + 1)) / n_quant_jobs; + mmctx->quant_ib_first[ith] = ib_first; + mmctx->quant_ib_last[ith] = ib_last; + mmctx->quant_r[ith] = ib_first / nb; + mmctx->quant_c[ith] = ib_first % nb; + } + } else { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled : quantize_f32_q8_0_tiled; + } + src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + break; } - // VTCM scratchpads for all tensors - size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; + size_t src0_sz = 0, src1_sz = 0, dst_sz = 0; + if (kparams->vtcm_src0_size > 0 || kparams->vtcm_src1_size > 0 || kparams->vtcm_dst_size > 0) { + src0_sz = kparams->vtcm_src0_size; + src1_sz = kparams->vtcm_src1_size; + dst_sz = kparams->vtcm_dst_size; + } else { + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + htp_mm_hvx_get_vtcm_sizes( + kparams->kernel_type, src0->type, ne10, src1_nrows, octx->n_threads, + dst_row_size, src0_row_size, src1_row_size, n_prefetch, + &src0_sz, &src1_sz, &dst_sz + ); + } - FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type, - octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size); + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_F16_F16_VTCM || + kparams->kernel_type == HTP_MM_KERNEL_HVX_F32_F32_VTCM || + kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW || + kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_BLOCK) { + mmctx->vtcm_src1_size_per_thread = src1_sz; + } else { + mmctx->vtcm_src1_size_per_thread = src1_sz / octx->n_threads; + } + + mmctx->vtcm_src0_size_per_thread = src0_sz / octx->n_threads; + mmctx->vtcm_dst_size_per_thread = dst_sz / octx->n_threads; + + size_t vtcm_size = kparams->vtcm_size > 0 ? (size_t)kparams->vtcm_size : (src1_sz + src0_sz + dst_sz); + + FARF(HIGH, "matmul-%s : src0-vtcm-size %zu src1-vtcm-size %zu dst-vtcm-size %zu (%zu)\n", mmctx->type, + src0_sz, src1_sz, dst_sz, vtcm_size); FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data); - // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { + if (octx->ctx->vtcm_size < vtcm_size) { FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, - octx->ctx->vtcm_size, spad_size); + octx->ctx->vtcm_size, vtcm_size); return HTP_STATUS_VTCM_TOO_SMALL; } - // Place src1 spad first. We use it for dyn.quant and may reuse between ops - octx->src1_spad.data = octx->ctx->vtcm_base; - octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size; - octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + uint8_t * vtcm_ptr = (uint8_t *) octx->ctx->vtcm_base; + mmctx->vtcm_src1 = vtcm_seq_alloc(&vtcm_ptr, src1_sz); + mmctx->vtcm_src0 = vtcm_seq_alloc(&vtcm_ptr, src0_sz); + mmctx->vtcm_dst = vtcm_seq_alloc(&vtcm_ptr, dst_sz); - octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL; + octx->src1_spad.src = NULL; octx->src0_spad.src = NULL; octx->dst_spad.src = NULL; - octx->src0_spad.stride = src0_row_size_padded; - octx->src1_spad.stride = src1_row_size; + mmctx->vtcm_src0_stride = src0_row_size_padded; + mmctx->vtcm_src1_stride = src1_row_size; if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) return HTP_STATUS_OK; - if (need_quant && !octx->src1_spad.src) { - const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); + if (need_quant) { mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); - octx->src1_spad.src = src1; } const uint32_t n_matmul_jobs = octx->n_threads; @@ -4707,72 +1841,1209 @@ static int op_matmul_hvx(struct htp_ops_context * octx) { return HTP_STATUS_OK; } -int op_matmul(struct htp_ops_context * octx) { +static void hvx_mm_qkv_2d(unsigned int nth, unsigned int ith, void * data) { + struct htp_mm_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * restrict src0 = octx->src[0]; // Wk + const struct htp_tensor * restrict src1 = octx->src[1]; // x + const struct htp_tensor * restrict src2 = octx->src[2]; // Wv + const struct htp_tensor * restrict src3 = octx->src[3]; // Wq + const struct htp_tensor * restrict dst_k = octx->dsts[0]; + const struct htp_tensor * restrict dst_v = octx->dsts[1]; + const struct htp_tensor * restrict dst_q = octx->dsts[2]; + + const uint32_t ne00 = src0->ne[0]; + const uint32_t ne01 = src0->ne[1]; + const uint32_t ne02 = src0->ne[2]; + const uint32_t ne03 = src0->ne[3]; + + const uint32_t ne11 = src1->ne[1]; + const uint32_t ne12 = src1->ne[2]; + const uint32_t ne13 = src1->ne[3]; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; + const uint32_t src1_nrows = ne11 * ne12 * ne13; + + const uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread; + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); + + if (src0_start_row >= src0_end_row) { + return; + } + + const size_t dst_k_row_size = dst_k->nb[1]; // K and V share output width + const size_t dst_q_row_size = dst_q->nb[1]; // Q may be wider (GQA) + const size_t src0_row_size = src0->nb[1]; + const size_t src2_row_size = src2->nb[1]; + const size_t src3_row_size = src3->nb[1]; + + const size_t src0_stride = mmctx->vtcm_src0_stride; + const size_t src2_stride = mmctx->vtcm_src2_stride; + const size_t src3_stride = mmctx->vtcm_src3_stride; + const size_t src1_stride = mmctx->vtcm_src1_stride; + + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; + uint8_t * restrict vtcm_src2_ptr = mmctx->vtcm_src2 + mmctx->vtcm_src2_size_per_thread * ith; + uint8_t * restrict vtcm_src3_ptr = mmctx->vtcm_src3 + mmctx->vtcm_src3_size_per_thread * ith; + uint8_t * restrict src1_data = mmctx->vtcm_src1; + + dma_queue * dma_queue = octx->ctx->dma[ith]; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + const uint32_t prefetch_mask = n_prefetch - 1; + + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; + const uint8_t * restrict src2_row = (const uint8_t *) src2->data; + const uint8_t * restrict src3_row = (const uint8_t *) src3->data; + + // Prefill spad with src0, src2, src3 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const int is0 = (ir0 - src0_start_row); + if (is0 >= (int)n_prefetch) { + break; + } + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + is0 * src2_stride, src2_row + ir0 * src2_row_size), + src2_stride, src2_row_size, src2_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src3_ptr + is0 * src3_stride, src3_row + ir0 * src3_row_size), + src3_stride, src3_row_size, src3_row_size, 2); + } + + // Process rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + const uint8_t * ss2 = dma_queue_pop(dma_queue).dst; + const uint8_t * ss3 = dma_queue_pop(dma_queue).dst; + + // Process src1 columns in pairs (2×2 tiling) + uint32_t ir1 = 0; + for (; ir1 + 1 < src1_nrows; ir1 += 2) { + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); + + float * restrict dst_row0_k = (float *) (dst_k->data + ((ir1+0) * dst_k_row_size)); + float * restrict dst_row1_k = (float *) (dst_k->data + ((ir1+1) * dst_k_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0_k[ir0], &dst_row1_k[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1); + + float * restrict dst_row0_v = (float *) (dst_v->data + ((ir1+0) * dst_k_row_size)); + float * restrict dst_row1_v = (float *) (dst_v->data + ((ir1+1) * dst_k_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0_v[ir0], &dst_row1_v[ir0], ss2, ss2 + src2_stride, src1_col0, src1_col1); + + float * restrict dst_row0_q = (float *) (dst_q->data + ((ir1+0) * dst_q_row_size)); + float * restrict dst_row1_q = (float *) (dst_q->data + ((ir1+1) * dst_q_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0_q[ir0], &dst_row1_q[ir0], ss3, ss3 + src3_stride, src1_col0, src1_col1); + } + + // Handle remaining src1 rows (fallback to 2×1) + for (; ir1 < src1_nrows; ++ir1) { + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); + + float * restrict dst_row_k = (float *) (dst_k->data + (ir1 * dst_k_row_size)); + mmctx->vec_dot_2x1(ne00, &dst_row_k[ir0], ss0, ss0 + src0_stride, src1_col); + + float * restrict dst_row_v = (float *) (dst_v->data + (ir1 * dst_k_row_size)); + mmctx->vec_dot_2x1(ne00, &dst_row_v[ir0], ss2, ss2 + src2_stride, src1_col); + + float * restrict dst_row_q = (float *) (dst_q->data + (ir1 * dst_q_row_size)); + mmctx->vec_dot_2x1(ne00, &dst_row_q[ir0], ss3, ss3 + src3_stride, src1_col); + } + + // Prefetch next (n + vtcm_nrows) rows + const int pr0 = (ir0 + n_prefetch); + const int is0 = (pr0 - src0_start_row) & prefetch_mask; + if (pr0 < src0_end_row_x2) { + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + is0 * src2_stride, src2_row + pr0 * src2_row_size), + src2_stride, src2_row_size, src2_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src3_ptr + is0 * src3_stride, src3_row + pr0 * src3_row_size), + src3_stride, src3_row_size, src3_row_size, 2); + } + } + + // Process last row (if any) + if (src0_end_row != src0_end_row_x2) { + uint32_t ir0 = src0_end_row_x2; + const int is0 = (ir0 - src0_start_row) & prefetch_mask; + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 1); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + is0 * src2_stride, src2_row + ir0 * src2_row_size), + src2_stride, src2_row_size, src2_row_size, 1); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src3_ptr + is0 * src3_stride, src3_row + ir0 * src3_row_size), + src3_stride, src3_row_size, src3_row_size, 1); + + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + const uint8_t * ss2 = dma_queue_pop(dma_queue).dst; + const uint8_t * ss3 = dma_queue_pop(dma_queue).dst; + + for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); + + float * restrict dst_row_k = (float *) (dst_k->data + (ir1 * dst_k_row_size)); + mmctx->vec_dot_1x1(ne00, &dst_row_k[ir0], ss0, src1_col); + + float * restrict dst_row_v = (float *) (dst_v->data + (ir1 * dst_k_row_size)); + mmctx->vec_dot_1x1(ne00, &dst_row_v[ir0], ss2, src1_col); + + float * restrict dst_row_q = (float *) (dst_q->data + (ir1 * dst_q_row_size)); + mmctx->vec_dot_1x1(ne00, &dst_row_q[ir0], ss3, src1_col); + } + } +} + +static void hvx_mm_ffn_2d(unsigned int nth, unsigned int ith, void * data) { + struct htp_mm_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * restrict src0 = octx->src[0]; // Wgate + const struct htp_tensor * restrict src1 = octx->src[1]; // y + const struct htp_tensor * restrict src2 = octx->src[2]; // Wup + const struct htp_tensor * restrict dst_gate = octx->dsts[0]; + const struct htp_tensor * restrict dst_up = octx->dsts[1]; + + const uint32_t ne00 = src0->ne[0]; + const uint32_t ne01 = src0->ne[1]; + const uint32_t ne02 = src0->ne[2]; + const uint32_t ne03 = src0->ne[3]; + + const uint32_t ne11 = src1->ne[1]; + const uint32_t ne12 = src1->ne[2]; + const uint32_t ne13 = src1->ne[3]; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; + const uint32_t src1_nrows = ne11 * ne12 * ne13; + + const uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread; + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); + + if (src0_start_row >= src0_end_row) { + return; + } + + const size_t dst_row_size = dst_gate->nb[1]; + const size_t src0_row_size = src0->nb[1]; + const size_t src2_row_size = src2->nb[1]; + + const size_t src0_stride = mmctx->vtcm_src0_stride; + const size_t src2_stride = mmctx->vtcm_src2_stride; + const size_t src1_stride = mmctx->vtcm_src1_stride; + + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; + uint8_t * restrict vtcm_src2_ptr = mmctx->vtcm_src2 + mmctx->vtcm_src2_size_per_thread * ith; + uint8_t * restrict src1_data = mmctx->vtcm_src1; + + dma_queue * dma_queue = octx->ctx->dma[ith]; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + const uint32_t prefetch_mask = n_prefetch - 1; + + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; + const uint8_t * restrict src2_row = (const uint8_t *) src2->data; + + // Prefill spad with src0, src2 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const int is0 = (ir0 - src0_start_row); + if (is0 >= (int)n_prefetch) { + break; + } + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + is0 * src2_stride, src2_row + ir0 * src2_row_size), + src2_stride, src2_row_size, src2_row_size, 2); + } + + // Process rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + const uint8_t * ss2 = dma_queue_pop(dma_queue).dst; + + // Process src1 columns in pairs (2×2 tiling) + uint32_t ir1 = 0; + for (; ir1 + 1 < src1_nrows; ir1 += 2) { + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); + + float * restrict dst_row0_gate = (float *) (dst_gate->data + ((ir1+0) * dst_row_size)); + float * restrict dst_row1_gate = (float *) (dst_gate->data + ((ir1+1) * dst_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0_gate[ir0], &dst_row1_gate[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1); + + float * restrict dst_row0_up = (float *) (dst_up->data + ((ir1+0) * dst_row_size)); + float * restrict dst_row1_up = (float *) (dst_up->data + ((ir1+1) * dst_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0_up[ir0], &dst_row1_up[ir0], ss2, ss2 + src2_stride, src1_col0, src1_col1); + } + + // Handle remaining src1 rows (fallback to 2×1) + for (; ir1 < src1_nrows; ++ir1) { + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); + + float * restrict dst_row_gate = (float *) (dst_gate->data + (ir1 * dst_row_size)); + mmctx->vec_dot_2x1(ne00, &dst_row_gate[ir0], ss0, ss0 + src0_stride, src1_col); + + float * restrict dst_row_up = (float *) (dst_up->data + (ir1 * dst_row_size)); + mmctx->vec_dot_2x1(ne00, &dst_row_up[ir0], ss2, ss2 + src2_stride, src1_col); + } + + // Prefetch next rows + const int pr0 = (ir0 + n_prefetch); + const int is0 = (pr0 - src0_start_row) & prefetch_mask; + if (pr0 < src0_end_row_x2) { + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + is0 * src2_stride, src2_row + pr0 * src2_row_size), + src2_stride, src2_row_size, src2_row_size, 2); + } + } + + // Process last row (if any) + if (src0_end_row != src0_end_row_x2) { + uint32_t ir0 = src0_end_row_x2; + const int is0 = (ir0 - src0_start_row) & prefetch_mask; + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 1); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + is0 * src2_stride, src2_row + ir0 * src2_row_size), + src2_stride, src2_row_size, src2_row_size, 1); + + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + const uint8_t * ss2 = dma_queue_pop(dma_queue).dst; + + for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); + + float * restrict dst_row_gate = (float *) (dst_gate->data + (ir1 * dst_row_size)); + mmctx->vec_dot_1x1(ne00, &dst_row_gate[ir0], ss0, src1_col); + + float * restrict dst_row_up = (float *) (dst_up->data + (ir1 * dst_row_size)); + mmctx->vec_dot_1x1(ne00, &dst_row_up[ir0], ss2, src1_col); + } + } +} + +#define DEQUANTIZE_WORKER_LOOP_IMPL(SUFFIX) \ +static void dequantize_tiled_worker_loop_##SUFFIX(unsigned int n, unsigned int i, void *data) { \ + tiled_dequantize_state_t *state = (tiled_dequantize_state_t *)data; \ + struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \ + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { \ + int start = task_id * state->n_tiles_per_task; \ + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); \ + dequantize_tiled_weight_to_fp16_task_##SUFFIX(state, start, end); \ + } \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \ +} + +DEQUANTIZE_WORKER_LOOP_IMPL(q4_0) +DEQUANTIZE_WORKER_LOOP_IMPL(q4_1) +DEQUANTIZE_WORKER_LOOP_IMPL(iq4_nl) +DEQUANTIZE_WORKER_LOOP_IMPL(mxfp4) +DEQUANTIZE_WORKER_LOOP_IMPL(q8_0) + +static void convert_f16_worker_loop(unsigned int n, unsigned int i, void *data) { + tiled_dequantize_state_t *state = (tiled_dequantize_state_t *)data; + struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + convert_f16_weight_to_fp16_tiles_task(state, start, end); + } + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); +} + +static void quantize_f32_worker_loop(unsigned int n, unsigned int i, void *data) { + tiled_dequantize_state_t *state = (tiled_dequantize_state_t *)data; + + struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, i); + + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + quantize_f32_weight_to_fp16_tiles_task(state, start, end); + } + + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, i); +} + +static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { + output_transfer_task_state_t *st = (output_transfer_task_state_t *) data; + + struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; + + int start_chunk_idx = i * st->n_chunks_per_task; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, start_chunk_idx); + + for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { + int chunk_idx = task_id * st->n_chunks_per_task; + size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); + + float *dst = st->dst + chunk_idx * st->dst_stride; + transfer_output_chunk_fp16_to_fp32(dst, st->vtcm_src, chunk_idx, chunk_size, st->n_cols, st->dst_stride, st->dst_cols); + } + + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, start_chunk_idx); +} + +static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data; + + struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; + + int start_chunk_idx = i * st->n_chunks_per_task; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, start_chunk_idx); + + for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { + int chunk_idx = task_id * st->n_chunks_per_task; + size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); + + __fp16 *dst = st->dst + chunk_idx * st->k_block; + const float *src = st->src + chunk_idx * st->k_stride; + + if (st->vtcm_f32_act) { + float *thread_f32_act = st->vtcm_f32_act + i * HTP_MM_DMA_ACT_MULTIPLIER * st->k_block; + transfer_activation_chunk_fp32_to_fp16_dma_pipelined( + st->ctx->dma[i], dst, src, chunk_size, st->k_block, st->k_stride, st->k_valid, thread_f32_act + ); + } else { + transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride, st->k_valid); + } + } + + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, start_chunk_idx); +} + +static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_gathered_task_state_t *st = data; + struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, chunk_idx); + transfer_activation_chunk_fp32_to_fp16_gathered( + st->dst, st->src, start_row, n_rows, st->k_block, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->ne11, &st->ne11_div, st->nb11, st->nb12, st->cne1, st->k_valid); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, chunk_idx); + } +} + +static void transfer_activation_chunk_gathered_worker_flat_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_gathered_task_state_t *st = data; + struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, chunk_idx); + transfer_activation_chunk_fp32_to_fp16_gathered_flat( + st->dst, st->src, start_row, n_rows, st->k_block, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->nb12, st->cne1, st->k_valid); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, chunk_idx); + } +} + +static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned int i, void *data) { + output_transfer_scattered_task_state_t *st = data; + struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, chunk_idx); + transfer_output_chunk_fp16_to_fp32_scattered( + st->dst, st->vtcm_src, start_row, n_rows, st->n_cols, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->dst_nb1, st->dst_nb2, st->cne1); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, chunk_idx); + } +} + +// --- HMX Dispatchers & Entry Points --- + +static void dequantize_tiled_weight_chunk_to_fp16_tiles( + struct htp_context *ctx, __fp16 *vtcm_dst, + const void *weight_src_ddr, + int n_cols, int k_block, + size_t row_stride, int weight_type, + int n_k_tiles, struct fastdiv_values n_k_tiles_div, + worker_callback_t dequant_worker_fn, int n_threads) { + + assert(n_cols % HTP_MM_HMX_TILE_N_COLS == 0); + assert(k_block % HTP_MM_HMX_TILE_N_COLS == 0); + + size_t n_col_tiles = n_cols / HTP_MM_HMX_TILE_N_COLS; + size_t n_tot_tiles = n_col_tiles * n_k_tiles; + + size_t n_tiles_per_task = (n_threads == 1) ? n_tot_tiles : hmx_ceil_div(n_tot_tiles, n_threads); + + tiled_dequantize_state_t state; + state.n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task; + state.n_tot_tiles = n_tot_tiles; + state.n_tiles_per_task = n_tiles_per_task; + state.dst = vtcm_dst; + state.src = (const uint8_t *)weight_src_ddr; + state.n_cols = n_cols; + state.k_block = k_block; + state.row_stride = row_stride; + state.weight_type = weight_type; + state.n_k_tiles = n_k_tiles; + state.n_k_tiles_div = n_k_tiles_div; + state.traces = ctx->trace; + state.ctx = ctx; + + state.tile_size = htp_mm_get_weight_tile_size(weight_type); + state.aligned_tile_size = htp_mm_get_weight_aligned_tile_size(weight_type); + + if (state.n_tasks == 1 || n_threads == 1) { + dequant_worker_fn(1, 0, &state); + } else { + int n_tasks = hex_smin((int) state.n_tasks, n_threads); + worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, n_tasks); + } +} + +static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src, + int n_rows, int n_cols, int dst_stride, int dst_cols, int n_threads) { + assert(n_cols % HTP_MM_HMX_TILE_N_COLS == 0); + + if (n_rows <= 0) return; + + size_t n_tot_chunks = n_rows; + size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : hmx_ceil_div(n_rows, n_threads); + n_chunks_per_task = hex_align_up(n_chunks_per_task, 2); + + int actual_threads = hmx_ceil_div(n_rows, n_chunks_per_task); + + output_transfer_task_state_t state; + state.n_tasks = actual_threads; + state.n_tot_chunks = n_tot_chunks; + state.n_chunks_per_task = n_chunks_per_task; + state.dst = dst; + state.vtcm_src = vtcm_src; + state.n_cols = n_cols; + state.dst_stride = dst_stride; + state.dst_cols = dst_cols; + state.traces = ctx->trace; + + if (actual_threads <= 1) { + transfer_output_chunk_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, actual_threads); + } +} + +static void transfer_activation_chunk_threaded( + struct htp_context *ctx, + __fp16 *dst, + const float *src, + int n_rows, + int k_block, + int k_stride, + int n_threads, + int k_valid, + float *vtcm_f32_act) { + assert(k_block % HTP_MM_HMX_TILE_N_COLS == 0 && k_stride % HTP_MM_HMX_TILE_N_COLS == 0); + + size_t n_tot_chunks = n_rows; + size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : 32; // must be multiple of 32 to ensure correct destination address + + activation_transfer_task_state_t state; + state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; + state.n_tot_chunks = n_tot_chunks; + state.n_chunks_per_task = n_chunks_per_task; + state.dst = dst; + state.src = src; + state.k_block = k_block; + state.k_stride = k_stride; + state.k_valid = k_valid; + state.traces = ctx->trace; + state.ctx = ctx; + state.vtcm_f32_act = vtcm_f32_act; + + if (state.n_tasks == 1 || n_threads == 1) { + transfer_activation_chunk_worker_fn(1, 0, &state); + } else { + int n_tasks = hex_smin((int) state.n_tasks, n_threads); + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, n_tasks); + } +} + +static int hmx_mm_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *weight, + int m, int k, int n, + int act_stride, + int weight_stride, + int weight_type, + int k_valid, + int dst_stride, + int dst_cols, + int m_chunk, + int n_chunk, + int pipeline, + int n_threads, + int act_threads, + int tile_size, + int aligned_tile_size, + int vtcm_size) { + if (k % 32 != 0 || n % 32 != 0) { return -1; } + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN)) { return -1; } + + size_t row_stride = htp_mm_get_tiled_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } + + worker_callback_t dequant_worker_fn = NULL; + switch (weight_type) { + case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_tiled_worker_loop_q4_0; break; + case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_tiled_worker_loop_iq4_nl; break; + case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_tiled_worker_loop_q4_1; break; + case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_tiled_worker_loop_mxfp4; break; + case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_tiled_worker_loop_q8_0; break; + case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; + case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; + default: + return -1; + } + + const int n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS; + const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); + + const bool is_quant = (weight_type != HTP_TYPE_F16 && weight_type != HTP_TYPE_F32); + const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + + size_t m_chunk_n_rows = m_chunk; + size_t n_chunk_n_cols = n_chunk; + size_t vtcm_used = vtcm_size; + + const size_t qweight_row_stride = is_quant ? (size_t)(n_k_tiles * aligned_tile_size) / 32 : 0; + + const size_t act_f32_size = hex_align_up((size_t)act_threads * HTP_MM_DMA_ACT_MULTIPLIER * k * sizeof(float), HTP_MM_HMX_TILE_SIZE); + + const size_t weight_area_size = is_quant + ? hex_align_up((n_chunk_n_cols / 32) * n_k_tiles * aligned_tile_size, HTP_MM_HMX_TILE_SIZE) + : hex_align_up(n_chunk_n_cols * row_stride, HTP_MM_HMX_TILE_SIZE); + const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HTP_MM_HMX_TILE_SIZE); + + size_t scratch0_size, scratch1_size, scratch2_size; + scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HTP_MM_HMX_TILE_SIZE); // dequant buf 0 + scratch1_size = pipeline ? scratch0_size : 0; // dequant buf 1 + scratch2_size = pipeline ? output_area_size : 0; // output buf 1 + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight_raw[2] = { NULL, NULL }; + if (weight_area_size) { + if (pipeline) { + vtcm_weight_raw[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + vtcm_weight_raw[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + } else { + vtcm_weight_raw[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + } + } + __fp16 *vtcm_f16_act = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); + float *vtcm_f32_act = (float *) vtcm_seq_alloc(&vtcm_ptr, act_f32_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); + void *vtcm_scratch1 = scratch1_size ? vtcm_seq_alloc(&vtcm_ptr, scratch1_size) : NULL; + void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + + vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; + if (vtcm_used > vtcm_budget) { + FARF(ERROR, "hmx-mm-2d-precomputed: VTCM overflow: used %zu budget %zu, m %d k %d n %d mc %zu nc %zu", + vtcm_used, vtcm_budget, m, k, n, m_chunk_n_rows, n_chunk_n_cols); + return -1; + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 + + FARF(HIGH, "hmx-mm-2d-precomputed: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", + m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget); + + int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); + + if (pipeline) { + // --- Asynchronous Pipelined Loop --- + hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors + + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + + void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; + void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; + + transfer_activation_chunk_threaded(ctx, vtcm_f16_act, activation + mr * act_stride, n_rows, k, act_stride, act_threads, k_valid, vtcm_f32_act); + + // Prologue: push A0 and optionally A1 (if n_chunk_cnt > 1) + const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); + if (is_quant) { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[0], weight), aligned_tile_size, tile_size, tile_size, (n_cols_A0 / 32) * n_k_tiles); + } else { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[0], weight), row_stride, weight_stride, row_stride, n_cols_A0); + } + + if (1 < n_chunk_cnt) { + const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); + if (is_quant) { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[1], weight + n_chunk_n_cols * weight_stride), aligned_tile_size, tile_size, tile_size, (n_cols_A1 / 32) * n_k_tiles); + } else { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[1], weight + n_chunk_n_cols * weight_stride), row_stride, weight_stride, row_stride, n_cols_A1); + } + } + + // pop A0 -> dequantize A0 -> submit C0 + dma_queue_pop(ctx->dma[0]); + dequantize_tiled_weight_chunk_to_fp16_tiles( + ctx, vtcm_weight_bufs[0], vtcm_weight_raw[0], + n_cols_A0, k, row_stride, weight_type, + n_k_tiles, n_k_tiles_div, dequant_worker_fn, n_threads); + + hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_f16_act, + (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HTP_MM_HMX_TILE_N_ROWS), + hmx_ceil_div(n_cols_A0, HTP_MM_HMX_TILE_N_COLS), k / HTP_MM_HMX_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); + + // Main loop: pop/dequantize A_{i+1} -> push A_{i+2} -> submit C_{i+1} -> wait C_i and store D_i + for (int i = 0; i < n_chunk_cnt; ++i) { + const size_t nc = i * n_chunk_n_cols; + const size_t nc_p1 = nc + 1 * n_chunk_n_cols; + const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); + const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + + // 1. pop A_{i+1} and dequantize it (if i+1 < n_chunk_cnt) + if (i + 1 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_tiled_weight_chunk_to_fp16_tiles( + ctx, vtcm_weight_bufs[(i + 1) % 2], vtcm_weight_raw[(i + 1) % 2], + n_cols_p1, k, row_stride, weight_type, + n_k_tiles, n_k_tiles_div, dequant_worker_fn, n_threads); + } + + // 2. push A_{i+2} (if i+2 < n_chunk_cnt) + if (i + 2 < n_chunk_cnt) { + if (is_quant) { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[(i + 2) % 2], weight + nc_p2 * weight_stride), aligned_tile_size, tile_size, tile_size, (n_cols_p2 / 32) * n_k_tiles); + } else { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[(i + 2) % 2], weight + nc_p2 * weight_stride), row_stride, weight_stride, row_stride, n_cols_p2); + } + } + + // 3. submit C_{i+1} (if i+1 < n_chunk_cnt) + if (i + 1 < n_chunk_cnt) { + hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], + (__fp16 *) vtcm_f16_act, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], + vtcm_scales, hmx_ceil_div(n_rows, HTP_MM_HMX_TILE_N_ROWS), + hmx_ceil_div(n_cols_p1, HTP_MM_HMX_TILE_N_COLS), k / HTP_MM_HMX_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); + } + + // 4. wait C_i and store D_i (multi-thread HVX, parallel with C_{i+1}) + hmx_queue_pop(ctx->hmx_queue); + float *output_chunk = dst + (mr * dst_stride + nc); + int chunk_dst_cols = dst_cols - (int)nc; + if (chunk_dst_cols > 0) { + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, dst_stride, chunk_dst_cols, n_threads); + } + } + } + hmx_queue_suspend(ctx->hmx_queue); + } else { + // --- Synchronous Un-pipelined loop (m <= 32 or fallback) --- + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + + transfer_activation_chunk_threaded(ctx, vtcm_f16_act, activation + mr * act_stride, n_rows, k, act_stride, act_threads, k_valid, vtcm_f32_act); + + for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HTP_MM_HMX_TILE_N_ROWS); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HTP_MM_HMX_TILE_N_COLS); + + // A: Weight DMA (Synchronous) + if (is_quant) { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[0], weight + nc * weight_stride), aligned_tile_size, tile_size, tile_size, (n_cols / 32) * n_k_tiles); + } else { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[0], weight + nc * weight_stride), row_stride, weight_stride, row_stride, n_cols); + } + dma_queue_pop(ctx->dma[0]); + + // B: Weight Dequantize (Threaded) + dequantize_tiled_weight_chunk_to_fp16_tiles( + ctx, vtcm_scratch0, vtcm_weight_raw[0], + n_cols, k, row_stride, weight_type, + n_k_tiles, n_k_tiles_div, dequant_worker_fn, n_threads); + + // C: HMX Compute (Synchronous) + core_dot_chunk_fp16(vtcm_output, vtcm_f16_act, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HTP_MM_HMX_TILE_N_ROWS); + + // D: Output Store + float *output_chunk = dst + (mr * dst_stride + nc); + int chunk_dst_cols = dst_cols - (int)nc; + if (chunk_dst_cols > 0) { + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output, n_rows, n_cols, dst_stride, chunk_dst_cols, n_threads); + } + } + } + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + } + + return 0; +} + +static inline int hmx_mm_batch_r2(const hmx_mm_f16_f32_batched_params_t *params) { + return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; +} + +static inline int hmx_mm_batch_r3(const hmx_mm_f16_f32_batched_params_t *params) { + return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; +} + +static inline const __fp16 *hmx_mm_weight_batch_ptr(const hmx_mm_f16_f32_batched_params_t *params, + int dst_b2, int dst_b3) { + const int r2 = hmx_mm_batch_r2(params); + const int r3 = hmx_mm_batch_r3(params); + return (const __fp16 *) ((const uint8_t *) params->weight + + (size_t) (dst_b2 / r2) * params->src0_nb2 + + (size_t) (dst_b3 / r3) * params->src0_nb3); +} + +static inline const float *hmx_mm_activation_batch_ptr(const hmx_mm_f16_f32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (const float *) ((const uint8_t *) params->activation + + (size_t) dst_b2 * params->src1_nb2 + + (size_t) dst_b3 * params->src1_nb3); +} + +static inline float *hmx_mm_dst_batch_ptr(const hmx_mm_f16_f32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (float *) ((uint8_t *) params->dst + + (size_t) dst_b2 * params->dst_nb2 + + (size_t) dst_b3 * params->dst_nb3); +} + +static int hmx_mm_f16_f32_batched_simple(struct htp_context *ctx, + const hmx_mm_f16_f32_batched_params_t *params, + int m_chunk, int n_chunk, int pipeline, int n_threads, int act_threads, int vtcm_size) { + int ret = 0; + for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { + for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { + ret = hmx_mm_2d_f32(ctx, hmx_mm_dst_batch_ptr(params, b2, b3), + hmx_mm_activation_batch_ptr(params, b2, b3), + (const uint8_t *)hmx_mm_weight_batch_ptr(params, b2, b3), + params->m, params->k, params->n, + params->act_stride, params->weight_stride * (int)sizeof(__fp16), + HTP_TYPE_F16, params->k, params->n, params->n, + m_chunk, n_chunk, pipeline, n_threads, act_threads, + 0, 0, vtcm_size); + } + } + return ret; +} + +static int hmx_mm_f16_f32_batched(struct htp_context *ctx, const hmx_mm_f16_f32_batched_params_t *params, + int m_chunk, int n_chunk, int pipeline, int n_threads, int act_threads, int vtcm_size) { + if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } + if (params->ne02 <= 0 || params->ne03 <= 0 || params->ne12 <= 0 || params->ne13 <= 0) { return -1; } + if (params->ne12 % params->ne02 != 0 || params->ne13 % params->ne03 != 0) { return -1; } + if (params->k % 32 != 0 || params->n % 32 != 0) { return -1; } + if (!hex_is_aligned(params->dst, VLEN) || !hex_is_aligned(params->activation, VLEN)) { return -1; } + + const int group_size = hmx_mm_batch_r2(params); + const size_t vtcm_budget = ctx->vtcm_size; + + // Check if the precomputed parameters are grouped or simple. + // If simple, or if group_size <= 1, we use simple fallback loop. + // Grouped path is only valid if group_size > 1 and it fits within VTCM budget. + bool run_grouped = (group_size > 1 && (size_t)vtcm_size <= vtcm_budget); + if (!run_grouped) { + return hmx_mm_f16_f32_batched_simple(ctx, params, m_chunk, n_chunk, pipeline, n_threads, act_threads, vtcm_size); + } + + const size_t vec_dot_size = params->k * sizeof(__fp16); + + const bool use_dma_activation = (params->act_stride > params->k); + const size_t f32_scratch_size = use_dma_activation + ? hex_align_up((size_t)act_threads * HTP_MM_DMA_ACT_MULTIPLIER * (size_t) params->k * sizeof(float), HTP_MM_HMX_TILE_SIZE) : 0; + + size_t m_chunk_n_rows = m_chunk; + size_t n_chunk_n_cols = n_chunk; + size_t vtcm_used = vtcm_size; + + const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + const size_t activation_area_size = hex_align_up(group_size * m_chunk_n_rows * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HTP_MM_HMX_TILE_SIZE); + const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_f16_act = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; + + if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { + FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to simple batched loop", __func__); + return hmx_mm_f16_f32_batched_simple(ctx, params, m_chunk, n_chunk, pipeline, n_threads, act_threads, vtcm_size); + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 + + FARF(HIGH, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, params->m, params->k, params->n, group_size, params->ne13, + m_chunk_n_rows, n_chunk_n_cols, + (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); + + const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16); + const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (int b3 = 0; b3 < params->ne13; ++b3) { + for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) { + const __fp16 *weight_group = hmx_mm_weight_batch_ptr(params, b2_base, b3); + + for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HTP_MM_HMX_TILE_N_ROWS); + + // Pre-load activations for all heads in the group (once per m_chunk). + // When the source is strided (permuted Q), use 2D DMA to gather + // contiguous rows into a VTCM scratch buffer first, then HVX + // converts from the contiguous VTCM buffer. This avoids L2 cache + // thrashing from HVX loads at large strides. + for (int g = 0; g < group_size; ++g) { + const float *activation_chunk = hmx_mm_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride; + __fp16 *vtcm_act_g = vtcm_f16_act + (size_t) g * act_head_stride; + if (use_dma_activation) { + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + activation_chunk, (int) n_rows, + params->k, params->act_stride, act_threads, params->k, vtcm_f32_act); + } else { + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + activation_chunk, (int) n_rows, + params->k, params->act_stride, act_threads, params->k, NULL); + } + } + + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; + + { + const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols); + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); + } + + for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HTP_MM_HMX_TILE_N_COLS); + + { + dma_queue_pop(ctx->dma[0]); + + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < (size_t) params->n) { + const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols); + const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride; + + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); + } + + hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, params->k, params->k, 0, n_cols); + hex_swap_ptr(&buf_curr, &buf_next); + } + + // Reuse the interleaved weight for every q_head in this GQA group + for (int g = 0; g < group_size; ++g) { + struct htp_thread_trace * tr = &ctx->trace[HTP_MAX_NTHREADS]; + htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, g); + { + const __fp16 * vtcm_act_g = vtcm_f16_act + (size_t) g * act_head_stride; + core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, + params->k / 32); + } + htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, g); + + { + float *output = hmx_mm_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; + int chunk_dst_cols = params->n - (int)nc; + if (chunk_dst_cols > 0) { + transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride, chunk_dst_cols, ctx->n_threads); + } + } + } + } + } + } + } + + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + + return 0; +} + +static void transfer_activation_chunk_gathered_threaded( + struct htp_context *ctx, + __fp16 *dst, + const float *src, + int start_row, + int n_rows, + int k_block, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + int ne11, + size_t nb11, + size_t nb12, + int cne1, + int n_threads, + int k_valid) { + if (n_rows <= 0) return; + int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); + chunks_per_thread = hex_align_up(chunks_per_thread, 2); + + int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); + + activation_transfer_gathered_task_state_t state = { + .dst = dst, + .src = src, + .n_tasks = actual_threads, + .n_tot_chunks = n_rows, + .n_chunks_per_task = chunks_per_thread, + .k_block = k_block, + .matrix_rows = matrix_rows, + .cur_a = cur_a, + .mapping_stride = mapping_stride, + .ne11 = ne11, + .ne11_div = ne11 > 1 ? init_fastdiv_values(ne11) : (struct fastdiv_values){0, 0}, + .nb11 = nb11, + .nb12 = nb12, + .start_row = start_row, + .cne1 = cne1, + .k_valid = k_valid, + .traces = ctx->trace, + }; + + worker_callback_t worker_fn = ne11 == 1 ? transfer_activation_chunk_gathered_worker_flat_fn : + transfer_activation_chunk_gathered_worker_fn; + + if (actual_threads <= 1) { + worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, worker_fn, &state, actual_threads); + } +} + +static void transfer_output_chunk_scattered_threaded( + struct htp_context *ctx, + float *dst, + const __fp16 *vtcm_src, + int start_row, + int n_rows, + int n_cols, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + size_t dst_nb1, + size_t dst_nb2, + int cne1, + int n_threads) { + if (n_rows <= 0) return; + int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); + chunks_per_thread = hex_align_up(chunks_per_thread, 2); + + int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); + + output_transfer_scattered_task_state_t state = { + .vtcm_src = vtcm_src, + .dst = dst, + .n_tasks = actual_threads, + .n_tot_chunks = n_rows, + .n_chunks_per_task = chunks_per_thread, + .n_cols = n_cols, + .matrix_rows = matrix_rows, + .cur_a = cur_a, + .mapping_stride = mapping_stride, + .dst_nb1 = dst_nb1, + .dst_nb2 = dst_nb2, + .start_row = start_row, + .cne1 = cne1, + .traces = ctx->trace, + }; + + if (actual_threads <= 1) { + transfer_output_chunk_scattered_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_scattered_worker_fn, &state, actual_threads); + } +} + +static int hmx_mm_id_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *weight, + int m, int k, int n, + int k_valid, + int ne11, + size_t act_nb1, size_t act_nb2, + size_t dst_nb1, size_t dst_nb2, + int weight_stride, + int weight_type, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride) { + const int cne1 = m; + const int m_padded = hex_align_up(m, 32); + + if (k % 32 != 0 || n % 32 != 0) { return -1; } + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN)) { return -1; } + + size_t row_stride = htp_mm_get_tiled_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } + + worker_callback_t dequant_worker_fn = NULL; + switch (weight_type) { + case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_tiled_worker_loop_q4_0; break; + case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_tiled_worker_loop_iq4_nl; break; + case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_tiled_worker_loop_q4_1; break; + case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_tiled_worker_loop_mxfp4; break; + case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_tiled_worker_loop_q8_0; break; + case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; + case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; + default: + return -1; + } + + const int n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS; + const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); + + const int n_threads = ctx->n_threads; + const bool is_quant = (weight_type != HTP_TYPE_F16 && weight_type != HTP_TYPE_F32); + + const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + size_t vtcm_used = 0; + + int tile_size = htp_mm_get_weight_tile_size(weight_type); + int aligned_tile_size = htp_mm_get_weight_aligned_tile_size(weight_type); + + const size_t qweight_row_stride = is_quant ? (size_t)(n_k_tiles * aligned_tile_size) / 32 : 0; + const size_t weight_row_stride = is_quant ? qweight_row_stride : row_stride; + + size_t size_per_n = 0, size_per_m = 0, size_per_mn = 0; + htp_mm_hmx_get_2d_chunk_costs(weight_type, k, /*pipeline=*/false, aligned_tile_size, + &size_per_n, &size_per_m, &size_per_mn); + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; + if (htp_mm_hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, size_per_m, size_per_mn, + m_padded, n, + /*m_block_cost=*/(size_t) n * HTP_MM_HMX_COST_W_DEQUANT, + /*n_block_cost=*/(size_t) m_padded * HTP_MM_HMX_COST_A_CONVERT, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { + FARF(ERROR, "hmx-mm-id-2d: VTCM too small : m %d k %d n %d budget %zu", m_padded, k, n, vtcm_budget); + return -1; + } + + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * weight_row_stride, HTP_MM_HMX_TILE_SIZE); + const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HTP_MM_HMX_TILE_SIZE); + + size_t scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = weight_area_size ? (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size) : NULL; + __fp16 *vtcm_f16_act = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + + vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; + if (vtcm_used > vtcm_budget) { + FARF(ERROR, "hmx-mm-id-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); + return -1; + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (size_t mr = 0; mr < (size_t) m_padded; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m_padded - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HTP_MM_HMX_TILE_N_ROWS); + + transfer_activation_chunk_gathered_threaded( + ctx, vtcm_f16_act, activation, (int) mr, (int) n_rows, k, + matrix_rows, cur_a, mapping_stride, ne11, act_nb1, act_nb2, cne1, n_threads, k_valid); + + for (size_t nc = 0; nc < (size_t) n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HTP_MM_HMX_TILE_N_COLS); + + if (is_quant) { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, weight + nc * weight_stride), aligned_tile_size, tile_size, tile_size, (n_cols / 32) * n_k_tiles); + } else { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, weight + nc * weight_stride), row_stride, weight_stride, row_stride, n_cols); + } + dma_queue_pop(ctx->dma[0]); + + dequantize_tiled_weight_chunk_to_fp16_tiles( + ctx, vtcm_scratch0, vtcm_weight, + n_cols, k, row_stride, weight_type, + n_k_tiles, n_k_tiles_div, dequant_worker_fn, n_threads + ); + + struct htp_thread_trace * tr = &ctx->trace[HTP_MAX_NTHREADS]; + htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, nc); + core_dot_chunk_fp16(vtcm_output, vtcm_f16_act, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HTP_MM_HMX_TILE_N_ROWS); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, nc); + + transfer_output_chunk_scattered_threaded( + ctx, dst + nc, vtcm_output, (int) mr, (int) n_rows, (int) n_cols, + matrix_rows, cur_a, mapping_stride, dst_nb1, dst_nb2, cne1, n_threads); + } + } + + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + return 0; +} + + +// --- Dispatchers and Public Entry Points --- + +static int hmx_mm_op_matmul(struct htp_ops_context * octx, const struct htp_mm_kernel_params * kparams) { htp_matmul_tensors_preamble; -#ifndef HTP_HAS_HMX - return op_matmul_hvx(octx); -#else - if (!octx->ctx->hmx_enabled) { - return op_matmul_hvx(octx); - } - - // HMX weight tile requires N to be 32-aligned. - if (src0->ne[1] % 32 != 0) { - return op_matmul_hvx(octx); - } - - // HMX supports F16, F32, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. - // Other types fall back to HVX. - uint32_t wtype = src0->type; - if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { - return op_matmul_hvx(octx); - } - - // Quantised HMX path requires K aligned to 256 (x4x2 super-block). - // F16 and F32 HMX paths require K aligned to 32 (tile width). - if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && src0->ne[0] % 256 != 0) { - return op_matmul_hvx(octx); - } - - if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && src0->ne[0] % 32 != 0) { - return op_matmul_hvx(octx); - } - - const bool is_batched = (src0->ne[2] * src0->ne[3] > 1 || src1->ne[2] * src1->ne[3] > 1); - - // Quantised HMX kernels only handle flat 2D matmul (host already rejects - // batched quantised, but guard here too). F16 batched matmul is handled - // by the dedicated wrapper in hmx-matmul-ops.c. - if (is_batched && src0->type != HTP_TYPE_F16) { - return op_matmul_hvx(octx); - } - - // HMX assumes contiguous row-major layout. Fall back for permuted - // tensors where strides are non-monotonic (e.g. transposed KV cache). - if (src0->nb[0] > src0->nb[1] || src1->nb[0] > src1->nb[1]) { - return op_matmul_hvx(octx); - } - - // M alignment: Use HMX when M >= 32, the last partial tile (m_total % 32 rows) - // is handled by HMX itself; when M < 32 fall back to HVX. - const int m_total = (int) src1->ne[1]; - const int m_hmx = m_total & ~31; // 0 when M < 32 - if (m_hmx == 0) { - return op_matmul_hvx(octx); - } - - // Always re-quantize src1 since HMX kernel overwrites vtcm/spad, - // so any previously cached quantized data is invalid. - octx->src1_spad.src = NULL; - - int k = (int) src0->ne[0]; // inner dimension - int n = (int) src0->ne[1]; // weight columns - - int ret = -1; - - // Row strides in elements. For compact tensors these equal k; for - // permuted attention views they can be larger, so pass the real stride. + int k = (int) src0->ne[0]; + int n = (int) src0->ne[1]; + const int m_total = (int) src1->ne[1]; const int act_stride = (int)(src1->nb[1] / sizeof(float)); const int wgt_stride = (int)(src0->nb[1] / sizeof(__fp16)); @@ -4780,54 +3051,204 @@ int op_matmul(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - if (is_batched) { - if (src0->type == HTP_TYPE_F16) { - hmx_matmul_f16_f32_batched_params_t batch_params = { - .dst = (float *) dst->data, - .activation = (float *) src1->data, - .permuted_weight = (const __fp16 *) src0->data, - .m = m_total, - .k = k, - .n = n, - .act_stride = act_stride, - .weight_stride = wgt_stride, - .dst_stride = (int) (dst->nb[1] / sizeof(float)), - .ne02 = ne02, - .ne03 = ne03, - .ne12 = ne12, - .ne13 = ne13, - .src0_nb2 = src0->nb[2], - .src0_nb3 = src0->nb[3], - .src1_nb2 = src1->nb[2], - .src1_nb3 = src1->nb[3], - .dst_nb2 = dst->nb[2], - .dst_nb3 = dst->nb[3], - }; - ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params); - } else { - return op_matmul_hvx(octx); - } + int ret = -1; + const int n_threads = MIN(kparams->n_threads, (int) octx->n_threads); + if (kparams->kernel_type == HTP_MM_KERNEL_HMX_F16_BATCHED) { + hmx_mm_f16_f32_batched_params_t batch_params = { + .dst = (float *) dst->data, + .activation = (float *) src1->data, + .weight = (const __fp16 *) src0->data, + .m = m_total, + .k = k, + .n = n, + .act_stride = act_stride, + .weight_stride = wgt_stride, + .dst_stride = (int) (dst->nb[1] / sizeof(float)), + .ne02 = ne02, + .ne03 = ne03, + .ne12 = ne12, + .ne13 = ne13, + .src0_nb2 = src0->nb[2], + .src0_nb3 = src0->nb[3], + .src1_nb2 = src1->nb[2], + .src1_nb3 = src1->nb[3], + .dst_nb2 = dst->nb[2], + .dst_nb3 = dst->nb[3], + }; + ret = hmx_mm_f16_f32_batched(octx->ctx, &batch_params, + kparams->m_chunk, kparams->n_chunk, + kparams->pipeline, n_threads, + kparams->n_act_threads, + kparams->vtcm_size); } else { - ret = hmx_matmul_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, - m_total, k, n, act_stride, (int) src0->nb[1], (int) src0->type); + ret = hmx_mm_2d_f32( + octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, + m_total, k, n, act_stride, (int) src0->nb[1], (int) src0->type, (int) src1->ne[0], + (int)(dst->nb[1] / sizeof(float)), (int)dst->ne[0], + kparams->m_chunk, kparams->n_chunk, kparams->pipeline, n_threads, + kparams->n_act_threads, + kparams->tile_size, kparams->aligned_tile_size, kparams->vtcm_size + ); } if (ret != 0) { - FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret); - return op_matmul(octx); + FARF(ERROR, "HMX matmul failed (ret=%d)\n", ret); + return HTP_STATUS_INTERNAL_ERR; + } + return HTP_STATUS_OK; +} + +int op_matmul(struct htp_ops_context * octx) { + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + + if (kparams->n_hmx) { + return hmx_mm_op_matmul(octx, kparams); } - return 0; -#endif // HTP_HAS_HMX + return hvx_mm_matmul(octx); +} + +static int hmx_mm_op_matmul_id( + struct htp_ops_context * octx, + struct htp_mm_context * mmctx, + const uint32_t * matrix_row_counts, + const struct mmid_row_mapping * matrix_rows, + void * mapping_buf, + bool must_free_mapping +) { + htp_matmul_tensors_preamble; + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const int n_ids = octx->src[2]->ne[0]; + const int n_as = ne02; + + for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) { + const int32_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) continue; + + int ret = hmx_mm_id_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, + (const uint8_t *) src0->data + cur_a * nb02, + cne1, ne00, ne01, + ne10, + ne11, + nb11, nb12, + nb1, nb2, + (int) src0->nb[1], (int) src0->type, + matrix_rows, cur_a, n_ids * octx->src[2]->ne[1]); + if (ret != 0) { + FARF(ERROR, "HMX matmul failed for expert %u, error %d\n", cur_a, ret); + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_NO_SUPPORT; + } + } + + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_OK; +} + +static int hvx_mm_op_matmul_id( + struct htp_ops_context * octx, + struct htp_mm_context * mmctx, + size_t src0_row_size_padded, + uint32_t src1_nrows, + worker_callback_t matmul_id_job_func, + void * mapping_buf, + bool must_free_mapping +) { + htp_matmul_tensors_preamble; + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const struct htp_tensor * restrict ids = octx->src[2]; + const size_t src0_row_size = nb01; + + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (ne10 + qk - 1) / qk; + const uint32_t total_nb = src1_nrows * nb; + + worker_callback_t quant_job_func; + uint32_t n_quant_jobs = 1; + if (src1_nrows < octx->n_threads) { + n_quant_jobs = MIN(total_nb, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled_block : quantize_f32_q8_0_tiled_block; + for (uint32_t ith = 0; ith < n_quant_jobs; ++ith) { + uint32_t ib_first = (total_nb * ith) / n_quant_jobs; + uint32_t ib_last = (total_nb * (ith + 1)) / n_quant_jobs; + mmctx->quant_ib_first[ith] = ib_first; + mmctx->quant_ib_last[ith] = ib_last; + mmctx->quant_r[ith] = ib_first / nb; + mmctx->quant_c[ith] = ib_first % nb; + } + } else { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled : quantize_f32_q8_0_tiled; + } + size_t src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + + // Scratchpad sizes are computed on the host (htp_mm_hvx_id_get_vtcm_sizes) and passed in. + // The ID layout is routing-independent, so the host has exact visibility -- consume it here + // rather than recomputing, to keep host budgeting and device allocation in lockstep. + size_t src0_sz = kparams->vtcm_src0_size; + size_t src1_sz = kparams->vtcm_src1_size; + size_t src2_sz = 0; // mapping lives in DDR + size_t dst_sz = 0; // ID kernels scatter straight to DDR + size_t vtcm_size = kparams->vtcm_size; + + size_t src0_sz_per_thread = src0_sz / octx->n_threads; + size_t src1_sz_per_thread = src1_sz; + size_t src2_sz_per_thread = 0; + size_t dst_sz_per_thread = 0; + + FARF(HIGH, "matmul-id-%s : src0-spad-size %zu src1-spad-size %zu src2-spad-size %zu dst-spad-size %zu (%zu)\n", mmctx->type, + src0_sz, src1_sz, src2_sz, dst_sz, vtcm_size); + + FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, + src1->data, dst->data); + + // Make sure the reserved vtcm size is sufficient + if (octx->ctx->vtcm_size < vtcm_size) { + FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, vtcm_size); + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + uint8_t * vtcm_ptr = (uint8_t *) octx->ctx->vtcm_base; + mmctx->vtcm_src1 = vtcm_seq_alloc(&vtcm_ptr, src1_sz); + mmctx->vtcm_src0 = vtcm_seq_alloc(&vtcm_ptr, src0_sz); + mmctx->vtcm_src2 = vtcm_seq_alloc(&vtcm_ptr, src2_sz); + mmctx->vtcm_dst = vtcm_seq_alloc(&vtcm_ptr, dst_sz); + + octx->src1_spad.src = NULL; + octx->src0_spad.src = NULL; + octx->src2_spad.src = NULL; + octx->dst_spad.src = NULL; + + mmctx->vtcm_src0_stride = src0_row_size_padded; + mmctx->vtcm_src1_stride = src1_row_size; + + mmctx->vtcm_src0_size_per_thread = src0_sz_per_thread; + mmctx->vtcm_src1_size_per_thread = src1_sz_per_thread; + mmctx->vtcm_src2_size_per_thread = src2_sz_per_thread; + mmctx->vtcm_dst_size_per_thread = dst_sz_per_thread; + + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); + + const uint32_t n_matmul_jobs = octx->n_threads; + worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); + + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_OK; } int op_matmul_id(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; - struct htp_matmul_context mmctx_struct = {0}; - struct htp_matmul_context * mmctx = &mmctx_struct; + struct htp_mm_context mmctx_struct = {0}; + struct htp_mm_context * mmctx = &mmctx_struct; mmctx->octx = octx; + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const struct htp_tensor * restrict ids = octx->src[2]; const size_t src0_row_size = nb01; @@ -4839,14 +3260,11 @@ int op_matmul_id(struct htp_ops_context * octx) { const uint32_t src1_nrows = ne11 * ne12 * ne13; worker_callback_t quant_job_func; - worker_callback_t matmul_id_job_func = src1_nrows > 1 ? matmul_id : matvec_id; + worker_callback_t matmul_id_job_func = src1_nrows > 1 ? hvx_mm_id : hvx_mv_id; // Compute src0_nrows_per_thread mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; - mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even - - size_t src1_row_size; - size_t src1_row_size_padded; + mmctx->src0_nrows_per_thread = hex_round_up(mmctx->src0_nrows_per_thread, 32); // row groups const int n_ids = ids->ne[0]; // n_expert_used @@ -4875,54 +3293,13 @@ int op_matmul_id(struct htp_ops_context * octx) { mmctx->matrix_row_counts = matrix_row_counts; mmctx->matrix_rows = matrix_rows; + mmctx->mm_div_ne11 = kparams->div_ne11; - if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { + if (hvx_mm_init_vec_dot(mmctx, src0->type) != 0) { if (must_free_mapping) free(mapping_buf); return HTP_STATUS_NO_SUPPORT; } - if (src0->type == HTP_TYPE_Q4_1) { - quant_job_func = quantize_f32_q8_1x4x2; - src1_row_size = q8_1x4x2_row_size(ne10); - } else { - quant_job_func = quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); - } - - const size_t src2_spad_size_per_thread = 0; // We moved the mapping to DDR! - htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread); - - size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; - - FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type, - octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size); - - FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], - ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, - src1->data, dst->data); - - // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); - if (must_free_mapping) free(mapping_buf); - return HTP_STATUS_VTCM_TOO_SMALL; - } - - // Place src1 spad first. We use it for dyn.quant and may reuse in subseq ops. - octx->src1_spad.data = octx->ctx->vtcm_base; - octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size; - octx->src2_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size; - - octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL; - octx->src0_spad.src = NULL; - octx->src2_spad.src = NULL; - octx->dst_spad.src = NULL; - - octx->src0_spad.stride = src0_row_size_padded; - octx->src1_spad.stride = src1_row_size; - if (src1_nrows > 1) { // initialize matrix_row_counts and map memset(matrix_row_counts, 0, n_as * sizeof(uint32_t)); @@ -4930,9 +3307,12 @@ int op_matmul_id(struct htp_ops_context * octx) { // group rows by src0 matrix for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { // token idx for (uint32_t id = 0; id < n_ids; ++id) { // expert idx - const uint32_t i02 = *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + const int32_t i02 = *(const int32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); - assert(i02 >= 0 && i02 < n_as); + if (i02 < 0) { + continue; + } + assert(i02 < n_as); matrix_rows[i02 * n_ids * ids->ne[1] + matrix_row_counts[i02]] = (struct mmid_row_mapping) { id, iid1 }; matrix_row_counts[i02] += 1; @@ -4945,60 +3325,292 @@ int op_matmul_id(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - bool hmx_eligible = false; -#ifdef HTP_HAS_HMX - if (octx->ctx->hmx_enabled && src1_nrows > 1) { - uint32_t wtype = src0->type; - if (ne01 % 32 == 0 && - (wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32 || wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 || wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL || wtype == HTP_TYPE_MXFP4)) { - if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && ne00 % 32 == 0) { - hmx_eligible = true; - } else if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && ne00 % 256 == 0) { - hmx_eligible = true; - } - } + if (kparams->n_hmx) { + return hmx_mm_op_matmul_id(octx, mmctx, matrix_row_counts, matrix_rows, mapping_buf, must_free_mapping); } -#endif - mmctx->hmx_eligible = hmx_eligible; + return hvx_mm_op_matmul_id(octx, mmctx, src0_row_size_padded, src1_nrows, matmul_id_job_func, mapping_buf, must_free_mapping); +} - if (hmx_eligible) { - for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) { - const int32_t cne1 = matrix_row_counts[cur_a]; - if (cne1 == 0) continue; +int op_matmul_qkv(struct htp_ops_context * octx) { + const struct htp_tensor * restrict src0 = octx->src[0]; // Wk + const struct htp_tensor * restrict src1 = octx->src[1]; // x + const struct htp_tensor * restrict src2 = octx->src[2]; // Wv + const struct htp_tensor * restrict src3 = octx->src[3]; // Wq + const struct htp_tensor * restrict dst_k = octx->dsts[0]; + const struct htp_tensor * restrict dst_v = octx->dsts[1]; + const struct htp_tensor * restrict dst_q = octx->dsts[2]; - int ret = hmx_matmul_id_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, - (const uint8_t *) src0->data + cur_a * nb02, - cne1, ne00, ne01, - ne11, - nb11, nb12, - nb1, nb2, - (int) src0->nb[1], (int) src0->type, - matrix_rows, cur_a, n_ids * ids->ne[1]); - if (ret != 0) { - FARF(ERROR, "HMX matmul failed for expert %u, error %d\n", cur_a, ret); - if (must_free_mapping) free(mapping_buf); - return HTP_STATUS_NO_SUPPORT; - } + bool is_repacked = (src0->type == HTP_TYPE_Q4_0 || src0->type == HTP_TYPE_Q4_1 || + src0->type == HTP_TYPE_Q8_0 || src0->type == HTP_TYPE_IQ4_NL || + src0->type == HTP_TYPE_MXFP4); + + struct htp_mm_context mmctx_struct = {0}; + struct htp_mm_context * mmctx = &mmctx_struct; + mmctx->octx = octx; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t src1_nrows = src1->ne[1] * src1->ne[2] * src1->ne[3]; + + // Compute src0_nrows_per_thread + mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + if (is_repacked) { + mmctx->src0_nrows_per_thread = hex_round_up(mmctx->src0_nrows_per_thread, 32); + } else { + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + } + + const size_t src0_row_size = src0->nb[1]; + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + + if (hvx_mm_init_vec_dot(mmctx, src0->type) != 0) { + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (src1->ne[0] + qk - 1) / qk; + const uint32_t total_nb = src1_nrows * nb; + + worker_callback_t quant_job_func; + uint32_t n_quant_jobs = 1; + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_flat : quantize_f32_q8_0_flat; + } else if (src1_nrows < octx->n_threads) { + n_quant_jobs = MIN(total_nb, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled_block : quantize_f32_q8_0_tiled_block; + for (uint32_t ith = 0; ith < n_quant_jobs; ++ith) { + uint32_t ib_first = (total_nb * ith) / n_quant_jobs; + uint32_t ib_last = (total_nb * (ith + 1)) / n_quant_jobs; + mmctx->quant_ib_first[ith] = ib_first; + mmctx->quant_ib_last[ith] = ib_last; + mmctx->quant_r[ith] = ib_first / nb; + mmctx->quant_c[ith] = ib_first % nb; } + } else { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled : quantize_f32_q8_0_tiled; + } - // HMX has overwritten VTCM, so force dynamic quantization cache to clear - octx->src1_spad.src = NULL; + size_t src1_row_size; + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(src1->ne[0]) : htp_mm_q8_0_flat_row_size(src1->ne[0]); + } else { + src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(src1->ne[0]) : htp_mm_q8_0_tiled_row_size(src1->ne[0]); + } - if (must_free_mapping) free(mapping_buf); + // Set up scratchpads using precomputed sizes from the host + size_t src0_sz = kparams->vtcm_src0_size; + size_t src1_sz = kparams->vtcm_src1_size; + size_t src2_sz = kparams->vtcm_src2_size; + size_t src3_sz = kparams->vtcm_src3_size; + size_t vtcm_size = kparams->vtcm_size; + + size_t src0_sz_per_thread = src0_sz / octx->n_threads; + size_t src1_sz_per_thread = src1_sz; + size_t src2_sz_per_thread = src2_sz / octx->n_threads; + size_t src3_sz_per_thread = src3_sz / octx->n_threads; + + if (octx->ctx->vtcm_size < vtcm_size) { + FARF(ERROR, "matmul-qkv: current VTCM reservation %zu is too small, needed %zu\n", + octx->ctx->vtcm_size, vtcm_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + uint8_t * vtcm_ptr = (uint8_t *) octx->ctx->vtcm_base; + mmctx->vtcm_src1 = vtcm_seq_alloc(&vtcm_ptr, src1_sz); + mmctx->vtcm_src0 = vtcm_seq_alloc(&vtcm_ptr, src0_sz); + mmctx->vtcm_src2 = vtcm_seq_alloc(&vtcm_ptr, src2_sz); + mmctx->vtcm_src3 = vtcm_seq_alloc(&vtcm_ptr, src3_sz); + + octx->src1_spad.src = NULL; + octx->src0_spad.src = NULL; + octx->src2_spad.src = NULL; + octx->src3_spad.src = NULL; + + mmctx->vtcm_src0_stride = is_repacked ? 0 : src0_row_size_padded; + mmctx->vtcm_src2_stride = is_repacked ? 0 : src0_row_size_padded; + mmctx->vtcm_src3_stride = is_repacked ? 0 : src0_row_size_padded; + mmctx->vtcm_src1_stride = src1_row_size; + + mmctx->vtcm_src0_size_per_thread = src0_sz_per_thread; + mmctx->vtcm_src1_size_per_thread = src1_sz_per_thread; + mmctx->vtcm_src2_size_per_thread = src2_sz_per_thread; + mmctx->vtcm_src3_size_per_thread = src3_sz_per_thread; + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) return HTP_STATUS_OK; - } - if (octx->src1_spad.src != src1) { - const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); - mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; - worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); - octx->src1_spad.src = src1; - } + // Run quantization once + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); + // Run fused matmul const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); + worker_callback_t matmul_job_func; + if (is_repacked) { + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mm_qkv_2d_repacked_q4_0_flat; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mm_qkv_2d_repacked_q4_1_flat; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mm_qkv_2d_repacked_q8_0_flat; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mm_qkv_2d_repacked_iq4nl_flat; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mm_qkv_2d_repacked_mxfp4_flat; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } else { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mm_qkv_2d_repacked_q4_0; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mm_qkv_2d_repacked_q4_1; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mm_qkv_2d_repacked_q8_0; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mm_qkv_2d_repacked_iq4nl; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mm_qkv_2d_repacked_mxfp4; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } + } else { + matmul_job_func = hvx_mm_qkv_2d; + } + worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); + + return HTP_STATUS_OK; +} + +int op_matmul_ffn(struct htp_ops_context * octx) { + const struct htp_tensor * restrict src0 = octx->src[0]; // Wgate + const struct htp_tensor * restrict src1 = octx->src[1]; // y + const struct htp_tensor * restrict src2 = octx->src[2]; // Wup + const struct htp_tensor * restrict dst_gate = octx->dsts[0]; + const struct htp_tensor * restrict dst_up = octx->dsts[1]; + + bool is_repacked = (src0->type == HTP_TYPE_Q4_0 || src0->type == HTP_TYPE_Q4_1 || + src0->type == HTP_TYPE_Q8_0 || src0->type == HTP_TYPE_IQ4_NL || + src0->type == HTP_TYPE_MXFP4); + + struct htp_mm_context mmctx_struct = {0}; + struct htp_mm_context * mmctx = &mmctx_struct; + mmctx->octx = octx; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t src1_nrows = src1->ne[1] * src1->ne[2] * src1->ne[3]; + + // Compute src0_nrows_per_thread + mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + if (is_repacked) { + mmctx->src0_nrows_per_thread = hex_round_up(mmctx->src0_nrows_per_thread, 32); + } else { + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + } + + const size_t src0_row_size = src0->nb[1]; + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + + if (hvx_mm_init_vec_dot(mmctx, src0->type) != 0) { + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (src1->ne[0] + qk - 1) / qk; + const uint32_t total_nb = src1_nrows * nb; + + worker_callback_t quant_job_func; + uint32_t n_quant_jobs = 1; + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_flat : quantize_f32_q8_0_flat; + } else if (src1_nrows < octx->n_threads) { + n_quant_jobs = MIN(total_nb, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled_block : quantize_f32_q8_0_tiled_block; + for (uint32_t ith = 0; ith < n_quant_jobs; ++ith) { + uint32_t ib_first = (total_nb * (ith + 0)) / n_quant_jobs; + uint32_t ib_last = (total_nb * (ith + 1)) / n_quant_jobs; + mmctx->quant_ib_first[ith] = ib_first; + mmctx->quant_ib_last[ith] = ib_last; + mmctx->quant_r[ith] = ib_first / nb; + mmctx->quant_c[ith] = ib_first % nb; + } + } else { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled : quantize_f32_q8_0_tiled; + } + + size_t src1_row_size; + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(src1->ne[0]) : htp_mm_q8_0_flat_row_size(src1->ne[0]); + } else { + src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(src1->ne[0]) : htp_mm_q8_0_tiled_row_size(src1->ne[0]); + } + + // Set up scratchpads using precomputed sizes from the host + size_t src0_sz = kparams->vtcm_src0_size; + size_t src1_sz = kparams->vtcm_src1_size; + size_t src2_sz = kparams->vtcm_src2_size; + size_t vtcm_size = kparams->vtcm_size; + + size_t src0_sz_per_thread = src0_sz / octx->n_threads; + size_t src1_sz_per_thread = src1_sz; + size_t src2_sz_per_thread = src2_sz / octx->n_threads; + + if (octx->ctx->vtcm_size < vtcm_size) { + FARF(ERROR, "matmul-ffn: current VTCM reservation %zu is too small, needed %zu\n", octx->ctx->vtcm_size, vtcm_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + uint8_t * vtcm_ptr = (uint8_t *) octx->ctx->vtcm_base; + mmctx->vtcm_src1 = vtcm_seq_alloc(&vtcm_ptr, src1_sz); + mmctx->vtcm_src0 = vtcm_seq_alloc(&vtcm_ptr, src0_sz); + mmctx->vtcm_src2 = vtcm_seq_alloc(&vtcm_ptr, src2_sz); + + octx->src1_spad.src = NULL; + octx->src0_spad.src = NULL; + octx->src2_spad.src = NULL; + + mmctx->vtcm_src0_stride = is_repacked ? 0 : src0_row_size_padded; + mmctx->vtcm_src2_stride = is_repacked ? 0 : src0_row_size_padded; + mmctx->vtcm_src1_stride = src1_row_size; + + mmctx->vtcm_src0_size_per_thread = src0_sz_per_thread; + mmctx->vtcm_src1_size_per_thread = src1_sz_per_thread; + mmctx->vtcm_src2_size_per_thread = src2_sz_per_thread; + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) + return HTP_STATUS_OK; + + // Run quantization once + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); + + // Run fused matmul + const uint32_t n_matmul_jobs = octx->n_threads; + worker_callback_t matmul_job_func; + if (is_repacked) { + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mm_ffn_2d_repacked_q4_0_flat; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mm_ffn_2d_repacked_q4_1_flat; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mm_ffn_2d_repacked_q8_0_flat; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mm_ffn_2d_repacked_iq4nl_flat; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mm_ffn_2d_repacked_mxfp4_flat; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } else { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mm_ffn_2d_repacked_q4_0; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mm_ffn_2d_repacked_q4_1; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mm_ffn_2d_repacked_q8_0; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mm_ffn_2d_repacked_iq4nl; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mm_ffn_2d_repacked_mxfp4; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } + } else { + matmul_job_func = hvx_mm_ffn_2d; + } + worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); - if (must_free_mapping) free(mapping_buf); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.h b/ggml/src/ggml-hexagon/htp/matmul-ops.h new file mode 100644 index 0000000000..a94d5430da --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.h @@ -0,0 +1,508 @@ +#ifndef HTP_MATMUL_OPS_H +#define HTP_MATMUL_OPS_H + +#include +#include +#include "htp-ops.h" +#include "hex-fastdiv.h" +#include "hex-common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// --- HMX Tile Constraints --- +#define HTP_MM_HMX_TILE_N_COLS 32 +#define HTP_MM_HMX_TILE_N_ROWS 32 +#define HTP_MM_HMX_TILE_SIZE (32 * 32 * sizeof(__fp16)) // 2048 bytes +#define HTP_MM_HMX_TILE_N_ELMS 1024 +#define HTP_MM_HMX_MIN_NROWS 4 + +// --- Weight Repacked Tile Sizes --- +#define HTP_MM_WEIGHT_TILE_SIZE_Q4_0 576 +#define HTP_MM_WEIGHT_TILE_SIZE_Q4_1 640 +#define HTP_MM_WEIGHT_TILE_SIZE_Q8_0 1088 +#define HTP_MM_WEIGHT_TILE_SIZE_IQ4_NL 576 +#define HTP_MM_WEIGHT_TILE_SIZE_MXFP4 544 + +// --- Weight Repacked Aligned Tile Sizes --- +#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_0 640 +#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_1 640 +#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q8_0 1152 +#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_IQ4_NL 640 +#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_MXFP4 640 + +// --- Activation Tiled Block Sizes (including padding) --- +#define HTP_MM_ACT_TILE_SIZE_Q8_0 1152 +#define HTP_MM_ACT_TILE_SIZE_Q8_1 1280 + +#define HTP_MM_MAX_PREFETCH 16 + +// --- Solver Cost Model Penalty Weights (HMX-specific) --- +#define HTP_MM_HMX_COST_W_DEQUANT 3 // cost penalty for quantized weight loading/dequantization +#define HTP_MM_HMX_COST_A_CONVERT 2 // cost penalty for activation loading/conversion + +// --- DMA Activation Transfer Configuration --- +#define HTP_MM_DMA_ACT_ROWS_PER_STEP 2 +#define HTP_MM_DMA_ACT_MULTIPLIER 4 + +enum htp_mm_kernel_type { + HTP_MM_KERNEL_UNSUPPORTED = 0, + + // HMX paths + HTP_MM_KERNEL_HMX_2D, + HTP_MM_KERNEL_HMX_F16_BATCHED, + + // HVX floating-point paths + HTP_MM_KERNEL_HVX_F16_F16_VTCM, + HTP_MM_KERNEL_HVX_F16_F16_DDR, + HTP_MM_KERNEL_HVX_F16_F32_DDR, + + HTP_MM_KERNEL_HVX_F32_F32_VTCM, + HTP_MM_KERNEL_HVX_F32_F32_DDR, + HTP_MM_KERNEL_HVX_F32_F16_DDR, + + // HVX quantized paths + HTP_MM_KERNEL_HVX_QUANT_ROW, // standard row-wise parallel quantization + HTP_MM_KERNEL_HVX_QUANT_BLOCK, // parallel block-wise quantization + HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT, // row-wise fallback flat quantization +}; + +// Op-specific struct for precomputed matmul params +struct htp_mm_kernel_params { + int32_t kernel_type; // enum htp_mm_kernel_type + int32_t pipeline; // 1 = pipelined execution, 0 = standard + int32_t m_chunk; // Row chunk size (M chunk) + int32_t n_chunk; // Col chunk size (N chunk) + int32_t n_threads; // Number of threads to spawn + int32_t n_act_threads; // Number of threads for activation preparation + int32_t n_hmx; // 1 = use HMX, 0 = use HVX + int32_t n_prefetch; // Prefetch lookahead buffers/rows in VTCM + int32_t tile_size; // Weight tile size + int32_t aligned_tile_size; // Aligned weight tile size (padded to 128) + int32_t src1_row_size; // Row size for quantized activation + int32_t vtcm_size; // Total required scratchpad size in VTCM + int32_t vtcm_src0_size; // src0 scratchpad size in VTCM + int32_t vtcm_src1_size; // src1 scratchpad size in VTCM + int32_t vtcm_src2_size; // src2 scratchpad size in VTCM (fused only) + int32_t vtcm_src3_size; // src3 scratchpad size in VTCM (fused only) + int32_t vtcm_dst_size; // dst scratchpad size in VTCM + + // Precomputed division values + struct fastdiv_values div_ne12_ne1; + struct fastdiv_values div_ne1; + struct fastdiv_values div_r2; + struct fastdiv_values div_r3; + struct fastdiv_values div_ne11; +}; + +#if defined(__cplusplus) +static_assert(sizeof(struct htp_mm_kernel_params) <= 128, "htp_matmul_kernel_params is too large for kernel_params blob"); +#else +_Static_assert(sizeof(struct htp_mm_kernel_params) <= 128, "htp_matmul_kernel_params is too large for kernel_params blob"); +#endif + +struct mmid_row_mapping { + uint32_t i1; + uint32_t i2; +}; + +// Search for optimal (mc, nc) chunk sizes within VTCM budget. +static inline int htp_mm_hmx_compute_chunks(size_t vtcm_total, + size_t overhead, + size_t per_n_cost, + size_t per_m_cost, + size_t per_mn_cost, + size_t m, + size_t n, + size_t m_block_cost, + size_t n_block_cost, + size_t * m_chunk_out, + size_t * n_chunk_out, + size_t * total_out) { + if (m == 0 || n == 0) return -1; + if (vtcm_total <= overhead) return -1; + if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1; + + const size_t usable = vtcm_total - overhead; + + size_t best_cost = SIZE_MAX; + size_t best_mn = 0; + size_t best_m = 0, best_n = 0; + + const size_t n_max = hex_align_down((size_t)n, HTP_MM_HMX_TILE_N_COLS); + for (size_t nc = n_max; nc >= HTP_MM_HMX_TILE_N_COLS; nc -= HTP_MM_HMX_TILE_N_COLS) { + size_t n_fixed = 0, ncmn = 0, mc_denom = 0; + if (hex_mul_overflow(nc, per_n_cost, &n_fixed)) continue; + if (n_fixed >= usable) goto next_nc; + + if (hex_mul_overflow(nc, per_mn_cost, &ncmn)) goto next_nc; + if (hex_add_overflow(per_m_cost, ncmn, &mc_denom) || mc_denom == 0) goto next_nc; + + { + size_t remain = usable - n_fixed; + size_t mc = remain / mc_denom; + mc = hex_align_down(mc, HTP_MM_HMX_TILE_N_ROWS); + mc = hex_smin(mc, m); + + if (mc == 0) { + goto next_nc; + } + + size_t mblocks = ((size_t) m + mc - 1) / mc; + size_t nblocks = ((size_t) n + nc - 1) / nc; + size_t cost = mblocks * m_block_cost + nblocks * n_block_cost; + size_t mn = mc * nc; + if (cost < best_cost || (cost == best_cost && mn > best_mn)) { + best_cost = cost; + best_mn = mn; + best_m = mc; + best_n = nc; + } + } + +next_nc: + if (nc == HTP_MM_HMX_TILE_N_COLS) break; // avoid size_t underflow + } + + if (best_m == 0 || best_n == 0) return -1; + + // Compute exact total (with overflow checks) + size_t t0 = 0, t1 = 0, t2 = 0, mn = 0, total = 0; + if (hex_mul_overflow(best_n, per_n_cost, &t0)) return -1; + if (hex_mul_overflow(best_m, per_m_cost, &t1)) return -1; + if (hex_mul_overflow(best_m, best_n, &mn)) return -1; + if (hex_mul_overflow(mn, per_mn_cost, &t2)) return -1; + if (hex_add_overflow(t0, t1, &total)) return -1; + if (hex_add_overflow(total, t2, &total)) return -1; + if (hex_add_overflow(total, overhead, &total)) return -1; + + *m_chunk_out = best_m; + *n_chunk_out = best_n; + *total_out = total; + return 0; +} + +// --- Tile Size Helpers --- +static inline uint32_t htp_mm_get_weight_tile_size(int weight_type) { + switch (weight_type) { + case HTP_TYPE_Q4_0: + case HTP_TYPE_IQ4_NL: + return HTP_MM_WEIGHT_TILE_SIZE_Q4_0; + case HTP_TYPE_Q4_1: + return HTP_MM_WEIGHT_TILE_SIZE_Q4_1; + case HTP_TYPE_Q8_0: + return HTP_MM_WEIGHT_TILE_SIZE_Q8_0; + case HTP_TYPE_MXFP4: + return HTP_MM_WEIGHT_TILE_SIZE_MXFP4; + default: + return 0; + } +} + +static inline uint32_t htp_mm_get_weight_aligned_tile_size(int weight_type) { + switch (weight_type) { + case HTP_TYPE_Q4_0: + case HTP_TYPE_IQ4_NL: + return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_0; + case HTP_TYPE_Q4_1: + return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_1; + case HTP_TYPE_Q8_0: + return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q8_0; + case HTP_TYPE_MXFP4: + return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_MXFP4; + default: + return 0; + } +} + +// --- Activation/Row Size Helpers --- +static inline size_t htp_mm_q8_0_tiled_row_size(uint32_t ne) { + const uint32_t ne_padded = ((ne + 127) / 128) * 128; + const uint32_t nb_32 = ne_padded / 32; + return nb_32 * HTP_MM_ACT_TILE_SIZE_Q8_0; +} + +static inline size_t htp_mm_q8_1_tiled_row_size(uint32_t ne) { + const uint32_t ne_padded = ((ne + 127) / 128) * 128; + const uint32_t nb_32 = ne_padded / 32; + return nb_32 * HTP_MM_ACT_TILE_SIZE_Q8_1; +} + +static inline size_t htp_mm_q8_0_flat_row_size(uint32_t ne) { + const uint32_t quants_size = hex_align_up(ne, 128); + const uint32_t num_scales = (ne + 31) / 32; + const uint32_t scales_size = hex_align_up(num_scales * 2, 128); + return quants_size + scales_size; +} + +static inline size_t htp_mm_q8_1_flat_row_size(uint32_t ne) { + const uint32_t quants_size = hex_align_up(ne, 128); + const uint32_t num_scales = (ne + 31) / 32; + const uint32_t scales_size = hex_align_up(num_scales * 4, 128); + return quants_size + scales_size; +} + +static inline size_t htp_mm_get_tiled_row_stride(int weight_type, uint32_t k) { + uint32_t nb = (k + QK_Q4_0_TILED - 1) / QK_Q4_0_TILED; + switch (weight_type) { + case HTP_TYPE_Q4_0: + case HTP_TYPE_IQ4_NL: + case HTP_TYPE_Q4_1: + case HTP_TYPE_Q8_0: + case HTP_TYPE_MXFP4: + return (size_t) nb * htp_mm_get_weight_tile_size(weight_type); + case HTP_TYPE_F16: + return (size_t) k * sizeof(__fp16); + case HTP_TYPE_F32: + return (size_t) k * sizeof(float); + default: + return 0; + } +} + +static inline size_t htp_mm_round_up(size_t n, size_t m) { + return ((n + m - 1) / m) * m; +} + +static inline bool htp_mm_hmx_pipeline(uint32_t m) { + return m > 32; +} + +static inline void htp_mm_hmx_get_2d_chunk_costs( + int wtype, uint32_t k, bool pipeline, uint32_t aligned_tile_size, + size_t * size_per_n_out, size_t * size_per_m_out, size_t * size_per_mn_out +) { + const bool is_quant = (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32); + const size_t row_stride = htp_mm_get_tiled_row_stride(wtype, k); + const size_t vec_dot_size = k * sizeof(uint16_t); + const uint32_t n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS; + const size_t qweight_row_stride = is_quant ? (size_t)(n_k_tiles * aligned_tile_size) / 32 : 0; + + *size_per_n_out = (pipeline ? 2 : 1) * (is_quant ? qweight_row_stride : row_stride) + + (pipeline ? 2 * vec_dot_size : vec_dot_size); + *size_per_m_out = vec_dot_size; + *size_per_mn_out = (pipeline ? 2 : 1) * sizeof(uint16_t); +} + +static inline void htp_mm_hmx_get_batched_chunk_costs( + uint32_t k, uint32_t group_size, + size_t * size_per_n_out, size_t * size_per_m_out, size_t * size_per_mn_out +) { + const size_t vec_dot_size = k * sizeof(uint16_t); + *size_per_n_out = 3 * vec_dot_size; + *size_per_m_out = group_size * vec_dot_size; + *size_per_mn_out = sizeof(uint16_t); +} + +static inline size_t htp_mm_hmx_get_2d_vtcm_size( + int wtype, uint32_t k, size_t mc, size_t nc, bool pipeline, uint32_t act_threads, uint32_t aligned_tile_size +) { + const uint32_t n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS; + const bool is_quant = (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32); + const size_t row_stride = htp_mm_get_tiled_row_stride(wtype, k); + const size_t vec_dot_size = k * sizeof(uint16_t); + + const size_t act_f32_size = htp_mm_round_up(act_threads * 4 * k * sizeof(float), HTP_MM_HMX_TILE_SIZE); + size_t weight_area_size = is_quant + ? htp_mm_round_up((nc / 32) * n_k_tiles * aligned_tile_size, HTP_MM_HMX_TILE_SIZE) + : htp_mm_round_up(nc * row_stride, HTP_MM_HMX_TILE_SIZE); + if (pipeline) { + weight_area_size *= 2; + } + const size_t act_area_size = htp_mm_round_up(mc * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + const size_t output_area_size = htp_mm_round_up(mc * nc * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE); + + size_t scratch0_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + size_t scratch1_size = pipeline ? scratch0_size : 0; + size_t scratch2_size = pipeline ? output_area_size : 0; + + return weight_area_size + act_area_size + act_f32_size + output_area_size + + scratch0_size + scratch1_size + scratch2_size + 256; +} + +static inline size_t htp_mm_hmx_get_batched_vtcm_size( + int wtype, uint32_t k, size_t mc, size_t nc, uint32_t group_size, bool use_dma_activation, bool pipeline, uint32_t act_threads) { + (void)wtype; + (void)pipeline; + const size_t vec_dot_size = k * sizeof(uint16_t); + const size_t f32_scratch_size = use_dma_activation + ? htp_mm_round_up(act_threads * 4 * k * sizeof(float), HTP_MM_HMX_TILE_SIZE) : 0; + + const size_t act_head_stride = mc * k; + const size_t weight_area_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + const size_t act_area_size = htp_mm_round_up(group_size * act_head_stride * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE); + const size_t output_area_size = htp_mm_round_up(group_size * mc * nc * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE); + const size_t scratch_area_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + + return weight_area_size + act_area_size + output_area_size + + 2 * scratch_area_size + 256 + f32_scratch_size; +} + +static inline size_t htp_mm_hvx_get_vtcm_sizes( + int kernel_type, + int wtype, + uint32_t ne10, // k + uint32_t src1_nrows, // m_total (or act_nrows) + uint32_t n_threads, + size_t dst_row_size, + size_t src0_row_size, + size_t src1_row_size, + uint32_t n_prefetch, + size_t * vtcm_src0_size_out, + size_t * vtcm_src1_size_out, + size_t * vtcm_dst_size_out +) { + size_t vtcm_src0_size = 0; + size_t vtcm_src1_size = 0; + size_t vtcm_dst_size = 0; + + const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 || + wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL || + wtype == HTP_TYPE_MXFP4); + + const size_t src0_row_size_padded = htp_mm_round_up(src0_row_size, 128); + const size_t dst_nrows = (src1_nrows > 1) ? 0 : 1; + + switch (kernel_type) { + case HTP_MM_KERNEL_HVX_F16_F16_VTCM: { + size_t f16_src1_row_size = htp_mm_round_up(ne10 * 2, 128); + vtcm_src1_size = htp_mm_round_up(f16_src1_row_size * src1_nrows, 256); + vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256) * n_threads; + vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0; + break; + } + case HTP_MM_KERNEL_HVX_F16_F32_DDR: + case HTP_MM_KERNEL_HVX_F16_F16_DDR: + case HTP_MM_KERNEL_HVX_F32_F32_DDR: + case HTP_MM_KERNEL_HVX_F32_F16_DDR: { + vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size, 256) * n_threads; + vtcm_src1_size = htp_mm_round_up(n_prefetch * src1_row_size, 256) * n_threads; + vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0; + break; + } + case HTP_MM_KERNEL_HVX_F32_F32_VTCM: { + size_t f32_src1_row_size = htp_mm_round_up(ne10 * 4, 128); + vtcm_src1_size = htp_mm_round_up(f32_src1_row_size * src1_nrows, 256); + vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256) * n_threads; + vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0; + break; + } + case HTP_MM_KERNEL_HVX_QUANT_BLOCK: + case HTP_MM_KERNEL_HVX_QUANT_ROW: { + size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + + vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0; + vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256); + vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256); + + // src0 spad is also used in dynamic quantizer to store padded src1 rows + size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, QK_Q8_0_TILED * sizeof(float)); + if (vtcm_src0_size < src1_row_size_padded) { + vtcm_src0_size = src1_row_size_padded; + } + + vtcm_src0_size = vtcm_src0_size * n_threads; + vtcm_dst_size = vtcm_dst_size * n_threads; + + if (is_repack) { + uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + uint32_t n_k_tiles = ne10 / 32; + uint32_t tile_row_size = n_k_tiles * aligned_tile_size; + size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + vtcm_src0_size = repacked_vtcm_size * n_threads; + } + break; + } + case HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT: { + size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10); + + vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0; + vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256); + vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256); + + size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, 256); + if (vtcm_src0_size < src1_row_size_padded) { + vtcm_src0_size = src1_row_size_padded; + } + + vtcm_src0_size = vtcm_src0_size * n_threads; + vtcm_dst_size = vtcm_dst_size * n_threads; + + if (is_repack) { + uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + uint32_t n_k_tiles = ne10 / 32; + uint32_t tile_row_size = n_k_tiles * aligned_tile_size; + size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + vtcm_src0_size = repacked_vtcm_size * n_threads; + } + break; + } + default: + break; + } + + *vtcm_src0_size_out = vtcm_src0_size; + *vtcm_src1_size_out = vtcm_src1_size; + *vtcm_dst_size_out = vtcm_dst_size; + + return vtcm_src0_size + vtcm_src1_size + vtcm_dst_size; +} + +static inline size_t htp_mm_hvx_id_get_vtcm_sizes( + int wtype, + uint32_t ne10, // k + uint32_t src1_nrows, + uint32_t n_threads, + size_t src0_row_size, // nb01 + uint32_t n_prefetch, + size_t * vtcm_src0_size_out, + size_t * vtcm_src1_size_out +) { + const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 || + wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL || + wtype == HTP_TYPE_MXFP4); + + const size_t src0_row_size_padded = htp_mm_round_up(src0_row_size, 128); + const size_t src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) + : htp_mm_q8_0_tiled_row_size(ne10); + + size_t src0_sz_per_thread = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256); + size_t src1_sz = htp_mm_round_up(src1_row_size * src1_nrows, 256); + + // src0 spad also holds temporary transposed src1 columns during dynamic quantization. + const size_t src1_row_size_padded = htp_mm_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float)); + if (src0_sz_per_thread < src1_row_size_padded) { + src0_sz_per_thread = src1_row_size_padded; + } + + if (is_repack) { + const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + const uint32_t n_k_tiles = ne10 / 32; + const uint32_t tile_row_size = n_k_tiles * aligned_tile_size; + size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + src0_sz_per_thread = repacked_vtcm_size; + } + + const size_t vtcm_src0_size = src0_sz_per_thread * n_threads; + + *vtcm_src0_size_out = vtcm_src0_size; + *vtcm_src1_size_out = src1_sz; + + return vtcm_src0_size + src1_sz; +} + +#ifdef __cplusplus +} +#endif + +#endif // HTP_MATMUL_OPS_H diff --git a/ggml/src/ggml-hexagon/libggml-htp.inf b/ggml/src/ggml-hexagon/libggml-htp.inf index 39cefcdda3..874dde1b88 100644 --- a/ggml/src/ggml-hexagon/libggml-htp.inf +++ b/ggml/src/ggml-hexagon/libggml-htp.inf @@ -14,8 +14,6 @@ Drivers_Dir = 13 1 = %DiskId% [SourceDisksFiles] -libggml-htp-v68.so = 1 -libggml-htp-v69.so = 1 libggml-htp-v73.so = 1 libggml-htp-v75.so = 1 libggml-htp-v79.so = 1 @@ -28,8 +26,6 @@ ExcludeFromSelect = * CopyFiles=Drivers_Dir [Drivers_Dir] -libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE -libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v79.so,,,0x10 ;COPYFLG_NO_OVERWRITE diff --git a/scripts/snapdragon/adb/run-completion.sh b/scripts/snapdragon/adb/run-completion.sh index fe14bb1422..f7622eb527 100755 --- a/scripts/snapdragon/adb/run-completion.sh +++ b/scripts/snapdragon/adb/run-completion.sh @@ -57,19 +57,25 @@ oppoll= opflt= [ "$OF" != "" ] && opflt="GGML_HEXAGON_OPFILTER=$OF" +opfuse= +[ "$OC" != "" ] && opfuse="GGML_HEXAGON_OPFUSION=$OC" + vmem= [ "$VM" != "" ] && vmem="GGML_HEXAGON_VMEM=$VM" mbuf= [ "$MB" != "" ] && mbuf="GGML_HEXAGON_MBUF=$MB" +mmsel= +[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM" + set -x adb $adbserial $adbhost shell " \ cd $basedir; ulimit -c unlimited; \ LD_LIBRARY_PATH=$basedir/$branch/lib \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \ - $verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $vmem $mbuf \ + $verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $opfuse $vmem $mbuf $mmsel \ ./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \ --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ --ctx-size 8192 --ubatch-size 1024 -fa on \ diff --git a/scripts/snapdragon/adb/run-tool.sh b/scripts/snapdragon/adb/run-tool.sh index 6d7e32b321..f6332391bc 100755 --- a/scripts/snapdragon/adb/run-tool.sh +++ b/scripts/snapdragon/adb/run-tool.sh @@ -51,6 +51,12 @@ opqueue= oppoll= [ "$OP" != "" ] && oppoll="GGML_HEXAGON_OPPOLL=$OP" +opfuse= +[ "$OC" != "" ] && opfuse="GGML_HEXAGON_OPFUSION=$OC" + +mmsel= +[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM" + set -x tool=$1; shift @@ -59,5 +65,5 @@ adb $adbserial $adbhost shell " \ cd $basedir; ulimit -c unlimited; \ LD_LIBRARY_PATH=$basedir/$branch/lib \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \ - $verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll ./$branch/bin/$tool $@ \ + $verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opfuse $mmsel ./$branch/bin/$tool $@ \ " diff --git a/scripts/snapdragon/ggml-hexagon-profile.py b/scripts/snapdragon/ggml-hexagon-profile.py index 05045262f2..c53ad77793 100755 --- a/scripts/snapdragon/ggml-hexagon-profile.py +++ b/scripts/snapdragon/ggml-hexagon-profile.py @@ -26,7 +26,7 @@ COL_MAP = { } op_pattern = re.compile( - r"profile-op\s+(?P[A-Z_0-9+]+):\s+.*?\s+:\s+(?P[\d:x\s\->!]+)\s+:\s+(?P[a-z\d_\s\->x]+)\s+:\s+.*?\s+(?:op-)?usec\s+(?P\d+)\s+(?:op-)?cycles\s+(?P\d+)(?:\s+start\s+(?P\d+))?(?:\s+mhz\s+(?P[\d.]+))?(?:\s+pmu\s+\[(?P[\d,\s]+)\])?(?:\s+evt\s+\[(?P[\d,\s]+)\])?" + r"profile-op\s+(?P[A-Z_0-9+]+):\s+.*?\s+:\s+(?P[\d:x\s\->!]+)\s+:\s+(?P[a-z\d_\s\->x]+)\s+:\s+.*?\s+:\s+(?:op-)?usec\s+(?P\d+)\s+(?:op-)?cycles\s+(?P\d+)(?:\s+start\s+(?P\d+))?(?:\s+mhz\s+(?P[\d.]+))?(?:\s+pmu\s+\[(?P[\d,\s]+)\])?(?:\s+evt\s+\[(?P[\d,\s]+)\])?" ) trace_pattern = re.compile( @@ -93,9 +93,40 @@ def parse_log(file_path, pmu_index=None): + int(ts_match.group('us')) ) - op_match = op_pattern.search(line) + if "|" in line and "profile-op" in line: + parts = [p.strip() for p in line.split("|")] + prefix = parts[0] + prefix_match = re.search(r"profile-op\s+(?P[A-Z_0-9+]+)", prefix) + if not prefix_match: + continue + + if len(parts) == 7: + dims, types, timings = parts[2], parts[3], parts[6] + elif len(parts) == 6: + dims, types, timings = parts[2], parts[3], parts[5] + else: + continue + + timing_match = re.search( + r"(?:op-)?usec\s+(?P\d+)\s+(?:op-)?cycles\s+(?P\d+)(?:\s+start\s+(?P\d+))?(?:\s+mhz\s+(?P[\d.]+))?(?:\s+pmu\s+\[(?P[\d,\s]+)\])?(?:\s+evt\s+\[(?P[\d,\s]+)\])?", + timings + ) + if not timing_match: + continue + + op_match = timing_match + op_name = prefix_match.group("op_name") + else: + op_match = op_pattern.search(line) + if op_match: + op_name = op_match.group('op_name') + dims = op_match.group('dims').strip() + types = op_match.group('types').strip() + else: + op_match = None + if op_match: - pmu_raw = op_match.group('pmu') + pmu_raw = op_match.group('pmu') if 'pmu' in op_match.groupdict() else None pmu_val = None if pmu_raw and pmu_index is not None: try: @@ -105,7 +136,7 @@ def parse_log(file_path, pmu_index=None): except (ValueError, IndexError): pmu_val = None - evt_raw = op_match.group('evt') + evt_raw = op_match.group('evt') if 'evt' in op_match.groupdict() else None evt_val = None if evt_raw: try: @@ -122,9 +153,9 @@ def parse_log(file_path, pmu_index=None): op_text = line[idx + 11:].strip() if idx != -1 else line.strip() current_op = { - 'name': op_match.group('op_name'), - 'dims': op_match.group('dims').strip(), - 'types': op_match.group('types').strip(), + 'name': op_name, + 'dims': dims, + 'types': types, 'op_text': op_text, 'usec': int(op_match.group('usec')), 'cycles': int(op_match.group('cycles')), diff --git a/scripts/snapdragon/ggml-hexagon-trace.py b/scripts/snapdragon/ggml-hexagon-trace.py index 18ec440a9f..37f137a9e7 100755 --- a/scripts/snapdragon/ggml-hexagon-trace.py +++ b/scripts/snapdragon/ggml-hexagon-trace.py @@ -12,7 +12,7 @@ from collections import defaultdict logger = logging.getLogger("ggml-hexagon-trace") op_pattern = re.compile( - r"profile-op\s+(?P[A-Z_0-9+]+):\s+.*?\s+:\s+(?P[\d:x\s\->!]+)\s+:\s+(?P[a-z\d_\s\->x]+)\s+:\s+(?P[\d:x\s\->!]+)\s+:\s+(?:op-)?usec\s+(?P\d+)\s+(?:op-)?cycles\s+(?P\d+)(?:\s+start\s+(?P\d+))?(?:\s+mhz\s+(?P[\d.]+))?(?:\s+pmu\s+\[(?P[\d,\s]+)\])?(?:\s+evt\s+\[(?P[\d,\s]+)\])?" + r"profile-op\s+(?P[A-Z_0-9+]+):\s+.*?\s+:\s+(?P[\d:x\s\->!]+)\s+:\s+(?P[a-z\d_\s\->x]+)\s+:\s+(?P[\d:x\s\->!]+?)\s+:\s+(?:(?P.*?)\s+:\s+)?(?:op-)?usec\s+(?P\d+)\s+(?:op-)?cycles\s+(?P\d+)(?:\s+start\s+(?P\d+))?(?:\s+mhz\s+(?P[\d.]+))?(?:\s+pmu\s+\[(?P[\d,\s]+)\])?(?:\s+evt\s+\[(?P[\d,\s]+)\])?" ) trace_pattern = re.compile( @@ -66,7 +66,40 @@ def parse_log(file_path): for line in f: line_idx += 1 - op_match = op_pattern.search(line) + if "|" in line and "profile-op" in line: + parts = [p.strip() for p in line.split("|")] + prefix = parts[0] + prefix_match = re.search(r"profile-op\s+(?P[A-Z_0-9+]+)", prefix) + if not prefix_match: + continue + + if len(parts) == 7: + dims, types, strides, params, timings = parts[2], parts[3], parts[4], parts[5], parts[6] + elif len(parts) == 6: + dims, types, strides, params, timings = parts[2], parts[3], parts[4], "", parts[5] + else: + continue + + timing_match = re.search( + r"(?:op-)?usec\s+(?P\d+)\s+(?:op-)?cycles\s+(?P\d+)(?:\s+start\s+(?P\d+))?(?:\s+mhz\s+(?P[\d.]+))?(?:\s+pmu\s+\[(?P[\d,\s]+)\])?(?:\s+evt\s+\[(?P[\d,\s]+)\])?", + timings + ) + if not timing_match: + continue + + op_match = timing_match + op_name = prefix_match.group("op_name") + else: + op_match = op_pattern.search(line) + if op_match: + op_name = op_match.group('op_name') + dims = op_match.group('dims').strip() if op_match.group('dims') else '' + types = op_match.group('types').strip() if op_match.group('types') else '' + strides = op_match.group('strides').strip() if op_match.group('strides') else '' + params = op_match.group('params').strip() if ('params' in op_match.groupdict() and op_match.group('params')) else '' + else: + op_match = None + if op_match: cycles_start_raw = op_match.group('start') unwrapped_cycles_start = None @@ -77,10 +110,11 @@ def parse_log(file_path): op_text = line[idx + 11:].strip() if idx != -1 else line.strip() current_op = { - 'name': op_match.group('op_name'), - 'dims': op_match.group('dims').strip() if op_match.group('dims') else '', - 'types': op_match.group('types').strip() if op_match.group('types') else '', - 'strides': op_match.group('strides').strip() if op_match.group('strides') else '', + 'name': op_name, + 'dims': dims, + 'types': types, + 'strides': strides, + 'params': params, 'op_text': op_text, 'usec': int(op_match.group('usec')), 'cycles': int(op_match.group('cycles')), @@ -397,6 +431,8 @@ def generate_perfetto_trace(filtered_ops, output_path): debug_annots.append(make_debug_annotation("line", int_val=op['line_num'])) if 'strides' in op and op['strides']: debug_annots.append(make_debug_annotation("strides", string_val=op['strides'])) + if 'params' in op and op['params'] and op['params'] != '----': + debug_annots.append(make_debug_annotation("params", string_val=op['params'])) # Slice Begin evt_begin = make_track_event(1, 2, name=f"{op['name']} ({op['dims']})", category="operator", debug_annotations=debug_annots) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index e1d3853e43..3f18dbe220 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8420,6 +8420,11 @@ static std::vector> make_test_cases_eval() { } } + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 2880, 32, 2880, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 2880, 32, 2880, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_MXFP4, GGML_TYPE_F32, 2880, 32, 2880, {1, 1}, {1, 1})); + + #if 0 { // Test paths in OpenCL @@ -8594,6 +8599,7 @@ static std::vector> make_test_cases_eval() { // gpt-oss issue with Vulkan mmq_id test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880)); + test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_Q4_0, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880)); for (ggml_type type_a : all_types) { test_cases.emplace_back(new test_mul_mat_id(type_a, GGML_TYPE_F32, 4, 2, false, 64, 16, 3*ggml_blck_size(type_a)));