fix(hexagon): use padded stride for ssm-conv weights (#24470)

This commit is contained in:
Guanhuai Zhang 2026-06-21 05:58:49 +08:00 committed by GitHub
parent 84de01a1f1
commit 4a80943174
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -183,24 +183,25 @@ static inline void hvx_transpose_32x32_f32(HVX_Vector m[32]) {
// transposed into VTCM.
//
// VTCM layouts (per thread):
// src1_T : {d_inner_per_thread, d_conv} — staged once per launch (small).
// src0_T : {d_inner_tile, ncs} staged per d_inner-tile.
// src1_T : {d_inner_stride, d_conv} - staged once per launch (small).
// src0_T : {d_inner_tile, ncs} - staged per d_inner-tile.
//
// d_inner_tile is chosen so that per-thread VTCM stays under the budget.
// Each thread iterates ceil(d_inner_per_thread d_inner_tile) tiles serially.
#define HTP_SSM_CONV_VTCM_BUDGET (1u << 20) // 1 MiB per thread
// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_per_thread, d_conv} (VTCM)
// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_stride, d_conv} (VTCM)
static inline void transpose_src1(const float * src1_data,
uint32_t src1_stride_inner,
uint32_t i1_off,
uint32_t d_inner_per_thread,
uint32_t d_inner_stride,
uint32_t d_conv,
float * src1_T) {
for (uint32_t i = 0; i < d_inner_per_thread; ++i) {
const float * src_row = src1_data + (i1_off + i) * src1_stride_inner;
for (uint32_t j = 0; j < d_conv; ++j) {
src1_T[j * d_inner_per_thread + i] = src_row[j];
src1_T[j * d_inner_stride + i] = src_row[j];
}
}
}
@ -280,6 +281,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
}
const uint32_t d_inner_per_thread = ir1 - ir0;
const uint32_t d_inner_stride = scctx->nrows_per_thread;
const uint32_t d_inner_tile = scctx->d_inner_tile;
const float * src0_data = (const float *) src0->data;
@ -290,8 +292,8 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
float * src0_T = (float *)(octx->src0_spad.data + ith * octx->src0_spad.size_per_thread);
float * src1_T = (float *)(octx->src1_spad.data + ith * octx->src1_spad.size_per_thread);
// Stage src1 weights once into VTCM in {d_inner_per_thread, d_conv} layout.
transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_conv, src1_T);
// Stage src1 weights once into VTCM in {d_inner_stride, d_conv} layout.
transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_inner_stride, d_conv, src1_T);
const uint32_t C_TILE = VLEN_FP32;
@ -314,7 +316,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
HVX_Vector acc = hvx_vec_splat_f32(0.0f);
for (uint32_t j = 0; j < d_conv; ++j) {
HVX_Vector x = *(const HVX_Vector *) (src0_T + (t + j) * d_inner_tile + cb);
HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_per_thread + tile_off + cb);
HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_stride + tile_off + cb);
acc = Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(x, w));
}
HVX_Vector res = Q6_Vsf_equals_Vqf32(acc);
@ -362,8 +364,7 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
use_hvx = 1;
}
scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads;
scctx.nrows_per_thread += (scctx.nrows_per_thread & 1);
scctx.nrows_per_thread = hex_round_up((d_inner + n_threads - 1) / n_threads, VLEN_FP32);
const uint32_t d_inner_per_thread = scctx.nrows_per_thread;
const uint32_t ncs = src0->ne[0];