diff --git a/src/llama-context.h b/src/llama-context.h index 6554f562..d3526cd2 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -143,7 +143,7 @@ struct llama_kv_cache { // Per-step checkpoint: allocate, restore step k's full state (SSM + conv) to cache bool per_step_alloc(int max_tokens); - bool per_step_restore(int step); + bool per_step_restore(ggml_backend_sched_t sched, int step); ~llama_kv_cache() { for (struct ggml_context * ctx : ctxs) { diff --git a/src/llama.cpp b/src/llama.cpp index 8edacfb0..ca74214c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1566,11 +1566,13 @@ bool llama_kv_cache::per_step_alloc(int max_tokens) { return true; } -bool llama_kv_cache::per_step_restore(int step) { +bool llama_kv_cache::per_step_restore(ggml_backend_sched_t sched, int step) { if (ckpt.per_step_ssm.empty() || step < 0) { return false; } + std::unordered_set backends_to_sync; + const int64_t ssm_state_dim = ckpt.per_step_ssm_state_size; const int64_t conv_state_dim = ckpt.per_step_conv_state_dim; const int64_t conv_dim = ckpt.per_step_conv_dim; @@ -1588,13 +1590,23 @@ bool llama_kv_cache::per_step_restore(int step) { std::vector qkv_buf(qkv_needed); const uint32_t n_layer = (uint32_t)s_l.size(); - int n_restored = 0; for (uint32_t il = 0; il < n_layer; ++il) { if (s_l[il] == nullptr || ckpt.per_step_ssm[il] == nullptr) continue; if (s_l[il]->extra != nullptr) continue; - ggml_backend_tensor_get(ckpt.per_step_ssm[il], ssm_buf.data(), - (size_t)step * ssm_bytes, ssm_bytes); + auto dst_backend = ggml_backend_sched_get_tensor_backend(sched, s_l[il]); + auto dst = *s_l[il]; + dst.ne[0] = ssm_bytes/sizeof(float); + dst.nb[1] = dst.nb[2] = dst.nb[3] = ssm_bytes; + dst.data = (char *)s_l[il]->data + conv_bytes; + auto src = dst; + // I think the commented out version is correct, bt only after fixing the delta-net implementation. + // For now let's just reproduce what we have on the main branch, only avoiding the expensive + // device -> host -> device copy. + //src.data = (char *)ckpt.per_step_ssm[il]->data + (size_t)(step + 1) * ssm_bytes; + src.data = (char *)ckpt.per_step_ssm[il]->data + (size_t)step * ssm_bytes; + ggml_backend_tensor_copy_async(dst_backend, dst_backend, &src, &dst); + backends_to_sync.insert(dst_backend); if (ckpt.s_l_shadow[il] != nullptr) { ggml_backend_tensor_get(ckpt.s_l_shadow[il], old_conv_buf.data(), 0, conv_bytes); @@ -1629,8 +1641,10 @@ bool llama_kv_cache::per_step_restore(int step) { } ggml_backend_tensor_set(s_l[il], conv_buf.data(), 0, conv_bytes); - ggml_backend_tensor_set(s_l[il], ssm_buf.data(), conv_bytes, ssm_bytes); - n_restored++; + } + + for (auto backend : backends_to_sync) { + ggml_backend_synchronize(backend); } return true; @@ -7117,7 +7131,7 @@ bool llama_spec_ckpt_restore(struct llama_context * ctx, llama_seq_id seq_id, switch (kv.ckpt.selected_spec_mode) { case LLAMA_SPEC_CKPT_PER_STEP: { - if (!kv.per_step_restore(accepted_step)) { + if (!kv.per_step_restore(ctx->sched, accepted_step)) { return false; } const llama_pos accepted_pos = n_past + accepted_step;