Avoid recurrent state copy (#1777)

This commit is contained in:
Kawrakow 2026-05-11 13:13:59 +03:00 committed by GitHub
parent 94940cd882
commit 3557b446f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 41 additions and 4 deletions

View File

@@ -1269,6 +1269,12 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b,
int dim);
// Concatenate a and b along dimension dim, using a caller-supplied tensor as the result.
//
// - result == NULL: behaves exactly like ggml_concat(ctx, a, b, dim).
// - otherwise: a and b must have equal extents in every dimension except dim, and
//   result->ne[] must equal the concatenated shape (a->ne[dim] + b->ne[dim] along dim);
//   these conditions are enforced with GGML_ASSERT. The returned tensor is a view of
//   result (created with ggml_view_tensor) whose op is GGML_OP_CONCAT with sources a and b.
//
// dim must satisfy 0 <= dim < GGML_MAX_DIMS.
// NOTE(review): the implementation does not compare the element type of result against
// a or b - presumably the caller must ensure they match; confirm against the backends.
GGML_API struct ggml_tensor * ggml_concat_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * result,
int dim);
GGML_API struct ggml_tensor * ggml_abs(
struct ggml_context * ctx,

View File

@@ -6944,6 +6944,40 @@ struct ggml_tensor * ggml_concat(
return result;
}
// ggml_concat_inplace
//
// Builds a GGML_OP_CONCAT node whose output tensor is a view of the caller-supplied
// result_in instead of a freshly allocated tensor.
//
//   ctx        context used to create the view tensor
//   a, b       tensors to concatenate; extents must agree in every dimension except dim
//   result_in  destination tensor; when NULL the call is forwarded to ggml_concat()
//   dim        concatenation dimension, 0 <= dim < GGML_MAX_DIMS
//
// Returns the view of result_in (or the tensor produced by ggml_concat() when
// result_in is NULL). The returned node never carries a gradient.
//
// NOTE(review): only the shape of result_in is validated here; its element type and
// memory layout are not checked - presumably callers guarantee compatibility; confirm.
struct ggml_tensor * ggml_concat_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        struct ggml_tensor  * result_in,
        int                   dim) {
    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);

    // No destination supplied: behave like the regular, allocating concat.
    if (result_in == NULL) {
        return ggml_concat(ctx, a, b, dim);
    }

    // The destination must already have the concatenated shape.
    for (int d = 0; d < GGML_MAX_DIMS; ++d) {
        int64_t ne = a->ne[d];
        if (d != dim) {
            // Every non-concatenated dimension has to match between the inputs.
            GGML_ASSERT(a->ne[d] == b->ne[d]);
        } else {
            ne += b->ne[d];
        }
        GGML_ASSERT(ne == result_in->ne[d]);
    }

    // Turn a view of the destination into the concat node.
    struct ggml_tensor * dst = ggml_view_tensor(ctx, result_in);

    dst->op     = GGML_OP_CONCAT;
    dst->grad   = NULL;
    dst->src[0] = a;
    dst->src[1] = b;
    ggml_set_op_params_i32(dst, 0, dim);

    return dst;
}
// ggml_abs
struct ggml_tensor * ggml_abs(

View File

@@ -405,10 +405,7 @@ ggml_tensor * delta_net::build_qkv(ggml_context * ctx0, ggml_tensor * state_stor
cb(new_conv_states_cont, "new_conv_states_cont", il);
ggml_tensor * new_conv_flat = ggml_reshape_2d(ctx0, new_conv_states_cont, conv_state_dim, 1);
ggml_tensor * new_ssm_flat = ggml_reshape_2d(ctx0, new_state, ssm_state_dim, 1);
ggml_tensor * new_state_flat = ggml_concat(ctx0, new_conv_flat, new_ssm_flat, 0);
cb(new_state_flat, "new_state_flat", il);
auto state_cpy = ggml_cpy(ctx0, new_state_flat, state_dst);
auto state_cpy = ggml_concat_inplace(ctx0, new_conv_flat, new_ssm_flat, state_dst, 0);
cb(state_cpy, "state_cpy", il);
ggml_build_forward_expand(gf, state_cpy);