mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
q6_0: can now be used for kv-cache on Metal
This commit is contained in:
parent
0d0cd1ee68
commit
037bbd2d58
@ -276,6 +276,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
|
||||
GGML_METAL_KERNEL_TYPE_CONCAT,
|
||||
GGML_METAL_KERNEL_TYPE_SQR,
|
||||
@ -803,6 +804,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0, cpy_f32_q6_0, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
||||
@ -970,6 +972,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q6_0:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
return true;
|
||||
default:
|
||||
@ -3318,6 +3321,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
|
||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
|
||||
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
|
||||
case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0].pipeline; break;
|
||||
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
|
||||
default: GGML_ABORT("not implemented");
|
||||
};
|
||||
|
||||
@ -3527,6 +3527,77 @@ kernel void kernel_cpy_f32_q5_1(
|
||||
}
|
||||
}
|
||||
|
||||
// Copy f32 data from src0 into dst, quantizing it to q6_0 blocks on the way.
//
// Work split:
//   - the threadgroup position selects one source row: tgpig.x -> i01, tgpig.y -> i02, tgpig.z -> i03
//   - the threads of a threadgroup walk that row in chunks of QK6_0 floats,
//     thread t starting at chunk t and advancing by ntg.x chunks
//
// Parameters:
//   src0        source floats, addressed with the byte strides nb00..nb03
//   dst         destination buffer, addressed with the byte strides nb0..nb3
//   ne00..ne03  source extents (elements per dimension)
//   ne0..ne3    destination extents (elements per dimension)
//
// Per block of QK6_0 values:
//   - d is the element with the largest magnitude (sign kept) divided by -32
//   - every value is scaled by 1/d, offset by 32.5, truncated and capped at 63
//   - the low 4 bits of the two values j and j + QK6_0/2 share byte qs[j]
//   - the remaining high 2 bits of both values are OR-ed into qh
kernel void kernel_cpy_f32_q6_0(
        device const float * src0,
        device       void  * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // source row handled by this threadgroup
    const int64_t i01 = tgpig[0];
    const int64_t i02 = tgpig[1];
    const int64_t i03 = tgpig[2];

    // flat element index of the first element of the source row
    const int64_t n = ((i03*ne02 + i02)*ne01 + i01)*ne00;

    // unflatten n with the destination extents, peeling one dimension at a time
    const int64_t i3   = n / (ne2*ne1*ne0);
    const int64_t rem3 = n - i3*ne2*ne1*ne0;
    const int64_t i2   = rem3 / (ne1*ne0);
    const int64_t rem2 = rem3 - i2*ne1*ne0;
    const int64_t i1   = rem2 / ne0;
    const int64_t i0   = (rem2 - i1*ne0)/QK6_0; // block index along the innermost dimension

    device block_q6_0 * dst_data = (device block_q6_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    const int half_block = QK6_0/2; // distance between the two values sharing one qs byte
    const int n_qh       = QK6_0/4; // number of qh bytes written per block

    for (int64_t i00 = tpitg.x*QK6_0; i00 < ne00; i00 += ntg.x*QK6_0) {
        device const float * vals = (device const float *)((device const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

        // locate the element with the largest magnitude, remembering its signed value
        float largest_abs = 0.0f;
        float signed_peak = 0.0f;

        for (int k = 0; k < QK6_0; k++) {
            const float v     = vals[k];
            const float abs_v = fabs(v);
            if (abs_v > largest_abs) {
                largest_abs = abs_v;
                signed_peak = v;
            }
        }

        // block scale and its inverse; an all-zero block gets scale 0 and quantizes with id = 0
        const float d = signed_peak / -32;
        float id = 0.0f;
        if (d) {
            id = 1.0f/d;
        }

        device block_q6_0 & blk = dst_data[i00/QK6_0];
        blk.d = d;

        // clear the high-bit bytes through a 16-bit view (4 x uint16_t) before OR-ing into them
        device uint16_t * qh16 = (device uint16_t *)blk.qh;
        for (int k = 0; k < 4; ++k) {
            qh16[k] = 0;
        }

        for (int j = 0; j < half_block; ++j) {
            const float scaled_lo = vals[j]*id;
            const float scaled_hi = vals[half_block + j]*id;

            // shift into the unsigned 6-bit range and cap at the largest representable value
            const uint8_t q_lo = MIN(63, (int8_t)(scaled_lo + 32.5f));
            const uint8_t q_hi = MIN(63, (int8_t)(scaled_hi + 32.5f));

            // low nibbles of both values share one byte of qs
            blk.qs[j] = (q_lo & 0xf) | ((q_hi & 0xf) << 4);

            // the two high bits of each value form a 4-bit group; the first n_qh groups
            // land in the low nibble of qh[j], the remaining ones in the high nibble
            const uint8_t hi_bits = (q_lo >> 4) | ((q_hi >> 4) << 2);
            blk.qh[j%n_qh] |= (hi_bits << 4*(j/n_qh));
        }
    }
}
|
||||
|
||||
static inline int best_index_int8(int n, constant float * val, float x) {
|
||||
if (x <= val[0]) return 0;
|
||||
if (x >= val[n-1]) return n-1;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user