diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 4f4f073cb6..cbc9459c91 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -66,7 +66,6 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml const char * op_str = "undefined"; switch (op) { case GGML_OP_ADD_ID: op_str = "add_id"; break; - case GGML_OP_CONCAT: op_str = "concat"; break; default: GGML_ABORT("fatal error"); }; @@ -211,6 +210,21 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_meta return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_concat(ggml_metal_library_t lib, ggml_type tsrc) { + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_concat_%s", ggml_type_name(tsrc)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) { char base[256]; char name[256]; diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 4a3ebb5569..d465f31c08 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -115,6 +115,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_concat (ggml_metal_library_t lib, enum ggml_type tsrc); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index d583bd6efc..26b0e77f2f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1123,13 +1123,24 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return true; case GGML_OP_CONCAT: { - // kernel_concat copies one float-sized value per element. - // Other scalar types need a type-generic copy kernel first. const enum ggml_type src0_type = op->src[0]->type; const enum ggml_type src1_type = op->src[1]->type; - return src0_type == src1_type && - src0_type == op->type && - (src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_I32); + if (src0_type != src1_type || src0_type != op->type) { + return false; + } + switch (src0_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_I64: + return true; + case GGML_TYPE_BF16: + return has_bfloat; + default: + return false; + } } case GGML_OP_ADD: case GGML_OP_SUB: diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index e2ce56e9e2..b13bf27115 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -556,7 +556,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) { /*.dim =*/ dim, }; - auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT); + auto pipeline = ggml_metal_library_get_pipeline_concat(lib, op->type); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 072f14e68f..97abecb474 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -7513,14 +7513,15 @@ template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32< template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template kernel void kernel_concat( - constant ggml_metal_kargs_concat & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { + constant ggml_metal_kargs_concat & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { const int i3 = tgpig.z; const int i2 = tgpig.y; @@ -7533,21 +7534,31 @@ kernel void kernel_concat( int o[4] = {0, 0, 0, 0}; o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03)); - device const float * x; - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + device const T * x; + if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { - x = (device const float *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00); + x = (device const T *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00); } else { - x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10); + x = (device const T *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10); } - device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + device T * y = (device T *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); *y = *x; } } +typedef decltype(kernel_concat) kernel_concat_t; + +template [[host_name("kernel_concat_f32")]] kernel kernel_concat_t kernel_concat; +template [[host_name("kernel_concat_f16")]] kernel kernel_concat_t kernel_concat; +template [[host_name("kernel_concat_bf16")]] kernel kernel_concat_t kernel_concat; +template [[host_name("kernel_concat_i8")]] kernel kernel_concat_t kernel_concat; +template [[host_name("kernel_concat_i16")]] kernel kernel_concat_t kernel_concat; +template [[host_name("kernel_concat_i32")]] kernel kernel_concat_t kernel_concat; +template [[host_name("kernel_concat_i64")]] kernel kernel_concat_t kernel_concat; + template void kernel_mul_mv_q2_K_f32_impl( args_t args,