diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp index ad2e6ca35e..306eeddc0c 100644 --- a/ggml/src/ggml-sycl/binbcast.cpp +++ b/ggml/src/ggml-sycl/binbcast.cpp @@ -293,6 +293,11 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t (sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream); + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_BF16) { + op()((const sycl::ext::oneapi::bfloat16 *) src0->data, (const float *) src1->data, + (sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, + ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), + ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream); #endif } else { fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type), diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index aca68e58ee..0c82ceb969 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -43,14 +43,44 @@ static __dpct_inline__ T op_sgn(T x) { return x > static_cast(0.f) ? static_cast(1.f) : ((x < static_cast(0.f) ? static_cast(-1.f) : static_cast(0.f))); } + template static __dpct_inline__ T op_abs(T x) { - return sycl::fabs(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::fabs(x); // or experimental namespace if needed + } else { + return sycl::fabs(x); + } +} + +template +static __dpct_inline__ T op_expm1(T x) { + if constexpr (std::is_same_v) { + return static_cast( + sycl::expm1(static_cast(x)) + ); + } else { + return sycl::expm1(x); + } } template static __dpct_inline__ T op_elu(T x) { - return (x > static_cast(0.f)) ? x : sycl::expm1(x); + return (x > static_cast(0.f)) ? x : op_expm1(x); +} + +template +static __dpct_inline__ T op_tanh(T x) { + if constexpr (std::is_same_v) { + constexpr int ver = __INTEL_LLVM_COMPILER; +#if defined(__INTEL_LLVM_COMPILER) && (__INTEL_LLVM_COMPILER >= 20260000) + return sycl::ext::oneapi::experimental::tanh(x); +#else + return static_cast(sycl::tanh(static_cast(x))); +#endif + } else { + return sycl::tanh(x); + } } template @@ -59,74 +89,106 @@ static __dpct_inline__ T op_gelu(T x) { const T SQRT_2_OVER_PI = static_cast(0.79788456080286535587989211986876f); return static_cast(0.5f) * x * (static_cast(1.0f) + - sycl::tanh(SQRT_2_OVER_PI * x * (static_cast(1.0f) + GELU_COEF_A * x * x))); + op_tanh(SQRT_2_OVER_PI * x * (static_cast(1.0f) + GELU_COEF_A * x * x))); +} + +template +static __dpct_inline__ T op_exp(T x) { + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::exp(x); + } else { + return sycl::exp(x); + } } template static __dpct_inline__ T op_silu(T x) { - return x / (static_cast(1.0f) + sycl::native::exp(-x)); + return x / (static_cast(1.0f) + op_exp(-x)); } template -static __dpct_inline__ T op_gelu_quick(T x) { - const T GELU_QUICK_COEF_LOCAL = static_cast(-1.702f); - return x * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x))); +static __dpct_inline__ T op_erf(T x) { + if constexpr (std::is_same_v) { + return static_cast( + sycl::erf(static_cast(x)) + ); + } else { + return sycl::erf(x); + } } template static __dpct_inline__ T op_gelu_erf(T x) { const T SQRT_2_INV = static_cast(0.70710678118654752440084436210484f); - return static_cast(0.5f) * x * (static_cast(1.0f) + sycl::erf(x * SQRT_2_INV)); + return static_cast(0.5f) * x * (static_cast(1.0f) + op_erf(x * SQRT_2_INV)); } template -static __dpct_inline__ T op_tanh(T x) { - return sycl::tanh(x); +static __dpct_inline__ T op_gelu_quick(T x) { + const T GELU_QUICK_COEF_LOCAL = static_cast(-1.702f); + return x * (static_cast(1.0f) / (static_cast(1.0f) + op_exp(GELU_QUICK_COEF_LOCAL * x))); } template static __dpct_inline__ T op_relu(T x) { - return sycl::fmax(x, static_cast(0)); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::fmax(x, static_cast(0)); + } else { + return sycl::fmax(x, static_cast(0)); + } } template static __dpct_inline__ T op_sigmoid(T x) { - return static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(-x)); + return static_cast(1.0f) / (static_cast(1.0f) + op_exp(-x)); } template static __dpct_inline__ T op_sqrt(T x) { - return sycl::sqrt(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::sqrt(x); + } else { + return sycl::sqrt(x); + } } template static __dpct_inline__ T op_sin(T x) { - return sycl::sin(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::sin(x); + } else { + return sycl::sin(x); + } } template static __dpct_inline__ T op_cos(T x) { - return sycl::cos(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::cos(x); + } else { + return sycl::cos(x); + } } template static __dpct_inline__ T op_hardsigmoid(T x) { - return sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::fmin( + static_cast(1.0f), sycl::ext::oneapi::experimental::fmax( + static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + } else { + return sycl::fmin(static_cast(1.0f), + sycl::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + } } template static __dpct_inline__ T op_hardswish(T x) { - return x * sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); -} - -template -static __dpct_inline__ T op_exp(T x) { - return sycl::exp(x); -} - -template -static __dpct_inline__ T op_expm1(T x) { - return sycl::expm1(x); + if constexpr (std::is_same_v) { + return x * sycl::ext::oneapi::experimental::fmin(static_cast(1.0f), sycl::ext::oneapi::experimental::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + } else { + return x * sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + } } template @@ -134,13 +196,17 @@ static __dpct_inline__ T op_log(T x) { if (x <= static_cast(0)) { return neg_infinity(); } - return sycl::log(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::log(x); + } else { + return sycl::log(x); + } } template static __dpct_inline__ T op_softplus(T x) { const float xf = (float) x; - const float ax = sycl::fabs(xf); + const float ax = op_abs(xf); const float m = sycl::fmax(xf, 0.0f); const float y = m + sycl::log1p(sycl::exp(-ax)); return (T) y; @@ -159,8 +225,14 @@ static __dpct_inline__ T op_step(T x) { template static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) { T neg_slope_T = static_cast(negative_slope); - return sycl::fmax(x, static_cast(0)) + + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::fmax(x, static_cast(0)) + + sycl::ext::oneapi::experimental::fmin(x, static_cast(0.0f)) * neg_slope_T; + + } else { + return sycl::fmax(x, static_cast(0)) + sycl::fmin(x, static_cast(0.0f)) * neg_slope_T; + } } template @@ -175,22 +247,40 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) { template static __dpct_inline__ T op_floor(T x) { - return sycl::floor(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::floor(x); + } else { + return sycl::floor(x); + } } template static __dpct_inline__ T op_ceil(T x) { - return sycl::ceil(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::ceil(x); + } else { + return sycl::ceil(x); + } } template static __dpct_inline__ T op_round(T x) { - return sycl::round(x); + if constexpr (std::is_same_v) { + return static_cast( + sycl::round(static_cast(x)) + ); + } else { + return sycl::round(x); + } } template static __dpct_inline__ T op_trunc(T x) { - return sycl::trunc(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::trunc(x); + } else { + return sycl::trunc(x); + } } template @@ -339,7 +429,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst, const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), - [=](sycl::nd_item<3> /*item_ct1*/) { + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset); }); } @@ -354,8 +444,8 @@ static void arange_kernel(T * dst, const int k, T start, T step, template static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16 || dst->src[0]->type == GGML_TYPE_BF16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_BF16); GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); @@ -367,6 +457,14 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward(args)...); break; } +#ifdef GGML_SYCL_HAS_BF16 + case GGML_TYPE_BF16: + { + auto data_pts = cast_data(dst); + kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward(args)...); + break; + } +#endif case GGML_TYPE_F32: { auto data_pts = cast_data(dst); @@ -480,7 +578,7 @@ static inline void ggml_sycl_op_unary( stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), sycl::range<1>(256)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_generic_kernel( src, dst_ptr, k_elements, ne0, ne1, ne2, ne3, @@ -508,7 +606,7 @@ static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_ten stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE), sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { arange_kernel(dst_ptr, k, start, step, item_ct1); }); } @@ -602,7 +700,7 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE), sycl::range<1>(SYCL_EXP_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -640,7 +738,7 @@ static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tenso stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE), sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -653,7 +751,7 @@ static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE), sycl::range<1>(SYCL_SIN_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -666,7 +764,7 @@ static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE), sycl::range<1>(SYCL_SIN_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -681,7 +779,7 @@ static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1); }); }, negative_slope); @@ -694,7 +792,7 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE), sycl::range<1>(SYCL_SQR_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -711,7 +809,7 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE), sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1); }); }, min_val, max_val); @@ -774,7 +872,8 @@ static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tens [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); @@ -785,7 +884,8 @@ static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tens [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), + sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); @@ -796,7 +896,8 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), + sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); @@ -811,7 +912,6 @@ __dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alp return out_glu; } - template static void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, @@ -845,7 +945,7 @@ static void swiglu_oai_sycl(const T * x, const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE; stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1); }); } @@ -899,7 +999,8 @@ static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); @@ -910,7 +1011,8 @@ static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggm [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); });