Use async copies to save/restore recurrent state (#1759)

This commit is contained in:
Kawrakow 2026-05-09 08:31:56 +03:00 committed by GitHub
parent 9f60de9cc5
commit 2f0b47c19d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 27 additions and 10 deletions

View File

@ -4064,7 +4064,7 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_
}
} else {
// src and dst are on the same backend
printf("Why is this being invoked?\n");
// printf("Why is this being invoked?\n");
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
}
return true;

View File

@ -137,8 +137,8 @@ struct llama_kv_cache {
bool checkpoint_alloc_shadows();
bool checkpoint_supported() const;
bool checkpoint_save();
bool checkpoint_restore();
bool checkpoint_save(ggml_backend_sched_t sched);
bool checkpoint_restore(ggml_backend_sched_t sched);
void checkpoint_delete();
// Per-step checkpoint: allocate, restore step k's full state (SSM + conv) to cache

View File

@ -1383,7 +1383,7 @@ bool llama_kv_cache::checkpoint_alloc_shadows() {
return true;
}
bool llama_kv_cache::checkpoint_save() {
bool llama_kv_cache::checkpoint_save(ggml_backend_sched_t sched) {
if (!checkpoint_alloc_shadows()) {
return false;
}
@ -1394,6 +1394,8 @@ bool llama_kv_cache::checkpoint_save() {
ckpt.head_snapshot = head;
ckpt.used_snapshot = used;
std::unordered_set<ggml_backend_t> backends_to_sync;
uint32_t split_s_idx = 0;
for (uint32_t il = 0; il < n_layer; ++il) {
if (s_l[il] == nullptr) {
@ -1405,7 +1407,10 @@ bool llama_kv_cache::checkpoint_save() {
auto & shadow_split = ckpt.split_s_l_shadow[split_s_idx];
for (int d = 0; d < split_info->n_device; ++d) {
if (split_info->splits[d] && shadow_split[d]) {
ggml_backend_tensor_copy(split_info->splits[d], shadow_split[d]);
//ggml_backend_tensor_copy(split_info->splits[d], shadow_split[d]);
auto src_backend = ggml_backend_sched_get_tensor_backend(sched, split_info->splits[d]);
ggml_backend_tensor_copy_async(src_backend, src_backend, split_info->splits[d], shadow_split[d]);
backends_to_sync.insert(src_backend);
}
}
split_s_idx++;
@ -1415,11 +1420,15 @@ bool llama_kv_cache::checkpoint_save() {
}
}
for (auto backend : backends_to_sync) {
ggml_backend_synchronize(backend);
}
ckpt.saved = true;
return true;
}
bool llama_kv_cache::checkpoint_restore() {
bool llama_kv_cache::checkpoint_restore(ggml_backend_sched_t sched) {
if (!ckpt.saved) {
LLAMA_LOG_ERROR("%s: no checkpoint saved\n", __func__);
return false;
@ -1431,6 +1440,8 @@ bool llama_kv_cache::checkpoint_restore() {
head = ckpt.head_snapshot;
used = ckpt.used_snapshot;
std::unordered_set<ggml_backend_t> backends_to_sync;
uint32_t split_s_idx = 0;
for (uint32_t il = 0; il < n_layer; ++il) {
if (s_l[il] == nullptr) {
@ -1442,7 +1453,9 @@ bool llama_kv_cache::checkpoint_restore() {
auto & shadow_split = ckpt.split_s_l_shadow[split_s_idx];
for (int d = 0; d < split_info->n_device; ++d) {
if (split_info->splits[d] && shadow_split[d]) {
ggml_backend_tensor_copy(shadow_split[d], split_info->splits[d]);
auto dst_backend = ggml_backend_sched_get_tensor_backend(sched, split_info->splits[d]);
ggml_backend_tensor_copy_async(dst_backend, dst_backend, shadow_split[d], split_info->splits[d]);
backends_to_sync.insert(dst_backend);
}
}
split_s_idx++;
@ -1452,6 +1465,10 @@ bool llama_kv_cache::checkpoint_restore() {
}
}
for (auto backend : backends_to_sync) {
ggml_backend_synchronize(backend);
}
return true;
}
@ -7015,10 +7032,10 @@ bool llama_spec_ckpt_save(struct llama_context * ctx, llama_seq_id seq_id) {
switch (kv.ckpt.selected_spec_mode) {
case LLAMA_SPEC_CKPT_PER_STEP:
kv.save_per_step_ssm = true;
return kv.checkpoint_save();
return kv.checkpoint_save(ctx->sched);
case LLAMA_SPEC_CKPT_GPU_FALLBACK:
return kv.checkpoint_save();
return kv.checkpoint_save(ctx->sched);
case LLAMA_SPEC_CKPT_CPU: {
const size_t need = llama_state_seq_get_size(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
@ -7052,7 +7069,7 @@ bool llama_spec_ckpt_restore(struct llama_context * ctx, llama_seq_id seq_id,
}
case LLAMA_SPEC_CKPT_GPU_FALLBACK:
kv.checkpoint_restore();
kv.checkpoint_restore(ctx->sched);
llama_kv_cache_seq_rm(kv, seq_id, n_past, -1);
return false;