From 6effcecd0bf3cb2209999cecfa297ed4d8523b5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 5 Jun 2026 17:35:13 +0200 Subject: [PATCH] TP: round up granularity to 128 (#24180) * TP: round up granularity to 128 * remove assert --- src/llama-model.cpp | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 1f442d8a32..784deb70af 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -553,10 +553,12 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str }; auto get_split_granularity = [&](int64_t blck_size, uint32_t il, const std::vector> & segments) -> std::vector { + // for better performance it may make sense to round up blck_size to a higher power of 2 so that more efficient kernels can be used if (hparams.is_recr(il)) { // linear attention - const int64_t head_dim = hparams.ssm_d_state; - const int64_t granularity_qkv = std::lcm(blck_size, head_dim); + const int64_t head_dim = hparams.ssm_d_state; + const int64_t blck_size_perf = std::lcm(blck_size, 128); + const int64_t granularity_qkv = std::lcm(blck_size_perf, head_dim); if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) { return std::vector(segments.size(), granularity_qkv); @@ -578,17 +580,24 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str // regular attention const uint32_t n_gqa = hparams.n_gqa(il); const uint32_t n_embd_q = n_gqa * hparams.n_embd_head_k(il); - if (std::regex_match(tensor_name, pattern_attn_sinks)) { - GGML_ASSERT(segments.size() == 1); - return {std::lcm(n_embd_q, blck_size)/n_embd_q * n_gqa}; + + // to handle head sizes like 80, only increase granularity while it doesn't cause underutilization + int64_t blck_size_perf = blck_size; + while (blck_size_perf < 128 && blck_size_perf*ud->n_devices < n_embd_q) { + blck_size_perf *= 2; } - const int64_t granularity_q = std::lcm(n_embd_q, blck_size); + if (std::regex_match(tensor_name, pattern_attn_sinks)) { + GGML_ASSERT(segments.size() == 1); + return {std::lcm(n_embd_q, blck_size_perf)/n_embd_q * n_gqa}; + } + + const int64_t granularity_q = std::lcm(n_embd_q, blck_size_perf); if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_q_bias)) { GGML_ASSERT(segments.size() == 1); // some models have Q gate tensors, for those cases the granularity needs to be doubled: if (ud->model->arch == LLM_ARCH_QWEN3NEXT || ud->model->arch == LLM_ARCH_QWEN35 || ud->model->arch == LLM_ARCH_QWEN35MOE) { - return {std::lcm(2*n_embd_q, blck_size)}; + return {std::lcm(2*n_embd_q, blck_size_perf)}; } return {granularity_q}; } @@ -613,8 +622,9 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str // FFN if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight) || std::regex_match(tensor_name, pattern_ffn_up_gate_bias) || std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) { + const int64_t blck_size_perf = std::lcm(blck_size, 128); GGML_ASSERT(segments.size() == 1); - return {blck_size}; + return {blck_size_perf}; } // everything else @@ -627,7 +637,6 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str tensor_config tc = get_tensor_config(); split_state.axis = tc.axis; if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { - const int64_t ne_full = tensor->ne[split_state.axis]; const int64_t blck_size = ggml_blck_size(tc.tensor_axis_0->type); const float * tensor_split = ud->model->tensor_split(); std::vector tensor_split_scan; @@ -644,7 +653,6 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str const int64_t ne_s = segments[is].first; const uint32_t nr_s = segments[is].second; const int64_t g_s = granularity[is]; - GGML_ASSERT(ne_full % g_s == 0); int64_t low = 0; size_t j = 0; for (; j < ud->n_devices - 1; j++) {