From 7c908502ea0868e6ae913f79ba84ba844a5b386a Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Tue, 23 Jun 2026 17:13:55 +0900 Subject: [PATCH] ggml-webgpu: improve MTP inference by using mat-vec path for small batches (#24811) * ggml-webgpu: improve small batches decoding * Add barrier to the NUM_COLS loop in mul-mat-vec --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 13 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 20 +- .../wgsl-shaders/mul_mat_id_vec.wgsl | 4 +- .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 88 +- .../wgsl-shaders/mul_mat_vec_acc.tmpl | 991 ++++++++++-------- .../wgsl-shaders/mul_mat_vec_q_acc.tmpl | 132 ++- .../ggml-webgpu/wgsl-shaders/quantize_q8.wgsl | 23 +- tests/test-backend-ops.cpp | 2 + 8 files changed, 682 insertions(+), 591 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 6f877f15ce..c00a2e9ee9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -905,11 +905,12 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key { ggml_type src0_type; ggml_type src1_type; int vectorized; + uint32_t num_cols; bool use_mmvq; bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const { return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized && - use_mmvq == other.use_mmvq; + num_cols == other.num_cols && use_mmvq == other.use_mmvq; } }; @@ -919,6 +920,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.src0_type); ggml_webgpu_hash_combine(seed, key.src1_type); ggml_webgpu_hash_combine(seed, key.vectorized); + ggml_webgpu_hash_combine(seed, key.num_cols); ggml_webgpu_hash_combine(seed, key.use_mmvq); return seed; } @@ -993,11 +995,12 @@ struct ggml_webgpu_mul_mat_id_pipeline_key { ggml_type src0_type; ggml_type src1_type; uint32_t n_experts; + uint32_t num_cols; int vectorized; bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const { return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts && - vectorized == other.vectorized; + num_cols == other.num_cols && vectorized == other.vectorized; } }; @@ -1007,6 +1010,7 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.src0_type); ggml_webgpu_hash_combine(seed, key.src1_type); ggml_webgpu_hash_combine(seed, key.n_experts); + ggml_webgpu_hash_combine(seed, key.num_cols); ggml_webgpu_hash_combine(seed, key.vectorized); return seed; } @@ -1107,7 +1111,7 @@ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0, const ggml_tensor * src1, bool supports_dot_product, const std::string & vendor) { - if (src1->ne[1] == 1) { + if (src1->ne[1] <= 4) { bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia"; if (supports_dp4a && supports_dot_product) { switch (src1->type) { @@ -1889,6 +1893,7 @@ class ggml_webgpu_shader_lib { (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 1 : 0; + key.num_cols = context.dst->ne[1]; key.use_mmvq = ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor); @@ -2004,6 +2009,7 @@ class ggml_webgpu_shader_lib { if (key.vectorized) { variant += "_vectorized"; } + defines.push_back(std::string("NUM_COLS=") + std::to_string(key.num_cols)); auto processed = preprocessor.preprocess(shader_src, defines); auto decisions = std::make_shared(); @@ -2421,6 +2427,7 @@ class ggml_webgpu_shader_lib { if (key.vectorized) { variant += "_vectorized"; } + defines.push_back(std::string("NUM_COLS=1")); defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts)); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f71d1aee73..e8eafd185a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1418,15 +1418,17 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context & const size_t dst_offset = ggml_webgpu_tensor_offset(dst); const size_t q8_src1_align_offset = ROUNDUP_POW2( dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - const size_t q8_src1_binding_size = - ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)), - WEBGPU_STORAGE_BUF_BINDING_MULT); + const size_t q8_src1_binding_size = ROUNDUP_POW2( + src1->ne[3] * src1->ne[2] * src1->ne[1] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)), + WEBGPU_STORAGE_BUF_BINDING_MULT); std::vector q8_params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), (uint32_t) src1->ne[0], + (uint32_t) src1->ne[1], (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], }; @@ -1442,7 +1444,7 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context & uint32_t q8_wg_x = 1; uint32_t q8_wg_y = 1; const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size; - const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec; + const uint32_t q8_total_wg = src1->ne[1] * src1->ne[2] * src1->ne[3] * wg_per_vec; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y); @@ -1456,7 +1458,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * dst) { // Determine if this is a mat-vec operation - bool is_vec = (dst->ne[1] == 1); + bool use_mat_vec = (dst->ne[1] <= 4); // use MMVQ path for mat-vec bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product, @@ -1482,7 +1484,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, webgpu_pipeline pipeline; std::vector dispatches; - if (is_vec) { + if (use_mat_vec) { if (use_mmvq) { ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches); } @@ -1529,7 +1531,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, uint32_t wg_y = 1; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - if (is_vec) { + if (use_mat_vec) { auto * decisions = static_cast(pipeline.context.get()); uint32_t batches = dst->ne[2] * dst->ne[3]; @@ -3691,8 +3693,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product, ctx->webgpu_global_ctx->vendor); if (use_mmvq) { - const size_t q8_src1_size = - src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)); + const size_t q8_src1_size = src1->ne[3] * src1->ne[2] * src1->ne[1] * + (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)); res = ROUNDUP_POW2(res + q8_src1_size + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment, WEBGPU_STORAGE_BUF_BINDING_MULT); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl index 6ff9bcf2df..78ae955e6b 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl @@ -103,7 +103,7 @@ fn main( #ifdef USE_SUBGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let subgroup_total = subgroupAdd(acc[row]); + let subgroup_total = subgroupAdd(acc[0][row]); if (subgroup_invocation_id == 0u) { partial_sums[partial_index(row, subgroup_id)] = subgroup_total; } @@ -126,7 +126,7 @@ fn main( #ifdef USE_WORKGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - partial_sums[partial_index(row, thread_id)] = acc[row]; + partial_sums[partial_index(row, thread_id)] = acc[0][row]; } workgroupBarrier(); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index f0a7fbd059..ebdf09513e 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -91,61 +91,67 @@ fn main( let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base; #ifdef MMVQ - let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u); + let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * params.n * (params.k / 32u); let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base); #else let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base); #endif + for (var col = 0u;col < NUM_COLS;col += 1) { + #ifdef USE_SUBGROUP_REDUCTION - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let subgroup_total = subgroupAdd(acc[row]); - if (subgroup_invocation_id == 0u) { - partial_sums[partial_index(row, subgroup_id)] = subgroup_total; - } - } + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let subgroup_total = subgroupAdd(acc[col][row]); + if (subgroup_invocation_id == 0u) { + partial_sums[partial_index(row, subgroup_id)] = subgroup_total; + } + } - workgroupBarrier(); + workgroupBarrier(); - for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { - let output_row = row_base + row; - var row_acc = 0.0f; - for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { - row_acc += partial_sums[partial_index(row, k)]; - } - let row_total = subgroupAdd(row_acc); - if (subgroup_invocation_id == 0) { - dst[dst_idx_base + row] = row_total; - } - } + for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { + let output_row = row_base + row; + var row_acc = 0.0f; + for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { + row_acc += partial_sums[partial_index(row, k)]; + } + let row_total = subgroupAdd(row_acc); + if (subgroup_invocation_id == 0) { + dst[dst_idx_base + col * params.m + row] = row_total; + } + } #endif #ifdef USE_WORKGROUP_REDUCTION - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - partial_sums[partial_index(row, thread_id)] = acc[row]; - } + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] = acc[col][row]; + } + + workgroupBarrier(); + + var stride = WG_SIZE / 2u; + + while (stride > 0) { + if (thread_id < stride) { + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; + } + } + + workgroupBarrier(); + stride = stride / 2; + } + + if (thread_id < OUTPUTS_PER_WG) { + let output_row = row_base + thread_id; + if (output_row < params.m) { + dst[dst_idx_base + col * params.m + thread_id] = partial_sums[partial_index(thread_id, 0)]; + } + } +#endif workgroupBarrier(); - var stride = WG_SIZE / 2u; - - while (stride > 0) { - if (thread_id < stride) { - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; - } - } - - workgroupBarrier(); - stride = stride / 2; } - - if (thread_id < OUTPUTS_PER_WG) { - let output_row = row_base + thread_id; - if (output_row < params.m) { - dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)]; - } - } -#endif } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl index 08753b9d64..b0703fe906 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl @@ -32,8 +32,8 @@ fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { #endif #ifdef MUL_ACC_FLOAT -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let k_vec = params.k / VEC_SIZE; let src1_idx_base_vec = src1_idx_base / VEC_SIZE; @@ -41,12 +41,18 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src // Each thread walks K, loads from the vector, and updates // a small block of output rows held in registers. for (var k = thread_id; k < k_vec; k += WG_SIZE) { - let x = src1[src1_idx_base_vec + k]; + var x_vals: array; + for (var col = 0u;col < NUM_COLS;col += 1) { + x_vals[col] = src1[src1_idx_base_vec + col * (params.stride_11 / VEC_SIZE) + k]; + } for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k; - acc[row] += inner_dot(src0[src0_idx], x); + let w = src0[src0_idx]; + for (var col = 0u;col < NUM_COLS;col += 1) { + acc[col][row] += inner_dot(w, x_vals[col]); + } } } } @@ -60,30 +66,33 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 18 #define THREADS_PER_BLOCK 16 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); let q_byte = load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu; - var row_sum = 0.0; - for (var bit = 0u; bit < 8u; bit++) { - let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); - row_sum += w * x_block[bit]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var bit = 0u; bit < 8u; bit++) { + let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + row_sum += w * x_block[col][bit]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -97,35 +106,37 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 18 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % 4; for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; - let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; + let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -139,36 +150,38 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 20 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); let m = f32(load_f16_at_src0(block_byte_base + 2u)); - var row_sum = 0.0; - let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(q_byte & 0xFu) * d + m; - let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(q_byte & 0xFu) * d + m; + let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -182,19 +195,20 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 22 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -203,18 +217,19 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qh_packed = load_u32_at_src0(block_byte_base + 2u); let q_packed = load_u32_at_src0(block_byte_base + 6u + 4u * thread_within_block); let qh_shift = thread_within_block * 4u; - var row_sum = 0.0; - - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; - let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; - let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; - let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; + let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -228,19 +243,20 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 24 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -250,18 +266,19 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qh_packed = load_u32_at_src0(block_byte_base + 4u); let q_packed = load_u32_at_src0(block_byte_base + 8u + 4u * thread_within_block); let qh_shift = thread_within_block * 4u; - var row_sum = 0.0; - - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; - let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; - let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; - let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; + let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -275,33 +292,38 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 34 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - + var q_packed: array; for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d; - row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; - } + q_packed[packed_idx] = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); + } + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed[packed_idx], byte_idx)) * d; + row_sum += q_val * x_block[col][packed_idx * 4u + byte_idx]; + } + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -315,34 +337,39 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 36 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); let m = f32(load_f16_at_src0(block_byte_base + 2u)); - var row_sum = 0.0; - + var q_packed: array; for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { - let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m; - row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; - } + q_packed[packed_idx] = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); + } + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed[packed_idx], byte_idx)) * d + m; + row_sum += q_val * x_block[col][packed_idx * 4u + byte_idx]; + } + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -355,8 +382,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 84 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -379,14 +406,15 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 64u + i]); - x_block[i + 12u] = f32(src1[x_base + 96u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 4u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4u] = f32(src1[x_base + col * params.stride_11 + 32u + i]); + x_block[col][i + 8u] = f32(src1[x_base + col * params.stride_11 + 64u + i]); + x_block[col][i + 12u] = f32(src1[x_base + col * params.stride_11 + 96u + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -404,30 +432,32 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qs0 = q_u32 & 0xFFFFu; let qs1 = q_u32 >> 16u; - var sumy = vec4(0.0, 0.0, 0.0, 0.0); - var acc1 = vec4(0.0, 0.0, 0.0, 0.0); - var acc2 = vec4(0.0, 0.0, 0.0, 0.0); + for (var col = 0u;col < NUM_COLS;col += 1) { + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + var acc1 = vec4(0.0, 0.0, 0.0, 0.0); + var acc2 = vec4(0.0, 0.0, 0.0, 0.0); - sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3]; - sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7]; - sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11]; - sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15]; + sumy[0] = x_block[col][0] + x_block[col][1] + x_block[col][2] + x_block[col][3]; + sumy[1] = x_block[col][4] + x_block[col][5] + x_block[col][6] + x_block[col][7]; + sumy[2] = x_block[col][8] + x_block[col][9] + x_block[col][10] + x_block[col][11]; + sumy[3] = x_block[col][12] + x_block[col][13] + x_block[col][14] + x_block[col][15]; - acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u); - acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u); - acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu); - acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u); - acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u); - acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u); - acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u); - acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u); + acc1[0] = x_block[col][0] * f32(qs0 & 0x0003u) + x_block[col][2] * f32(qs1 & 0x0003u); + acc2[0] = x_block[col][1] * f32(qs0 & 0x0300u) + x_block[col][3] * f32(qs1 & 0x0300u); + acc1[1] = x_block[col][4] * f32(qs0 & 0x000Cu) + x_block[col][6] * f32(qs1 & 0x000Cu); + acc2[1] = x_block[col][5] * f32(qs0 & 0x0C00u) + x_block[col][7] * f32(qs1 & 0x0C00u); + acc1[2] = x_block[col][8] * f32(qs0 & 0x0030u) + x_block[col][10] * f32(qs1 & 0x0030u); + acc2[2] = x_block[col][9] * f32(qs0 & 0x3000u) + x_block[col][11] * f32(qs1 & 0x3000u); + acc1[3] = x_block[col][12] * f32(qs0 & 0x00C0u) + x_block[col][14] * f32(qs1 & 0x00C0u); + acc2[3] = x_block[col][13] * f32(qs0 & 0xC000u) + x_block[col][15] * f32(qs1 & 0xC000u); - acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + - (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + - (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + - (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) - - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + - sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + acc[col][row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + + (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + + (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + + (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) + - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + + sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + } } } } @@ -440,8 +470,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 110 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -485,12 +515,13 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 8u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 8u] = f32(src1[x_base + 32u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 8u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 8u] = f32(src1[x_base + col * params.stride_11 + 32u + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -516,28 +547,30 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let h_u32_0 = load_u32_at_src0(block_byte_base + h_byte + 0u); let h_u32_1 = load_u32_at_src0(block_byte_base + h_byte + 4u); - var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; - var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; + for (var col = 0u;col < NUM_COLS;col += 1) { + var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; + var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; - for (var l = 0u; l < 8u; l += 2u) { - let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); - let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); - let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); - let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); + for (var l = 0u; l < 8u; l += 2u) { + let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); + let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); + let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); + let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); - s1 += x_block[l + 0u] * f32(qs & qm0); - s2 += x_block[l + 1u] * f32(qs & qm1); - s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) + - select(0.0, x_block[l + 1u], (hv & hm1) == 0u); - s4 += x_block[l + 8u] * f32(qs & qm2); - s5 += x_block[l + 9u] * f32(qs & qm3); - s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) + - select(0.0, x_block[l + 9u], (hv & hm3) == 0u); + s1 += x_block[col][l + 0u] * f32(qs & qm0); + s2 += x_block[col][l + 1u] * f32(qs & qm1); + s3 += select(0.0, x_block[col][l + 0u], (hv & hm0) == 0u) + + select(0.0, x_block[col][l + 1u], (hv & hm1) == 0u); + s4 += x_block[col][l + 8u] * f32(qs & qm2); + s5 += x_block[col][l + 9u] * f32(qs & qm3); + s6 += select(0.0, x_block[col][l + 8u], (hv & hm2) == 0u) + + select(0.0, x_block[col][l + 9u], (hv & hm3) == 0u); + } + + let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); + let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); + acc[col][row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); } - - let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); - let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); - acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); } } } @@ -550,8 +583,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 144 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -573,12 +606,15 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 128u + i]); - x_block[i + 12u] = f32(src1[x_base + 160u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + let col_base = x_base + col * params.stride_11; + for (var i = 0u; i < 4u; i++) { + x_block[col][i] = f32(src1[col_base + i]); + x_block[col][i + 4u] = f32(src1[col_base + 32u + i]); + x_block[col][i + 8u] = f32(src1[col_base + 128u + i]); + x_block[col][i + 12u] = f32(src1[col_base + 160u + i]); + } } for (var row = 0u; row < OUTPUTS_PER_WG; row++) { @@ -613,23 +649,25 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let q1_u32 = load_u32_at_src0_aligned(block_byte_base + 16u + q_offset); let q2_u32 = load_u32_at_src0_aligned(block_byte_base + 80u + q_offset); - var dot = vec4(0.0, 0.0, 0.0, 0.0); - var sumx = vec4(0.0, 0.0, 0.0, 0.0); - for (var i = 0u; i < 4u; i++) { - let q1b = byte_of(q1_u32, i); - let q2b = byte_of(q2_u32, i); - dot[0] += x_block[i] * f32(q1b & 0x0Fu); - dot[1] += x_block[i + 4u] * f32(q1b >> 4u); - dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu); - dot[3] += x_block[i + 12u] * f32(q2b >> 4u); - sumx[0] += x_block[i]; - sumx[1] += x_block[i + 4u]; - sumx[2] += x_block[i + 8u]; - sumx[3] += x_block[i + 12u]; - } + for (var col = 0u;col < NUM_COLS;col += 1) { + var dot = vec4(0.0, 0.0, 0.0, 0.0); + var sumx = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + dot[0] += x_block[col][i] * f32(q1b & 0x0Fu); + dot[1] += x_block[col][i + 4u] * f32(q1b >> 4u); + dot[2] += x_block[col][i + 8u] * f32(q2b & 0x0Fu); + dot[3] += x_block[col][i + 12u] * f32(q2b >> 4u); + sumx[0] += x_block[col][i]; + sumx[1] += x_block[col][i + 4u]; + sumx[2] += x_block[col][i + 8u]; + sumx[3] += x_block[col][i + 12u]; + } - acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) - - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + acc[col][row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) + - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + } } } } @@ -642,8 +680,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 176 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -671,14 +709,16 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 128u + i]); - x_block[i + 12u] = f32(src1[x_base + 160u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + let col_base = x_base + col * params.stride_11; + for (var i = 0u; i < 4u; i++) { + x_block[col][i] = f32(src1[col_base + i]); + x_block[col][i + 4u] = f32(src1[col_base + 32u + i]); + x_block[col][i + 8u] = f32(src1[col_base + 128u + i]); + x_block[col][i + 12u] = f32(src1[col_base + 160u + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -712,37 +752,39 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let q2_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset + 64u); let qh_u32 = load_u32_at_src0_aligned(block_byte_base + qh_offset); - var vals = vec4(0.0, 0.0, 0.0, 0.0); - var sumy = vec4(0.0, 0.0, 0.0, 0.0); - for (var i = 0u; i < 4u; i++) { - let q1b = byte_of(q1_u32, i); - let q2b = byte_of(q2_u32, i); - let qhb = byte_of(qh_u32, i); + for (var col = 0u;col < NUM_COLS;col += 1) { + var vals = vec4(0.0, 0.0, 0.0, 0.0); + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + let qhb = byte_of(qh_u32, i); - let yl0 = x_block[i]; - let yl8 = x_block[i + 4u]; - let yh0 = x_block[i + 8u]; - let yh8 = x_block[i + 12u]; + let yl0 = x_block[col][i]; + let yl8 = x_block[col][i + 4u]; + let yh0 = x_block[col][i + 8u]; + let yh8 = x_block[col][i + 12u]; - sumy[0] += yl0; - sumy[1] += yl8; - sumy[2] += yh0; - sumy[3] += yh8; + sumy[0] += yl0; + sumy[1] += yl8; + sumy[2] += yh0; + sumy[3] += yh8; - let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); - let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); - let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); - let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); + let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); + let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); + let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); + let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); - vals[0] += yl0 * q0; - vals[1] += yl8 * q1; - vals[2] += yh0 * q2; - vals[3] += yh8 * q3; + vals[0] += yl0 * q0; + vals[1] += yl8 * q1; + vals[2] += yh0 * q2; + vals[3] += yh8 * q3; + } + + acc[col][row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) + - dmin * (sumy[0] * m0 + sumy[1] * m1 + + sumy[2] * m4 + sumy[3] * m5); } - - acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) - - dmin * (sumy[0] * m0 + sumy[1] * m1 + - sumy[2] * m4 + sumy[3] * m5); } } } @@ -755,8 +797,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 210 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -777,14 +819,16 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var l = 0u; l < 4u; l++) { - x_block[l] = f32(src1[x_base + l]); - x_block[l + 4u] = f32(src1[x_base + 32u + l]); - x_block[l + 8u] = f32(src1[x_base + 64u + l]); - x_block[l + 12u] = f32(src1[x_base + 96u + l]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + let col_base = x_base + col * params.stride_11; + for (var l = 0u; l < 4u; l++) { + x_block[col][l] = f32(src1[col_base + l]); + x_block[col][l + 4u] = f32(src1[col_base + 32u + l]); + x_block[col][l + 8u] = f32(src1[col_base + 64u + l]); + x_block[col][l + 12u] = f32(src1[col_base + 96u + l]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -802,26 +846,28 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); - var sums = vec4(0.0, 0.0, 0.0, 0.0); + for (var col = 0u;col < NUM_COLS;col += 1) { + var sums = vec4(0.0, 0.0, 0.0, 0.0); - for (var l = 0u; l < 4u; l++) { - let q1b = byte_of(ql1_u32, l); - let q2b = byte_of(ql2_u32, l); - let qhb = byte_of(qh_u32, l); + for (var l = 0u; l < 4u; l++) { + let q1b = byte_of(ql1_u32, l); + let q2b = byte_of(ql2_u32, l); + let qhb = byte_of(qh_u32, l); - let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); - let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); - let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); - let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); + let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); + let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); + let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); + let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); - sums[0] += x_block[l] * dq0; - sums[1] += x_block[l + 4u] * dq1; - sums[2] += x_block[l + 8u] * dq2; - sums[3] += x_block[l + 12u] * dq3; + sums[0] += x_block[col][l] * dq0; + sums[1] += x_block[col][l + 4u] * dq1; + sums[2] += x_block[col][l + 8u] * dq2; + sums[3] += x_block[col][l + 12u] * dq3; + } + + acc[col][row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + + sums[2] * f32(sc4) + sums[3] * f32(sc6)); } - - acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + - sums[2] * f32(sc4) + sums[3] * f32(sc6)); } } } @@ -834,8 +880,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 50 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -850,11 +896,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -866,20 +913,22 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_byte = get_byte(qs_w, l); - let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; - let gw = iq1_grid[ig / 16u]; - let bit_base = (ig % 16u) * 2u; - for (var j = 0u; j < 8u; j++) { - let g = (gw >> (bit_base + j * 2u)) & 3u; - let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); - row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -892,8 +941,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 56 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -908,11 +957,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -936,26 +986,28 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qh_lo = qh & 0xFFu; let qh_hi = (qh >> 8u) & 0xFFu; - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); - let sub_scale = (sc_u16 >> bit_off) & 0x7u; - let dl = d * f32(2u * sub_scale + 1u); - let qh_byte = select(qh_lo, qh_hi, l >= 2u); - let ll2 = l % 2u; - let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); - let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); - let ig = grid_idx * 8u; - let gw = iq1_grid[ig / 16u]; - let bit_base = (ig % 16u) * 2u; - for (var j = 0u; j < 8u; j++) { - let g = (gw >> (bit_base + j * 2u)) & 3u; - let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); - row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); + let sub_scale = (sc_u16 >> bit_off) & 0x7u; + let dl = d * f32(2u * sub_scale + 1u); + let qh_byte = select(qh_lo, qh_hi, l >= 2u); + let ll2 = l % 2u; + let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); + let ig = grid_idx * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -968,8 +1020,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 66 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -984,11 +1036,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -999,22 +1052,24 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let ls = aux_hi >> 28u; let db = d * (0.5 + f32(ls)) * 0.25; - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; - let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let gw_lo = iq2xxs_grid[grid_idx * 2u]; - let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; + let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xxs_grid[grid_idx * 2u]; + let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1027,8 +1082,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 74 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1043,11 +1098,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1058,27 +1114,29 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); let scales_byte = get_byte(scales_word, sub_blk % 4u); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let half2 = (l % 2u) * 16u; - let qs_val = (qs_word >> half2) & 0xFFFFu; - let grid_idx = qs_val & 0x1FFu; - let signs_idx = (qs_val >> 9u) & 0x7Fu; - let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; - let db = d * (0.5 + f32(sub_scale)) * 0.25; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let gw_lo = iq2xs_grid[grid_idx * 2u]; - let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let half2 = (l % 2u) * 16u; + let qs_val = (qs_word >> half2) & 0xFFFFu; + let grid_idx = qs_val & 0x1FFu; + let signs_idx = (qs_val >> 9u) & 0x7Fu; + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xs_grid[grid_idx * 2u]; + let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1091,8 +1149,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 82 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1107,11 +1165,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1124,24 +1183,26 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u); let scales_byte = get_byte(sc_word, sub_blk % 4u); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_byte = get_byte(qs_w, l); - let sign_byte = get_byte(sg_w, l); - let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u); - let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; - let db = d * (0.5 + f32(sub_scale)) * 0.25; - let gw_lo = iq2s_grid[grid_idx * 2u]; - let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let sign_byte = get_byte(sg_w, l); + let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u); + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let gw_lo = iq2s_grid[grid_idx * 2u]; + let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1154,8 +1215,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 98 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1170,11 +1231,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1186,27 +1248,29 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let ls = aux >> 28u; let db = d * (0.5 + f32(ls)) * 0.5; - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let byte_pos = (l % 2u) * 2u; - let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; - let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; - let signs_idx = (aux >> (7u * l)) & 0x7Fu; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let grid1 = iq3xxs_grid[grid_idx_0]; - let grid2 = iq3xxs_grid[grid_idx_1]; - for (var j = 0u; j < 4u; j++) { - let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); - let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); - let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); - row_sum += db * b1 * s1 * x_block[ll * 8u + j]; - row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let signs_idx = (aux >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let grid1 = iq3xxs_grid[grid_idx_0]; + let grid2 = iq3xxs_grid[grid_idx_1]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[col][ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[col][ll * 8u + j + 4u]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1219,8 +1283,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 110 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1235,11 +1299,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1255,28 +1320,30 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu; let db = d * (1.0 + 2.0 * f32(sub_scale)); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let byte_pos = (l % 2u) * 2u; - let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; - let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; - let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); - let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); - let sign_byte = get_byte(sg_w, l); - let grid1 = iq3s_grid[grid_idx_1]; - let grid2 = iq3s_grid[grid_idx_2]; - for (var j = 0u; j < 4u; j++) { - let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); - let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); - let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); - let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); - row_sum += db * b1 * s1 * x_block[ll * 8u + j]; - row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); + let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); + let sign_byte = get_byte(sg_w, l); + let grid1 = iq3s_grid[grid_idx_1]; + let grid2 = iq3s_grid[grid_idx_2]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[col][ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[col][ll * 8u + j + 4u]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1290,35 +1357,37 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 18 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + i + 16u]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4u] = f32(src1[x_base + col * params.stride_11 + i + 16u]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; - let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; + let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1331,8 +1400,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 136 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1346,11 +1415,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1370,17 +1440,19 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u); let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u); - var row_sum = 0.0; - for (var i = 0u; i < 16u; i++) { - let q_word = select( - select(q_w0, q_w1, i >= 4u), - select(q_w2, q_w3, i >= 12u), - i >= 8u); - let q_byte = get_byte(q_word, i % 4u); - let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); - row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var i = 0u; i < 16u; i++) { + let q_word = select( + select(q_w0, q_w1, i >= 4u), + select(q_w2, q_w3, i >= 12u), + i >= 8u); + let q_byte = get_byte(q_word, i % 4u); + let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); + row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[col][i]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1394,35 +1466,38 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 17 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % 4; for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0); let e = ldexp(1.0, i32(eu8) - 128); - var row_sum = 0.0; let q_packed = load_u32_at_src0(block_byte_base + 1u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * e; - let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * e; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * e; + let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * e; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl index 3ef2f77ebe..6ccaf61a6a 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl @@ -51,10 +51,7 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE { fn get_dm(block_byte_base: u32) -> f32 { return f32(load_f16_at_src0(block_byte_base)); } -fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { - return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK; -} -#endif +#endif // MUL_ACC_Q4_0 #ifdef MUL_ACC_Q4_1 #define BLOCK_SIZE_BYTES 20 @@ -85,10 +82,7 @@ fn get_dm(block_byte_base: u32) -> vec2 { f32(load_f16_at_src0(block_byte_base + 2u)) ); } -fn mul_q8_1(row_sum: i32, dma: vec2, b_ds: B_DS_TYPE) -> f32 { - return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK; -} -#endif +#endif // MUL_ACC_Q4_1 #ifdef MUL_ACC_Q8_0 #define BLOCK_SIZE_BYTES 34 @@ -111,46 +105,48 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE { fn get_dm(block_byte_base: u32) -> f32 { return f32(load_f16_at_src0(block_byte_base)); } -fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { - return f32(row_sum) * (da * b_ds); -} -#endif +#endif // MUL_ACC_Q8_0 -#ifdef LEGACY_QUANTS -fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2, b_ds: B_DS_TYPE) -> f32 { - var row_sum = 0; - let a_repacked = repack_a(a_byte_base, b_inner_id); - - row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]); - row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]); - - return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds); -} - -fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array { - var acc: array; +#if defined(LEGACY_QUANTS) +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let b_inner_id = thread_id % THREADS_PER_BLOCK; - let b_block_idx = src1q_idx_base + block; - - let b_repacked = repack_b_qs(b_block_idx, b_inner_id); - let b_ds = repack_b_dm(b_block_idx); - + let inner_id = thread_id % THREADS_PER_BLOCK; for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds); + let a_repacked = repack_a(block_byte_base, inner_id); + let da = get_dm(block_byte_base); + for (var col = 0u;col < NUM_COLS;col += 1) { + let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + block; + let b_repacked = repack_b_qs(src1q_idx, inner_id); + let b_ds = repack_b_dm(src1q_idx); + + let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]); + +#if defined(MUL_ACC_Q4_0) + acc[col][row] += f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK; +#endif // MUL_ACC_Q4_0 + +#if defined(MUL_ACC_Q4_1) + acc[col][row] += f32(row_sum) * (da.x * b_ds.x) + da.y * b_ds.y / THREADS_PER_BLOCK; +#endif // MUL_ACC_Q4_1 + +#if defined(MUL_ACC_Q8_0) + acc[col][row] += f32(row_sum) * (da * b_ds); +#endif // MUL_ACC_Q8_0 + } } } } return acc; } -#endif +#endif // LEGACY_QUANTS #ifdef MUL_ACC_Q2_K #define BLOCK_SIZE_BYTES 84 @@ -191,22 +187,7 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2 { let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u); return vec2(f32(scale & 0xFu), f32(scale >> 4u)); } -fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4, b_ds: B_DS_TYPE) -> f32 { - let a_repacked = repack_a(a_byte_base, tid); - let dm = get_dm(a_byte_base); - let scale_min = get_scale_min(a_byte_base, tid); - - let scale_q = i32(scale_min.x); - let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u; - - let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1]) - + dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q; - let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4) - + dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4); - - return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m)); -} -#endif +#endif // MUL_ACC_Q2_K #ifdef MUL_ACC_Q4_K #define BLOCK_SIZE_BYTES 144 @@ -265,39 +246,52 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2 { return vec2(scale, min_val); } -fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4, b_ds: B_DS_TYPE) -> f32 { - let a_repacked = repack_a(a_byte_base, tid); - let dm = get_dm(a_byte_base); - let scale_min = get_scale_min(a_byte_base, tid); - - let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]) - + dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]); - - // Each thread covers half of the Q8_1 block, so add only b_ds.y/2. - return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD)); -} -#endif +#endif // MUL_ACC_Q4_K #ifdef K_QUANTS -fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) { - let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE; - let b_repacked = repack_b_qs(src1q_idx, tid); - let b_ds = repack_b_dm(src1q_idx); - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds); + let a_repacked = repack_a(block_byte_base, tid); + let dm = get_dm(block_byte_base); + let scale_min = get_scale_min(block_byte_base, tid); + for (var col = 0u;col < NUM_COLS;col += 1) { + let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE; + let b_repacked = repack_b_qs(src1q_idx, tid); + let b_ds = repack_b_dm(src1q_idx); + +#if defined(MUL_ACC_Q2_K) + let scale_q = i32(scale_min.x); + let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u; + + let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1]) + + dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q; + let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4) + + dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4); + + acc[col][row] += b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m)); +#endif // MUL_ACC_Q2_K + +#if defined(MUL_ACC_Q4_K) + let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]) + + dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]); + + // Each thread covers half of the Q8_1 block, so add only b_ds.y/2. + acc[col][row] += b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD)); +#endif // MUL_ACC_Q4_K + + } } } } return acc; } -#endif +#endif // K_QUANTS diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl index b3f1fa04b8..847b27ffad 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl @@ -9,9 +9,11 @@ requires packed_4x8_integer_dot_product; struct Params { offset_src1: u32, + stride_11: u32, stride_12: u32, stride_13: u32, ne0: u32, + ne1: u32, ne2: u32, ne3: u32, }; @@ -57,25 +59,28 @@ fn main( @builtin(num_workgroups) num_wg: vec3 ) { let thread_id = local_id.x; - let num_vec4 = params.ne0 / 4u; + let ne0_vec4 = params.ne0 / 4u; - let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE; - let total_batches = wg_per_vec * params.ne2 * params.ne3; + let wg_per_vec = (ne0_vec4 + (WG_SIZE - 1u)) / WG_SIZE; + let total_batches = wg_per_vec * params.ne1 * params.ne2 * params.ne3; let wg_linear = wg_id.y * num_wg.x + wg_id.x; if (wg_linear >= total_batches) { return; } - let src13_idx = wg_linear / (params.ne2 * wg_per_vec); - let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec; - let src11_wg_idx = wg_linear % wg_per_vec; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let vec_idx = wg_linear / wg_per_vec; + let src13_idx = vec_idx / (params.ne2 * params.ne1); + let vec_ne12_num = vec_idx % (params.ne2 * params.ne1); + let src12_idx = vec_ne12_num / params.ne1; + let src11_idx = vec_ne12_num % params.ne1; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + src11_idx * params.stride_11; let src1_idx_vec4_base = src1_idx_base / 4u; let blocks_per_row = params.ne0 / 32u; let blocks_per_wg = (WG_SIZE * 4u) / 32u; - let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row; + let src1q_idx_base = ((src13_idx * params.ne2 + src12_idx) * params.ne1 + src11_idx) * blocks_per_row; + let src11_wg_idx = wg_linear % wg_per_vec; let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u; let qs_idx = thread_id % 8u; @@ -85,7 +90,7 @@ fn main( var thread_amax = 0.0; let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id; - let is_valid = src11_vec4_idx < num_vec4; + let is_valid = src11_vec4_idx < ne0_vec4; #ifdef USE_SUBGROUP_REDUCTION diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 15ae38927c..127c4634c0 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8433,6 +8433,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {1, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {3, 2}, {2, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {2, 1})); @@ -8449,6 +8450,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 1, 3, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 3, 2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {2, 3}, {1, 1}, {0, 3, 2, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 1, 3, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));