diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c090d44420..ea91918732 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4719,7 +4719,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { } uint32_t rm_iq = 2 * rm_kq; - const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN; + const bool use_subgroups = device->subgroup_arithmetic; // Ensure a subgroup size >= 16 is available const bool use_subgroups16 = use_subgroups && subgroup_min_size_16; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp index fd84c3c91d..7bbee577fb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -28,13 +28,10 @@ vec2 cache_b_ds; #include "mul_mat_vecq_funcs.glsl" -void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) { +void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint col, const uint b_qs_idx) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - const uint col = i*BLOCK_SIZE + tid*K_PER_ITER; - // Preload data_b block const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset; - const uint b_qs_idx = tid % (32 / K_PER_ITER); const uint b_block_idx_outer = b_block_idx / 4; const uint b_block_idx_inner = b_block_idx % 4; cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]); @@ -91,35 +88,35 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { } } - uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE); - if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) { + const uint col_stride = K_PER_ITER * BLOCK_SIZE; + uint num_iters = p.ncols / col_stride; + if (num_iters * col_stride + K_PER_ITER * tid < p.ncols) { num_iters++; } - int unroll_count = 4; - uint unrolled_iters = num_iters & ~(unroll_count - 1); - uint i = 0; - while (i < unrolled_iters) { + const uint b_qs_idx = tid % (32 / K_PER_ITER); + uint col = tid * K_PER_ITER; + while (num_iters >= 4) { // Manually partially unroll the loop - [[unroll]] for (uint k = 0; k < unroll_count; ++k) { - iter(temp, first_row, num_rows, tid, i*K_PER_ITER); - i++; + [[unroll]] for (uint k = 0; k < 4; ++k) { + iter(temp, first_row, num_rows, col, b_qs_idx); + col += col_stride; } + + num_iters -= 4; } - unroll_count = 2; - unrolled_iters = num_iters & ~(unroll_count - 1); - - while (i < unrolled_iters) { + if (num_iters >= 2) { // Manually partially unroll the loop - [[unroll]] for (uint k = 0; k < unroll_count; ++k) { - iter(temp, first_row, num_rows, tid, i*K_PER_ITER); - i++; - } + iter(temp, first_row, num_rows, col, b_qs_idx); + col += col_stride; + iter(temp, first_row, num_rows, col, b_qs_idx); + col += col_stride; + num_iters -= 2; } - while (i < num_iters) { - iter(temp, first_row, num_rows, tid, i*K_PER_ITER); - i++; + + if (num_iters > 0) { + iter(temp, first_row, num_rows, col, b_qs_idx); } reduce_result(temp, d_offset, first_row, num_rows, tid);