mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Faster per step recurrent state restore when using MTP
This commit is contained in:
parent
ab0f22b819
commit
f6deca0f97
@ -143,7 +143,7 @@ struct llama_kv_cache {
|
||||
|
||||
// Per-step checkpoint: allocate, restore step k's full state (SSM + conv) to cache
|
||||
bool per_step_alloc(int max_tokens);
|
||||
bool per_step_restore(int step);
|
||||
bool per_step_restore(ggml_backend_sched_t sched, int step);
|
||||
|
||||
~llama_kv_cache() {
|
||||
for (struct ggml_context * ctx : ctxs) {
|
||||
|
||||
@ -1559,11 +1559,13 @@ bool llama_kv_cache::per_step_alloc(int max_tokens) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_kv_cache::per_step_restore(int step) {
|
||||
bool llama_kv_cache::per_step_restore(ggml_backend_sched_t sched, int step) {
|
||||
if (ckpt.per_step_ssm.empty() || step < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::unordered_set<ggml_backend_t> backends_to_sync;
|
||||
|
||||
const int64_t ssm_state_dim = ckpt.per_step_ssm_state_size;
|
||||
const int64_t conv_state_dim = ckpt.per_step_conv_state_dim;
|
||||
const int64_t conv_dim = ckpt.per_step_conv_dim;
|
||||
@ -1581,13 +1583,23 @@ bool llama_kv_cache::per_step_restore(int step) {
|
||||
std::vector<float> qkv_buf(qkv_needed);
|
||||
|
||||
const uint32_t n_layer = (uint32_t)s_l.size();
|
||||
int n_restored = 0;
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
if (s_l[il] == nullptr || ckpt.per_step_ssm[il] == nullptr) continue;
|
||||
if (s_l[il]->extra != nullptr) continue;
|
||||
|
||||
ggml_backend_tensor_get(ckpt.per_step_ssm[il], ssm_buf.data(),
|
||||
(size_t)step * ssm_bytes, ssm_bytes);
|
||||
auto dst_backend = ggml_backend_sched_get_tensor_backend(sched, s_l[il]);
|
||||
auto dst = *s_l[il];
|
||||
dst.ne[0] = ssm_bytes/sizeof(float);
|
||||
dst.nb[1] = dst.nb[2] = dst.nb[3] = ssm_bytes;
|
||||
dst.data = (char *)s_l[il]->data + conv_bytes;
|
||||
auto src = dst;
|
||||
// I think the commented out version is correct, bt only after fixing the delta-net implementation.
|
||||
// For now let's just reproduce what we have on the main branch, only avoiding the expensive
|
||||
// device -> host -> device copy.
|
||||
//src.data = (char *)ckpt.per_step_ssm[il]->data + (size_t)(step + 1) * ssm_bytes;
|
||||
src.data = (char *)ckpt.per_step_ssm[il]->data + (size_t)step * ssm_bytes;
|
||||
ggml_backend_tensor_copy_async(dst_backend, dst_backend, &src, &dst);
|
||||
backends_to_sync.insert(dst_backend);
|
||||
|
||||
if (ckpt.s_l_shadow[il] != nullptr) {
|
||||
ggml_backend_tensor_get(ckpt.s_l_shadow[il], old_conv_buf.data(), 0, conv_bytes);
|
||||
@ -1622,8 +1634,10 @@ bool llama_kv_cache::per_step_restore(int step) {
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(s_l[il], conv_buf.data(), 0, conv_bytes);
|
||||
ggml_backend_tensor_set(s_l[il], ssm_buf.data(), conv_bytes, ssm_bytes);
|
||||
n_restored++;
|
||||
}
|
||||
|
||||
for (auto backend : backends_to_sync) {
|
||||
ggml_backend_synchronize(backend);
|
||||
}
|
||||
|
||||
return true;
|
||||
@ -7057,7 +7071,7 @@ bool llama_spec_ckpt_restore(struct llama_context * ctx, llama_seq_id seq_id,
|
||||
|
||||
switch (kv.ckpt.selected_spec_mode) {
|
||||
case LLAMA_SPEC_CKPT_PER_STEP: {
|
||||
if (!kv.per_step_restore(accepted_step)) {
|
||||
if (!kv.per_step_restore(ctx->sched, accepted_step)) {
|
||||
return false;
|
||||
}
|
||||
const llama_pos accepted_pos = n_past + accepted_step;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user