diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 0b605fa86b..f71d1aee73 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3788,7 +3788,7 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } -static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { +static void ggml_backend_webgpu_request_adapter(wgpu::Instance & instance, wgpu::Adapter & adapter) { wgpu::RequestAdapterOptions options = {}; #ifndef __EMSCRIPTEN__ @@ -3800,17 +3800,20 @@ static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { options.nextInChain = &adapterTogglesDesc; #endif - ctx->webgpu_global_ctx->instance.WaitAny( - ctx->webgpu_global_ctx->instance.RequestAdapter( - &options, wgpu::CallbackMode::AllowSpontaneous, - [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { - if (status != wgpu::RequestAdapterStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); - return; - } - ctx->webgpu_global_ctx->adapter = std::move(adapter); - }), - UINT64_MAX); + instance.WaitAny(instance.RequestAdapter( + &options, wgpu::CallbackMode::AllowSpontaneous, + [&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) { + if (status != wgpu::RequestAdapterStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); + return; + } + adapter = std::move(_adapter); + }), + UINT64_MAX); +} + +static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { + ggml_backend_webgpu_request_adapter(ctx->webgpu_global_ctx->instance, ctx->webgpu_global_ctx->adapter); GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr); ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits); @@ -4543,20 +4546,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { // Probe for adapter support wgpu::Adapter adapter; if (ctx->webgpu_global_ctx->instance != nullptr) { - wgpu::RequestAdapterOptions options = {}; - - // probe for adapter support - ctx->webgpu_global_ctx->instance.WaitAny( - ctx->webgpu_global_ctx->instance.RequestAdapter( - &options, wgpu::CallbackMode::AllowSpontaneous, - [&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) { - if (status != wgpu::RequestAdapterStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); - return; - } - adapter = std::move(_adapter); - }), - UINT64_MAX); + ggml_backend_webgpu_request_adapter(ctx->webgpu_global_ctx->instance, adapter); } // WebGPU backend requires f16 support and, on native, implicit device synchronization.