mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Faster prompt processing on CUDA (#1687)
* Better fixup_stream_k * ggml_cuda_op_mul_mat_q -> ggml_cuda_mul_mat_q_id * Adding forgotten file
This commit is contained in:
parent
cb58a561f0
commit
3a945af45d
@ -2323,8 +2323,7 @@ static int ggml_cuda_mul_mat_q(ggml_backend_cuda_context & ctx, const ggml_tenso
|
||||
quantize_mmq_q8_1_cuda((const float *)src1->data, src1_quantized.get(), src1->ne[0], src1->ne[1], 1, ne10_padded, src0->type, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
ggml_cuda_op_mul_mat_q(ctx, src0, src1, dst, (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data,
|
||||
0, src0->ne[1], src1->ne[1], ne10_padded, stream);
|
||||
ggml_cuda_mul_mat_q_id(ctx, src0, src1, nullptr, dst, nullptr, src1_quantized.get());
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
@ -2355,8 +2354,7 @@ static int ggml_cuda_mul_mat_q(ggml_backend_cuda_context & ctx, const ggml_tenso
|
||||
(float *)dst->data, 0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
|
||||
}
|
||||
} else {
|
||||
ggml_cuda_op_mul_mat_q(ctx, dst->src[0], src1, dst, (const char *)dst->src[0]->data, nullptr, src1_quantized.get(),
|
||||
(float *)dst->data, 0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
|
||||
ggml_cuda_mul_mat_q_id(ctx, dst->src[0], src1, nullptr, dst, nullptr, src1_quantized.get());
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
++node_n;
|
||||
|
||||
47
ggml/src/ggml-cuda/fastdiv.cuh
Normal file
47
ggml/src/ggml-cuda/fastdiv.cuh
Normal file
@ -0,0 +1,47 @@
|
||||
#pragma once
|
||||
|
||||
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
|
||||
// Precompute mp (m' in the paper) and L such that division
|
||||
// can be computed using a multiply (high 32b of 64b result)
|
||||
// and a shift:
|
||||
//
|
||||
// n/d = (mulhi(n, mp) + n) >> L;
|
||||
static const uint3 init_fastdiv_values(uint64_t d_64) {
    // Host-side precomputation for fastdiv/fastmodulo/fast_div_modulo.
    // Returns <mp, L, divisor> packed into <x, y, z> of a uint3.
    // The divisor has to be non-zero and must fit into 32 bits.
    GGML_ASSERT(d_64 != 0);
    GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());

    const uint32_t divisor = static_cast<uint32_t>(d_64);

    // shift = ceil(log2(divisor)): the smallest exponent with 2^shift >= divisor.
    // The search stops at 32, which is what divisors above 2^31 end up with.
    uint32_t shift = 0;
    for (; shift < 32; ++shift) {
        if ((uint32_t{ 1 } << shift) >= divisor) {
            break;
        }
    }

    // mp = floor(2^32 * (2^shift - divisor) / divisor) + 1, evaluated in 64 bits
    // and then truncated to 32 bits.
    const uint64_t two_pow_32    = uint64_t{ 1 } << 32;
    const uint64_t two_pow_shift = uint64_t{ 1 } << shift;
    const uint32_t magic = static_cast<uint32_t>(two_pow_32 * (two_pow_shift - divisor) / divisor + 1);

    // pack divisor as well to reduce error surface
    return make_uint3(magic, shift, divisor);
}
|
||||
|
||||
static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fastdiv_values) {
    // Returns n / divisor using a multiply-high, an add and a shift instead of an integer division.
    // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
    // fastdiv_values.z is unused and optimized away by the compiler.
    // Compute high 32 bits of n * mp
    const uint32_t hi = __umulhi(n, fastdiv_values.x);
    // Add n and apply the bit shift in 64-bit arithmetic:
    //  - hi + n no longer fits into 32 bits once n >= 2^31, so a 32-bit add would wrap
    //    and produce a wrong quotient;
    //  - init_fastdiv_values yields L == 32 for divisors above 2^31, and shifting a
    //    32-bit value by 32 is undefined.
    // The quotient itself always fits into 32 bits, so the narrowing cast is lossless.
    return (uint32_t) (((uint64_t) hi + n) >> fastdiv_values.y);
}
|
||||
|
||||
static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fastdiv_values) {
    // Remainder via the identity n % d == n - (n / d) * d, with the quotient taken from fastdiv.
    // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
    const uint32_t quotient = fastdiv(n, fastdiv_values);
    return n - quotient * fastdiv_values.z;
}
|
||||
|
||||
// Calculate both division and modulo at once, returns <n/divisor, n%divisor>
|
||||
static __device__ __forceinline__ uint2 fast_div_modulo(uint32_t n, const uint3 fastdiv_values) {
    // One fastdiv call yields the quotient; the remainder is derived from it.
    // Result layout: .x = n / divisor, .y = n % divisor.
    // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
    const uint32_t quotient = fastdiv(n, fastdiv_values);
    return make_uint2(quotient, n - quotient * fastdiv_values.z);
}
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
#include "mma_new.cuh"
|
||||
#include "vecdotq.cuh"
|
||||
#include "iqk_cuda_common.h"
|
||||
#include "fastdiv.cuh"
|
||||
|
||||
#include <vector>
|
||||
#include <climits>
|
||||
@ -3546,10 +3547,10 @@ template <ggml_type type, int mmq_x, bool need_check>
|
||||
static __global__ void mul_mat_q_id(
|
||||
const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
|
||||
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
|
||||
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
|
||||
const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const int ncols_max) {
|
||||
const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
|
||||
const uint3 channel_ratio, const uint3 nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const uint3 sample_ratio, const uint3 nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const uint3 ntx) {
|
||||
|
||||
// Skip unused template specializations for faster compilation:
|
||||
if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
|
||||
@ -3563,8 +3564,7 @@ static __global__ void mul_mat_q_id(
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int mmq_y = get_mmq_y_device();
|
||||
|
||||
const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
|
||||
const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
|
||||
const uint32_t nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
|
||||
|
||||
// Initialize the ids for writing back data with just the index.
|
||||
// For regular matrix multiplications this is never changed.
|
||||
@ -3585,8 +3585,9 @@ static __global__ void mul_mat_q_id(
|
||||
// On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
|
||||
#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
||||
{
|
||||
const int wt = blockIdx.z / nchannels_y;
|
||||
const int zt = blockIdx.z - wt*nchannels_y;
|
||||
const uint2 tmp2 = fast_div_modulo(blockIdx.z, nchannels_y);
|
||||
const int wt = tmp2.x;
|
||||
const int zt = tmp2.y;
|
||||
const int jt = blockIdx.y;
|
||||
const int it = blockIdx.x;
|
||||
|
||||
@ -3629,39 +3630,39 @@ static __global__ void mul_mat_q_id(
|
||||
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
||||
const int tile_y_max_j = col_diff - jt*mmq_x - 1;
|
||||
|
||||
const int64_t offset_x = (wt/sample_ratio )*int64_t(stride_sample_x)
|
||||
+ (zt/channel_ratio)*int64_t(stride_channel_x) + it*mmq_y*int64_t(stride_row_x);
|
||||
const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
||||
|
||||
constexpr bool fixup = false;
|
||||
mul_mat_q_process_tile_id<type, mmq_x, need_check, fixup>
|
||||
(x + offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
|
||||
tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
|
||||
tile_x_max_i, tile_y_max_j, 0, blocks_per_ne00.z);
|
||||
return;
|
||||
}
|
||||
#endif // (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
||||
|
||||
const int64_t blocks_per_ne00 = ncols_x / qk;
|
||||
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
|
||||
|
||||
// kbc == k block continuous, current index in continuous ijk space.
|
||||
int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
|
||||
int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
|
||||
int64_t kbc = int64_t(blockIdx.x) *nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z / gridDim.x;
|
||||
int64_t kbc_stop = int64_t(blockIdx.x + 1)*nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z / gridDim.x;
|
||||
|
||||
kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
|
||||
kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
|
||||
kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter;
|
||||
kbc_stop -= fastmodulo(kbc_stop, blocks_per_ne00) % blocks_per_iter;
|
||||
|
||||
// kb0 == k index when doing the matrix multiplication for an output tile.
|
||||
int kb0_start = kbc % blocks_per_ne00;
|
||||
int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
|
||||
while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
|
||||
int tmp = kbc;
|
||||
const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
|
||||
tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
|
||||
const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
|
||||
tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
|
||||
const int zt = tmp / (ntx*blocks_per_ne00);
|
||||
tmp -= zt * (ntx*blocks_per_ne00);
|
||||
const int jt = tmp / blocks_per_ne00;
|
||||
int kb0_start = fastmodulo(kbc, blocks_per_ne00);
|
||||
int kb0_stop = min(blocks_per_ne00.z, uint32_t(kb0_start + kbc_stop - kbc));
|
||||
while (kbc < kbc_stop && kb0_stop == int(blocks_per_ne00.z)) {
|
||||
int tmp = fastdiv(kbc, blocks_per_ne00);
|
||||
uint2 tmp2 = fast_div_modulo(tmp, ntx);
|
||||
const int jt = tmp2.y;
|
||||
tmp = tmp2.x;
|
||||
tmp2 = fast_div_modulo(tmp, nchannels_y);
|
||||
const int zt = tmp2.y;
|
||||
tmp = tmp2.x;
|
||||
tmp2 = fast_div_modulo(tmp, nsamples_y);
|
||||
const int wt = tmp2.y;
|
||||
const int it = tmp2.x;
|
||||
|
||||
// Defaults for regular matrix multiplication:
|
||||
int col_low = 0;
|
||||
@ -3679,11 +3680,11 @@ static __global__ void mul_mat_q_id(
|
||||
offset_dst = 0;
|
||||
|
||||
if (jt*mmq_x >= col_diff) {
|
||||
kbc += blocks_per_ne00;
|
||||
kbc -= kbc % blocks_per_ne00;
|
||||
kbc += blocks_per_ne00.z;
|
||||
kbc -= fastmodulo(kbc, blocks_per_ne00);
|
||||
|
||||
kb0_start = 0;
|
||||
kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
|
||||
kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc));
|
||||
|
||||
continue;
|
||||
}
|
||||
@ -3708,33 +3709,34 @@ static __global__ void mul_mat_q_id(
|
||||
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
||||
const int tile_y_max_j = col_diff - jt*mmq_x - 1;
|
||||
|
||||
const int64_t offset_x = (wt/sample_ratio )*int64_t(stride_sample_x)
|
||||
+ (zt/channel_ratio)*int64_t(stride_channel_x) + it*mmq_y*int64_t(stride_row_x);
|
||||
const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
||||
|
||||
constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
||||
mul_mat_q_process_tile_id<type, mmq_x, need_check, fixup>
|
||||
(x + offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
|
||||
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
|
||||
|
||||
kbc += blocks_per_ne00;
|
||||
kbc -= kbc % blocks_per_ne00;
|
||||
kbc += blocks_per_ne00.z;
|
||||
kbc -= fastmodulo(kbc, blocks_per_ne00);
|
||||
|
||||
kb0_start = 0;
|
||||
kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
|
||||
kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc));
|
||||
}
|
||||
|
||||
if (kbc >= kbc_stop) {
|
||||
return;
|
||||
}
|
||||
|
||||
int tmp = kbc;
|
||||
const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
|
||||
tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
|
||||
const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
|
||||
tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
|
||||
const int zt = tmp / (ntx*blocks_per_ne00);
|
||||
tmp -= zt * (ntx*blocks_per_ne00);
|
||||
const int jt = tmp / blocks_per_ne00;
|
||||
int tmp = fastdiv(kbc, blocks_per_ne00);
|
||||
uint2 tmp2 = fast_div_modulo(tmp, ntx);
|
||||
const int jt = tmp2.y;
|
||||
tmp = tmp2.x;
|
||||
tmp2 = fast_div_modulo(tmp, nchannels_y);
|
||||
const int zt = tmp2.y;
|
||||
tmp = tmp2.x;
|
||||
tmp2 = fast_div_modulo(tmp, nsamples_y);
|
||||
const int wt = tmp2.y;
|
||||
const int it = tmp2.x;
|
||||
|
||||
// Defaults for regular matrix multiplication:
|
||||
int col_low = 0;
|
||||
@ -3776,8 +3778,7 @@ static __global__ void mul_mat_q_id(
|
||||
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
||||
const int tile_y_max_j = col_diff - jt*mmq_x - 1;
|
||||
|
||||
const int64_t offset_x = (wt/sample_ratio )*int64_t(stride_sample_x)
|
||||
+ (zt/channel_ratio)*int64_t(stride_channel_x) + it*mmq_y*int64_t(stride_row_x);
|
||||
const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
||||
|
||||
constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||
mul_mat_q_process_tile_id<type, mmq_x, need_check, fixup>
|
||||
@ -3787,36 +3788,37 @@ static __global__ void mul_mat_q_id(
|
||||
|
||||
|
||||
template <ggml_type type, int mmq_x, bool need_check>
|
||||
__launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device()/2, 1)
|
||||
static __global__ void mul_mat_q_stream_k_fixup_id(
|
||||
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
|
||||
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
|
||||
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
|
||||
const int ncols_max) {
|
||||
constexpr int mmq_y = get_mmq_y_device();
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
|
||||
const int64_t blocks_per_ne00 = ncols_x / qk;
|
||||
const int32_t * __restrict__ ids_dst, const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
|
||||
float * __restrict__ tmp_last_tile, const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst,
|
||||
const int stride_col_dst, const uint3 nchannels_y, const int stride_channel_dst, const uint3 nsamples_y,
|
||||
const int stride_sample_dst, const uint3 ntx) {
|
||||
constexpr int mmq_y = get_mmq_y_device();
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int ITER_K = MMQ_ITER_K; //get_iter_k(type);
|
||||
constexpr int blocks_per_iter = ITER_K / qk;
|
||||
|
||||
constexpr int nwarps = mmq_get_nwarps_device();
|
||||
constexpr int nwarps = mmq_get_nwarps_device()/2;
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
|
||||
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
|
||||
float sum[mmq_x / nwarps] = {0.0f};
|
||||
const int i = blockIdx.y*warp_size + threadIdx.x;
|
||||
|
||||
const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
|
||||
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
|
||||
// kbc == k block continuous, current index in continuous ijk space.
|
||||
int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
|
||||
int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
|
||||
int64_t kbc0 = int64_t(blockIdx.x) *nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z / gridDim.x;
|
||||
int64_t kbc0_stop = int64_t(blockIdx.x + 1)*nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z / gridDim.x;
|
||||
|
||||
kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter;
|
||||
kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter;
|
||||
kbc0 -= fastmodulo(kbc0, blocks_per_ne00) % blocks_per_iter;
|
||||
kbc0_stop -= fastmodulo(kbc0_stop, blocks_per_ne00) % blocks_per_iter;
|
||||
|
||||
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
||||
const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0;
|
||||
const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0;
|
||||
const bool wrote_beginning_of_tile = fastmodulo(kbc0, blocks_per_ne00) == 0;
|
||||
const bool did_not_write_last = fastdiv(kbc0, blocks_per_ne00) == fastdiv(kbc0_stop, blocks_per_ne00) && fastmodulo(kbc0_stop, blocks_per_ne00) != 0;
|
||||
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
|
||||
return;
|
||||
}
|
||||
@ -3828,8 +3830,8 @@ static __global__ void mul_mat_q_stream_k_fixup_id(
|
||||
int64_t bidx = bidx0 - 1;
|
||||
int64_t kbc_stop = kbc0;
|
||||
while(true) {
|
||||
int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
|
||||
kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
|
||||
int64_t kbc = bidx*nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z / gridDim.x;
|
||||
kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter;
|
||||
|
||||
if (kbc == kbc_stop) { // Did not have any data.
|
||||
bidx--;
|
||||
@ -3843,16 +3845,11 @@ static __global__ void mul_mat_q_stream_k_fixup_id(
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
|
||||
}
|
||||
sum[j0/nwarps] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
|
||||
}
|
||||
|
||||
// If this block started in a previous tile we are done and don't need to combine additional partial results.
|
||||
if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) {
|
||||
if (fastmodulo(kbc, blocks_per_ne00) == 0 || fastdiv(kbc, blocks_per_ne00) < fastdiv(kbc0, blocks_per_ne00)) {
|
||||
break;
|
||||
}
|
||||
bidx--;
|
||||
@ -3863,14 +3860,16 @@ static __global__ void mul_mat_q_stream_k_fixup_id(
|
||||
return;
|
||||
}
|
||||
|
||||
int tmp = kbc0;
|
||||
const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
|
||||
tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
|
||||
const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
|
||||
tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
|
||||
const int zt = tmp / (ntx*blocks_per_ne00);
|
||||
tmp -= zt * (ntx*blocks_per_ne00);
|
||||
const int jt = tmp / blocks_per_ne00;
|
||||
int tmp = fastdiv(kbc0, blocks_per_ne00);
|
||||
uint2 tmp2 = fast_div_modulo(tmp, ntx);
|
||||
const int jt = tmp2.y;
|
||||
tmp = tmp2.x;
|
||||
tmp2 = fast_div_modulo(tmp, nchannels_y);
|
||||
const int zt = tmp2.y;
|
||||
tmp = tmp2.x;
|
||||
tmp2 = fast_div_modulo(tmp, nsamples_y);
|
||||
const int wt = tmp2.y;
|
||||
const int it = tmp2.x;
|
||||
|
||||
if (!ids_dst) {
|
||||
const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
|
||||
@ -3878,6 +3877,9 @@ static __global__ void mul_mat_q_stream_k_fixup_id(
|
||||
|
||||
const int i_max = nrows_x - it*mmq_y - 1;
|
||||
const int j_max = ncols_dst - jt*mmq_x - 1;
|
||||
if (need_check && i > i_max) {
|
||||
return;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
||||
@ -3887,16 +3889,7 @@ static __global__ void mul_mat_q_stream_k_fixup_id(
|
||||
return;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
if (need_check && i > i_max) {
|
||||
continue;
|
||||
}
|
||||
|
||||
dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
|
||||
}
|
||||
dst[j*stride_col_dst + i] += sum[j0/nwarps];
|
||||
}
|
||||
return;
|
||||
}
|
||||
@ -3916,6 +3909,9 @@ static __global__ void mul_mat_q_stream_k_fixup_id(
|
||||
|
||||
const int i_max = nrows_x - it*mmq_y - 1;
|
||||
const int j_max = col_diff - jt*mmq_x - 1;
|
||||
if (need_check && i > i_max) {
|
||||
return;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
||||
@ -3925,16 +3921,7 @@ static __global__ void mul_mat_q_stream_k_fixup_id(
|
||||
return;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
if (need_check && i > i_max) {
|
||||
continue;
|
||||
}
|
||||
|
||||
dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
|
||||
}
|
||||
dst[ids_dst_shared[j]*stride_col_dst + i] += sum[j0/nwarps];
|
||||
}
|
||||
}
|
||||
|
||||
@ -3986,29 +3973,42 @@ static void launch_mul_mat_q_id(ggml_backend_cuda_context & ctx, const mmq_args_
|
||||
const int channel_ratio = args.nchannels_y / args.nchannels_x;
|
||||
const int sample_ratio = args.nsamples_y / args.nsamples_x;
|
||||
|
||||
const uint3 blocks_per_ne00_fd = init_fastdiv_values(args.ncols_x / ggml_cuda_type_traits<type>::qk);
|
||||
const uint3 ntx_fd = init_fastdiv_values(ntx);
|
||||
const uint3 nchannels_y_fd = init_fastdiv_values(args.nchannels_y);
|
||||
const uint3 nsamples_y_fd = init_fastdiv_values(args.nsamples_y);
|
||||
const uint3 channel_ratio_fd = init_fastdiv_values(channel_ratio);
|
||||
const uint3 sample_ratio_fd = init_fastdiv_values(sample_ratio);
|
||||
|
||||
if (!args.use_stream_k) {
|
||||
if (args.nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
mul_mat_q_id<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
ntx_fd);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
mul_mat_q_id<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
ntx_fd);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const dim3 block_nums_stream_k(nsm, 1, 1);
|
||||
const bool fixup_needed = ntx*nty*ntzw % nsm != 0;
|
||||
// For the stream-k kernel it is possible to run it with tiling by setting the number of CUDA blocks equal to the number of tiles.
|
||||
// This is worthwhile if the efficiency of tiling is high and skipping the fixup kernel is more important.
|
||||
const int ntiles_dst = ntx * nty * ntzw;
|
||||
const int tiles_nwaves = (ntiles_dst + nsm - 1) / nsm;
|
||||
const int tiles_efficiency_percent = 100 * ntiles_dst / (nsm*tiles_nwaves);
|
||||
const dim3 block_nums_stream_k(GGML_CUDA_CC_IS_NVIDIA(cc) && tiles_efficiency_percent >= 90 ? ntiles_dst : nsm, 1, 1);
|
||||
|
||||
const bool fixup_needed = ntiles_dst % block_nums_stream_k.x != 0;
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool(id);
|
||||
ggml_cuda_pool_alloc<float> tmp_fixup(pool);
|
||||
@ -4016,40 +4016,45 @@ static void launch_mul_mat_q_id(ggml_backend_cuda_context & ctx, const mmq_args_
|
||||
tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y);
|
||||
}
|
||||
|
||||
const dim3 block_nums_fixup(block_nums_stream_k.x, mmq_y/warp_size, 1);
|
||||
const dim3 block_dims_fixup(block_dims.x, block_dims.y/2, block_dims.z);
|
||||
|
||||
if (args.nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
mul_mat_q_id<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
ntx_fd);
|
||||
|
||||
if (!fixup_needed) {
|
||||
return;
|
||||
}
|
||||
|
||||
mul_mat_q_stream_k_fixup_id<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
|
||||
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
|
||||
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
mul_mat_q_stream_k_fixup_id<type, mmq_x, need_check><<<block_nums_fixup, block_dims_fixup, 0, stream>>>
|
||||
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst,
|
||||
args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst,
|
||||
ntx_fd);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
mul_mat_q_id<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
ntx_fd);
|
||||
|
||||
if (!fixup_needed) {
|
||||
return;
|
||||
}
|
||||
|
||||
mul_mat_q_stream_k_fixup_id<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
|
||||
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
|
||||
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
mul_mat_q_stream_k_fixup_id<type, mmq_x, need_check><<<block_nums_fixup, block_dims_fixup, 0, stream>>>
|
||||
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst,
|
||||
args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst,
|
||||
ntx_fd);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user