mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Avoid recurrent state copy (#1777)
This commit is contained in:
parent
94940cd882
commit
3557b446f8
@@ -1269,6 +1269,12 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int dim);
|
||||
// Concatenate a and b along dim into an existing tensor.
// The returned node is a view of `result` (created with ggml_view_tensor) carrying
// the CONCAT op, so no new tensor is allocated for the concatenated data.
// `result` must already have the concatenated shape (asserted in the definition).
// If `result` is NULL this is equivalent to ggml_concat(ctx, a, b, dim).
GGML_API struct ggml_tensor * ggml_concat_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * result,
|
||||
int dim);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_abs(
|
||||
struct ggml_context * ctx,
|
||||
|
||||
@@ -6944,6 +6944,40 @@ struct ggml_tensor * ggml_concat(
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_concat_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * result_in,
|
||||
int dim) {
|
||||
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
|
||||
if (!result_in) {
|
||||
return ggml_concat(ctx, a, b, dim);
|
||||
}
|
||||
|
||||
for (int d = 0; d < GGML_MAX_DIMS; ++d) {
|
||||
int64_t ne;
|
||||
if (d == dim) {
|
||||
ne = a->ne[d] + b->ne[d];
|
||||
} else {
|
||||
GGML_ASSERT(a->ne[d] == b->ne[d]);
|
||||
ne = a->ne[d];
|
||||
}
|
||||
GGML_ASSERT(ne == result_in->ne[d]);
|
||||
}
|
||||
|
||||
struct ggml_tensor * result = ggml_view_tensor(ctx, result_in);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, dim);
|
||||
|
||||
result->op = GGML_OP_CONCAT;
|
||||
result->grad = NULL;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_abs
|
||||
|
||||
struct ggml_tensor * ggml_abs(
|
||||
|
||||
@@ -405,10 +405,7 @@ ggml_tensor * delta_net::build_qkv(ggml_context * ctx0, ggml_tensor * state_stor
|
||||
cb(new_conv_states_cont, "new_conv_states_cont", il);
|
||||
ggml_tensor * new_conv_flat = ggml_reshape_2d(ctx0, new_conv_states_cont, conv_state_dim, 1);
|
||||
ggml_tensor * new_ssm_flat = ggml_reshape_2d(ctx0, new_state, ssm_state_dim, 1);
|
||||
ggml_tensor * new_state_flat = ggml_concat(ctx0, new_conv_flat, new_ssm_flat, 0);
|
||||
cb(new_state_flat, "new_state_flat", il);
|
||||
|
||||
auto state_cpy = ggml_cpy(ctx0, new_state_flat, state_dst);
|
||||
auto state_cpy = ggml_concat_inplace(ctx0, new_conv_flat, new_ssm_flat, state_dst, 0);
|
||||
cb(state_cpy, "state_cpy", il);
|
||||
ggml_build_forward_expand(gf, state_cpy);
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user