MTP: ebable per step recurrent state for split mode graph

This commit is contained in:
Kawrakow 2026-05-10 12:47:30 +00:00
parent 4bbdb8ed0b
commit d81090541b
4 changed files with 193 additions and 96 deletions

View File

@ -88,12 +88,12 @@ struct llama_kv_cache {
std::vector<std::vector<ggml_tensor *>> split_s_l_shadow;
// Per-step SSM state checkpoints for speculative decoding.
std::vector<ggml_tensor *> per_step_ssm;
std::vector<std::vector<ggml_tensor *>> per_step_ssm;
// Per-step conv feature buffer: stores qkv_mixed features from the
// verification forward pass so conv state can be reconstructed at any step.
// One tensor per recurrent layer, each sized [conv_dim * max_tokens].
std::vector<ggml_tensor *> per_step_qkv;
std::vector<std::vector<ggml_tensor *>> per_step_qkv;
int32_t per_step_n_tokens = 0;
int32_t per_step_max_allocated = 0;
@ -142,8 +142,8 @@ struct llama_kv_cache {
void checkpoint_delete();
// Per-step checkpoint: allocate, restore step k's full state (SSM + conv) to cache
bool per_step_alloc(int max_tokens);
bool per_step_restore(ggml_backend_sched_t sched, int step);
bool per_step_alloc(const llama_model & model, int max_tokens);
bool per_step_restore(const llama_model & model, ggml_backend_sched_t sched, int step);
~llama_kv_cache() {
for (struct ggml_context * ctx : ctxs) {

View File

@ -454,6 +454,15 @@ static ggml_tensor * get_input_tensor_sm_graph(ggml_context * ctx, ggml_tensor *
return cur;
}
static void build_qkv_cpy(ggml_context * ctx0, ggml_cgraph * gf, ggml_tensor * qkv_mixed, ggml_tensor * per_step_qkv) {
const int64_t conv_dim = qkv_mixed->ne[0];
const int64_t n_tok_qkv = qkv_mixed->ne[1] * qkv_mixed->ne[2];
ggml_tensor * qkv_flat = ggml_reshape_2d(ctx0, qkv_mixed, conv_dim, n_tok_qkv);
ggml_tensor * qkv_dst = ggml_view_2d(ctx0, per_step_qkv, conv_dim, n_tok_qkv, conv_dim * sizeof(float), 0);
auto qkv_cpy = ggml_cpy(ctx0, qkv_flat, qkv_dst);
ggml_build_forward_expand(gf, qkv_cpy);
}
ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_cgraph * gf,
ggml_tensor * delta_input, ggml_tensor * inp_s_seq_qnext, ggml_tensor * inp_out_ids,
uint32_t state_seq_id_local, bool reset_state_local, int il, const llm_build_cb & cb) const {
@ -540,10 +549,18 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
}
auto split_ssm_conv1d = (ggml_split_tensor_t *)l.ssm_conv1d->extra;
GGML_ASSERT(split_ssm_conv1d && split_ssm_conv1d->splits[id]);
ggml_tensor * per_step_ckpt = nullptr;
if (save_per_step_states && il < (int)kv_self.ckpt.per_step_ssm.size()) {
per_step_ckpt = kv_self.ckpt.per_step_ssm[il][id];
//GGML_ASSERT(per_step_ckpt);
}
if (save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && id < (int)kv_self.ckpt.per_step_qkv[il].size()) {
build_qkv_cpy(ctx0, gf, qkv_mixed, kv_self.ckpt.per_step_qkv[il][id]);
}
auto output = build_qkv(ctx0, split_s_l->splits[id], split_ssm_conv1d->splits[id], qkv_mixed, inp_s_seq_qnext, beta, gate,
head_k_dim, num_k_heads_id, head_v_dim, num_v_heads_id, hparams.ssm_d_conv,
state_seq_id_local, qnext_state_slots, reset_state_local, hparams.f_norm_rms_eps,
l.ssm_beta_alpha ? 0 : 1, il, cb, gf);
l.ssm_beta_alpha ? 0 : 1, il, cb, gf, save_per_step_states, per_step_ckpt);
split_norm = (ggml_split_tensor_t *)l.ssm_norm->extra;
GGML_ASSERT(split_norm && split_norm->splits[id]);
auto split_ssm_out = (ggml_split_tensor_t *)l.ssm_out->extra;
@ -596,18 +613,19 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
// Get per-step checkpoint tensor if available
ggml_tensor * per_step_ckpt = nullptr;
if (save_per_step_states && il < (int)kv_self.ckpt.per_step_ssm.size()) {
per_step_ckpt = kv_self.ckpt.per_step_ssm[il];
per_step_ckpt = kv_self.ckpt.per_step_ssm[il].front();
}
// Save qkv_mixed features for per-step conv state reconstruction
if (save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && kv_self.ckpt.per_step_qkv[il] != nullptr) {
const int64_t conv_dim = qkv_mixed->ne[0];
const int64_t n_tok_qkv = qkv_mixed->ne[1] * qkv_mixed->ne[2];
ggml_tensor * qkv_flat = ggml_reshape_2d(ctx0, qkv_mixed, conv_dim, n_tok_qkv);
ggml_tensor * qkv_dst = ggml_view_2d(ctx0, kv_self.ckpt.per_step_qkv[il],
conv_dim, n_tok_qkv, conv_dim * sizeof(float), 0);
auto qkv_cpy = ggml_cpy(ctx0, qkv_flat, qkv_dst);
ggml_build_forward_expand(gf, qkv_cpy);
if (save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && !kv_self.ckpt.per_step_qkv[il].empty()) {
build_qkv_cpy(ctx0, gf, qkv_mixed, kv_self.ckpt.per_step_qkv[il].front());
//const int64_t conv_dim = qkv_mixed->ne[0];
//const int64_t n_tok_qkv = qkv_mixed->ne[1] * qkv_mixed->ne[2];
//ggml_tensor * qkv_flat = ggml_reshape_2d(ctx0, qkv_mixed, conv_dim, n_tok_qkv);
//ggml_tensor * qkv_dst = ggml_view_2d(ctx0, kv_self.ckpt.per_step_qkv[il].front(),
// conv_dim, n_tok_qkv, conv_dim * sizeof(float), 0);
//auto qkv_cpy = ggml_cpy(ctx0, qkv_flat, qkv_dst);
//ggml_build_forward_expand(gf, qkv_cpy);
}
auto output = build_qkv(ctx0, kv_self.s_l[il], model.layers[il].ssm_conv1d,

View File

@ -297,8 +297,8 @@ struct llama_hparams {
return ssm_d_state * ssm_d_inner;
}
uint32_t n_embd_v_s_id(int nv) const {
if (ssm_n_group <= 0 || nv < 1 || ssm_dt_rank < 1) return 0;
std::pair<uint32_t, uint32_t> n_embd_v_s_dims(int nv) const {
if (ssm_n_group <= 0 || nv < 1 || ssm_dt_rank < 1) return {0, 0};
int num_v_heads = ssm_dt_rank;
int num_k_heads = ssm_n_group;
int gqa_ratio = num_v_heads / num_k_heads;
@ -308,10 +308,16 @@ struct llama_hparams {
int head_k_dim = ssm_d_state;
int head_v_dim = ssm_d_inner / num_v_heads;
uint32_t conv_dim = 2 * nk * head_k_dim + nv * head_v_dim;
uint32_t conv_state_dim = conv_dim * (ssm_d_conv - 1);
//uint32_t conv_state_dim = conv_dim * (ssm_d_conv - 1);
//uint32_t ssm_state_dim = head_v_dim * head_v_dim * nv;
//return {conv_state_dim, ssm_state_dim};
uint32_t ssm_state_dim = head_v_dim * head_v_dim * nv;
return conv_state_dim + ssm_state_dim;
return {conv_dim, ssm_state_dim};
}
uint32_t n_embd_v_s_id(int nv) const {
auto [conv_dim, ssm_state_dim] = n_embd_v_s_dims(nv);
return (ssm_d_conv - 1) * conv_dim + ssm_state_dim;
}
bool is_recurrent(uint32_t il) const {

View File

@ -119,6 +119,7 @@ void llama_set_mtp_target_context(struct llama_context * ctx, struct llama_conte
#include <type_traits>
#include <unordered_map>
#include <regex>
#include <tuple>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
@ -1273,7 +1274,6 @@ bool llama_kv_cache::checkpoint_alloc_shadows() {
std::map<ggml_backend_buffer_type_t, std::vector<tensor_entry>> split_buft_entries;
uint32_t split_s_idx = 0;
for (uint32_t il = 0; il < n_layer; ++il) {
if (s_l[il] == nullptr) {
continue;
@ -1287,7 +1287,6 @@ bool llama_kv_cache::checkpoint_alloc_shadows() {
ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(split_info->splits[d]->buffer);
split_buft_entries[buft].push_back({split_info->splits[d], il, d});
}
split_s_idx++;
} else {
nonsplit_entries.push_back({s_l[il], il, -1});
}
@ -1362,15 +1361,14 @@ bool llama_kv_cache::checkpoint_alloc_shadows() {
}
// Build split shadow lookup
ckpt.split_s_l_shadow.resize(split_s_l.size());
split_s_idx = 0;
ckpt.split_s_l_shadow.resize(n_layer);
for (uint32_t il = 0; il < n_layer; ++il) {
if (s_l[il] == nullptr || s_l[il]->extra == nullptr) {
continue;
}
auto * split_info = (const ggml_split_tensor_t *)s_l[il]->extra;
auto & shadow_split = ckpt.split_s_l_shadow[split_s_idx];
auto & shadow_split = ckpt.split_s_l_shadow[il];
shadow_split.resize(split_info->n_device, nullptr);
for (int d = 0; d < split_info->n_device; ++d) {
@ -1383,7 +1381,6 @@ bool llama_kv_cache::checkpoint_alloc_shadows() {
}
}
}
split_s_idx++;
}
ckpt.allocated = true;
@ -1403,7 +1400,6 @@ bool llama_kv_cache::checkpoint_save(ggml_backend_sched_t sched) {
std::unordered_set<ggml_backend_t> backends_to_sync;
uint32_t split_s_idx = 0;
for (uint32_t il = 0; il < n_layer; ++il) {
if (s_l[il] == nullptr) {
continue;
@ -1411,7 +1407,7 @@ bool llama_kv_cache::checkpoint_save(ggml_backend_sched_t sched) {
if (s_l[il]->extra != nullptr) {
auto * split_info = (const ggml_split_tensor_t *)s_l[il]->extra;
auto & shadow_split = ckpt.split_s_l_shadow[split_s_idx];
auto & shadow_split = ckpt.split_s_l_shadow[il];
for (int d = 0; d < split_info->n_device; ++d) {
if (split_info->splits[d] && shadow_split[d]) {
//ggml_backend_tensor_copy(split_info->splits[d], shadow_split[d]);
@ -1420,7 +1416,6 @@ bool llama_kv_cache::checkpoint_save(ggml_backend_sched_t sched) {
backends_to_sync.insert(src_backend);
}
}
split_s_idx++;
} else {
const size_t nbytes = ggml_nbytes(ckpt.s_l_shadow[il]);
ggml_backend_tensor_get(s_l[il], ckpt.s_l_shadow[il]->data, 0, nbytes);
@ -1449,7 +1444,6 @@ bool llama_kv_cache::checkpoint_restore(ggml_backend_sched_t sched) {
std::unordered_set<ggml_backend_t> backends_to_sync;
uint32_t split_s_idx = 0;
for (uint32_t il = 0; il < n_layer; ++il) {
if (s_l[il] == nullptr) {
continue;
@ -1457,7 +1451,7 @@ bool llama_kv_cache::checkpoint_restore(ggml_backend_sched_t sched) {
if (s_l[il]->extra != nullptr) {
auto * split_info = (const ggml_split_tensor_t *)s_l[il]->extra;
auto & shadow_split = ckpt.split_s_l_shadow[split_s_idx];
auto & shadow_split = ckpt.split_s_l_shadow[il];
for (int d = 0; d < split_info->n_device; ++d) {
if (split_info->splits[d] && shadow_split[d]) {
auto dst_backend = ggml_backend_sched_get_tensor_backend(sched, split_info->splits[d]);
@ -1465,7 +1459,6 @@ bool llama_kv_cache::checkpoint_restore(ggml_backend_sched_t sched) {
backends_to_sync.insert(dst_backend);
}
}
split_s_idx++;
} else {
GGML_ASSERT(ggml_nbytes(ckpt.s_l_shadow[il]) == ggml_nbytes(s_l[il]));
ggml_backend_tensor_copy(ckpt.s_l_shadow[il], s_l[il]);
@ -1483,7 +1476,7 @@ void llama_kv_cache::checkpoint_delete() {
ckpt.saved = false;
}
bool llama_kv_cache::per_step_alloc(int max_tokens) {
bool llama_kv_cache::per_step_alloc(const llama_model & model, int max_tokens) {
if (ckpt.per_step_max_allocated >= max_tokens) {
return true;
}
@ -1503,8 +1496,8 @@ bool llama_kv_cache::per_step_alloc(int max_tokens) {
}
const uint32_t n_layer = (uint32_t)s_l.size();
ckpt.per_step_ssm.resize(n_layer, nullptr);
ckpt.per_step_qkv.resize(n_layer, nullptr);
ckpt.per_step_ssm.resize(n_layer);
ckpt.per_step_qkv.resize(n_layer);
const int64_t ssm_state_dim = ckpt.per_step_ssm_state_size;
const int64_t conv_dim = ckpt.per_step_conv_dim;
@ -1514,14 +1507,26 @@ bool llama_kv_cache::per_step_alloc(int max_tokens) {
return false;
}
std::map<ggml_backend_buffer_type_t, std::vector<std::pair<uint32_t, ggml_backend_buffer_type_t>>> buft_layers;
int num_v_heads = model.hparams.ssm_dt_rank;
int head_v_dim = model.hparams.ssm_d_inner / num_v_heads;
std::map<ggml_backend_buffer_type_t, std::vector<std::pair<std::pair<uint32_t, int32_t>, ggml_backend_buffer_type_t>>> buft_layers;
for (uint32_t il = 0; il < n_layer; ++il) {
if (s_l[il] == nullptr) continue;
if (s_l[il]->extra != nullptr) continue; // skip split tensors
ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(s_l[il]->buffer);
buft_layers[buft].push_back({il, buft});
if (s_l[il]->extra) {
auto split_sl = (ggml_split_tensor_t *)s_l[il]->extra;
for (int id = 0; id < split_sl->n_device; ++id) {
if (!split_sl->splits[id]) continue;
auto buft = ggml_backend_buffer_get_type(split_sl->splits[id]->buffer);
GGML_ASSERT(buft != nullptr);
buft_layers[buft].push_back({{il, id}, buft});
}
} else {
ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(s_l[il]->buffer);
buft_layers[buft].push_back({{il, -1}, buft});
}
}
for (auto & [buft, layers] : buft_layers) {
@ -1537,16 +1542,46 @@ bool llama_kv_cache::per_step_alloc(int max_tokens) {
return false;
}
for (auto & [il, bt] : layers) {
for (auto & [p, bt] : layers) {
auto [il, id] = p;
// SSM state: max_tokens * ssm_state_dim
ggml_tensor * t_ssm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * ssm_state_dim);
ggml_format_name(t_ssm, "per_step_ssm_l%d", il);
ckpt.per_step_ssm[il] = t_ssm;
if (id < 0) {
GGML_ASSERT(ckpt.per_step_ssm[il].empty());
GGML_ASSERT(ckpt.per_step_qkv[il].empty());
ggml_tensor * t_ssm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * ssm_state_dim);
ggml_format_name(t_ssm, "per_step_ssm_l%d", il);
ckpt.per_step_ssm[il].push_back(t_ssm);
// Conv features (qkv_mixed): max_tokens * conv_dim
ggml_tensor * t_qkv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * conv_dim);
ggml_format_name(t_qkv, "per_step_qkv_l%d", il);
ckpt.per_step_qkv[il] = t_qkv;
// Conv features (qkv_mixed): max_tokens * conv_dim
ggml_tensor * t_qkv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * conv_dim);
ggml_format_name(t_qkv, "per_step_qkv_l%d", il);
ckpt.per_step_qkv[il].push_back(t_qkv);
} else {
auto split_sl = (ggml_split_tensor_t *)s_l[il]->extra;
auto split_ssm_out = (const ggml_split_tensor_t *)model.layers[il].ssm_out->extra;
GGML_ASSERT(split_ssm_out && split_ssm_out->splits[id]);
auto split = split_ssm_out->splits[id];
GGML_ASSERT(split->ne[0] % head_v_dim == 0);
if (ckpt.per_step_ssm[il].empty()) {
ckpt.per_step_ssm[il].resize(split_sl->n_device, nullptr);
ckpt.per_step_qkv[il].resize(split_sl->n_device, nullptr);
} else {
GGML_ASSERT(int(ckpt.per_step_ssm[il].size()) == split_sl->n_device);
GGML_ASSERT(int(ckpt.per_step_qkv[il].size()) == split_sl->n_device);
GGML_ASSERT(ckpt.per_step_ssm[il][id] == nullptr);
GGML_ASSERT(ckpt.per_step_qkv[il][id] == nullptr);
}
int nv = split->ne[0] / head_v_dim; // number of heads handled by this device
auto [this_conv_dim, this_ssm_dim] = model.hparams.n_embd_v_s_dims(nv);
auto t_ssm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * this_ssm_dim);
ggml_format_name(t_ssm, "per_step_ssm_l%d_%d", il, id);
ckpt.per_step_ssm[il][id] = t_ssm;
auto t_qkv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * this_conv_dim);
ggml_format_name(t_qkv, "per_step_qkv_l%d_%d", il, id);
ckpt.per_step_qkv[il][id] = t_qkv;
}
}
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
@ -1566,7 +1601,63 @@ bool llama_kv_cache::per_step_alloc(int max_tokens) {
return true;
}
bool llama_kv_cache::per_step_restore(ggml_backend_sched_t sched, int step) {
static void restore_recurrent_cache_tensors(int step, ggml_backend_sched_t sched,
size_t ssm_bytes, size_t conv_bytes, size_t conv_dim, int d_conv_m1,
ggml_tensor * s_l, ggml_tensor * per_step_ssm,
ggml_tensor * s_l_shadow, ggml_tensor * per_step_qkv,
std::vector<float> & old_conv_buf, std::vector<float> & qkv_buf, std::vector<float> & conv_buf,
std::unordered_set<ggml_backend_t> & backends_to_sync) {
auto dst_backend = ggml_backend_sched_get_tensor_backend(sched, s_l);
auto dst = *s_l;
dst.ne[0] = ssm_bytes/sizeof(float);
dst.nb[1] = dst.nb[2] = dst.nb[3] = ssm_bytes;
dst.data = (char *)s_l->data + conv_bytes;
auto src = dst;
// I think the commented out version is correct, bt only after fixing the delta-net implementation.
// For now let's just reproduce what we have on the main branch, only avoiding the expensive
// device -> host -> device copy.
//src.data = (char *)per_step_ssm->data + (size_t)(step + 1) * ssm_bytes;
src.data = (char *)per_step_ssm->data + (size_t)step * ssm_bytes;
ggml_backend_tensor_copy_async(dst_backend, dst_backend, &src, &dst);
backends_to_sync.insert(dst_backend);
if (s_l_shadow != nullptr) {
ggml_backend_tensor_get(s_l_shadow, old_conv_buf.data(), 0, conv_bytes);
} else {
memset(old_conv_buf.data(), 0, conv_bytes);
}
auto qkv_needed = size_t(step + 1) * conv_dim;
if (per_step_qkv) {
ggml_backend_tensor_get(per_step_qkv, qkv_buf.data(), 0, qkv_needed * sizeof(float));
} else {
memset(qkv_buf.data(), 0, qkv_needed * sizeof(float));
}
for (int32_t col = 0; col < d_conv_m1; col++) {
int32_t src_token = step - (d_conv_m1 - 1) + col; // e.g., K-2, K-1, K for d_conv=4
if (src_token >= 0) {
for (int64_t d = 0; d < conv_dim; d++) {
conv_buf[col + d * d_conv_m1] = qkv_buf[d + (int64_t)src_token * conv_dim];
}
} else {
int32_t old_col = d_conv_m1 + src_token; // maps to 0, 1, ... for early steps
if (old_col >= 0 && old_col < d_conv_m1) {
for (int64_t d = 0; d < conv_dim; d++) {
conv_buf[col + d * d_conv_m1] = old_conv_buf[old_col + d * d_conv_m1];
}
} else {
for (int64_t d = 0; d < conv_dim; d++) {
conv_buf[col + d * d_conv_m1] = 0.0f;
}
}
}
}
ggml_backend_tensor_set(s_l, conv_buf.data(), 0, conv_bytes);
}
bool llama_kv_cache::per_step_restore(const llama_model & model, ggml_backend_sched_t sched, int step) {
if (ckpt.per_step_ssm.empty() || step < 0) {
return false;
}
@ -1589,60 +1680,38 @@ bool llama_kv_cache::per_step_restore(ggml_backend_sched_t sched, int step) {
const int64_t qkv_needed = (int64_t)(step + 1) * conv_dim;
std::vector<float> qkv_buf(qkv_needed);
int num_v_heads = model.hparams.ssm_dt_rank;
int head_v_dim = model.hparams.ssm_d_inner / num_v_heads;
const uint32_t n_layer = (uint32_t)s_l.size();
for (uint32_t il = 0; il < n_layer; ++il) {
if (s_l[il] == nullptr || ckpt.per_step_ssm[il] == nullptr) continue;
if (s_l[il]->extra != nullptr) continue;
if (s_l[il] == nullptr || ckpt.per_step_ssm[il].empty()) continue;
auto dst_backend = ggml_backend_sched_get_tensor_backend(sched, s_l[il]);
auto dst = *s_l[il];
dst.ne[0] = ssm_bytes/sizeof(float);
dst.nb[1] = dst.nb[2] = dst.nb[3] = ssm_bytes;
dst.data = (char *)s_l[il]->data + conv_bytes;
auto src = dst;
// I think the commented out version is correct, bt only after fixing the delta-net implementation.
// For now let's just reproduce what we have on the main branch, only avoiding the expensive
// device -> host -> device copy.
//src.data = (char *)ckpt.per_step_ssm[il]->data + (size_t)(step + 1) * ssm_bytes;
src.data = (char *)ckpt.per_step_ssm[il]->data + (size_t)step * ssm_bytes;
ggml_backend_tensor_copy_async(dst_backend, dst_backend, &src, &dst);
backends_to_sync.insert(dst_backend);
if (ckpt.s_l_shadow[il] != nullptr) {
ggml_backend_tensor_get(ckpt.s_l_shadow[il], old_conv_buf.data(), 0, conv_bytes);
if (!s_l[il]->extra) {
restore_recurrent_cache_tensors(step, sched, ssm_bytes, conv_bytes, conv_dim, d_conv_m1,
s_l[il], ckpt.per_step_ssm[il].front(), ckpt.s_l_shadow[il], ckpt.per_step_qkv[il].front(),
old_conv_buf, qkv_buf, conv_buf, backends_to_sync);
} else {
memset(old_conv_buf.data(), 0, conv_bytes);
}
if (ckpt.per_step_qkv[il] != nullptr) {
ggml_backend_tensor_get(ckpt.per_step_qkv[il], qkv_buf.data(), 0, qkv_needed * sizeof(float));
} else {
memset(qkv_buf.data(), 0, qkv_needed * sizeof(float));
}
for (int32_t col = 0; col < d_conv_m1; col++) {
int32_t src_token = step - (d_conv_m1 - 1) + col; // e.g., K-2, K-1, K for d_conv=4
if (src_token >= 0) {
for (int64_t d = 0; d < conv_dim; d++) {
conv_buf[col + d * d_conv_m1] = qkv_buf[d + (int64_t)src_token * conv_dim];
}
} else {
int32_t old_col = d_conv_m1 + src_token; // maps to 0, 1, ... for early steps
if (old_col >= 0 && old_col < d_conv_m1) {
for (int64_t d = 0; d < conv_dim; d++) {
conv_buf[col + d * d_conv_m1] = old_conv_buf[old_col + d * d_conv_m1];
}
} else {
for (int64_t d = 0; d < conv_dim; d++) {
conv_buf[col + d * d_conv_m1] = 0.0f;
}
}
auto split_sl = (const ggml_split_tensor_t *)s_l[il]->extra;
for (int id = 0; id < split_sl->n_device; ++id) {
if (!split_sl->splits[id]) continue;
auto split_ssm_out = (const ggml_split_tensor_t *)model.layers[il].ssm_out->extra;
GGML_ASSERT(split_ssm_out && split_ssm_out->splits[id]);
auto split = split_ssm_out->splits[id];
GGML_ASSERT(split->ne[0] % head_v_dim == 0);
int nv = split->ne[0] / head_v_dim; // number of heads handled by this device
auto [this_conv_dim, this_ssm_dim] = model.hparams.n_embd_v_s_dims(nv);
auto this_conv_bytes = (d_conv - 1) * this_conv_dim * sizeof(float);
restore_recurrent_cache_tensors(step, sched, this_ssm_dim * sizeof(float), this_conv_bytes, this_conv_dim, d_conv_m1,
split_sl->splits[id], ckpt.per_step_ssm[il][id], ckpt.split_s_l_shadow[il][id], ckpt.per_step_qkv[il][id],
old_conv_buf, qkv_buf, conv_buf, backends_to_sync);
}
}
ggml_backend_tensor_set(s_l[il], conv_buf.data(), 0, conv_bytes);
}
// Strictly speaking we shouldn't need to do this. We are doing a synchronous copy of the
// convolution state in the above loop, and that involves a ggml_backend_synchronize call.
// But just in case.
for (auto backend : backends_to_sync) {
ggml_backend_synchronize(backend);
}
@ -7024,9 +7093,13 @@ static bool spec_ckpt_try_per_step(llama_kv_cache & kv, const llama_model & mode
for (const auto * sl : kv.s_l) {
if (!sl) continue;
if (sl->extra) {
kv.save_per_step_ssm = false;
return false;
has_gpu = true;
continue;
}
//if (sl->extra) {
// kv.save_per_step_ssm = false;
// return false;
//}
if (sl->buffer && !ggml_backend_buffer_is_host(sl->buffer)) {
has_gpu = true;
} else if (sl->buffer) {
@ -7055,7 +7128,7 @@ static bool spec_ckpt_try_per_step(llama_kv_cache & kv, const llama_model & mode
kv.ckpt.per_step_d_conv = hp.ssm_d_conv;
}
if (!kv.per_step_alloc(max_tokens)) {
if (!kv.per_step_alloc(model, max_tokens)) {
kv.save_per_step_ssm = false;
return false;
}
@ -7131,7 +7204,7 @@ bool llama_spec_ckpt_restore(struct llama_context * ctx, llama_seq_id seq_id,
switch (kv.ckpt.selected_spec_mode) {
case LLAMA_SPEC_CKPT_PER_STEP: {
if (!kv.per_step_restore(ctx->sched, accepted_step)) {
if (!kv.per_step_restore(ctx->model, ctx->sched, accepted_step)) {
return false;
}
const llama_pos accepted_pos = n_past + accepted_step;