diff --git a/src/llama-context.h b/src/llama-context.h index d3526cd2..02a20052 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -88,12 +88,12 @@ struct llama_kv_cache { std::vector> split_s_l_shadow; // Per-step SSM state checkpoints for speculative decoding. - std::vector per_step_ssm; + std::vector> per_step_ssm; // Per-step conv feature buffer: stores qkv_mixed features from the // verification forward pass so conv state can be reconstructed at any step. // One tensor per recurrent layer, each sized [conv_dim * max_tokens]. - std::vector per_step_qkv; + std::vector> per_step_qkv; int32_t per_step_n_tokens = 0; int32_t per_step_max_allocated = 0; @@ -142,8 +142,8 @@ struct llama_kv_cache { void checkpoint_delete(); // Per-step checkpoint: allocate, restore step k's full state (SSM + conv) to cache - bool per_step_alloc(int max_tokens); - bool per_step_restore(ggml_backend_sched_t sched, int step); + bool per_step_alloc(const llama_model & model, int max_tokens); + bool per_step_restore(const llama_model & model, ggml_backend_sched_t sched, int step); ~llama_kv_cache() { for (struct ggml_context * ctx : ctxs) { diff --git a/src/llama-delta-net.cpp b/src/llama-delta-net.cpp index 9c6d0402..81c51020 100644 --- a/src/llama-delta-net.cpp +++ b/src/llama-delta-net.cpp @@ -454,6 +454,15 @@ static ggml_tensor * get_input_tensor_sm_graph(ggml_context * ctx, ggml_tensor * return cur; } +static void build_qkv_cpy(ggml_context * ctx0, ggml_cgraph * gf, ggml_tensor * qkv_mixed, ggml_tensor * per_step_qkv) { + const int64_t conv_dim = qkv_mixed->ne[0]; + const int64_t n_tok_qkv = qkv_mixed->ne[1] * qkv_mixed->ne[2]; + ggml_tensor * qkv_flat = ggml_reshape_2d(ctx0, qkv_mixed, conv_dim, n_tok_qkv); + ggml_tensor * qkv_dst = ggml_view_2d(ctx0, per_step_qkv, conv_dim, n_tok_qkv, conv_dim * sizeof(float), 0); + auto qkv_cpy = ggml_cpy(ctx0, qkv_flat, qkv_dst); + ggml_build_forward_expand(gf, qkv_cpy); +} + ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_cgraph * gf, ggml_tensor * delta_input, ggml_tensor * inp_s_seq_qnext, ggml_tensor * inp_out_ids, uint32_t state_seq_id_local, bool reset_state_local, int il, const llm_build_cb & cb) const { @@ -540,10 +549,18 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_ } auto split_ssm_conv1d = (ggml_split_tensor_t *)l.ssm_conv1d->extra; GGML_ASSERT(split_ssm_conv1d && split_ssm_conv1d->splits[id]); + ggml_tensor * per_step_ckpt = nullptr; + if (save_per_step_states && il < (int)kv_self.ckpt.per_step_ssm.size()) { + per_step_ckpt = kv_self.ckpt.per_step_ssm[il][id]; + //GGML_ASSERT(per_step_ckpt); + } + if (save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && id < (int)kv_self.ckpt.per_step_qkv[il].size()) { + build_qkv_cpy(ctx0, gf, qkv_mixed, kv_self.ckpt.per_step_qkv[il][id]); + } auto output = build_qkv(ctx0, split_s_l->splits[id], split_ssm_conv1d->splits[id], qkv_mixed, inp_s_seq_qnext, beta, gate, head_k_dim, num_k_heads_id, head_v_dim, num_v_heads_id, hparams.ssm_d_conv, state_seq_id_local, qnext_state_slots, reset_state_local, hparams.f_norm_rms_eps, - l.ssm_beta_alpha ? 0 : 1, il, cb, gf); + l.ssm_beta_alpha ? 0 : 1, il, cb, gf, save_per_step_states, per_step_ckpt); split_norm = (ggml_split_tensor_t *)l.ssm_norm->extra; GGML_ASSERT(split_norm && split_norm->splits[id]); auto split_ssm_out = (ggml_split_tensor_t *)l.ssm_out->extra; @@ -596,18 +613,19 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_ // Get per-step checkpoint tensor if available ggml_tensor * per_step_ckpt = nullptr; if (save_per_step_states && il < (int)kv_self.ckpt.per_step_ssm.size()) { - per_step_ckpt = kv_self.ckpt.per_step_ssm[il]; + per_step_ckpt = kv_self.ckpt.per_step_ssm[il].front(); } // Save qkv_mixed features for per-step conv state reconstruction - if (save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && kv_self.ckpt.per_step_qkv[il] != nullptr) { - const int64_t conv_dim = qkv_mixed->ne[0]; - const int64_t n_tok_qkv = qkv_mixed->ne[1] * qkv_mixed->ne[2]; - ggml_tensor * qkv_flat = ggml_reshape_2d(ctx0, qkv_mixed, conv_dim, n_tok_qkv); - ggml_tensor * qkv_dst = ggml_view_2d(ctx0, kv_self.ckpt.per_step_qkv[il], - conv_dim, n_tok_qkv, conv_dim * sizeof(float), 0); - auto qkv_cpy = ggml_cpy(ctx0, qkv_flat, qkv_dst); - ggml_build_forward_expand(gf, qkv_cpy); + if (save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && !kv_self.ckpt.per_step_qkv[il].empty()) { + build_qkv_cpy(ctx0, gf, qkv_mixed, kv_self.ckpt.per_step_qkv[il].front()); + //const int64_t conv_dim = qkv_mixed->ne[0]; + //const int64_t n_tok_qkv = qkv_mixed->ne[1] * qkv_mixed->ne[2]; + //ggml_tensor * qkv_flat = ggml_reshape_2d(ctx0, qkv_mixed, conv_dim, n_tok_qkv); + //ggml_tensor * qkv_dst = ggml_view_2d(ctx0, kv_self.ckpt.per_step_qkv[il].front(), + // conv_dim, n_tok_qkv, conv_dim * sizeof(float), 0); + //auto qkv_cpy = ggml_cpy(ctx0, qkv_flat, qkv_dst); + //ggml_build_forward_expand(gf, qkv_cpy); } auto output = build_qkv(ctx0, kv_self.s_l[il], model.layers[il].ssm_conv1d, diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 77ea6f13..13fe1811 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -297,8 +297,8 @@ struct llama_hparams { return ssm_d_state * ssm_d_inner; } - uint32_t n_embd_v_s_id(int nv) const { - if (ssm_n_group <= 0 || nv < 1 || ssm_dt_rank < 1) return 0; + std::pair n_embd_v_s_dims(int nv) const { + if (ssm_n_group <= 0 || nv < 1 || ssm_dt_rank < 1) return {0, 0}; int num_v_heads = ssm_dt_rank; int num_k_heads = ssm_n_group; int gqa_ratio = num_v_heads / num_k_heads; @@ -308,10 +308,16 @@ struct llama_hparams { int head_k_dim = ssm_d_state; int head_v_dim = ssm_d_inner / num_v_heads; uint32_t conv_dim = 2 * nk * head_k_dim + nv * head_v_dim; - uint32_t conv_state_dim = conv_dim * (ssm_d_conv - 1); + //uint32_t conv_state_dim = conv_dim * (ssm_d_conv - 1); + //uint32_t ssm_state_dim = head_v_dim * head_v_dim * nv; + //return {conv_state_dim, ssm_state_dim}; uint32_t ssm_state_dim = head_v_dim * head_v_dim * nv; - return conv_state_dim + ssm_state_dim; + return {conv_dim, ssm_state_dim}; + } + uint32_t n_embd_v_s_id(int nv) const { + auto [conv_dim, ssm_state_dim] = n_embd_v_s_dims(nv); + return (ssm_d_conv - 1) * conv_dim + ssm_state_dim; } bool is_recurrent(uint32_t il) const { diff --git a/src/llama.cpp b/src/llama.cpp index ca74214c..63bf3585 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -119,6 +119,7 @@ void llama_set_mtp_target_context(struct llama_context * ctx, struct llama_conte #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -1273,7 +1274,6 @@ bool llama_kv_cache::checkpoint_alloc_shadows() { std::map> split_buft_entries; - uint32_t split_s_idx = 0; for (uint32_t il = 0; il < n_layer; ++il) { if (s_l[il] == nullptr) { continue; @@ -1287,7 +1287,6 @@ bool llama_kv_cache::checkpoint_alloc_shadows() { ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(split_info->splits[d]->buffer); split_buft_entries[buft].push_back({split_info->splits[d], il, d}); } - split_s_idx++; } else { nonsplit_entries.push_back({s_l[il], il, -1}); } @@ -1362,15 +1361,14 @@ bool llama_kv_cache::checkpoint_alloc_shadows() { } // Build split shadow lookup - ckpt.split_s_l_shadow.resize(split_s_l.size()); - split_s_idx = 0; + ckpt.split_s_l_shadow.resize(n_layer); for (uint32_t il = 0; il < n_layer; ++il) { if (s_l[il] == nullptr || s_l[il]->extra == nullptr) { continue; } auto * split_info = (const ggml_split_tensor_t *)s_l[il]->extra; - auto & shadow_split = ckpt.split_s_l_shadow[split_s_idx]; + auto & shadow_split = ckpt.split_s_l_shadow[il]; shadow_split.resize(split_info->n_device, nullptr); for (int d = 0; d < split_info->n_device; ++d) { @@ -1383,7 +1381,6 @@ bool llama_kv_cache::checkpoint_alloc_shadows() { } } } - split_s_idx++; } ckpt.allocated = true; @@ -1403,7 +1400,6 @@ bool llama_kv_cache::checkpoint_save(ggml_backend_sched_t sched) { std::unordered_set backends_to_sync; - uint32_t split_s_idx = 0; for (uint32_t il = 0; il < n_layer; ++il) { if (s_l[il] == nullptr) { continue; @@ -1411,7 +1407,7 @@ bool llama_kv_cache::checkpoint_save(ggml_backend_sched_t sched) { if (s_l[il]->extra != nullptr) { auto * split_info = (const ggml_split_tensor_t *)s_l[il]->extra; - auto & shadow_split = ckpt.split_s_l_shadow[split_s_idx]; + auto & shadow_split = ckpt.split_s_l_shadow[il]; for (int d = 0; d < split_info->n_device; ++d) { if (split_info->splits[d] && shadow_split[d]) { //ggml_backend_tensor_copy(split_info->splits[d], shadow_split[d]); @@ -1420,7 +1416,6 @@ bool llama_kv_cache::checkpoint_save(ggml_backend_sched_t sched) { backends_to_sync.insert(src_backend); } } - split_s_idx++; } else { const size_t nbytes = ggml_nbytes(ckpt.s_l_shadow[il]); ggml_backend_tensor_get(s_l[il], ckpt.s_l_shadow[il]->data, 0, nbytes); @@ -1449,7 +1444,6 @@ bool llama_kv_cache::checkpoint_restore(ggml_backend_sched_t sched) { std::unordered_set backends_to_sync; - uint32_t split_s_idx = 0; for (uint32_t il = 0; il < n_layer; ++il) { if (s_l[il] == nullptr) { continue; @@ -1457,7 +1451,7 @@ bool llama_kv_cache::checkpoint_restore(ggml_backend_sched_t sched) { if (s_l[il]->extra != nullptr) { auto * split_info = (const ggml_split_tensor_t *)s_l[il]->extra; - auto & shadow_split = ckpt.split_s_l_shadow[split_s_idx]; + auto & shadow_split = ckpt.split_s_l_shadow[il]; for (int d = 0; d < split_info->n_device; ++d) { if (split_info->splits[d] && shadow_split[d]) { auto dst_backend = ggml_backend_sched_get_tensor_backend(sched, split_info->splits[d]); @@ -1465,7 +1459,6 @@ bool llama_kv_cache::checkpoint_restore(ggml_backend_sched_t sched) { backends_to_sync.insert(dst_backend); } } - split_s_idx++; } else { GGML_ASSERT(ggml_nbytes(ckpt.s_l_shadow[il]) == ggml_nbytes(s_l[il])); ggml_backend_tensor_copy(ckpt.s_l_shadow[il], s_l[il]); @@ -1483,7 +1476,7 @@ void llama_kv_cache::checkpoint_delete() { ckpt.saved = false; } -bool llama_kv_cache::per_step_alloc(int max_tokens) { +bool llama_kv_cache::per_step_alloc(const llama_model & model, int max_tokens) { if (ckpt.per_step_max_allocated >= max_tokens) { return true; } @@ -1503,8 +1496,8 @@ bool llama_kv_cache::per_step_alloc(int max_tokens) { } const uint32_t n_layer = (uint32_t)s_l.size(); - ckpt.per_step_ssm.resize(n_layer, nullptr); - ckpt.per_step_qkv.resize(n_layer, nullptr); + ckpt.per_step_ssm.resize(n_layer); + ckpt.per_step_qkv.resize(n_layer); const int64_t ssm_state_dim = ckpt.per_step_ssm_state_size; const int64_t conv_dim = ckpt.per_step_conv_dim; @@ -1514,14 +1507,26 @@ bool llama_kv_cache::per_step_alloc(int max_tokens) { return false; } - std::map>> buft_layers; + int num_v_heads = model.hparams.ssm_dt_rank; + int head_v_dim = model.hparams.ssm_d_inner / num_v_heads; + + std::map, ggml_backend_buffer_type_t>>> buft_layers; for (uint32_t il = 0; il < n_layer; ++il) { if (s_l[il] == nullptr) continue; - if (s_l[il]->extra != nullptr) continue; // skip split tensors - ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(s_l[il]->buffer); - buft_layers[buft].push_back({il, buft}); + if (s_l[il]->extra) { + auto split_sl = (ggml_split_tensor_t *)s_l[il]->extra; + for (int id = 0; id < split_sl->n_device; ++id) { + if (!split_sl->splits[id]) continue; + auto buft = ggml_backend_buffer_get_type(split_sl->splits[id]->buffer); + GGML_ASSERT(buft != nullptr); + buft_layers[buft].push_back({{il, id}, buft}); + } + } else { + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(s_l[il]->buffer); + buft_layers[buft].push_back({{il, -1}, buft}); + } } for (auto & [buft, layers] : buft_layers) { @@ -1537,16 +1542,46 @@ bool llama_kv_cache::per_step_alloc(int max_tokens) { return false; } - for (auto & [il, bt] : layers) { + for (auto & [p, bt] : layers) { + auto [il, id] = p; // SSM state: max_tokens * ssm_state_dim - ggml_tensor * t_ssm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * ssm_state_dim); - ggml_format_name(t_ssm, "per_step_ssm_l%d", il); - ckpt.per_step_ssm[il] = t_ssm; + if (id < 0) { + GGML_ASSERT(ckpt.per_step_ssm[il].empty()); + GGML_ASSERT(ckpt.per_step_qkv[il].empty()); + ggml_tensor * t_ssm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * ssm_state_dim); + ggml_format_name(t_ssm, "per_step_ssm_l%d", il); + ckpt.per_step_ssm[il].push_back(t_ssm); - // Conv features (qkv_mixed): max_tokens * conv_dim - ggml_tensor * t_qkv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * conv_dim); - ggml_format_name(t_qkv, "per_step_qkv_l%d", il); - ckpt.per_step_qkv[il] = t_qkv; + // Conv features (qkv_mixed): max_tokens * conv_dim + ggml_tensor * t_qkv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * conv_dim); + ggml_format_name(t_qkv, "per_step_qkv_l%d", il); + ckpt.per_step_qkv[il].push_back(t_qkv); + } else { + auto split_sl = (ggml_split_tensor_t *)s_l[il]->extra; + auto split_ssm_out = (const ggml_split_tensor_t *)model.layers[il].ssm_out->extra; + GGML_ASSERT(split_ssm_out && split_ssm_out->splits[id]); + auto split = split_ssm_out->splits[id]; + GGML_ASSERT(split->ne[0] % head_v_dim == 0); + if (ckpt.per_step_ssm[il].empty()) { + ckpt.per_step_ssm[il].resize(split_sl->n_device, nullptr); + ckpt.per_step_qkv[il].resize(split_sl->n_device, nullptr); + } else { + GGML_ASSERT(int(ckpt.per_step_ssm[il].size()) == split_sl->n_device); + GGML_ASSERT(int(ckpt.per_step_qkv[il].size()) == split_sl->n_device); + GGML_ASSERT(ckpt.per_step_ssm[il][id] == nullptr); + GGML_ASSERT(ckpt.per_step_qkv[il][id] == nullptr); + } + int nv = split->ne[0] / head_v_dim; // number of heads handled by this device + auto [this_conv_dim, this_ssm_dim] = model.hparams.n_embd_v_s_dims(nv); + + auto t_ssm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * this_ssm_dim); + ggml_format_name(t_ssm, "per_step_ssm_l%d_%d", il, id); + ckpt.per_step_ssm[il][id] = t_ssm; + + auto t_qkv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * this_conv_dim); + ggml_format_name(t_qkv, "per_step_qkv_l%d_%d", il, id); + ckpt.per_step_qkv[il][id] = t_qkv; + } } ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); @@ -1566,7 +1601,63 @@ bool llama_kv_cache::per_step_alloc(int max_tokens) { return true; } -bool llama_kv_cache::per_step_restore(ggml_backend_sched_t sched, int step) { +static void restore_recurrent_cache_tensors(int step, ggml_backend_sched_t sched, + size_t ssm_bytes, size_t conv_bytes, size_t conv_dim, int d_conv_m1, + ggml_tensor * s_l, ggml_tensor * per_step_ssm, + ggml_tensor * s_l_shadow, ggml_tensor * per_step_qkv, + std::vector & old_conv_buf, std::vector & qkv_buf, std::vector & conv_buf, + std::unordered_set & backends_to_sync) { + auto dst_backend = ggml_backend_sched_get_tensor_backend(sched, s_l); + auto dst = *s_l; + dst.ne[0] = ssm_bytes/sizeof(float); + dst.nb[1] = dst.nb[2] = dst.nb[3] = ssm_bytes; + dst.data = (char *)s_l->data + conv_bytes; + auto src = dst; + // I think the commented out version is correct, bt only after fixing the delta-net implementation. + // For now let's just reproduce what we have on the main branch, only avoiding the expensive + // device -> host -> device copy. + //src.data = (char *)per_step_ssm->data + (size_t)(step + 1) * ssm_bytes; + src.data = (char *)per_step_ssm->data + (size_t)step * ssm_bytes; + ggml_backend_tensor_copy_async(dst_backend, dst_backend, &src, &dst); + backends_to_sync.insert(dst_backend); + + if (s_l_shadow != nullptr) { + ggml_backend_tensor_get(s_l_shadow, old_conv_buf.data(), 0, conv_bytes); + } else { + memset(old_conv_buf.data(), 0, conv_bytes); + } + + auto qkv_needed = size_t(step + 1) * conv_dim; + if (per_step_qkv) { + ggml_backend_tensor_get(per_step_qkv, qkv_buf.data(), 0, qkv_needed * sizeof(float)); + } else { + memset(qkv_buf.data(), 0, qkv_needed * sizeof(float)); + } + + for (int32_t col = 0; col < d_conv_m1; col++) { + int32_t src_token = step - (d_conv_m1 - 1) + col; // e.g., K-2, K-1, K for d_conv=4 + if (src_token >= 0) { + for (int64_t d = 0; d < conv_dim; d++) { + conv_buf[col + d * d_conv_m1] = qkv_buf[d + (int64_t)src_token * conv_dim]; + } + } else { + int32_t old_col = d_conv_m1 + src_token; // maps to 0, 1, ... for early steps + if (old_col >= 0 && old_col < d_conv_m1) { + for (int64_t d = 0; d < conv_dim; d++) { + conv_buf[col + d * d_conv_m1] = old_conv_buf[old_col + d * d_conv_m1]; + } + } else { + for (int64_t d = 0; d < conv_dim; d++) { + conv_buf[col + d * d_conv_m1] = 0.0f; + } + } + } + } + + ggml_backend_tensor_set(s_l, conv_buf.data(), 0, conv_bytes); +} + +bool llama_kv_cache::per_step_restore(const llama_model & model, ggml_backend_sched_t sched, int step) { if (ckpt.per_step_ssm.empty() || step < 0) { return false; } @@ -1589,60 +1680,38 @@ bool llama_kv_cache::per_step_restore(ggml_backend_sched_t sched, int step) { const int64_t qkv_needed = (int64_t)(step + 1) * conv_dim; std::vector qkv_buf(qkv_needed); + int num_v_heads = model.hparams.ssm_dt_rank; + int head_v_dim = model.hparams.ssm_d_inner / num_v_heads; + const uint32_t n_layer = (uint32_t)s_l.size(); for (uint32_t il = 0; il < n_layer; ++il) { - if (s_l[il] == nullptr || ckpt.per_step_ssm[il] == nullptr) continue; - if (s_l[il]->extra != nullptr) continue; + if (s_l[il] == nullptr || ckpt.per_step_ssm[il].empty()) continue; - auto dst_backend = ggml_backend_sched_get_tensor_backend(sched, s_l[il]); - auto dst = *s_l[il]; - dst.ne[0] = ssm_bytes/sizeof(float); - dst.nb[1] = dst.nb[2] = dst.nb[3] = ssm_bytes; - dst.data = (char *)s_l[il]->data + conv_bytes; - auto src = dst; - // I think the commented out version is correct, bt only after fixing the delta-net implementation. - // For now let's just reproduce what we have on the main branch, only avoiding the expensive - // device -> host -> device copy. - //src.data = (char *)ckpt.per_step_ssm[il]->data + (size_t)(step + 1) * ssm_bytes; - src.data = (char *)ckpt.per_step_ssm[il]->data + (size_t)step * ssm_bytes; - ggml_backend_tensor_copy_async(dst_backend, dst_backend, &src, &dst); - backends_to_sync.insert(dst_backend); - - if (ckpt.s_l_shadow[il] != nullptr) { - ggml_backend_tensor_get(ckpt.s_l_shadow[il], old_conv_buf.data(), 0, conv_bytes); + if (!s_l[il]->extra) { + restore_recurrent_cache_tensors(step, sched, ssm_bytes, conv_bytes, conv_dim, d_conv_m1, + s_l[il], ckpt.per_step_ssm[il].front(), ckpt.s_l_shadow[il], ckpt.per_step_qkv[il].front(), + old_conv_buf, qkv_buf, conv_buf, backends_to_sync); } else { - memset(old_conv_buf.data(), 0, conv_bytes); - } - - if (ckpt.per_step_qkv[il] != nullptr) { - ggml_backend_tensor_get(ckpt.per_step_qkv[il], qkv_buf.data(), 0, qkv_needed * sizeof(float)); - } else { - memset(qkv_buf.data(), 0, qkv_needed * sizeof(float)); - } - - for (int32_t col = 0; col < d_conv_m1; col++) { - int32_t src_token = step - (d_conv_m1 - 1) + col; // e.g., K-2, K-1, K for d_conv=4 - if (src_token >= 0) { - for (int64_t d = 0; d < conv_dim; d++) { - conv_buf[col + d * d_conv_m1] = qkv_buf[d + (int64_t)src_token * conv_dim]; - } - } else { - int32_t old_col = d_conv_m1 + src_token; // maps to 0, 1, ... for early steps - if (old_col >= 0 && old_col < d_conv_m1) { - for (int64_t d = 0; d < conv_dim; d++) { - conv_buf[col + d * d_conv_m1] = old_conv_buf[old_col + d * d_conv_m1]; - } - } else { - for (int64_t d = 0; d < conv_dim; d++) { - conv_buf[col + d * d_conv_m1] = 0.0f; - } - } + auto split_sl = (const ggml_split_tensor_t *)s_l[il]->extra; + for (int id = 0; id < split_sl->n_device; ++id) { + if (!split_sl->splits[id]) continue; + auto split_ssm_out = (const ggml_split_tensor_t *)model.layers[il].ssm_out->extra; + GGML_ASSERT(split_ssm_out && split_ssm_out->splits[id]); + auto split = split_ssm_out->splits[id]; + GGML_ASSERT(split->ne[0] % head_v_dim == 0); + int nv = split->ne[0] / head_v_dim; // number of heads handled by this device + auto [this_conv_dim, this_ssm_dim] = model.hparams.n_embd_v_s_dims(nv); + auto this_conv_bytes = (d_conv - 1) * this_conv_dim * sizeof(float); + restore_recurrent_cache_tensors(step, sched, this_ssm_dim * sizeof(float), this_conv_bytes, this_conv_dim, d_conv_m1, + split_sl->splits[id], ckpt.per_step_ssm[il][id], ckpt.split_s_l_shadow[il][id], ckpt.per_step_qkv[il][id], + old_conv_buf, qkv_buf, conv_buf, backends_to_sync); } } - - ggml_backend_tensor_set(s_l[il], conv_buf.data(), 0, conv_bytes); } + // Strictly speaking we shouldn't need to do this. We are doing a synchronous copy of the + // convolution state in the above loop, and that involves a ggml_backend_synchronize call. + // But just in case. for (auto backend : backends_to_sync) { ggml_backend_synchronize(backend); } @@ -7024,9 +7093,13 @@ static bool spec_ckpt_try_per_step(llama_kv_cache & kv, const llama_model & mode for (const auto * sl : kv.s_l) { if (!sl) continue; if (sl->extra) { - kv.save_per_step_ssm = false; - return false; + has_gpu = true; + continue; } + //if (sl->extra) { + // kv.save_per_step_ssm = false; + // return false; + //} if (sl->buffer && !ggml_backend_buffer_is_host(sl->buffer)) { has_gpu = true; } else if (sl->buffer) { @@ -7055,7 +7128,7 @@ static bool spec_ckpt_try_per_step(llama_kv_cache & kv, const llama_model & mode kv.ckpt.per_step_d_conv = hp.ssm_d_conv; } - if (!kv.per_step_alloc(max_tokens)) { + if (!kv.per_step_alloc(model, max_tokens)) { kv.save_per_step_ssm = false; return false; } @@ -7131,7 +7204,7 @@ bool llama_spec_ckpt_restore(struct llama_context * ctx, llama_seq_id seq_id, switch (kv.ckpt.selected_spec_mode) { case LLAMA_SPEC_CKPT_PER_STEP: { - if (!kv.per_step_restore(ctx->sched, accepted_step)) { + if (!kv.per_step_restore(ctx->model, ctx->sched, accepted_step)) { return false; } const llama_pos accepted_pos = n_past + accepted_step;