Faster small batch inference for MoE models

This commit is contained in:
Kawrakow 2026-04-29 07:32:39 +00:00
parent fb05c2e9a2
commit 47b46ba399
2 changed files with 77 additions and 59 deletions

View File

@ -2798,7 +2798,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
const ggml_tensor * src1 = dst->src[2];
const ggml_tensor * ids = dst->src[3];
if (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1 &&
if (src1->ne[1] == 1 && src1->ne[2] <= 8 && src1->ne[3] == 1 &&
ggml_is_quantized(src0_1->type) &&
(!src0_2 || ggml_is_quantized(src0_2->type)) &&
ggml_backend_buffer_is_cuda(src0_1->buffer) &&
@ -2826,21 +2826,22 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
auto local_src1 = *src1;
local_src1.nb[2] = local_src1.nb[3] = 0;
local_src1.ne[1] = local_src1.ne[2] = local_src1.ne[3] = 1;
int Ny = src1->ne[2];
const int64_t src1_padded_col_size = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool());
if (ggml_is_quantized(src0_1->type) || (src0_2 && ggml_is_quantized(src0_2->type))) {
GGML_ASSERT(src1->ne[0] % QK8_1 == 0);
auto src_1_ddq_size = src1_padded_col_size*sizeof(block_q8_1)/QK8_1;
local_src1.data = src1_quantized.alloc(src_1_ddq_size);
// Note: no use is currently made of the quantization type passed into quantize_row_q8_1_cuda.
// If that were to change, we would need to adjust the code to handle src0_1->type != src0_2->type
quantize_row_q8_1_cuda((const float *)src1->data, (void *)src1_quantized.get(), src1->ne[0], 1, 1, src1_padded_col_size,
src0_1->type, stream);
CUDA_CHECK(cudaGetLastError());
GGML_ASSERT(src1->ne[0] % QK8_1 == 0);
auto src_1_ddq_size = src1_padded_col_size*sizeof(block_q8_1)/QK8_1;
local_src1.data = src1_quantized.alloc(src_1_ddq_size * Ny);
// Note: no use is currently made of the quantization type passed into quantize_row_q8_1_cuda.
// If that were to change, we would need to adjust the code to handle src0_1->type != src0_2->type
quantize_row_q8_1_cuda((const float *)src1->data, (void *)src1_quantized.get(), src1->ne[0], Ny, 1, src1_padded_col_size,
src0_1->type, stream);
CUDA_CHECK(cudaGetLastError());
local_src1.nb[1] = src_1_ddq_size;
}
local_src1.nb[1] = src_1_ddq_size;
bool fuse_next = next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
@ -2850,37 +2851,46 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
auto unary_op = (ggml_unary_op)dst->op_params[0];
float limit = *(const float *)(dst->op_params + 1);
if (src0_2) {
ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, src0_1, &local_src1, ids, &local_dst,
dst->src[4], dst->src[5],
(const char *)src0_1->data, (const char *)src0_2->data,
(const float *)src1->data, src1_quantized.get(),
(float *)local_dst.data, 0, src0_1->ne[1], 1, src1_padded_col_size, unary_op, limit, stream);
} else {
auto local_src0_1 = *src0_1;
local_src0_1.ne[1] /= 2;
auto local_src0_2 = local_src0_1;
local_src0_2.data = (char *)local_src0_1.data + local_src0_1.ne[1]*local_src0_1.nb[1];
if (!dst->src[4]) {
ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, &local_src0_1, &local_src1, ids, &local_dst,
nullptr, nullptr,
(const char *)local_src0_2.data, (const char *)local_src0_1.data,
(const float *)src1->data, src1_quantized.get(),
(float *)local_dst.data, 0, local_src0_1.ne[1], 1, src1_padded_col_size, unary_op, limit, stream);
auto local_ids = *ids;
local_ids.ne[1] = 1;
for (int iy = 0; iy < Ny; ++iy) {
local_src1.data = src1_quantized.get() + iy*src_1_ddq_size;
local_ids.data = (char *)ids->data + iy*ids->nb[1];
local_dst.data = (char *)dst->data + iy*dst->nb[2];
if (src0_2) {
ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, src0_1, &local_src1, &local_ids, &local_dst,
dst->src[4], dst->src[5],
(const char *)src0_1->data, (const char *)src0_2->data,
(const float *)src1->data, (const char *)local_src1.data,
(float *)local_dst.data, 0, src0_1->ne[1], 1, src1_padded_col_size, unary_op, limit, stream);
} else {
GGML_ASSERT(!dst->src[5]);
auto local_bias_1 = *dst->src[4];
local_bias_1.ne[0] /= 2;
auto local_bias_2 = local_bias_1;
local_bias_2.data = (char *)local_bias_1.data + local_bias_1.ne[0]*local_bias_1.nb[0];
ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, &local_src0_1, &local_src1, ids, &local_dst,
&local_bias_2, &local_bias_1,
(const char *)local_src0_2.data, (const char *)local_src0_1.data,
(const float *)src1->data, src1_quantized.get(),
(float *)local_dst.data, 0, local_src0_1.ne[1], 1, src1_padded_col_size, unary_op, limit, stream);
auto local_src0_1 = *src0_1;
local_src0_1.ne[1] /= 2;
auto local_src0_2 = local_src0_1;
local_src0_2.data = (char *)local_src0_1.data + local_src0_1.ne[1]*local_src0_1.nb[1];
if (!dst->src[4]) {
ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, &local_src0_1, &local_src1, &local_ids, &local_dst,
nullptr, nullptr,
(const char *)local_src0_2.data, (const char *)local_src0_1.data,
(const float *)src1->data, (const char *)local_src1.data,
(float *)local_dst.data, 0, local_src0_1.ne[1], 1, src1_padded_col_size, unary_op, limit, stream);
} else {
GGML_ASSERT(!dst->src[5]);
auto local_bias_1 = *dst->src[4];
local_bias_1.ne[0] /= 2;
auto local_bias_2 = local_bias_1;
local_bias_2.data = (char *)local_bias_1.data + local_bias_1.ne[0]*local_bias_1.nb[0];
ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, &local_src0_1, &local_src1, &local_ids, &local_dst,
&local_bias_2, &local_bias_1,
(const char *)local_src0_2.data, (const char *)local_src0_1.data,
(const float *)src1->data, (const char *)local_src1.data,
(float *)local_dst.data, 0, local_src0_1.ne[1], 1, src1_padded_col_size, unary_op, limit, stream);
}
}
CUDA_CHECK(cudaGetLastError());
}
CUDA_CHECK(cudaGetLastError());
if (!fuse_next) return i;
@ -2888,8 +2898,8 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
GGML_ASSERT(dst->ne[0] % QK8_1 == 0);
auto dst_row_size = dst_padded_col_size*sizeof(block_q8_1)/QK8_1;
auto dst_ddq_size = n_ids*dst_row_size;
ggml_cuda_pool_alloc<char> dst_quantized(ctx.pool(), dst_ddq_size);
quantize_row_q8_1_cuda((const float *)local_dst.data, (void *)dst_quantized.get(), dst->ne[0], n_ids, 1,
ggml_cuda_pool_alloc<char> dst_quantized(ctx.pool(), dst_ddq_size*Ny);
quantize_row_q8_1_cuda((const float *)dst->data, (void *)dst_quantized.get(), dst->ne[0], n_ids*Ny, 1,
dst_padded_col_size, next->src[0]->type, stream);
CUDA_CHECK(cudaGetLastError());
@ -2909,20 +2919,28 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
int result = i + 1;
if (i+2 < graph->n_nodes &&
graph->nodes[i+2]->op == GGML_OP_ADD_ID &&
graph->nodes[i+2]->src[0] == next &&
graph->nodes[i+2]->src[2] == ids) {
ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, ids, &local_next, graph->nodes[i+2]->src[1],
(const char *)next->src[0]->data, nullptr, dst_quantized.get(), (float *)graph->nodes[i+2]->data,
0, next->src[0]->ne[1], 1, dst_padded_col_size, stream);
++result;
} else {
ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, ids, &local_next, nullptr,
(const char *)next->src[0]->data, nullptr, dst_quantized.get(), (float *)next->data,
0, next->src[0]->ne[1], 1, dst_padded_col_size, stream);
//printf("next: %ld x %ld x %ld x %ld, %zu x %zu x %zu x %zu\n", next->ne[0], next->ne[1], next->ne[2], next->ne[3], next->nb[0], next->nb[1], next->nb[2], next->nb[3]);
for (int iy = 0; iy < Ny; ++iy) {
local_ids.data = (char *)ids->data + iy*ids->nb[1];
auto this_dst_quantized = dst_quantized.get() + iy*dst_ddq_size;
if (i+2 < graph->n_nodes &&
graph->nodes[i+2]->op == GGML_OP_ADD_ID &&
graph->nodes[i+2]->src[0] == next &&
graph->nodes[i+2]->src[2] == ids) {
ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, &local_ids, &local_next, graph->nodes[i+2]->src[1],
(const char *)next->src[0]->data, nullptr, this_dst_quantized, (float *)graph->nodes[i+2]->data + iy*next->ne[0]*n_ids,
0, next->src[0]->ne[1], 1, dst_padded_col_size, stream);
if (iy == 0) {
++result;
}
} else {
ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, &local_ids, &local_next, nullptr,
(const char *)next->src[0]->data, nullptr, this_dst_quantized, (float *)next->data + iy*next->ne[0]*n_ids,
0, next->src[0]->ne[1], 1, dst_padded_col_size, stream);
}
CUDA_CHECK(cudaGetLastError());
}
CUDA_CHECK(cudaGetLastError());
return result;
}

View File

@ -66,7 +66,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
}
template <ggml_type type, int ncols_y, int nwarps>
static __device__ void mul_mat_vec_q(
static __device__ void k_mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy,
const float * bias, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@ -150,7 +150,7 @@ static __device__ void mul_mat_vec_q(
}
template <ggml_type type, int ncols_y, int nwarps>
static __device__ void fused_mul_mat_vec_q(
static __device__ void k_fused_mul_mat_vec_q(
const void * __restrict__ vup, const void * __restrict__ vgate,
const float * __restrict__ bias_u, const float * __restrict__ bias_g,
const void * __restrict__ vy, float * __restrict__ dst,
@ -291,7 +291,7 @@ static __global__ void mul_mat_vec_q(
const char * cx = (const char *)vx + i02*nb02;
const char * cy = (const char *)vy + i2*nb12;
const float * b = (const float *)(bias ? ids_data ? (const char *)bias + i02*bias_nb1 : bias : nullptr);
mul_mat_vec_q<type, ncols_y, nwarps>(cx, cy, b, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
k_mul_mat_vec_q<type, ncols_y, nwarps>(cx, cy, b, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
template <ggml_type type, int ncols_y, int nwarps>
@ -317,7 +317,7 @@ static __global__ void fused_mul_mat_vec_q(
const float * cx_u_b = bias_u ? (const float *)((const char *)bias_u + i02*bias_nb1) : nullptr;
const float * cx_g_b = bias_g ? (const float *)((const char *)bias_g + i02*bias_nb1) : nullptr;
const char * cy = (const char *)vy + i2*nb12;
fused_mul_mat_vec_q<type, ncols_y, nwarps>(cx_u, cx_g, cx_u_b, cx_g_b, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst,
k_fused_mul_mat_vec_q<type, ncols_y, nwarps>(cx_u, cx_g, cx_u_b, cx_g_b, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst,
unary_op, limit);
}