mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
vulkan: Support GET_ROWS_BACK (#24883)
This commit is contained in:
parent
c5606364b2
commit
92e854ab83
@ -791,6 +791,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
|
||||
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_get_rows_back_f32;
|
||||
vk_pipeline pipeline_acc_f32;
|
||||
vk_pipeline pipeline_set_f32;
|
||||
|
||||
@ -4946,6 +4947,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_back_f32, "get_rows_back_f32", get_rows_back_f32_len, get_rows_back_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {256, 1, 1}, {}, 1, true);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
|
||||
@ -10408,6 +10410,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_get_rows_f32[src0->type];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_get_rows_back_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_ACC:
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_acc_f32;
|
||||
@ -11304,6 +11311,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
||||
break;
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
elements = { (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], 1 };
|
||||
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||
break;
|
||||
case GGML_OP_ARGSORT:
|
||||
GGML_ASSERT(0);
|
||||
break;
|
||||
@ -11564,6 +11575,21 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||
});
|
||||
}
|
||||
|
||||
// Issues the GET_ROWS_BACK operation for (src0, src1) -> dst.
//
// All the work is delegated to ggml_vk_op_f32 with GGML_OP_GET_ROWS_BACK; this
// function only assembles the vk_op_binary_push_constants payload:
//   - the total element count of src0,
//   - for each of src0, src1 and dst: the four extents (ne[0..3]) followed by
//     the four strides (nb[0..3]) expressed in elements rather than bytes,
//   - four trailing zero fields (presumably the misalignment offset and the
//     unused op parameters -- confirm against vk_op_binary_push_constants).
static void ggml_vk_get_rows_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    // Extent of dimension `dim`, narrowed to the 32-bit push-constant width.
    const auto extent = [](const ggml_tensor * t, int dim) -> uint32_t {
        return static_cast<uint32_t>(t->ne[dim]);
    };

    // Stride of dimension `dim` in elements. The byte stride is narrowed to
    // 32 bits *before* the division by the element size.
    const auto stride = [](const ggml_tensor * t, int dim) -> uint32_t {
        const uint32_t elem_bytes = ggml_type_size(t->type);
        return static_cast<uint32_t>(t->nb[dim]) / elem_bytes;
    };

    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS_BACK, {
        static_cast<uint32_t>(ggml_nelements(src0)),
        // src0: extents, then element strides
        extent(src0, 0), extent(src0, 1), extent(src0, 2), extent(src0, 3),
        stride(src0, 0), stride(src0, 1), stride(src0, 2), stride(src0, 3),
        // src1: extents, then element strides
        extent(src1, 0), extent(src1, 1), extent(src1, 2), extent(src1, 3),
        stride(src1, 0), stride(src1, 1), stride(src1, 2), stride(src1, 3),
        // dst: extents, then element strides
        extent(dst, 0), extent(dst, 1), extent(dst, 2), extent(dst, 3),
        stride(dst, 0), stride(dst, 1), stride(dst, 2), stride(dst, 3),
        0,
        0.0f, 0.0f, 0,
    });
}
|
||||
|
||||
static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||
@ -14476,6 +14502,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_OP_GET_ROWS:
|
||||
ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
ggml_vk_get_rows_back(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_ADD:
|
||||
if (ctx->num_additional_fused_ops) {
|
||||
@ -17197,6 +17227,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
return false;
|
||||
}
|
||||
}
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
switch (op->type) {
|
||||
|
||||
25
ggml/src/ggml-vulkan/vulkan-shaders/get_rows_back.comp
Normal file
25
ggml/src/ggml-vulkan/vulkan-shaders/get_rows_back.comp
Normal file
@ -0,0 +1,25 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_binary_head.glsl"
|
||||
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
// Each invocation owns one destination column (global x id) and walks the
// destination rows starting at its global y id, stepping by the total number
// of invocations along y.
//
// For every (row, col) it owns, the invocation scans all p.ne10 entries of
// data_b and accumulates data_a[i, col] for each i whose entry equals `row`,
// then stores the total to data_d[row, col]. Rows that no entry refers to are
// therefore written as 0.
//
// p, data_a/data_b/data_d and the get_*offset() helpers are not defined in
// this file; presumably they come from generic_binary_head.glsl.
void main() {
    const uint c = gl_GlobalInvocationID.x;

    // Invocations beyond the last destination column have nothing to write.
    if (c < p.ne20) {
        // Loop-invariant parts of the three buffer indices.
        const uint a_base = get_aoffset() + c*p.nb00;
        const uint b_base = get_boffset();
        const uint d_base = get_doffset() + c*p.nb20;

        // Total number of invocations along y: the row step for this thread.
        const uint row_step = gl_WorkGroupSize.y * gl_NumWorkGroups.y;

        uint r = gl_GlobalInvocationID.y;
        while (r < p.ne21) {
            float acc = 0.0f;

            // Gather every source row whose index entry points at row r.
            for (uint k = 0; k < p.ne10; ++k) {
                if (data_b[b_base + k*p.nb10] == int(r)) {
                    acc += data_a[a_base + k*p.nb01];
                }
            }

            data_d[d_base + r*p.nb21] = acc;
            r += row_step;
        }
    }
}
|
||||
@ -843,6 +843,7 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("repeat_i32", "repeat.comp", {{"A_TYPE", "int32_t"}, {"D_TYPE", "int32_t"}});
|
||||
string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("get_rows_back_f32", "get_rows_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("repeat_i16", "repeat.comp", {{"A_TYPE", "int16_t"}, {"D_TYPE", "int16_t"}});
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user