Faster per step recurrent state restore when using MTP

This commit is contained in:
Kawrakow 2026-05-09 13:31:03 +00:00
parent ab0f22b819
commit f6deca0f97
2 changed files with 22 additions and 8 deletions

View File

@ -143,7 +143,7 @@ struct llama_kv_cache {
// Per-step checkpoint: allocate, restore step k's full state (SSM + conv) to cache
bool per_step_alloc(int max_tokens);
bool per_step_restore(int step);
bool per_step_restore(ggml_backend_sched_t sched, int step);
~llama_kv_cache() {
for (struct ggml_context * ctx : ctxs) {

View File

@ -1559,11 +1559,13 @@ bool llama_kv_cache::per_step_alloc(int max_tokens) {
return true;
}
bool llama_kv_cache::per_step_restore(int step) {
bool llama_kv_cache::per_step_restore(ggml_backend_sched_t sched, int step) {
if (ckpt.per_step_ssm.empty() || step < 0) {
return false;
}
std::unordered_set<ggml_backend_t> backends_to_sync;
const int64_t ssm_state_dim = ckpt.per_step_ssm_state_size;
const int64_t conv_state_dim = ckpt.per_step_conv_state_dim;
const int64_t conv_dim = ckpt.per_step_conv_dim;
@ -1581,13 +1583,23 @@ bool llama_kv_cache::per_step_restore(int step) {
std::vector<float> qkv_buf(qkv_needed);
const uint32_t n_layer = (uint32_t)s_l.size();
int n_restored = 0;
for (uint32_t il = 0; il < n_layer; ++il) {
if (s_l[il] == nullptr || ckpt.per_step_ssm[il] == nullptr) continue;
if (s_l[il]->extra != nullptr) continue;
ggml_backend_tensor_get(ckpt.per_step_ssm[il], ssm_buf.data(),
(size_t)step * ssm_bytes, ssm_bytes);
auto dst_backend = ggml_backend_sched_get_tensor_backend(sched, s_l[il]);
auto dst = *s_l[il];
dst.ne[0] = ssm_bytes/sizeof(float);
dst.nb[1] = dst.nb[2] = dst.nb[3] = ssm_bytes;
dst.data = (char *)s_l[il]->data + conv_bytes;
auto src = dst;
// I think the commented out version is correct, bt only after fixing the delta-net implementation.
// For now let's just reproduce what we have on the main branch, only avoiding the expensive
// device -> host -> device copy.
//src.data = (char *)ckpt.per_step_ssm[il]->data + (size_t)(step + 1) * ssm_bytes;
src.data = (char *)ckpt.per_step_ssm[il]->data + (size_t)step * ssm_bytes;
ggml_backend_tensor_copy_async(dst_backend, dst_backend, &src, &dst);
backends_to_sync.insert(dst_backend);
if (ckpt.s_l_shadow[il] != nullptr) {
ggml_backend_tensor_get(ckpt.s_l_shadow[il], old_conv_buf.data(), 0, conv_bytes);
@ -1622,8 +1634,10 @@ bool llama_kv_cache::per_step_restore(int step) {
}
ggml_backend_tensor_set(s_l[il], conv_buf.data(), 0, conv_bytes);
ggml_backend_tensor_set(s_l[il], ssm_buf.data(), conv_bytes, ssm_bytes);
n_restored++;
}
for (auto backend : backends_to_sync) {
ggml_backend_synchronize(backend);
}
return true;
@ -7057,7 +7071,7 @@ bool llama_spec_ckpt_restore(struct llama_context * ctx, llama_seq_id seq_id,
switch (kv.ckpt.selected_spec_mode) {
case LLAMA_SPEC_CKPT_PER_STEP: {
if (!kv.per_step_restore(accepted_step)) {
if (!kv.per_step_restore(ctx->sched, accepted_step)) {
return false;
}
const llama_pos accepted_pos = n_past + accepted_step;