diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c index d574da2e2b..a48bc9ed86 100644 --- a/ggml/src/ggml-hexagon/htp/ssm-conv.c +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -183,24 +183,25 @@ static inline void hvx_transpose_32x32_f32(HVX_Vector m[32]) { // transposed into VTCM. // // VTCM layouts (per thread): -// src1_T : {d_inner_per_thread, d_conv} — staged once per launch (small). -// src0_T : {d_inner_tile, ncs} — staged per d_inner-tile. +// src1_T : {d_inner_stride, d_conv} - staged once per launch (small). +// src0_T : {d_inner_tile, ncs} - staged per d_inner-tile. // // d_inner_tile is chosen so that per-thread VTCM stays under the budget. // Each thread iterates ceil(d_inner_per_thread d_inner_tile) tiles serially. #define HTP_SSM_CONV_VTCM_BUDGET (1u << 20) // 1 MiB per thread -// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_per_thread, d_conv} (VTCM) +// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_stride, d_conv} (VTCM) static inline void transpose_src1(const float * src1_data, uint32_t src1_stride_inner, uint32_t i1_off, uint32_t d_inner_per_thread, + uint32_t d_inner_stride, uint32_t d_conv, float * src1_T) { for (uint32_t i = 0; i < d_inner_per_thread; ++i) { const float * src_row = src1_data + (i1_off + i) * src1_stride_inner; for (uint32_t j = 0; j < d_conv; ++j) { - src1_T[j * d_inner_per_thread + i] = src_row[j]; + src1_T[j * d_inner_stride + i] = src_row[j]; } } } @@ -280,6 +281,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void } const uint32_t d_inner_per_thread = ir1 - ir0; + const uint32_t d_inner_stride = scctx->nrows_per_thread; const uint32_t d_inner_tile = scctx->d_inner_tile; const float * src0_data = (const float *) src0->data; @@ -290,8 +292,8 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void float * src0_T = (float *)(octx->src0_spad.data + ith * octx->src0_spad.size_per_thread); float * src1_T = (float *)(octx->src1_spad.data + ith * octx->src1_spad.size_per_thread); - // Stage src1 weights once into VTCM in {d_inner_per_thread, d_conv} layout. - transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_conv, src1_T); + // Stage src1 weights once into VTCM in {d_inner_stride, d_conv} layout. + transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_inner_stride, d_conv, src1_T); const uint32_t C_TILE = VLEN_FP32; @@ -314,7 +316,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void HVX_Vector acc = hvx_vec_splat_f32(0.0f); for (uint32_t j = 0; j < d_conv; ++j) { HVX_Vector x = *(const HVX_Vector *) (src0_T + (t + j) * d_inner_tile + cb); - HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_per_thread + tile_off + cb); + HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_stride + tile_off + cb); acc = Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(x, w)); } HVX_Vector res = Q6_Vsf_equals_Vqf32(acc); @@ -362,8 +364,7 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) { use_hvx = 1; } - scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; - scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); + scctx.nrows_per_thread = hex_round_up((d_inner + n_threads - 1) / n_threads, VLEN_FP32); const uint32_t d_inner_per_thread = scctx.nrows_per_thread; const uint32_t ncs = src0->ne[0];