diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index a19b1f1d..8c0b0c89 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1843,6 +1843,16 @@ extern "C" { int64_t ne2, int64_t ne3); + GGML_API struct ggml_tensor * ggml_reshape_4d_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + + // offset in bytes GGML_API struct ggml_tensor * ggml_view_1d( struct ggml_context * ctx, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 92157c1c..112eb0ae 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -8579,6 +8579,36 @@ struct ggml_tensor * ggml_reshape_4d( return result; } +struct ggml_tensor * ggml_reshape_4d_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, type, 4, ne, a, 0); + ggml_format_name_fast(a->name, " (reshaped)", 11, result->name); + + GGML_ASSERT(ggml_nbytes(a) == ggml_nbytes(result)); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + + + static struct ggml_tensor * ggml_view_impl( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/src/graphs/build_gemma4.cpp b/src/graphs/build_gemma4.cpp index fa254bb7..b0853a26 100644 --- a/src/graphs/build_gemma4.cpp +++ b/src/graphs/build_gemma4.cpp @@ -63,6 +63,11 @@ static void gemma4_mtp_prepare_frozen_kv_views( k_parts.reserve(split_k->n_device); v_parts.reserve(split_v->n_device); + int n_k_reshaped = 0; + int n_v_reshaped = 0; + int n_k_heads = 0; + int n_v_heads = 0; + for (int id = 0; id < split_k->n_device; ++id) { ggml_tensor * split_kl = split_k->splits[id]; ggml_tensor * split_vl = split_v->splits[id]; @@ -82,7 +87,12 @@ static void gemma4_mtp_prepare_frozen_kv_views( ggml_row_size(split_kl->type, n_embd_head_k) * split_n_head_kv, ggml_row_size(split_kl->type, n_embd_head_k), 0); - if (k_part->type != GGML_TYPE_F32) { + if (auto row_size = ggml_row_size(k_part->type, k_part->ne[0]); row_size % sizeof(float) == 0) { + n_k_heads += split_n_head_kv; + k_part = ggml_reshape_4d_ext(ctx0, k_part, GGML_TYPE_F32, (row_size/sizeof(float))*k_part->ne[2], k_part->ne[1], 1, 1); + ++n_k_reshaped; + } + else if (k_part->type != GGML_TYPE_F32) { k_part = ggml_cast(ctx0, k_part, GGML_TYPE_F32); } cb(k_part, "mtp_frozen_k_split", 1000 * (assistant_il + 1) + id); @@ -92,7 +102,12 @@ static void gemma4_mtp_prepare_frozen_kv_views( ggml_row_size(split_vl->type, split_n_head_kv * n_embd_head_v), ggml_row_size(split_vl->type, n_embd_head_v), 0); - if (v_part->type != GGML_TYPE_F32) { + if (auto row_size = ggml_row_size(v_part->type, v_part->ne[0]); row_size % sizeof(float) == 0) { + v_part = ggml_reshape_4d_ext(ctx0, v_part, GGML_TYPE_F32, (row_size/sizeof(float))*v_part->ne[2], v_part->ne[1], 1, 1); + n_v_heads += split_n_head_kv; + ++n_v_reshaped; + } + else if (v_part->type != GGML_TYPE_F32) { v_part = ggml_cast(ctx0, v_part, GGML_TYPE_F32); } cb(v_part, "mtp_frozen_v_split", 1000 * (assistant_il + 1) + id); @@ -102,12 +117,32 @@ static void gemma4_mtp_prepare_frozen_kv_views( } GGML_ASSERT(!k_parts.empty() && k_parts.size() == v_parts.size()); + GGML_ASSERT((int)k_parts.size() == n_k_reshaped || n_k_reshaped == 0); + GGML_ASSERT((int)v_parts.size() == n_v_reshaped || n_v_reshaped == 0); ggml_tensor * k_full = k_parts[0]; ggml_tensor * v_full = v_parts[0]; - for (size_t i = 1; i < k_parts.size(); ++i) { - k_full = ggml_concat(ctx0, k_full, k_parts[i], 2); - v_full = ggml_concat(ctx0, v_full, v_parts[i], 2); + if ((int)k_parts.size() == n_k_reshaped) { + for (int i = 1; i < n_k_reshaped; ++i) { + k_full = ggml_concat(ctx0, k_full, k_parts[i], 0); + } + k_full = ggml_reshape_4d_ext(ctx0, k_full, k_cache->type, n_embd_head_k, n_k_heads, k_full->ne[1], 1); + k_full = ggml_permute(ctx0, k_full, 0, 2, 1, 3); + } else { + for (size_t i = 1; i < k_parts.size(); ++i) { + k_full = ggml_concat(ctx0, k_full, k_parts[i], 2); + } + } + if ((int)v_parts.size() == n_v_reshaped) { + for (int i = 1; i < n_v_reshaped; ++i) { + v_full = ggml_concat(ctx0, v_full, v_parts[i], 0); + } + v_full = ggml_reshape_4d_ext(ctx0, v_full, v_cache->type, n_embd_head_v, n_v_heads, v_full->ne[1], 1); + v_full = ggml_permute(ctx0, v_full, 0, 2, 1, 3); + } else { + for (size_t i = 1; i < v_parts.size(); ++i) { + v_full = ggml_concat(ctx0, v_full, v_parts[i], 2); + } } if (k_full->type != GGML_TYPE_F16) {