Gemma4 MTP: avoid casting KV cache to f32 (#1786)

This commit is contained in:
Kawrakow 2026-05-13 09:11:27 +03:00 committed by GitHub
parent f478a3ec0b
commit 86b5d076c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 80 additions and 5 deletions

View File

@ -1843,6 +1843,16 @@ extern "C" {
int64_t ne2,
int64_t ne3);
GGML_API struct ggml_tensor * ggml_reshape_4d_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_type type,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3);
// offset in bytes
GGML_API struct ggml_tensor * ggml_view_1d(
struct ggml_context * ctx,

View File

@ -8579,6 +8579,36 @@ struct ggml_tensor * ggml_reshape_4d(
return result;
}
struct ggml_tensor * ggml_reshape_4d_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_type type,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3) {
bool is_node = false;
if (a->grad) {
is_node = true;
}
const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, type, 4, ne, a, 0);
ggml_format_name_fast(a->name, " (reshaped)", 11, result->name);
GGML_ASSERT(ggml_nbytes(a) == ggml_nbytes(result));
result->op = GGML_OP_RESHAPE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
static struct ggml_tensor * ggml_view_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,

View File

@ -63,6 +63,11 @@ static void gemma4_mtp_prepare_frozen_kv_views(
k_parts.reserve(split_k->n_device);
v_parts.reserve(split_v->n_device);
int n_k_reshaped = 0;
int n_v_reshaped = 0;
int n_k_heads = 0;
int n_v_heads = 0;
for (int id = 0; id < split_k->n_device; ++id) {
ggml_tensor * split_kl = split_k->splits[id];
ggml_tensor * split_vl = split_v->splits[id];
@ -82,7 +87,12 @@ static void gemma4_mtp_prepare_frozen_kv_views(
ggml_row_size(split_kl->type, n_embd_head_k) * split_n_head_kv,
ggml_row_size(split_kl->type, n_embd_head_k),
0);
if (k_part->type != GGML_TYPE_F32) {
if (auto row_size = ggml_row_size(k_part->type, k_part->ne[0]); row_size % sizeof(float) == 0) {
n_k_heads += split_n_head_kv;
k_part = ggml_reshape_4d_ext(ctx0, k_part, GGML_TYPE_F32, (row_size/sizeof(float))*k_part->ne[2], k_part->ne[1], 1, 1);
++n_k_reshaped;
}
else if (k_part->type != GGML_TYPE_F32) {
k_part = ggml_cast(ctx0, k_part, GGML_TYPE_F32);
}
cb(k_part, "mtp_frozen_k_split", 1000 * (assistant_il + 1) + id);
@ -92,7 +102,12 @@ static void gemma4_mtp_prepare_frozen_kv_views(
ggml_row_size(split_vl->type, split_n_head_kv * n_embd_head_v),
ggml_row_size(split_vl->type, n_embd_head_v),
0);
if (v_part->type != GGML_TYPE_F32) {
if (auto row_size = ggml_row_size(v_part->type, v_part->ne[0]); row_size % sizeof(float) == 0) {
v_part = ggml_reshape_4d_ext(ctx0, v_part, GGML_TYPE_F32, (row_size/sizeof(float))*v_part->ne[2], v_part->ne[1], 1, 1);
n_v_heads += split_n_head_kv;
++n_v_reshaped;
}
else if (v_part->type != GGML_TYPE_F32) {
v_part = ggml_cast(ctx0, v_part, GGML_TYPE_F32);
}
cb(v_part, "mtp_frozen_v_split", 1000 * (assistant_il + 1) + id);
@ -102,12 +117,32 @@ static void gemma4_mtp_prepare_frozen_kv_views(
}
GGML_ASSERT(!k_parts.empty() && k_parts.size() == v_parts.size());
GGML_ASSERT((int)k_parts.size() == n_k_reshaped || n_k_reshaped == 0);
GGML_ASSERT((int)v_parts.size() == n_v_reshaped || n_v_reshaped == 0);
ggml_tensor * k_full = k_parts[0];
ggml_tensor * v_full = v_parts[0];
for (size_t i = 1; i < k_parts.size(); ++i) {
k_full = ggml_concat(ctx0, k_full, k_parts[i], 2);
v_full = ggml_concat(ctx0, v_full, v_parts[i], 2);
if ((int)k_parts.size() == n_k_reshaped) {
for (int i = 1; i < n_k_reshaped; ++i) {
k_full = ggml_concat(ctx0, k_full, k_parts[i], 0);
}
k_full = ggml_reshape_4d_ext(ctx0, k_full, k_cache->type, n_embd_head_k, n_k_heads, k_full->ne[1], 1);
k_full = ggml_permute(ctx0, k_full, 0, 2, 1, 3);
} else {
for (size_t i = 1; i < k_parts.size(); ++i) {
k_full = ggml_concat(ctx0, k_full, k_parts[i], 2);
}
}
if ((int)v_parts.size() == n_v_reshaped) {
for (int i = 1; i < n_v_reshaped; ++i) {
v_full = ggml_concat(ctx0, v_full, v_parts[i], 0);
}
v_full = ggml_reshape_4d_ext(ctx0, v_full, v_cache->type, n_embd_head_v, n_v_heads, v_full->ne[1], 1);
v_full = ggml_permute(ctx0, v_full, 0, 2, 1, 3);
} else {
for (size_t i = 1; i < v_parts.size(); ++i) {
v_full = ggml_concat(ctx0, v_full, v_parts[i], 2);
}
}
if (k_full->type != GGML_TYPE_F16) {