mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
vulkan: opt mul_mat_vecq for mi50 (#22933)
This commit is contained in:
parent
5a6a0dd7e1
commit
487a6cc164
@ -4719,7 +4719,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
|||||||
}
|
}
|
||||||
uint32_t rm_iq = 2 * rm_kq;
|
uint32_t rm_iq = 2 * rm_kq;
|
||||||
|
|
||||||
const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
|
const bool use_subgroups = device->subgroup_arithmetic;
|
||||||
// Ensure a subgroup size >= 16 is available
|
// Ensure a subgroup size >= 16 is available
|
||||||
const bool use_subgroups16 = use_subgroups && subgroup_min_size_16;
|
const bool use_subgroups16 = use_subgroups && subgroup_min_size_16;
|
||||||
|
|
||||||
|
|||||||
@ -28,13 +28,10 @@ vec2 cache_b_ds;
|
|||||||
|
|
||||||
#include "mul_mat_vecq_funcs.glsl"
|
#include "mul_mat_vecq_funcs.glsl"
|
||||||
|
|
||||||
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
|
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint col, const uint b_qs_idx) {
|
||||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
|
|
||||||
|
|
||||||
// Preload data_b block
|
// Preload data_b block
|
||||||
const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
|
const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
|
||||||
const uint b_qs_idx = tid % (32 / K_PER_ITER);
|
|
||||||
const uint b_block_idx_outer = b_block_idx / 4;
|
const uint b_block_idx_outer = b_block_idx / 4;
|
||||||
const uint b_block_idx_inner = b_block_idx % 4;
|
const uint b_block_idx_inner = b_block_idx % 4;
|
||||||
cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
|
cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
|
||||||
@ -91,35 +88,35 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
|
const uint col_stride = K_PER_ITER * BLOCK_SIZE;
|
||||||
if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
|
uint num_iters = p.ncols / col_stride;
|
||||||
|
if (num_iters * col_stride + K_PER_ITER * tid < p.ncols) {
|
||||||
num_iters++;
|
num_iters++;
|
||||||
}
|
}
|
||||||
int unroll_count = 4;
|
|
||||||
uint unrolled_iters = num_iters & ~(unroll_count - 1);
|
|
||||||
|
|
||||||
uint i = 0;
|
const uint b_qs_idx = tid % (32 / K_PER_ITER);
|
||||||
while (i < unrolled_iters) {
|
uint col = tid * K_PER_ITER;
|
||||||
|
while (num_iters >= 4) {
|
||||||
// Manually partially unroll the loop
|
// Manually partially unroll the loop
|
||||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
[[unroll]] for (uint k = 0; k < 4; ++k) {
|
||||||
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
|
iter(temp, first_row, num_rows, col, b_qs_idx);
|
||||||
i++;
|
col += col_stride;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
num_iters -= 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
unroll_count = 2;
|
if (num_iters >= 2) {
|
||||||
unrolled_iters = num_iters & ~(unroll_count - 1);
|
|
||||||
|
|
||||||
while (i < unrolled_iters) {
|
|
||||||
// Manually partially unroll the loop
|
// Manually partially unroll the loop
|
||||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
iter(temp, first_row, num_rows, col, b_qs_idx);
|
||||||
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
|
col += col_stride;
|
||||||
i++;
|
iter(temp, first_row, num_rows, col, b_qs_idx);
|
||||||
}
|
col += col_stride;
|
||||||
|
num_iters -= 2;
|
||||||
}
|
}
|
||||||
while (i < num_iters) {
|
|
||||||
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
|
if (num_iters > 0) {
|
||||||
i++;
|
iter(temp, first_row, num_rows, col, b_qs_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user