mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
MTP: faster recurrent state restore (#1791)
* MTP: store ready per step convolution states * Cleanup
This commit is contained in:
parent
86b5d076c5
commit
397150caa2
@ -2470,7 +2470,8 @@ extern "C" {
|
||||
struct ggml_tensor * s,
|
||||
struct ggml_tensor * x,
|
||||
struct ggml_tensor * c,
|
||||
struct ggml_tensor * sq);
|
||||
struct ggml_tensor * sq,
|
||||
struct ggml_tensor * saved_steps);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_ssm_scan(
|
||||
struct ggml_context * ctx,
|
||||
|
||||
@ -2,12 +2,13 @@
|
||||
|
||||
#define CUDA_SSM_CONV_BLOCK_SIZE 256
|
||||
|
||||
template <int split_n_t>
|
||||
template <int split_n_t, bool save_steps>
|
||||
static __global__ void ssm_conv_single_seq_f32(
|
||||
const float * src0,
|
||||
const float * src1,
|
||||
const float * src2,
|
||||
float * dst_x,
|
||||
[[maybe_unused]] float * saved,
|
||||
int nc,
|
||||
int nr,
|
||||
int n_t,
|
||||
@ -27,6 +28,11 @@ static __global__ void ssm_conv_single_seq_f32(
|
||||
const float * state_row = src0 + (size_t) row * src0_s1;
|
||||
const float * c_row = src2 + (size_t) row * nc;
|
||||
|
||||
[[maybe_unused]] float * y;
|
||||
if constexpr (save_steps) {
|
||||
y = saved + t0*(nc - 1)*nr;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int it = 0; it < split_n_t; ++it) {
|
||||
const int t = t0 + it;
|
||||
@ -42,18 +48,28 @@ static __global__ void ssm_conv_single_seq_f32(
|
||||
: src1[row + (size_t) (idx - (nc - 1)) * src1_s1];
|
||||
|
||||
sumf += x * c_row[j];
|
||||
if constexpr (save_steps) {
|
||||
if (j > 0) {
|
||||
y[j-1] = x;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dst_x[row + (size_t) t * nr] = sumf;
|
||||
|
||||
if constexpr (save_steps) {
|
||||
y += (nc - 1)*nr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int split_n_t>
|
||||
template <int split_n_t, bool save_steps>
|
||||
static __global__ void ssm_conv_single_seq_f32_nc4(
|
||||
const float * src0,
|
||||
const float * src1,
|
||||
const float * src2,
|
||||
float * dst_x,
|
||||
[[maybe_unused]] float * saved,
|
||||
int nr,
|
||||
int n_t,
|
||||
int src0_s0,
|
||||
@ -69,6 +85,11 @@ static __global__ void ssm_conv_single_seq_f32_nc4(
|
||||
return;
|
||||
}
|
||||
|
||||
[[maybe_unused]] float * y;
|
||||
if constexpr (save_steps) {
|
||||
y = saved + 3*(t0*nr + row);
|
||||
}
|
||||
|
||||
const float * state_row = src0 + (size_t) row * src0_s1;
|
||||
const float * c_row = src2 + (size_t) row * 4;
|
||||
const float c0 = c_row[0];
|
||||
@ -94,6 +115,13 @@ static __global__ void ssm_conv_single_seq_f32_nc4(
|
||||
const float x3 = i3 < 3 ? state_row[(size_t) i3 * src0_s0] : src1[row + (size_t) (i3 - 3) * src1_s1];
|
||||
|
||||
dst_x[row + (size_t) t * nr] = x0 * c0 + x1 * c1 + x2 * c2 + x3 * c3;
|
||||
|
||||
if constexpr (save_steps) {
|
||||
y[0] = x1;
|
||||
y[1] = x2;
|
||||
y[2] = x3;
|
||||
y += 3*nr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -427,6 +455,7 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src1 = dst->src[1]; // x: [d_inner, n_tokens]
|
||||
const ggml_tensor * src2 = dst->src[2]; // conv1d.weight: [d_conv, d_inner]
|
||||
const ggml_tensor * src3 = dst->src[3]; // state_seq: [n_kv, n_tokens]
|
||||
const ggml_tensor * src4 = dst->src[4]; // [d_conv - 1, d_inner, n_tokens]
|
||||
|
||||
const int nc = src2->ne[0];
|
||||
const int nr = src0->ne[1];
|
||||
@ -455,6 +484,12 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src3->ne[0] == src0->ne[2]);
|
||||
GGML_ASSERT(src3->ne[1] == src1->ne[1]);
|
||||
|
||||
if (src4) {
|
||||
GGML_ASSERT(n_kv == 1);
|
||||
GGML_ASSERT(src4->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_nelements(src4) >= (nc - 1)*nr*n_t);
|
||||
}
|
||||
|
||||
float * dst_data = (float *) dst->data;
|
||||
float * dst_x = dst_data;
|
||||
float * dst_state = dst_data + (size_t) nr * n_t;
|
||||
@ -476,22 +511,42 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
constexpr int split_n_t = 32;
|
||||
const dim3 token_grid(row_grid.x, (n_t + split_n_t - 1) / split_n_t, 1);
|
||||
|
||||
if (nc == 4) {
|
||||
ssm_conv_single_seq_f32_nc4<split_n_t><<<token_grid, block_dims, 0, ctx.stream()>>>(
|
||||
(const float *) src0->data,
|
||||
(const float *) src1->data,
|
||||
(const float *) src2->data,
|
||||
dst_x,
|
||||
nr, n_t,
|
||||
src0_s0, src0_s1, src1_s1);
|
||||
if (src4) {
|
||||
if (nc == 4) {
|
||||
ssm_conv_single_seq_f32_nc4<split_n_t, true><<<token_grid, block_dims, 0, ctx.stream()>>>(
|
||||
(const float *) src0->data,
|
||||
(const float *) src1->data,
|
||||
(const float *) src2->data,
|
||||
dst_x, (float *)src4->data,
|
||||
nr, n_t,
|
||||
src0_s0, src0_s1, src1_s1);
|
||||
} else {
|
||||
ssm_conv_single_seq_f32<split_n_t, true><<<token_grid, block_dims, 0, ctx.stream()>>>(
|
||||
(const float *) src0->data,
|
||||
(const float *) src1->data,
|
||||
(const float *) src2->data,
|
||||
dst_x, (float *)src4->data,
|
||||
nc, nr, n_t,
|
||||
src0_s0, src0_s1, src1_s1);
|
||||
}
|
||||
} else {
|
||||
ssm_conv_single_seq_f32<split_n_t><<<token_grid, block_dims, 0, ctx.stream()>>>(
|
||||
(const float *) src0->data,
|
||||
(const float *) src1->data,
|
||||
(const float *) src2->data,
|
||||
dst_x,
|
||||
nc, nr, n_t,
|
||||
src0_s0, src0_s1, src1_s1);
|
||||
if (nc == 4) {
|
||||
ssm_conv_single_seq_f32_nc4<split_n_t, false><<<token_grid, block_dims, 0, ctx.stream()>>>(
|
||||
(const float *) src0->data,
|
||||
(const float *) src1->data,
|
||||
(const float *) src2->data,
|
||||
dst_x, nullptr,
|
||||
nr, n_t,
|
||||
src0_s0, src0_s1, src1_s1);
|
||||
} else {
|
||||
ssm_conv_single_seq_f32<split_n_t, false><<<token_grid, block_dims, 0, ctx.stream()>>>(
|
||||
(const float *) src0->data,
|
||||
(const float *) src1->data,
|
||||
(const float *) src2->data,
|
||||
dst_x, nullptr,
|
||||
nc, nr, n_t,
|
||||
src0_s0, src0_s1, src1_s1);
|
||||
}
|
||||
}
|
||||
|
||||
ssm_conv_single_seq_final_state_f32<<<row_grid, block_dims, 0, ctx.stream()>>>(
|
||||
|
||||
@ -10449,7 +10449,8 @@ struct ggml_tensor * ggml_ssm_conv(
|
||||
struct ggml_tensor * s,
|
||||
struct ggml_tensor * x,
|
||||
struct ggml_tensor * c,
|
||||
struct ggml_tensor * sq) {
|
||||
struct ggml_tensor * sq,
|
||||
struct ggml_tensor * saved_steps) {
|
||||
GGML_ASSERT(ggml_is_3d(s));
|
||||
GGML_ASSERT(ggml_is_matrix(x));
|
||||
GGML_ASSERT(ggml_is_matrix(c));
|
||||
@ -10467,6 +10468,12 @@ struct ggml_tensor * ggml_ssm_conv(
|
||||
GGML_ASSERT(sq->ne[0] == n_kv);
|
||||
GGML_ASSERT(sq->ne[1] == n_tokens);
|
||||
|
||||
if (saved_steps) {
|
||||
GGML_ASSERT(n_kv == 1);
|
||||
GGML_ASSERT(saved_steps->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_nelements(saved_steps) >= (d_conv - 1)*d_inner*n_tokens);
|
||||
}
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
if (s->grad || x->grad || c->grad || sq->grad) {
|
||||
@ -10483,6 +10490,7 @@ struct ggml_tensor * ggml_ssm_conv(
|
||||
result->src[1] = x;
|
||||
result->src[2] = c;
|
||||
result->src[3] = sq;
|
||||
result->src[4] = saved_steps;
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -22399,6 +22407,7 @@ static int ggml_compute_forward_ssm_conv_f32(
|
||||
const struct ggml_tensor * src1 = dst->src[1]; // x
|
||||
const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
|
||||
const struct ggml_tensor * src3 = dst->src[3]; // state_seq
|
||||
const struct ggml_tensor * src4 = dst->src[4]; // saved_steps (optional): receives the last (d_conv - 1) conv state columns after each token
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
@ -22408,6 +22417,11 @@ static int ggml_compute_forward_ssm_conv_f32(
|
||||
const int n_t = src1->ne[1]; // n_tokens
|
||||
const int n_kv = src0->ne[2]; // max number of sequences in the batch
|
||||
|
||||
if (src4) {
|
||||
GGML_ASSERT(n_kv == 1);
|
||||
GGML_ASSERT(src4->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_nelements(src4) >= (nc - 1)*nr*n_t);
|
||||
}
|
||||
GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||
@ -22417,7 +22431,7 @@ static int ggml_compute_forward_ssm_conv_f32(
|
||||
// for use with the destination state offset between sequences
|
||||
GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
|
||||
|
||||
if (n_kv == 1 && nc == 4) {
|
||||
if (n_kv == 1 && nc == 4 && !src4) { // TODO: implement per token state saving in iqk_ssm_conv4
|
||||
float * dst_silu = NULL;
|
||||
if (node < cgraph->n_nodes + 2 &&
|
||||
cgraph->nodes[node+1]->op == GGML_OP_VIEW && cgraph->nodes[node+1]->src[0] == dst &&
|
||||
@ -22462,6 +22476,8 @@ static int ggml_compute_forward_ssm_conv_f32(
|
||||
}
|
||||
}
|
||||
|
||||
float * y = src4 ? (float *)src4->data + ir0*(nc-1) : NULL;
|
||||
|
||||
for (int i2 = 0; i2 < n_t; ++i2) {
|
||||
int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens}
|
||||
float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
|
||||
@ -22492,6 +22508,14 @@ static int ggml_compute_forward_ssm_conv_f32(
|
||||
// insert x on the last column
|
||||
s[(nc - 1) + i1*nc] = x0[i1];
|
||||
}
|
||||
if (y) {
|
||||
for (int i1 = 0; i1 < ir; ++i1) {
|
||||
for (int i0 = 0; i0 < nc - 1; ++i0) {
|
||||
y[i0 + i1*(nc-1)] = s[i0 + 1 + i1*nc];
|
||||
}
|
||||
}
|
||||
y += nr*(nc - 1);
|
||||
}
|
||||
|
||||
// handle copies when there are multiple output states
|
||||
for (int i3 = 1; i3 < n_kv; ++i3) {
|
||||
|
||||
@ -63,7 +63,7 @@ ggml_cgraph * llm_build_context::build_mamba() {
|
||||
// The new conv_states is the last (d_conv - 1) columns
|
||||
// of the last 3rd dimensional "layer" of the self-overlapping view.
|
||||
// For simultaneous sequences, it's more complicated.
|
||||
struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq);
|
||||
struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq, nullptr);
|
||||
|
||||
// store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache
|
||||
ggml_build_forward_expand(gf,
|
||||
|
||||
@ -93,7 +93,8 @@ struct llama_kv_cache {
|
||||
// Per-step conv state buffer: stores the ready-to-use convolution state after each
|
||||
// token of the verification forward pass so the conv state can be restored at any step.
|
||||
// One tensor per recurrent layer (one per device when split), each sized [conv_state_dim * max_tokens].
|
||||
std::vector<std::vector<ggml_tensor *>> per_step_qkv;
|
||||
//std::vector<std::vector<ggml_tensor *>> per_step_qkv;
|
||||
std::vector<std::vector<ggml_tensor *>> per_step_conv;
|
||||
|
||||
int32_t per_step_n_tokens = 0;
|
||||
int32_t per_step_max_allocated = 0;
|
||||
|
||||
@ -150,11 +150,10 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(llama_context & lctx, ggml_context * ctx0,
|
||||
ggml_tensor * wqkv, ggml_tensor * wqkv_gate,
|
||||
ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf, ggml_tensor * qkv_cpy) {
|
||||
ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf) {
|
||||
|
||||
const int64_t n_tok = input->ne[1];
|
||||
auto qkv_mixed = qkv_cpy ? ggml_mul_mat_inplace(ctx0, wqkv, input, qkv_cpy)
|
||||
: llm_build_context::llm_build_lora_mm(lctx, ctx0, wqkv, input);
|
||||
auto qkv_mixed = llm_build_context::llm_build_lora_mm(lctx, ctx0, wqkv, input);
|
||||
cb(qkv_mixed, "qkv_mixed", il);
|
||||
ggml_tensor * z = llm_build_context::llm_build_lora_mm(lctx, ctx0, wqkv_gate, input);
|
||||
cb(z, "z", il);
|
||||
@ -167,7 +166,7 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(llama_context & lc
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(llama_context & lctx, ggml_context * ctx0, ggml_tensor * ssm_in,
|
||||
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads,
|
||||
ggml_tensor * input, int il, const llm_build_cb & cb, ggml_tensor * qkv_cpy) {
|
||||
ggml_tensor * input, int il, const llm_build_cb & cb) {
|
||||
|
||||
const int64_t n_tok = input->ne[1];
|
||||
|
||||
@ -214,17 +213,17 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(llama_context & lc
|
||||
cb(value_flat, "value_flat", il);
|
||||
|
||||
ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
|
||||
qkv_mixed = qkv_cpy ? ggml_concat_inplace(ctx0, qkv_mixed, value_flat, qkv_cpy, 0) : ggml_concat(ctx0, qkv_mixed, value_flat, 0);
|
||||
qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
|
||||
cb(qkv_mixed, "qkv_mixed", il);
|
||||
|
||||
return { qkv_mixed, z };
|
||||
}
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(llama_context & lctx, ggml_context * ctx0, ggml_tensor * wqkv, ggml_tensor * wqkv_gate, ggml_tensor * ssm_in,
|
||||
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf, ggml_tensor * qkv_cpy) {
|
||||
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf) {
|
||||
GGML_ASSERT((wqkv && wqkv_gate) || ssm_in);
|
||||
return wqkv && wqkv_gate ? build_qkvz(lctx, ctx0, wqkv, wqkv_gate, input, il, cb, gf, qkv_cpy)
|
||||
: build_qkvz(lctx, ctx0, ssm_in, head_k_dim, num_k_heads, head_v_dim, num_v_heads, input, il, cb, qkv_cpy);
|
||||
return wqkv && wqkv_gate ? build_qkvz(lctx, ctx0, wqkv, wqkv_gate, input, il, cb, gf)
|
||||
: build_qkvz(lctx, ctx0, ssm_in, head_k_dim, num_k_heads, head_v_dim, num_v_heads, input, il, cb);
|
||||
}
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_beta_gate(llama_context & lctx, ggml_context * ctx0,
|
||||
@ -289,7 +288,7 @@ ggml_tensor * delta_net::build_qkv(ggml_context * ctx0, ggml_tensor * state_stor
|
||||
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, int64_t ssm_d_conv,
|
||||
int64_t state_seq_id_local, uint32_t qnext_state_slots, bool reset_state_local,
|
||||
float eps_norm, int repeat_type, int il, const llm_build_cb & cb, ggml_cgraph * gf,
|
||||
ggml_tensor * per_step_ckpt) {
|
||||
ggml_tensor * per_step_ckpt, ggml_tensor * per_step_conv) {
|
||||
const int64_t key_dim = head_k_dim * num_k_heads;
|
||||
const int64_t value_dim = head_v_dim * num_v_heads;
|
||||
const int64_t conv_dim = key_dim * 2 + value_dim;
|
||||
@ -332,7 +331,7 @@ ggml_tensor * delta_net::build_qkv(ggml_context * ctx0, ggml_tensor * state_stor
|
||||
cb(state, "state_predelta", il);
|
||||
ggml_build_forward_expand(gf, state);
|
||||
|
||||
ggml_tensor * conv_output_raw = ggml_ssm_conv(ctx0, conv_states, qkv_mixed, ssm_conv1d, inp_s_seq_qnext);
|
||||
ggml_tensor * conv_output_raw = ggml_ssm_conv(ctx0, conv_states, qkv_mixed, ssm_conv1d, inp_s_seq_qnext, per_step_conv);
|
||||
cb(conv_output_raw, "conv_output_raw", il);
|
||||
|
||||
ggml_tensor * conv_output = ggml_view_2d(ctx0, conv_output_raw, conv_dim, n_tok, conv_dim * ggml_element_size(conv_output_raw), 0);
|
||||
@ -433,15 +432,6 @@ static ggml_tensor * get_input_tensor_sm_graph(ggml_context * ctx, ggml_tensor *
|
||||
return cur;
|
||||
}
|
||||
|
||||
static void build_qkv_cpy(ggml_context * ctx0, ggml_cgraph * gf, ggml_tensor * qkv_mixed, ggml_tensor * per_step_qkv) {
|
||||
const int64_t conv_dim = qkv_mixed->ne[0];
|
||||
const int64_t n_tok_qkv = qkv_mixed->ne[1] * qkv_mixed->ne[2];
|
||||
ggml_tensor * qkv_flat = ggml_reshape_2d(ctx0, qkv_mixed, conv_dim, n_tok_qkv);
|
||||
ggml_tensor * qkv_dst = ggml_view_2d(ctx0, per_step_qkv, conv_dim, n_tok_qkv, conv_dim * sizeof(float), 0);
|
||||
auto qkv_cpy = ggml_cpy(ctx0, qkv_flat, qkv_dst);
|
||||
ggml_build_forward_expand(gf, qkv_cpy);
|
||||
}
|
||||
|
||||
ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_cgraph * gf,
|
||||
ggml_tensor * delta_input, ggml_tensor * inp_s_seq_qnext, ggml_tensor * inp_out_ids,
|
||||
uint32_t state_seq_id_local, bool reset_state_local, int il, const llm_build_cb & cb) const {
|
||||
@ -492,21 +482,17 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
int il_cb = 1000*il + id;
|
||||
int64_t num_k_heads_id, num_v_heads_id;
|
||||
ggml_tensor *qkv_mixed, *z;
|
||||
auto qkv_cpy = save_per_step_states &&
|
||||
il < (int)kv_self.ckpt.per_step_qkv.size() &&
|
||||
id < (int)kv_self.ckpt.per_step_qkv[il].size() ?
|
||||
kv_self.ckpt.per_step_qkv[il][id] : nullptr;
|
||||
if (split_wqkv && split_wqkv_gate) {
|
||||
num_k_heads_id = split_wqkv->splits[id]->ne[1]/(head_k_dim*(2 + gqa_ratio));
|
||||
num_v_heads_id = num_k_heads_id * gqa_ratio;
|
||||
auto p = build_qkvz(lctx, ctx0, split_wqkv->splits[id], split_wqkv_gate->splits[id], cur, il_cb, cb, gf, qkv_cpy);
|
||||
auto p = build_qkvz(lctx, ctx0, split_wqkv->splits[id], split_wqkv_gate->splits[id], cur, il_cb, cb, gf);
|
||||
qkv_mixed = p.first;
|
||||
z = p.second;
|
||||
} else {
|
||||
num_k_heads_id = split_smm_in->splits[id]->ne[1]/(2*head_k_dim*(1 + gqa_ratio));
|
||||
num_v_heads_id = num_k_heads_id * gqa_ratio;
|
||||
auto p = build_qkvz(lctx, ctx0, nullptr, nullptr, split_smm_in->splits[id],
|
||||
head_k_dim, num_k_heads_id, head_v_dim, num_v_heads_id, cur, il, cb, gf, qkv_cpy);
|
||||
head_k_dim, num_k_heads_id, head_v_dim, num_v_heads_id, cur, il, cb, gf);
|
||||
qkv_mixed = p.first;
|
||||
z = p.second;
|
||||
}
|
||||
@ -535,15 +521,15 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
ggml_tensor * per_step_ckpt = nullptr;
|
||||
if (save_per_step_states && il < (int)kv_self.ckpt.per_step_ssm.size()) {
|
||||
per_step_ckpt = kv_self.ckpt.per_step_ssm[il][id];
|
||||
//GGML_ASSERT(per_step_ckpt);
|
||||
}
|
||||
//if (save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && id < (int)kv_self.ckpt.per_step_qkv[il].size()) {
|
||||
// build_qkv_cpy(ctx0, gf, qkv_mixed, kv_self.ckpt.per_step_qkv[il][id]);
|
||||
//}
|
||||
auto per_step_conv = save_per_step_states && il < (int)kv_self.ckpt.per_step_conv.size() &&
|
||||
id < (int)kv_self.ckpt.per_step_conv[il].size()
|
||||
? kv_self.ckpt.per_step_conv[il][id] : nullptr;
|
||||
|
||||
auto output = build_qkv(ctx0, split_s_l->splits[id], split_ssm_conv1d->splits[id], qkv_mixed, inp_s_seq_qnext, beta, gate,
|
||||
head_k_dim, num_k_heads_id, head_v_dim, num_v_heads_id, hparams.ssm_d_conv,
|
||||
state_seq_id_local, qnext_state_slots, reset_state_local, hparams.f_norm_rms_eps,
|
||||
l.ssm_beta_alpha ? 0 : 1, il, cb, gf, per_step_ckpt);
|
||||
l.ssm_beta_alpha ? 0 : 1, il, cb, gf, per_step_ckpt, per_step_conv);
|
||||
split_norm = (ggml_split_tensor_t *)l.ssm_norm->extra;
|
||||
GGML_ASSERT(split_norm && split_norm->splits[id]);
|
||||
auto split_ssm_out = (ggml_split_tensor_t *)l.ssm_out->extra;
|
||||
@ -587,10 +573,8 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
auto norm = model.layers[il].attn_norm->extra ? ((ggml_split_tensor_t *)model.layers[il].attn_norm->extra)->splits[idx] : model.layers[il].attn_norm;
|
||||
auto cur = llm_build_context::llm_build_norm(ctx0, input, hparams, norm, nullptr, LLM_NORM_RMS, cb, il);
|
||||
|
||||
auto qkv_cpy = save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && !kv_self.ckpt.per_step_qkv[il].empty()
|
||||
? kv_self.ckpt.per_step_qkv[il].front() : nullptr;
|
||||
auto [qkv_mixed, z] = build_qkvz(lctx, ctx0, model.layers[il].wqkv, model.layers[il].wqkv_gate, model.layers[il].ssm_in,
|
||||
head_k_dim, num_k_heads, head_v_dim, num_v_heads, cur, il, cb, gf, qkv_cpy);
|
||||
head_k_dim, num_k_heads, head_v_dim, num_v_heads, cur, il, cb, gf);
|
||||
|
||||
auto [beta, gate] = build_beta_gate(lctx, ctx0, model.layers[il].ssm_beta_alpha, model.layers[il].ssm_beta, model.layers[il].ssm_alpha,
|
||||
model.layers[il].ssm_dt, model.layers[il].ssm_a, num_k_heads, num_v_heads, n_seqs, cur, il, cb, gf);
|
||||
@ -601,16 +585,14 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
per_step_ckpt = kv_self.ckpt.per_step_ssm[il].front();
|
||||
}
|
||||
|
||||
// Save qkv_mixed features for per-step conv state reconstruction
|
||||
//if (save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && !kv_self.ckpt.per_step_qkv[il].empty()) {
|
||||
// build_qkv_cpy(ctx0, gf, qkv_mixed, kv_self.ckpt.per_step_qkv[il].front());
|
||||
//}
|
||||
auto per_step_conv = save_per_step_states && il < (int)kv_self.ckpt.per_step_conv.size() && !kv_self.ckpt.per_step_conv[il].empty()
|
||||
? kv_self.ckpt.per_step_conv[il].front() : nullptr;
|
||||
|
||||
auto output = build_qkv(ctx0, kv_self.s_l[il], model.layers[il].ssm_conv1d,
|
||||
qkv_mixed, inp_s_seq_qnext, beta, gate,
|
||||
head_k_dim, num_k_heads, head_v_dim, num_v_heads, hparams.ssm_d_conv,
|
||||
state_seq_id_local, qnext_state_slots, reset_state_local, hparams.f_norm_rms_eps,
|
||||
model.layers[il].ssm_beta_alpha ? 0 : 1, il, cb, gf, per_step_ckpt);
|
||||
model.layers[il].ssm_beta_alpha ? 0 : 1, il, cb, gf, per_step_ckpt, per_step_conv);
|
||||
|
||||
auto gated_output = build_gated_output(lctx, ctx0, model.layers[il].ssm_norm, model.layers[il].ssm_out, output, z, head_v_dim, num_v_heads, n_tok, il, cb);
|
||||
if (inp_out_ids) {
|
||||
@ -620,7 +602,6 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
output = ggml_add(ctx0, gated_output, input);
|
||||
cb(output, "ssm_output", il);
|
||||
return output;
|
||||
//return build_gated_output(lctx, ctx0, model.layers[il].ssm_norm, model.layers[il].ssm_out, output, z, head_v_dim, num_v_heads, n_tok, il, cb);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@ -34,16 +34,16 @@ private:
|
||||
|
||||
static std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(llama_context & lctx, ggml_context * ctx0,
|
||||
ggml_tensor * wqkv, ggml_tensor * wqkv_gate, ggml_tensor * input, int il, const llm_build_cb & cb,
|
||||
ggml_cgraph * gf, ggml_tensor * qkv_copy);
|
||||
ggml_cgraph * gf);
|
||||
|
||||
static std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(llama_context & lctx, ggml_context * ctx0, ggml_tensor * ssm_in,
|
||||
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, ggml_tensor * input, int il,
|
||||
const llm_build_cb & cb, ggml_tensor * qkv_copy);
|
||||
const llm_build_cb & cb);
|
||||
|
||||
static std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(llama_context & lctx, ggml_context * ctx0,
|
||||
ggml_tensor * wqkv, ggml_tensor * wqkv_gate, ggml_tensor * ssm_in,
|
||||
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, ggml_tensor * input,
|
||||
int il, const llm_build_cb & cb, ggml_cgraph * gf, ggml_tensor * qkv_copy);
|
||||
int il, const llm_build_cb & cb, ggml_cgraph * gf);
|
||||
|
||||
static std::pair<ggml_tensor *, ggml_tensor *> build_beta_gate(llama_context & lctx, ggml_context * ctx0,
|
||||
ggml_tensor * ssm_beta_alpha, ggml_tensor * ssm_beta, ggml_tensor * ssm_alpha,
|
||||
@ -55,7 +55,7 @@ private:
|
||||
int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, int64_t ssm_d_conv,
|
||||
int64_t state_seq_id_local, uint32_t qnext_state_slots, bool reset_state_local,
|
||||
float eps_norm, int repeat_type, int il, const llm_build_cb & cb, ggml_cgraph * gf,
|
||||
ggml_tensor * per_step_ckpt = nullptr);
|
||||
ggml_tensor * per_step_ssm = nullptr, ggml_tensor * per_step_conv = nullptr);
|
||||
|
||||
static ggml_tensor * build_gated_output(llama_context & lctx, ggml_context * ctx0, ggml_tensor * ssm_norm, ggml_tensor * ssm_out,
|
||||
ggml_tensor * output, ggml_tensor * z, int64_t head_v_dim, int64_t num_v_heads, int64_t n_tok, int il, const llm_build_cb & cb);
|
||||
|
||||
@ -1554,16 +1554,17 @@ bool llama_kv_cache::per_step_alloc(const llama_model & model, int max_tokens) {
|
||||
ckpt.per_step_ctxs.clear();
|
||||
ckpt.per_step_bufs.clear();
|
||||
ckpt.per_step_ssm.clear();
|
||||
ckpt.per_step_qkv.clear();
|
||||
ckpt.per_step_conv.clear();
|
||||
ckpt.per_step_max_allocated = 0;
|
||||
}
|
||||
|
||||
const uint32_t n_layer = (uint32_t)s_l.size();
|
||||
ckpt.per_step_ssm.resize(n_layer);
|
||||
ckpt.per_step_qkv.resize(n_layer);
|
||||
ckpt.per_step_conv.resize(n_layer);
|
||||
|
||||
const int64_t ssm_state_dim = ckpt.per_step_ssm_state_size;
|
||||
const int64_t ssm_state_dim = ckpt.per_step_ssm_state_size;
|
||||
const int64_t conv_dim = ckpt.per_step_conv_dim;
|
||||
const int64_t conv_state_dim = ckpt.per_step_conv_state_dim;
|
||||
if (ssm_state_dim <= 0 || conv_dim <= 0) {
|
||||
LLAMA_LOG_ERROR("%s: per_step dimensions not set (ssm=%lld, conv_dim=%lld)\n",
|
||||
__func__, (long long)ssm_state_dim, (long long)conv_dim);
|
||||
@ -1572,6 +1573,7 @@ bool llama_kv_cache::per_step_alloc(const llama_model & model, int max_tokens) {
|
||||
|
||||
int num_v_heads = model.hparams.ssm_dt_rank;
|
||||
int head_v_dim = model.hparams.ssm_d_inner / num_v_heads;
|
||||
int d_conv = model.hparams.ssm_d_conv;
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, std::vector<std::pair<std::pair<uint32_t, int32_t>, ggml_backend_buffer_type_t>>> buft_layers;
|
||||
|
||||
@ -1611,16 +1613,16 @@ bool llama_kv_cache::per_step_alloc(const llama_model & model, int max_tokens) {
|
||||
if (id < 0) {
|
||||
if (max_tokens > 1) {
|
||||
GGML_ASSERT(ckpt.per_step_ssm[il].empty());
|
||||
GGML_ASSERT(ckpt.per_step_qkv[il].empty());
|
||||
GGML_ASSERT(ckpt.per_step_conv[il].empty());
|
||||
ggml_tensor * t_ssm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)(max_tokens - 1) * ssm_state_dim);
|
||||
ggml_format_name(t_ssm, "per_step_ssm_l%d", il);
|
||||
ckpt.per_step_ssm[il].push_back(t_ssm);
|
||||
}
|
||||
|
||||
// Per-step conv states: max_tokens * conv_state_dim
|
||||
ggml_tensor * t_qkv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * conv_dim);
|
||||
ggml_tensor * t_qkv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * conv_state_dim);
|
||||
ggml_format_name(t_qkv, "per_step_qkv_l%d", il);
|
||||
ckpt.per_step_qkv[il].push_back(t_qkv);
|
||||
ckpt.per_step_conv[il].push_back(t_qkv);
|
||||
} else {
|
||||
auto split_sl = (ggml_split_tensor_t *)s_l[il]->extra;
|
||||
auto split_ssm_out = (const ggml_split_tensor_t *)model.layers[il].ssm_out->extra;
|
||||
@ -1628,13 +1630,13 @@ bool llama_kv_cache::per_step_alloc(const llama_model & model, int max_tokens) {
|
||||
auto split = split_ssm_out->splits[id];
|
||||
GGML_ASSERT(split->ne[0] % head_v_dim == 0);
|
||||
if (ckpt.per_step_ssm[il].empty()) {
|
||||
ckpt.per_step_ssm[il].resize(split_sl->n_device, nullptr);
|
||||
ckpt.per_step_qkv[il].resize(split_sl->n_device, nullptr);
|
||||
ckpt.per_step_ssm [il].resize(split_sl->n_device, nullptr);
|
||||
ckpt.per_step_conv[il].resize(split_sl->n_device, nullptr);
|
||||
} else {
|
||||
GGML_ASSERT(int(ckpt.per_step_ssm[il].size()) == split_sl->n_device);
|
||||
GGML_ASSERT(int(ckpt.per_step_qkv[il].size()) == split_sl->n_device);
|
||||
GGML_ASSERT(int(ckpt.per_step_conv[il].size()) == split_sl->n_device);
|
||||
GGML_ASSERT(ckpt.per_step_ssm[il][id] == nullptr);
|
||||
GGML_ASSERT(ckpt.per_step_qkv[il][id] == nullptr);
|
||||
GGML_ASSERT(ckpt.per_step_conv[il][id] == nullptr);
|
||||
}
|
||||
int nv = split->ne[0] / head_v_dim; // number of heads handled by this device
|
||||
auto [this_conv_dim, this_ssm_dim] = model.hparams.n_embd_v_s_dims(nv);
|
||||
@ -1645,9 +1647,9 @@ bool llama_kv_cache::per_step_alloc(const llama_model & model, int max_tokens) {
|
||||
ckpt.per_step_ssm[il][id] = t_ssm;
|
||||
}
|
||||
|
||||
auto t_qkv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * this_conv_dim);
|
||||
auto t_qkv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * this_conv_dim * (d_conv - 1));
|
||||
ggml_format_name(t_qkv, "per_step_qkv_l%d_%d", il, id);
|
||||
ckpt.per_step_qkv[il][id] = t_qkv;
|
||||
ckpt.per_step_conv[il][id] = t_qkv;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1669,59 +1671,24 @@ bool llama_kv_cache::per_step_alloc(const llama_model & model, int max_tokens) {
|
||||
}
|
||||
|
||||
static void restore_recurrent_cache_tensors(int step, ggml_backend_sched_t sched,
|
||||
size_t ssm_bytes, size_t conv_bytes, size_t conv_dim, int d_conv_m1,
|
||||
ggml_tensor * s_l, ggml_tensor * per_step_ssm,
|
||||
ggml_tensor * s_l_shadow, ggml_tensor * per_step_qkv,
|
||||
std::vector<float> & old_conv_buf, std::vector<float> & qkv_buf, std::vector<float> & conv_buf,
|
||||
size_t ssm_bytes, size_t conv_bytes,
|
||||
ggml_tensor * s_l, ggml_tensor * per_step_ssm, ggml_tensor * per_step_conv,
|
||||
std::unordered_set<ggml_backend_t> & backends_to_sync) {
|
||||
auto dst_backend = ggml_backend_sched_get_tensor_backend(sched, s_l);
|
||||
auto dst = *s_l;
|
||||
dst.ne[0] = ssm_bytes/sizeof(float);
|
||||
dst.nb[1] = dst.nb[2] = dst.nb[3] = ssm_bytes;
|
||||
dst.nb[1] = dst.nb[2] = dst.nb[3] = ssm_bytes + conv_bytes;
|
||||
dst.data = (char *)s_l->data + conv_bytes;
|
||||
auto src = dst;
|
||||
// I think the commented out version is correct, but only after fixing the delta-net implementation.
|
||||
// For now let's just reproduce what we have on the main branch, only avoiding the expensive
|
||||
// device -> host -> device copy.
|
||||
//src.data = (char *)per_step_ssm->data + (size_t)(step + 1) * ssm_bytes;
|
||||
src.data = (char *)per_step_ssm->data + (size_t)step * ssm_bytes;
|
||||
ggml_backend_tensor_copy_async(dst_backend, dst_backend, &src, &dst);
|
||||
backends_to_sync.insert(dst_backend);
|
||||
|
||||
if (s_l_shadow != nullptr) {
|
||||
ggml_backend_tensor_get(s_l_shadow, old_conv_buf.data(), 0, conv_bytes);
|
||||
} else {
|
||||
memset(old_conv_buf.data(), 0, conv_bytes);
|
||||
}
|
||||
|
||||
auto qkv_needed = size_t(step + 1) * conv_dim;
|
||||
if (per_step_qkv) {
|
||||
ggml_backend_tensor_get(per_step_qkv, qkv_buf.data(), 0, qkv_needed * sizeof(float));
|
||||
} else {
|
||||
memset(qkv_buf.data(), 0, qkv_needed * sizeof(float));
|
||||
}
|
||||
|
||||
for (int32_t col = 0; col < d_conv_m1; col++) {
|
||||
int32_t src_token = step - (d_conv_m1 - 1) + col; // e.g., K-2, K-1, K for d_conv=4
|
||||
if (src_token >= 0) {
|
||||
for (int64_t d = 0; d < conv_dim; d++) {
|
||||
conv_buf[col + d * d_conv_m1] = qkv_buf[d + (int64_t)src_token * conv_dim];
|
||||
}
|
||||
} else {
|
||||
int32_t old_col = d_conv_m1 + src_token; // maps to 0, 1, ... for early steps
|
||||
if (old_col >= 0 && old_col < d_conv_m1) {
|
||||
for (int64_t d = 0; d < conv_dim; d++) {
|
||||
conv_buf[col + d * d_conv_m1] = old_conv_buf[old_col + d * d_conv_m1];
|
||||
}
|
||||
} else {
|
||||
for (int64_t d = 0; d < conv_dim; d++) {
|
||||
conv_buf[col + d * d_conv_m1] = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(s_l, conv_buf.data(), 0, conv_bytes);
|
||||
dst.data = (char *)s_l->data;
|
||||
dst.ne[0] = conv_bytes/sizeof(float);
|
||||
src = dst;
|
||||
src.data = (char *)per_step_conv->data + (size_t)step * conv_bytes;
|
||||
ggml_backend_tensor_copy_async(dst_backend, dst_backend, &src, &dst);
|
||||
}
|
||||
|
||||
bool llama_kv_cache::per_step_restore(const llama_model & model, ggml_backend_sched_t sched, int step) {
|
||||
@ -1739,7 +1706,6 @@ bool llama_kv_cache::per_step_restore(const llama_model & model, ggml_backend_sc
|
||||
|
||||
const int64_t ssm_bytes = ssm_state_dim * sizeof(float);
|
||||
const int64_t conv_bytes = conv_state_dim * sizeof(float);
|
||||
const int32_t d_conv_m1 = d_conv - 1; // number of columns in conv state
|
||||
|
||||
std::vector<float> ssm_buf(ssm_state_dim);
|
||||
std::vector<float> conv_buf(conv_state_dim); // reconstructed conv state
|
||||
@ -1755,9 +1721,9 @@ bool llama_kv_cache::per_step_restore(const llama_model & model, ggml_backend_sc
|
||||
if (s_l[il] == nullptr || ckpt.per_step_ssm[il].empty()) continue;
|
||||
|
||||
if (!s_l[il]->extra) {
|
||||
restore_recurrent_cache_tensors(step, sched, ssm_bytes, conv_bytes, conv_dim, d_conv_m1,
|
||||
s_l[il], ckpt.per_step_ssm[il].front(), ckpt.s_l_shadow[il], ckpt.per_step_qkv[il].front(),
|
||||
old_conv_buf, qkv_buf, conv_buf, backends_to_sync);
|
||||
restore_recurrent_cache_tensors(step, sched, ssm_bytes, conv_bytes,
|
||||
s_l[il], ckpt.per_step_ssm[il].front(), ckpt.per_step_conv[il].front(),
|
||||
backends_to_sync);
|
||||
} else {
|
||||
auto split_sl = (const ggml_split_tensor_t *)s_l[il]->extra;
|
||||
for (int id = 0; id < split_sl->n_device; ++id) {
|
||||
@ -1769,16 +1735,14 @@ bool llama_kv_cache::per_step_restore(const llama_model & model, ggml_backend_sc
|
||||
int nv = split->ne[0] / head_v_dim; // number of heads handled by this device
|
||||
auto [this_conv_dim, this_ssm_dim] = model.hparams.n_embd_v_s_dims(nv);
|
||||
auto this_conv_bytes = (d_conv - 1) * this_conv_dim * sizeof(float);
|
||||
restore_recurrent_cache_tensors(step, sched, this_ssm_dim * sizeof(float), this_conv_bytes, this_conv_dim, d_conv_m1,
|
||||
split_sl->splits[id], ckpt.per_step_ssm[il][id], ckpt.split_s_l_shadow[il][id], ckpt.per_step_qkv[il][id],
|
||||
old_conv_buf, qkv_buf, conv_buf, backends_to_sync);
|
||||
restore_recurrent_cache_tensors(step, sched, this_ssm_dim * sizeof(float), this_conv_bytes,
|
||||
split_sl->splits[id], ckpt.per_step_ssm[il][id], ckpt.per_step_conv[il][id],
|
||||
backends_to_sync);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Strictly speaking we shouldn't need to do this. We are doing a synchronous copy of the
|
||||
// convolution state in the above loop, and that involves a ggml_backend_synchronize call.
|
||||
// But just in case.
|
||||
// TODO: do we need to synchronize here?
|
||||
for (auto backend : backends_to_sync) {
|
||||
ggml_backend_synchronize(backend);
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user