mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Gemma4 MTP: avoid casting KV cache to f32 (#1786)
This commit is contained in:
parent
f478a3ec0b
commit
86b5d076c5
@ -1843,6 +1843,16 @@ extern "C" {
|
||||
int64_t ne2,
|
||||
int64_t ne3);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_reshape_4d_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_type type,
|
||||
int64_t ne0,
|
||||
int64_t ne1,
|
||||
int64_t ne2,
|
||||
int64_t ne3);
|
||||
|
||||
|
||||
// offset in bytes
|
||||
GGML_API struct ggml_tensor * ggml_view_1d(
|
||||
struct ggml_context * ctx,
|
||||
|
||||
@ -8579,6 +8579,36 @@ struct ggml_tensor * ggml_reshape_4d(
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_reshape_4d_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_type type,
|
||||
int64_t ne0,
|
||||
int64_t ne1,
|
||||
int64_t ne2,
|
||||
int64_t ne3) {
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
|
||||
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, type, 4, ne, a, 0);
|
||||
ggml_format_name_fast(a->name, " (reshaped)", 11, result->name);
|
||||
|
||||
GGML_ASSERT(ggml_nbytes(a) == ggml_nbytes(result));
|
||||
|
||||
result->op = GGML_OP_RESHAPE;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
|
||||
static struct ggml_tensor * ggml_view_impl(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
|
||||
@ -63,6 +63,11 @@ static void gemma4_mtp_prepare_frozen_kv_views(
|
||||
k_parts.reserve(split_k->n_device);
|
||||
v_parts.reserve(split_v->n_device);
|
||||
|
||||
int n_k_reshaped = 0;
|
||||
int n_v_reshaped = 0;
|
||||
int n_k_heads = 0;
|
||||
int n_v_heads = 0;
|
||||
|
||||
for (int id = 0; id < split_k->n_device; ++id) {
|
||||
ggml_tensor * split_kl = split_k->splits[id];
|
||||
ggml_tensor * split_vl = split_v->splits[id];
|
||||
@ -82,7 +87,12 @@ static void gemma4_mtp_prepare_frozen_kv_views(
|
||||
ggml_row_size(split_kl->type, n_embd_head_k) * split_n_head_kv,
|
||||
ggml_row_size(split_kl->type, n_embd_head_k),
|
||||
0);
|
||||
if (k_part->type != GGML_TYPE_F32) {
|
||||
if (auto row_size = ggml_row_size(k_part->type, k_part->ne[0]); row_size % sizeof(float) == 0) {
|
||||
n_k_heads += split_n_head_kv;
|
||||
k_part = ggml_reshape_4d_ext(ctx0, k_part, GGML_TYPE_F32, (row_size/sizeof(float))*k_part->ne[2], k_part->ne[1], 1, 1);
|
||||
++n_k_reshaped;
|
||||
}
|
||||
else if (k_part->type != GGML_TYPE_F32) {
|
||||
k_part = ggml_cast(ctx0, k_part, GGML_TYPE_F32);
|
||||
}
|
||||
cb(k_part, "mtp_frozen_k_split", 1000 * (assistant_il + 1) + id);
|
||||
@ -92,7 +102,12 @@ static void gemma4_mtp_prepare_frozen_kv_views(
|
||||
ggml_row_size(split_vl->type, split_n_head_kv * n_embd_head_v),
|
||||
ggml_row_size(split_vl->type, n_embd_head_v),
|
||||
0);
|
||||
if (v_part->type != GGML_TYPE_F32) {
|
||||
if (auto row_size = ggml_row_size(v_part->type, v_part->ne[0]); row_size % sizeof(float) == 0) {
|
||||
v_part = ggml_reshape_4d_ext(ctx0, v_part, GGML_TYPE_F32, (row_size/sizeof(float))*v_part->ne[2], v_part->ne[1], 1, 1);
|
||||
n_v_heads += split_n_head_kv;
|
||||
++n_v_reshaped;
|
||||
}
|
||||
else if (v_part->type != GGML_TYPE_F32) {
|
||||
v_part = ggml_cast(ctx0, v_part, GGML_TYPE_F32);
|
||||
}
|
||||
cb(v_part, "mtp_frozen_v_split", 1000 * (assistant_il + 1) + id);
|
||||
@ -102,12 +117,32 @@ static void gemma4_mtp_prepare_frozen_kv_views(
|
||||
}
|
||||
|
||||
GGML_ASSERT(!k_parts.empty() && k_parts.size() == v_parts.size());
|
||||
GGML_ASSERT((int)k_parts.size() == n_k_reshaped || n_k_reshaped == 0);
|
||||
GGML_ASSERT((int)v_parts.size() == n_v_reshaped || n_v_reshaped == 0);
|
||||
|
||||
ggml_tensor * k_full = k_parts[0];
|
||||
ggml_tensor * v_full = v_parts[0];
|
||||
for (size_t i = 1; i < k_parts.size(); ++i) {
|
||||
k_full = ggml_concat(ctx0, k_full, k_parts[i], 2);
|
||||
v_full = ggml_concat(ctx0, v_full, v_parts[i], 2);
|
||||
if ((int)k_parts.size() == n_k_reshaped) {
|
||||
for (int i = 1; i < n_k_reshaped; ++i) {
|
||||
k_full = ggml_concat(ctx0, k_full, k_parts[i], 0);
|
||||
}
|
||||
k_full = ggml_reshape_4d_ext(ctx0, k_full, k_cache->type, n_embd_head_k, n_k_heads, k_full->ne[1], 1);
|
||||
k_full = ggml_permute(ctx0, k_full, 0, 2, 1, 3);
|
||||
} else {
|
||||
for (size_t i = 1; i < k_parts.size(); ++i) {
|
||||
k_full = ggml_concat(ctx0, k_full, k_parts[i], 2);
|
||||
}
|
||||
}
|
||||
if ((int)v_parts.size() == n_v_reshaped) {
|
||||
for (int i = 1; i < n_v_reshaped; ++i) {
|
||||
v_full = ggml_concat(ctx0, v_full, v_parts[i], 0);
|
||||
}
|
||||
v_full = ggml_reshape_4d_ext(ctx0, v_full, v_cache->type, n_embd_head_v, n_v_heads, v_full->ne[1], 1);
|
||||
v_full = ggml_permute(ctx0, v_full, 0, 2, 1, 3);
|
||||
} else {
|
||||
for (size_t i = 1; i < v_parts.size(); ++i) {
|
||||
v_full = ggml_concat(ctx0, v_full, v_parts[i], 2);
|
||||
}
|
||||
}
|
||||
|
||||
if (k_full->type != GGML_TYPE_F16) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user