From 2973e80970367f01200531e40227cc04d6e96623 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 30 Apr 2026 08:14:30 +0200 Subject: [PATCH] Faster small batch inference for MoE models (#1707) --- ggml/src/ggml-cuda.cu | 128 +++++++++++++++----------- ggml/src/ggml-cuda/mmvq-templates.cuh | 8 +- 2 files changed, 77 insertions(+), 59 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 8f46691c..23c76b3a 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2798,7 +2798,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten const ggml_tensor * src1 = dst->src[2]; const ggml_tensor * ids = dst->src[3]; - if (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1 && + if (src1->ne[1] == 1 && src1->ne[2] <= 8 && src1->ne[3] == 1 && ggml_is_quantized(src0_1->type) && (!src0_2 || ggml_is_quantized(src0_2->type)) && ggml_backend_buffer_is_cuda(src0_1->buffer) && @@ -2826,21 +2826,22 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten auto local_src1 = *src1; local_src1.nb[2] = local_src1.nb[3] = 0; + local_src1.ne[1] = local_src1.ne[2] = local_src1.ne[3] = 1; + + int Ny = src1->ne[2]; const int64_t src1_padded_col_size = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING); ggml_cuda_pool_alloc src1_quantized(ctx.pool()); - if (ggml_is_quantized(src0_1->type) || (src0_2 && ggml_is_quantized(src0_2->type))) { - GGML_ASSERT(src1->ne[0] % QK8_1 == 0); - auto src_1_ddq_size = src1_padded_col_size*sizeof(block_q8_1)/QK8_1; - local_src1.data = src1_quantized.alloc(src_1_ddq_size); - // Note: no use is currently made of the quantization type passed into quantize_row_q8_1_cuda. - // If that were to change, we would need to adjust the code to handle src0_1->type != src0_2->type - quantize_row_q8_1_cuda((const float *)src1->data, (void *)src1_quantized.get(), src1->ne[0], 1, 1, src1_padded_col_size, - src0_1->type, stream); - CUDA_CHECK(cudaGetLastError()); + GGML_ASSERT(src1->ne[0] % QK8_1 == 0); + auto src_1_ddq_size = src1_padded_col_size*sizeof(block_q8_1)/QK8_1; + local_src1.data = src1_quantized.alloc(src_1_ddq_size * Ny); + // Note: no use is currently made of the quantization type passed into quantize_row_q8_1_cuda. + // If that were to change, we would need to adjust the code to handle src0_1->type != src0_2->type + quantize_row_q8_1_cuda((const float *)src1->data, (void *)src1_quantized.get(), src1->ne[0], Ny, 1, src1_padded_col_size, + src0_1->type, stream); + CUDA_CHECK(cudaGetLastError()); - local_src1.nb[1] = src_1_ddq_size; - } + local_src1.nb[1] = src_1_ddq_size; bool fuse_next = next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) && ggml_backend_buffer_is_cuda(next->src[0]->buffer) && @@ -2850,37 +2851,46 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten auto unary_op = (ggml_unary_op)dst->op_params[0]; float limit = *(const float *)(dst->op_params + 1); - if (src0_2) { - ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, src0_1, &local_src1, ids, &local_dst, - dst->src[4], dst->src[5], - (const char *)src0_1->data, (const char *)src0_2->data, - (const float *)src1->data, src1_quantized.get(), - (float *)local_dst.data, 0, src0_1->ne[1], 1, src1_padded_col_size, unary_op, limit, stream); - } else { - auto local_src0_1 = *src0_1; - local_src0_1.ne[1] /= 2; - auto local_src0_2 = local_src0_1; - local_src0_2.data = (char *)local_src0_1.data + local_src0_1.ne[1]*local_src0_1.nb[1]; - if (!dst->src[4]) { - ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, &local_src0_1, &local_src1, ids, &local_dst, - nullptr, nullptr, - (const char *)local_src0_2.data, (const char *)local_src0_1.data, - (const float *)src1->data, src1_quantized.get(), - (float *)local_dst.data, 0, local_src0_1.ne[1], 1, src1_padded_col_size, unary_op, limit, stream); + + auto local_ids = *ids; + local_ids.ne[1] = 1; + + for (int iy = 0; iy < Ny; ++iy) { + local_src1.data = src1_quantized.get() + iy*src_1_ddq_size; + local_ids.data = (char *)ids->data + iy*ids->nb[1]; + local_dst.data = (char *)dst->data + iy*dst->nb[2]; + if (src0_2) { + ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, src0_1, &local_src1, &local_ids, &local_dst, + dst->src[4], dst->src[5], + (const char *)src0_1->data, (const char *)src0_2->data, + (const float *)src1->data, (const char *)local_src1.data, + (float *)local_dst.data, 0, src0_1->ne[1], 1, src1_padded_col_size, unary_op, limit, stream); } else { - GGML_ASSERT(!dst->src[5]); - auto local_bias_1 = *dst->src[4]; - local_bias_1.ne[0] /= 2; - auto local_bias_2 = local_bias_1; - local_bias_2.data = (char *)local_bias_1.data + local_bias_1.ne[0]*local_bias_1.nb[0]; - ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, &local_src0_1, &local_src1, ids, &local_dst, - &local_bias_2, &local_bias_1, - (const char *)local_src0_2.data, (const char *)local_src0_1.data, - (const float *)src1->data, src1_quantized.get(), - (float *)local_dst.data, 0, local_src0_1.ne[1], 1, src1_padded_col_size, unary_op, limit, stream); + auto local_src0_1 = *src0_1; + local_src0_1.ne[1] /= 2; + auto local_src0_2 = local_src0_1; + local_src0_2.data = (char *)local_src0_1.data + local_src0_1.ne[1]*local_src0_1.nb[1]; + if (!dst->src[4]) { + ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, &local_src0_1, &local_src1, &local_ids, &local_dst, + nullptr, nullptr, + (const char *)local_src0_2.data, (const char *)local_src0_1.data, + (const float *)src1->data, (const char *)local_src1.data, + (float *)local_dst.data, 0, local_src0_1.ne[1], 1, src1_padded_col_size, unary_op, limit, stream); + } else { + GGML_ASSERT(!dst->src[5]); + auto local_bias_1 = *dst->src[4]; + local_bias_1.ne[0] /= 2; + auto local_bias_2 = local_bias_1; + local_bias_2.data = (char *)local_bias_1.data + local_bias_1.ne[0]*local_bias_1.nb[0]; + ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, &local_src0_1, &local_src1, &local_ids, &local_dst, + &local_bias_2, &local_bias_1, + (const char *)local_src0_2.data, (const char *)local_src0_1.data, + (const float *)src1->data, (const char *)local_src1.data, + (float *)local_dst.data, 0, local_src0_1.ne[1], 1, src1_padded_col_size, unary_op, limit, stream); + } } + CUDA_CHECK(cudaGetLastError()); } - CUDA_CHECK(cudaGetLastError()); if (!fuse_next) return i; @@ -2888,8 +2898,8 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten GGML_ASSERT(dst->ne[0] % QK8_1 == 0); auto dst_row_size = dst_padded_col_size*sizeof(block_q8_1)/QK8_1; auto dst_ddq_size = n_ids*dst_row_size; - ggml_cuda_pool_alloc dst_quantized(ctx.pool(), dst_ddq_size); - quantize_row_q8_1_cuda((const float *)local_dst.data, (void *)dst_quantized.get(), dst->ne[0], n_ids, 1, + ggml_cuda_pool_alloc dst_quantized(ctx.pool(), dst_ddq_size*Ny); + quantize_row_q8_1_cuda((const float *)dst->data, (void *)dst_quantized.get(), dst->ne[0], n_ids*Ny, 1, dst_padded_col_size, next->src[0]->type, stream); CUDA_CHECK(cudaGetLastError()); @@ -2909,20 +2919,28 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten int result = i + 1; - if (i+2 < graph->n_nodes && - graph->nodes[i+2]->op == GGML_OP_ADD_ID && - graph->nodes[i+2]->src[0] == next && - graph->nodes[i+2]->src[2] == ids) { - ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, ids, &local_next, graph->nodes[i+2]->src[1], - (const char *)next->src[0]->data, nullptr, dst_quantized.get(), (float *)graph->nodes[i+2]->data, - 0, next->src[0]->ne[1], 1, dst_padded_col_size, stream); - ++result; - } else { - ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, ids, &local_next, nullptr, - (const char *)next->src[0]->data, nullptr, dst_quantized.get(), (float *)next->data, - 0, next->src[0]->ne[1], 1, dst_padded_col_size, stream); + //printf("next: %ld x %ld x %ld x %ld, %zu x %zu x %zu x %zu\n", next->ne[0], next->ne[1], next->ne[2], next->ne[3], next->nb[0], next->nb[1], next->nb[2], next->nb[3]); + + for (int iy = 0; iy < Ny; ++iy) { + local_ids.data = (char *)ids->data + iy*ids->nb[1]; + auto this_dst_quantized = dst_quantized.get() + iy*dst_ddq_size; + if (i+2 < graph->n_nodes && + graph->nodes[i+2]->op == GGML_OP_ADD_ID && + graph->nodes[i+2]->src[0] == next && + graph->nodes[i+2]->src[2] == ids) { + ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, &local_ids, &local_next, graph->nodes[i+2]->src[1], + (const char *)next->src[0]->data, nullptr, this_dst_quantized, (float *)graph->nodes[i+2]->data + iy*next->ne[0]*n_ids, + 0, next->src[0]->ne[1], 1, dst_padded_col_size, stream); + if (iy == 0) { + ++result; + } + } else { + ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, &local_ids, &local_next, nullptr, + (const char *)next->src[0]->data, nullptr, this_dst_quantized, (float *)next->data + iy*next->ne[0]*n_ids, + 0, next->src[0]->ne[1], 1, dst_padded_col_size, stream); + } + CUDA_CHECK(cudaGetLastError()); } - CUDA_CHECK(cudaGetLastError()); return result; } diff --git a/ggml/src/ggml-cuda/mmvq-templates.cuh b/ggml/src/ggml-cuda/mmvq-templates.cuh index bd3f9f0b..d9bc431b 100644 --- a/ggml/src/ggml-cuda/mmvq-templates.cuh +++ b/ggml/src/ggml-cuda/mmvq-templates.cuh @@ -66,7 +66,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { } template -static __device__ void mul_mat_vec_q( +static __device__ void k_mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const float * bias, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { @@ -150,7 +150,7 @@ static __device__ void mul_mat_vec_q( } template -static __device__ void fused_mul_mat_vec_q( +static __device__ void k_fused_mul_mat_vec_q( const void * __restrict__ vup, const void * __restrict__ vgate, const float * __restrict__ bias_u, const float * __restrict__ bias_g, const void * __restrict__ vy, float * __restrict__ dst, @@ -291,7 +291,7 @@ static __global__ void mul_mat_vec_q( const char * cx = (const char *)vx + i02*nb02; const char * cy = (const char *)vy + i2*nb12; const float * b = (const float *)(bias ? ids_data ? (const char *)bias + i02*bias_nb1 : bias : nullptr); - mul_mat_vec_q(cx, cy, b, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst); + k_mul_mat_vec_q(cx, cy, b, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst); } template @@ -317,7 +317,7 @@ static __global__ void fused_mul_mat_vec_q( const float * cx_u_b = bias_u ? (const float *)((const char *)bias_u + i02*bias_nb1) : nullptr; const float * cx_g_b = bias_g ? (const float *)((const char *)bias_g + i02*bias_nb1) : nullptr; const char * cy = (const char *)vy + i2*nb12; - fused_mul_mat_vec_q(cx_u, cx_g, cx_u_b, cx_g_b, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst, + k_fused_mul_mat_vec_q(cx_u, cx_g, cx_u_b, cx_g_b, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst, unary_op, limit); }