mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
* vulkan: optimize operations in the IM2COL shader * Add comments and improve the code formatting
139 lines
3.7 KiB
Plaintext
139 lines
3.7 KiB
Plaintext
#version 450
|
|
|
|
#extension GL_EXT_shader_16bit_storage : require
|
|
#extension GL_EXT_control_flow_attributes : require
|
|
|
|
#include "types.glsl"
|
|
|
|
// Push-constant parameters for the im2col shader.
// NOTE(review): member order and types presumably mirror a host-side struct;
// do not reorder without checking the code that fills the push constants.
layout (push_constant) uniform parameter
{
    // Destination base address; only read by the `#if BDA` path of im2col.
    BDA_STORAGE_T dst_addr;

    // Source offset multipliers: batch_offset is scaled by the batch index,
    // offset_delta by the channel index derived from the CHW position.
    uint batch_offset;
    uint offset_delta;

    // Not referenced by the code in this file (presumably the input channel count).
    uint IC;

    // Source extents; source coordinates outside [0, IW) x [0, IH) yield zero.
    uint IW;
    uint IH;

    // Output extents: OW bounds the `ow` loop, OH splits the z index into (batch, oh).
    uint OW;
    uint OH;

    // Kernel extents used to decompose a CHW position into (ic, ky, kx).
    uint KW;
    uint KH;

    // Upper bound of the z loop in main() (presumably OH * batch count; verify against caller).
    uint OH_batch;

    // Number of elements in one destination row; upper bound for the CHW index.
    uint CHW;

    // s0/s1 multiply the output x/y coordinate.
    int s0;
    int s1;

    // p0/p1 are subtracted from the scaled output x/y coordinate.
    int p0;
    int p1;

    // d0/d1 multiply the kernel x/y coordinate.
    int d0;
    int d1;

    // Not referenced by the code in this file.
    uint batch_IC;
} p;
|
|
|
|
// Specialization constant 0: number of invocations per workgroup along X (default 32).
layout (constant_id = 0) const uint BLOCK_SIZE = 32;

// Inner-loop trip count in im2col: each invocation handles NUM_ITER elements,
// spaced BLOCK_SIZE apart, out of a 512-element window.
const uint NUM_ITER = 512 / BLOCK_SIZE;

// local_size_x is taken from specialization constant 0, i.e. the same id as BLOCK_SIZE.
layout (local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Source tensor (read-only) and destination matrix (write-only).
// A_TYPE / D_TYPE are not defined in this file (presumably supplied by the build
// or by types.glsl).
layout (binding = 0) readonly  buffer X { A_TYPE data_a[]; };
layout (binding = 1) writeonly buffer D { D_TYPE data_d[]; };

#if BDA
// Buffer-reference view of a single destination element; im2col uses it to
// store through a raw address instead of indexing data_d.
layout (buffer_reference) buffer D_ptr { D_TYPE d; };
#endif
|
|
|
|
// Fills one destination row of the im2col matrix.
//
//   ow    - output x coordinate of the row.
//   z_idx - combined index, split below as batch_idx * p.OH + oh.
//
// A position `chw` inside the row decomposes as (ic * p.KH + ky) * p.KW + kx.
// Each invocation starts at its own lane and advances by BLOCK_SIZE per step,
// updating (kx, ky, source channel offset) incrementally instead of
// re-dividing on every step.
void im2col(const uint ow, const uint z_idx) {
    // Number of consecutive CHW positions one workgroup covers per outer pass
    // (NUM_ITER is defined elsewhere in this file as 512 / BLOCK_SIZE).
    const uint window = 512;

    const uint lane      = gl_LocalInvocationID.x;
    const uint batch_idx = z_idx / p.OH;
    const uint oh        = z_idx % p.OH;

    const uint src_batch = batch_idx * p.batch_offset;
    const BDA_OFFSET_T dst_row = ((BDA_OFFSET_T(batch_idx) * p.OH + oh) * p.OW + ow) * p.CHW;

    // Source coordinate of kernel tap (0, 0) for this output position
    const int iw_origin = int(ow * p.s0) - p.p0;
    const int ih_origin = int(oh * p.s1) - p.p1;

    // How far (ic, ky, kx) move when the CHW position advances by BLOCK_SIZE
    const uint kernel_area = p.KH * p.KW;
    const uint step_ic     = BLOCK_SIZE / kernel_area;
    const uint step_rem    = BLOCK_SIZE % kernel_area;
    const uint step_ky     = step_rem / p.KW;
    const uint step_kx     = step_rem % p.KW;
    const uint step_src    = step_ic * p.offset_delta;

#if BDA
    // Byte address of the start of the destination row and byte stride per step
    const BDA_STORAGE_T dst_base = p.dst_addr + D_SIZE * dst_row;
    const uint dst_stride = D_SIZE * BLOCK_SIZE;
#endif

    uint group = gl_WorkGroupID.x;
    do {
        uint chw = group * window + lane;

        // Full decomposition is done once per window; the inner loop only adds deltas
        const uint tap = chw % kernel_area;
        uint ky = tap / p.KW;
        uint kx = tap % p.KW;
        uint src_chan = src_batch + (chw / kernel_area) * p.offset_delta;

#if BDA
        BDA_STORAGE_T dst_ptr = dst_base + D_SIZE * chw;
#else
        uint dst_idx = dst_row + chw;
#endif

        [[unroll]] for (uint step = 0; step < NUM_ITER; ++step) {
            // chw only grows from here on (within this pass and in later passes),
            // so nothing is left for this invocation once it leaves the row.
            if (chw >= p.CHW) {
                return;
            }

            const int iih = ih_origin + int(ky * p.d1);
            const int iiw = iw_origin + int(kx * p.d0);

            // A negative coordinate converts to a large uint and fails the compare,
            // so one unsigned test covers both the lower and the upper bound.
            const bool in_bounds = uint(iih) < p.IH && uint(iiw) < p.IW;

            A_TYPE val = A_TYPE(0);
            if (in_bounds) {
                val = data_a[src_chan + uint(iih) * p.IW + uint(iiw)];
            }

#if BDA
            D_ptr(dst_ptr).d = D_TYPE(val);
            dst_ptr += dst_stride;
#else
            data_d[dst_idx] = D_TYPE(val);
            dst_idx += BLOCK_SIZE;
#endif

            // Advance to the position handled in the next step
            chw      += BLOCK_SIZE;
            src_chan += step_src;
            kx       += step_kx;
            ky       += step_ky;

            // kx < KW and step_kx < KW, so at most one carry into ky
            if (kx >= p.KW) {
                kx -= p.KW;
                ky += 1;
            }

            // ky <= KH - 1, step_ky <= KH - 1, carry <= 1: at most one carry into the channel
            if (ky >= p.KH) {
                ky -= p.KH;
                src_chan += p.offset_delta;
            }
        }

        group += gl_NumWorkGroups.x;
    } while (group * window < p.CHW);
}
|
|
|
|
// Entry point: visits every (ow, z) pair with ow < p.OW and z < p.OH_batch,
// calling im2col for each one. The Y and Z local sizes are 1, so starting
// at the global invocation id and stepping by the workgroup count covers the
// whole range even when the dispatch is smaller than the problem.
void main() {
    for (uint ow = gl_GlobalInvocationID.y; ow < p.OW; ow += gl_NumWorkGroups.y) {
        for (uint z = gl_GlobalInvocationID.z; z < p.OH_batch; z += gl_NumWorkGroups.z) {
            im2col(ow, z);
        }
    }
}
|