diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index f4a578b893..5fbebc6d75 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -699,6 +699,7 @@ struct vk_device_struct {
 
     bool add_rms_fusion;
     uint32_t partials_binding_alignment;
+    uint32_t max_nodes_per_submit;
 
     bool shader_64b_indexing;
 
@@ -5878,6 +5879,22 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->subgroup_vote = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
                                 (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eVote);
 
+        // Submit at least every max_nodes_per_submit nodes (default 100), in case there are workloads without as much matmul.
+        // GGML_VK_MAX_NODES_PER_SUBMIT overrides the default; the value is clamped to [1, UINT32_MAX].
+        device->max_nodes_per_submit = 100;
+        const char* GGML_VK_MAX_NODES_PER_SUBMIT = getenv("GGML_VK_MAX_NODES_PER_SUBMIT");
+        if (GGML_VK_MAX_NODES_PER_SUBMIT != nullptr) {
+            // strtoull instead of std::stoul: a malformed value must not throw during device init, and parsing
+            // as 64-bit lets us clamp before narrowing instead of silently wrapping in uint32_t.
+            char * parse_end = nullptr;
+            const unsigned long long requested = strtoull(GGML_VK_MAX_NODES_PER_SUBMIT, &parse_end, 10);
+            if (parse_end != GGML_VK_MAX_NODES_PER_SUBMIT) {
+                const unsigned long long clamped = std::min<unsigned long long>(std::max<unsigned long long>(requested, 1), UINT32_MAX);
+                device->max_nodes_per_submit = (uint32_t)clamped;
+            }
+            // No digits parsed: keep the default.
+        }
+
         const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
 
         device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -16173,8 +16190,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
     // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
     // (and scaled down based on model size, so smaller models submit earlier).
-    // Also submit at least every 100 nodes, in case there are workloads without as much matmul.
-    int nodes_per_submit = 100;
+    // Also submit at least every ctx->device->max_nodes_per_submit nodes, in case there are workloads without as much matmul.
     int submitted_nodes = 0;
     int submit_count = 0;
     uint64_t mul_mat_bytes = 0;
@@ -16400,7 +16416,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
 
         // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
         bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
-        bool submit = (submitted_nodes >= nodes_per_submit) ||
+        bool submit = ((uint32_t)submitted_nodes >= ctx->device->max_nodes_per_submit) ||
                       (mul_mat_bytes_per_submit != 0 && mul_mat_bytes >= mul_mat_bytes_per_submit) ||
                       (i + ctx->num_additional_fused_ops >= last_node) ||
                       (almost_ready && !ctx->almost_ready_fence_pending);