Pre-allocate buffers for hybrid model checkpoints (#1774)

* hybrid-spec: improve recurrent checkpoint handling in speculative decoding

* change per-step save to support scheduling and asynchronous tensor operations

* remove redundant backend tensor fallback

* improve recurrent tensor handling for split graph
This commit is contained in:
Samuel Oliveira Alves 2026-05-12 01:21:25 -03:00 committed by GitHub
parent c2f498ab4c
commit be8435793e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 304 additions and 89 deletions

View File

@ -117,7 +117,7 @@ void spec_tuner::write_best(common_params_speculative & params) const {
}
}
void spec_tuner::init(common_speculative_type type, const common_params_speculative & user_params) {
void spec_tuner::init(common_speculative_type type, const common_params_speculative & user_params, const llama_model * model_tgt) {
enabled = true;
spec_type = type;
coords.clear();
@ -136,7 +136,9 @@ void spec_tuner::init(common_speculative_type type, const common_params_speculat
{
spec_tuner_coord coord;
coord.name = "n_max";
int hi = std::max(16, (int)user_params.n_max);
const bool recurrent_target = model_tgt != nullptr && llama_model_has_recurrent(model_tgt);
int hi = recurrent_target ? std::max(1, (int) user_params.n_max)
: std::max(16, (int) user_params.n_max);
coord.build_grid_int(1, hi, 1, user_params.n_max);
coords.push_back(std::move(coord));
}

View File

@ -2,6 +2,8 @@
#include "common.h"
struct llama_model;
struct spec_tuner_arm {
float value;
double Q = 0.0; // mean per-step Tokens-Per-Second (TPS)
@ -55,7 +57,7 @@ struct spec_tuner {
common_speculative_type spec_type = COMMON_SPECULATIVE_TYPE_NONE;
std::vector<spec_tuner_coord> coords;
void init(common_speculative_type type, const common_params_speculative & user_params);
void init(common_speculative_type type, const common_params_speculative & user_params, const llama_model * model_tgt);
void propose(common_params_speculative & params);
void accept_feedback(int n_accepted, int n_drafted, double step_tps);
void end_of_request(double slot_tps, int n_past, common_params_speculative & active_params);

View File

@ -1122,6 +1122,25 @@ common_speculative * common_speculative_init(
}
}
if (!configs.empty() && llama_model_has_recurrent(llama_get_model(ctx_tgt))) {
const int ckpt_tokens = std::max(1, params.n_max + 1);
const int actual_mode = llama_spec_ckpt_init(ctx_tgt, params.recurrent_ckpt_mode, ckpt_tokens);
if (actual_mode == LLAMA_SPEC_CKPT_NONE) {
LOG_ERR("%s: failed to prepare recurrent checkpoint mode '%s' during speculative init (max_tokens=%d)\n",
__func__,
params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_PER_STEP ? "per-step" :
params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_GPU_FALLBACK ? "gpu-fallback" :
params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_CPU ? "cpu" : "auto",
ckpt_tokens);
if (ctx_dft != nullptr) {
llama_free(ctx_dft);
}
return nullptr;
}
llama_spec_ckpt_discard(ctx_tgt);
params.recurrent_ckpt_mode = actual_mode;
}
std::vector<std::unique_ptr<common_speculative_state>> impls = {};
for (const common_speculative_config & config : configs) {
@ -1221,7 +1240,7 @@ common_speculative * common_speculative_init(
if (actual_type != COMMON_SPECULATIVE_TYPE_NONE &&
actual_type != COMMON_SPECULATIVE_TYPE_EAGLE3) {
result->tuner = std::make_unique<spec_tuner>();
result->tuner->init(actual_type, params);
result->tuner->init(actual_type, params, llama_get_model(ctx_tgt));
LOG_DBG("Autotune initialized for %s, tuning %zu parameters\n",
common_speculative_type_to_str(actual_type).c_str(),
result->tuner->coords.size());

View File

@ -454,6 +454,9 @@ void server_context::init() {
}
}
const bool requested_spec = params_base.speculative.type != COMMON_SPECULATIVE_TYPE_NONE ||
params_base.speculative.has_dft();
bool can_spec = true;
if (!params_base.dry_run) {
can_spec = common_speculative_is_compat(ctx);
@ -462,7 +465,7 @@ void server_context::init() {
SRV_WRN("%s", "speculative decoding not supported by this context\n");
}
// try speculative decoding
if (can_spec) {
if (can_spec && requested_spec) {
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
if (slot.spec) {
if (mctx && !slot.has_mtp) {
@ -471,11 +474,15 @@ void server_context::init() {
}
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
} else {
if (slot.has_mtp) {
SRV_ERR("%s", "failed to initialize MTP speculative context, aborting\n");
GGML_ABORT("MTP context creation failed");
if (llama_model_has_recurrent(model)) {
SRV_ERR("%s", "failed to initialize recurrent speculative context\n");
throw std::runtime_error("recurrent speculative context initialization failed");
} else if (slot.has_mtp) {
SRV_ERR("%s", "failed to initialize MTP speculative context\n");
throw std::runtime_error("MTP speculative context initialization failed");
} else {
SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
SRV_ERR("%s", "failed to initialize speculative decoding context\n");
throw std::runtime_error("speculative decoding context initialization failed");
}
}
}
@ -1233,6 +1240,17 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 0);
slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);
if (slot.can_speculate() &&
llama_model_has_recurrent(model) &&
slot.params.speculative.n_max > params_base.speculative.n_max) {
send_error(task,
"Error: speculative.n_max=" + std::to_string(slot.params.speculative.n_max) +
" exceeds the recurrent speculative startup limit of " + std::to_string(params_base.speculative.n_max) +
"; restart the server with a higher --draft-max to reserve checkpoint capacity",
ERROR_TYPE_INVALID_REQUEST);
return false;
}
slot.params.speculative.ngram_size_n = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n);
slot.params.speculative.ngram_size_m = json_value(data, "speculative.ngram_size_m", defaults.speculative.ngram_size_m);
slot.params.speculative.ngram_min_hits = json_value(data, "speculative.ngram_m_hits", defaults.speculative.ngram_min_hits);

View File

@ -258,7 +258,7 @@ struct server_context {
gpt_params params_base;
llama_batch batch;
llama_batch batch = {};
bool clean_kv_cache = true;
bool add_bos_token = true;

View File

@ -584,7 +584,13 @@ int main(int argc, char ** argv) {
state.store(SERVER_STATE_ERROR);
return 1;
} else {
try {
ctx_server.init();
} catch (const std::exception & e) {
LOG_ERROR("server init failed", {{"error", e.what()}});
state.store(SERVER_STATE_ERROR);
return 1;
}
state.store(SERVER_STATE_READY);
}

View File

@ -103,6 +103,8 @@ struct llama_kv_cache {
int32_t per_step_d_conv = 0;
int selected_spec_mode = -1;
int fixed_spec_mode = LLAMA_SPEC_CKPT_NONE;
int32_t fixed_max_tokens = 0;
// Serialised sequence state for CPU mode
std::vector<uint8_t> cpu_state_data;
@ -115,6 +117,7 @@ struct llama_kv_cache {
std::vector<ggml_backend_buffer_t> shadow_bufs;
bool allocated = false;
bool shadow_conv_only = false;
bool saved = false;
~gpu_checkpoint() {
@ -135,13 +138,14 @@ struct llama_kv_cache {
gpu_checkpoint ckpt;
bool checkpoint_alloc_shadows();
bool checkpoint_alloc_shadows(bool conv_only_shadow = false);
bool checkpoint_supported() const;
bool checkpoint_save(ggml_backend_sched_t sched);
bool checkpoint_restore(ggml_backend_sched_t sched);
void checkpoint_delete();
// Per-step checkpoint: allocate, restore step k's full state (SSM + conv) to cache
bool per_step_save(ggml_backend_sched_t sched);
bool per_step_alloc(const llama_model & model, int max_tokens);
bool per_step_restore(const llama_model & model, ggml_backend_sched_t sched, int step);

View File

@ -70,7 +70,9 @@ delta_net::delta_net(llama_context & _lctx, const llama_batch & _batch) : lctx(_
GGML_ASSERT((uint32_t) s < qnext_state_slots);
}
int max_per_step = lctx.kv_self.save_per_step_ssm ? std::min<int>(8, lctx.kv_self.ckpt.per_step_max_allocated) : 0;
int max_per_step = lctx.kv_self.save_per_step_ssm
? lctx.kv_self.ckpt.per_step_max_allocated
: 0;
save_per_step_states = lctx.kv_self.save_per_step_ssm && batch.n_tokens > 1 && batch.n_tokens <= max_per_step;
}

View File

@ -1255,8 +1255,15 @@ bool llama_kv_cache::checkpoint_supported() const {
return false;
}
bool llama_kv_cache::checkpoint_alloc_shadows() {
bool llama_kv_cache::checkpoint_alloc_shadows(bool conv_only_shadow) {
if (ckpt.allocated) {
if (ckpt.shadow_conv_only != conv_only_shadow) {
LLAMA_LOG_ERROR("%s: requested %s shadow buffers, but %s shadow buffers are already allocated\n",
__func__,
conv_only_shadow ? "conv-state-only" : "full-state",
ckpt.shadow_conv_only ? "conv-state-only" : "full-state");
return false;
}
return true;
}
@ -1269,10 +1276,7 @@ bool llama_kv_cache::checkpoint_alloc_shadows() {
int split_idx; // -1 for non-split
};
const bool conv_only_shadow = save_per_step_ssm && ckpt.per_step_conv_state_dim > 0;
std::vector<tensor_entry> nonsplit_entries;
std::map<ggml_backend_buffer_type_t, std::vector<tensor_entry>> split_buft_entries;
std::map<ggml_backend_buffer_type_t, std::vector<tensor_entry>> buft_entries;
for (uint32_t il = 0; il < n_layer; ++il) {
if (s_l[il] == nullptr) {
@ -1285,16 +1289,18 @@ bool llama_kv_cache::checkpoint_alloc_shadows() {
for (int d = 0; d < split_info->n_device; ++d) {
if (split_info->splits[d] == nullptr) continue;
ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(split_info->splits[d]->buffer);
split_buft_entries[buft].push_back({split_info->splits[d], il, d});
buft_entries[buft].push_back({split_info->splits[d], il, d});
}
} else {
nonsplit_entries.push_back({s_l[il], il, -1});
ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(s_l[il]->buffer);
buft_entries[buft].push_back({s_l[il], il, -1});
}
}
if (!nonsplit_entries.empty()) {
// Allocate all shadows on the same backend type as the source tensor.
for (auto & [buft, entries] : buft_entries) {
ggml_init_params params = {
/*.mem_size =*/ nonsplit_entries.size() * ggml_tensor_overhead(),
/*.mem_size =*/ entries.size() * ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
@ -1304,60 +1310,39 @@ bool llama_kv_cache::checkpoint_alloc_shadows() {
return false;
}
for (auto & entry : nonsplit_entries) {
// Only need the conv portion when per-step is active.
const int64_t nelems = conv_only_shadow
? ckpt.per_step_conv_state_dim
: (int64_t)ggml_nelements(entry.primary);
ggml_tensor * shadow = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelems);
ggml_format_name(shadow, "shadow_s_l%d", entry.il);
ckpt.s_l_shadow[entry.il] = shadow;
}
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_buffer_type());
if (!buf) {
LLAMA_LOG_ERROR("%s: failed to allocate CPU buffer for shadow tensors\n", __func__);
ggml_free(ctx);
return false;
}
ggml_backend_buffer_clear(buf, 0);
LLAMA_LOG_INFO("%s: CPU shadow buffer = %8.2f MiB (%s)\n", __func__,
ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0,
conv_only_shadow ? "conv-state only" : "full recurrent state");
ckpt.shadow_ctxs.push_back(ctx);
ckpt.shadow_bufs.push_back(buf);
}
// Allocate split shadows on their respective devices
for (auto & [buft, entries] : split_buft_entries) {
ggml_init_params params = {
/*.mem_size =*/ entries.size() * ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
ggml_context * ctx = ggml_init(params);
if (!ctx) {
LLAMA_LOG_ERROR("%s: failed to create ggml context for split shadow tensors\n", __func__);
return false;
}
for (auto & entry : entries) {
ggml_tensor * shadow = ggml_dup_tensor(ctx, entry.primary);
ggml_tensor * shadow = nullptr;
if (conv_only_shadow && entry.split_idx < 0) {
shadow = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ckpt.per_step_conv_state_dim);
} else {
shadow = ggml_dup_tensor(ctx, entry.primary);
}
if (entry.split_idx >= 0) {
ggml_format_name(shadow, "shadow_s_l%d_d%d", entry.il, entry.split_idx);
} else {
ggml_format_name(shadow, "shadow_s_l%d", entry.il);
}
entry.primary = shadow;
}
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
if (!buf) {
LLAMA_LOG_ERROR("%s: failed to allocate buffer for split shadow tensors\n", __func__);
LLAMA_LOG_ERROR("%s: failed to allocate buffer for shadow tensors\n", __func__);
ggml_free(ctx);
return false;
}
ggml_backend_buffer_clear(buf, 0);
LLAMA_LOG_INFO("%s: %10s split shadow buffer = %8.2f MiB\n", __func__,
ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
LLAMA_LOG_INFO("%s: %10s shadow buffer = %8.2f MiB%s\n", __func__,
ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0,
conv_only_shadow ? " (conv-state only)" : "");
ckpt.shadow_ctxs.push_back(ctx);
ckpt.shadow_bufs.push_back(buf);
for (const auto & entry : entries) {
if (entry.split_idx < 0) {
ckpt.s_l_shadow[entry.il] = entry.primary;
}
}
}
// Build split shadow lookup
@ -1374,7 +1359,7 @@ bool llama_kv_cache::checkpoint_alloc_shadows() {
for (int d = 0; d < split_info->n_device; ++d) {
if (split_info->splits[d] == nullptr) continue;
ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(split_info->splits[d]->buffer);
for (auto & entry : split_buft_entries[buft]) {
for (auto & entry : buft_entries[buft]) {
if (entry.il == il && entry.split_idx == d) {
shadow_split[d] = entry.primary;
break;
@ -1383,15 +1368,18 @@ bool llama_kv_cache::checkpoint_alloc_shadows() {
}
}
ckpt.shadow_conv_only = conv_only_shadow;
ckpt.allocated = true;
return true;
}
bool llama_kv_cache::checkpoint_save(ggml_backend_sched_t sched) {
if (!checkpoint_alloc_shadows()) {
if (!checkpoint_alloc_shadows(false)) {
return false;
}
GGML_ASSERT(!ckpt.shadow_conv_only);
const uint32_t n_layer = (uint32_t)s_l.size();
ckpt.cells_snapshot = cells;
@ -1417,8 +1405,11 @@ bool llama_kv_cache::checkpoint_save(ggml_backend_sched_t sched) {
}
}
} else {
const size_t nbytes = ggml_nbytes(ckpt.s_l_shadow[il]);
ggml_backend_tensor_get(s_l[il], ckpt.s_l_shadow[il]->data, 0, nbytes);
GGML_ASSERT(ckpt.s_l_shadow[il] != nullptr);
auto src_backend = ggml_backend_sched_get_tensor_backend(sched, s_l[il]);
GGML_ASSERT(src_backend != nullptr);
ggml_backend_tensor_copy_async(src_backend, src_backend, s_l[il], ckpt.s_l_shadow[il]);
backends_to_sync.insert(src_backend);
}
}
@ -1436,6 +1427,8 @@ bool llama_kv_cache::checkpoint_restore(ggml_backend_sched_t sched) {
return false;
}
GGML_ASSERT(!ckpt.shadow_conv_only);
const uint32_t n_layer = (uint32_t)s_l.size();
cells = ckpt.cells_snapshot;
@ -1460,8 +1453,12 @@ bool llama_kv_cache::checkpoint_restore(ggml_backend_sched_t sched) {
}
}
} else {
GGML_ASSERT(ckpt.s_l_shadow[il] != nullptr);
GGML_ASSERT(ggml_nbytes(ckpt.s_l_shadow[il]) == ggml_nbytes(s_l[il]));
ggml_backend_tensor_copy(ckpt.s_l_shadow[il], s_l[il]);
auto dst_backend = ggml_backend_sched_get_tensor_backend(sched, s_l[il]);
GGML_ASSERT(dst_backend != nullptr);
ggml_backend_tensor_copy_async(dst_backend, dst_backend, ckpt.s_l_shadow[il], s_l[il]);
backends_to_sync.insert(dst_backend);
}
}
@ -1476,6 +1473,68 @@ void llama_kv_cache::checkpoint_delete() {
ckpt.saved = false;
}
// Capture the state needed to roll back a per-step speculative batch.
//
// Stores the current cells/head/used bookkeeping in ckpt, queues asynchronous
// copies of the recurrent tensors into their shadow tensors, and then waits on
// every backend that received a copy.
//
// sched: scheduler used to look up the backend that owns each source tensor.
// Returns false only when the conv-only shadow buffers cannot be prepared;
// otherwise sets ckpt.saved and returns true.
bool llama_kv_cache::per_step_save(ggml_backend_sched_t sched) {
    const uint32_t n_layer = (uint32_t)s_l.size();
    const int64_t conv_state_dim = ckpt.per_step_conv_state_dim;

    // The bookkeeping snapshot is taken before any allocation is attempted.
    ckpt.cells_snapshot = cells;
    ckpt.head_snapshot  = head;
    ckpt.used_snapshot  = used;

    if (conv_state_dim > 0) {
        if (!checkpoint_alloc_shadows(true)) {
            return false;
        }
    }

    // Byte length of the prefix copied from each non-split tensor; a
    // non-positive dimension clamps to zero and disables those copies.
    const size_t conv_bytes = (size_t)std::max<int64_t>(conv_state_dim, 0) * sizeof(float);

    std::unordered_set<ggml_backend_t> backends_to_sync;

    // Queue a same-backend async copy `from` -> `to`; the backend is the one
    // the scheduler reports for `owner`. It is remembered for the final sync.
    const auto enqueue_copy = [&](ggml_tensor * owner, ggml_tensor * from, ggml_tensor * to) {
        auto src_backend = ggml_backend_sched_get_tensor_backend(sched, owner);
        GGML_ASSERT(src_backend != nullptr);
        ggml_backend_tensor_copy_async(src_backend, src_backend, from, to);
        backends_to_sync.insert(src_backend);
    };

    for (uint32_t il = 0; il < n_layer; ++il) {
        auto * layer_state = s_l[il];
        if (layer_state == nullptr) {
            continue;
        }

        if (layer_state->extra == nullptr) {
            // Non-split tensor: only the leading conv_state_dim elements go
            // into the reduced shadow buffer.
            if (conv_bytes != 0) {
                GGML_ASSERT(ckpt.s_l_shadow[il] != nullptr);

                // Stack copy of the tensor descriptor narrowed to the prefix;
                // it still points at the original tensor's data.
                ggml_tensor src = *layer_state;
                src.ne[0] = conv_bytes / sizeof(float);
                src.nb[3] = conv_bytes;
                src.nb[2] = conv_bytes;
                src.nb[1] = conv_bytes;

                enqueue_copy(layer_state, &src, ckpt.s_l_shadow[il]);
            }
        } else {
            // Split tensor: every device slice that has a shadow is copied in
            // full, so restore can seed each split from split_s_l_shadow.
            auto * split_info = (const ggml_split_tensor_t *)layer_state->extra;
            auto & shadow_split = ckpt.split_s_l_shadow[il];
            for (int d = 0; d < split_info->n_device; ++d) {
                const bool has_pair = split_info->splits[d] != nullptr && shadow_split[d] != nullptr;
                if (has_pair) {
                    enqueue_copy(split_info->splits[d], split_info->splits[d], shadow_split[d]);
                }
            }
        }
    }

    // Wait for all queued copies to finish before reporting the checkpoint.
    for (auto backend : backends_to_sync) {
        ggml_backend_synchronize(backend);
    }

    ckpt.saved = true;
    return true;
}
bool llama_kv_cache::per_step_alloc(const llama_model & model, int max_tokens) {
if (ckpt.per_step_max_allocated >= max_tokens) {
return true;
@ -7089,9 +7148,9 @@ void llama_kv_cache_clear(struct llama_context * ctx) {
// Unified speculative-checkpoint
static bool spec_ckpt_try_per_step(llama_kv_cache & kv, const llama_model & model, int max_tokens) {
// Graph-split recurrent tensors are not supported. CPU-only and mixed
// CPU/GPU recurrent placement are allowed as long as each layer has a
// concrete backend buffer for the per-step tensors.
// Split recurrent tensors are supported as long as each layer exposes
// concrete backend buffers for the per-step tensors. CPU-only and mixed
// CPU/GPU recurrent placement are also allowed.
bool has_gpu = false;
bool has_cpu = false;
for (const auto * sl : kv.s_l) {
@ -7100,10 +7159,6 @@ static bool spec_ckpt_try_per_step(llama_kv_cache & kv, const llama_model & mode
has_gpu = true;
continue;
}
//if (sl->extra) {
// kv.save_per_step_ssm = false;
// return false;
//}
if (sl->buffer && !ggml_backend_buffer_is_host(sl->buffer)) {
has_gpu = true;
} else if (sl->buffer) {
@ -7137,9 +7192,73 @@ static bool spec_ckpt_try_per_step(llama_kv_cache & kv, const llama_model & mode
return false;
}
if (!kv.checkpoint_alloc_shadows(true)) {
LLAMA_LOG_ERROR("%s: failed to allocate conv-state shadow buffers for per-step checkpoints\n", __func__);
kv.save_per_step_ssm = false;
return false;
}
return true;
}
// Compute a byte count for the CPU-serialised checkpoint of one sequence.
//
// The total is the sum of fixed-size header fields plus, when the sequence id
// is addressable, one recurrent-state row per layer that has an s_l tensor.
// NOTE(review): the field sizes presumably mirror the sequence-state
// serialiser's layout — confirm against the writer before changing either.
static size_t llama_spec_ckpt_cpu_state_reserve(const llama_context * ctx, llama_seq_id seq_id) {
    const auto & kv_self = ctx->kv_self;

    // True when seq_id is non-negative, inside the qnext range and below kv size.
    const auto seq_is_addressable = [&]() -> bool {
        return seq_id >= 0 && llama_kv_qnext_seq_id_in_range(kv_self, seq_id) && (uint32_t) seq_id < kv_self.size;
    };

    size_t total = sizeof(uint32_t); // cell_count

    if (seq_is_addressable()) {
        total += sizeof(llama_pos);
        total += sizeof(uint32_t); // n_seq_id = 0 for seq-specific saves
    }

    // v_state: 2 when there are no V tensors, 1 when V is transposed, else 0.
    uint32_t v_kind = 0;
    if (kv_self.v_l.empty()) {
        v_kind = 2;
    } else if (kv_self.v_trans) {
        v_kind = 1;
    }
    const uint32_t v_state = v_kind;
    const uint32_t n_layer = kv_self.k_l.size();

    total += sizeof(v_state);
    total += sizeof(n_layer);

    // Per-layer K header.
    total += (size_t) n_layer * (sizeof(int32_t) + sizeof(uint64_t));

    // Per-layer V header; its shape depends on v_state (none when v_state == 2).
    switch (v_state) {
        case 0:
            total += (size_t) n_layer * (sizeof(int32_t) + sizeof(uint64_t));
            break;
        case 1:
            total += (size_t) n_layer * (sizeof(int32_t) + sizeof(uint32_t) + sizeof(uint32_t));
            break;
        default:
            break;
    }

    const uint32_t qnext_state = llama_kv_has_qnext_state_storage(kv_self) ? 1 : 0;
    total += sizeof(qnext_state);

    if (qnext_state == 0) {
        return total;
    }

    for (uint32_t il = 0; il < n_layer; ++il) {
        // Fixed per-layer recurrent header, counted even without an s_l tensor.
        total += sizeof(int32_t);
        total += sizeof(uint64_t);
        total += sizeof(uint32_t);

        const bool has_s_cache = il < kv_self.s_l.size() && kv_self.s_l[il] != nullptr;
        if (!has_s_cache) {
            continue;
        }

        // One row of recurrent state, counted only for an addressable sequence.
        const uint64_t s_size_row = ggml_row_size(kv_self.s_l[il]->type, kv_self.s_l[il]->ne[0]);
        if (seq_is_addressable()) {
            total += s_size_row;
        }
    }

    return total;
}
// Map a speculative-checkpoint mode value to the short name used in log
// messages. Any value other than the four handled modes yields "none".
static const char * llama_spec_ckpt_mode_name(int mode) {
    if (mode == LLAMA_SPEC_CKPT_PER_STEP) {
        return "per-step";
    }
    if (mode == LLAMA_SPEC_CKPT_GPU_FALLBACK) {
        return "gpu-fallback";
    }
    if (mode == LLAMA_SPEC_CKPT_CPU) {
        return "cpu";
    }
    if (mode == LLAMA_SPEC_CKPT_AUTO) {
        return "auto";
    }
    return "none";
}
int llama_spec_ckpt_init(struct llama_context * ctx, int mode, int max_tokens) {
auto & kv = ctx->kv_self;
@ -7150,7 +7269,19 @@ int llama_spec_ckpt_init(struct llama_context * ctx, int mode, int max_tokens) {
return (int)LLAMA_SPEC_CKPT_NONE;
}
if (kv.ckpt.fixed_spec_mode != LLAMA_SPEC_CKPT_NONE) {
if (kv.ckpt.fixed_spec_mode == LLAMA_SPEC_CKPT_PER_STEP && max_tokens > kv.ckpt.fixed_max_tokens) {
LLAMA_LOG_WARN("%s: fixed per-step checkpoint capacity is %d tokens, but the current speculative batch requests %d; disabling checkpoint for this batch\n",
__func__, kv.ckpt.fixed_max_tokens, max_tokens);
return (int)LLAMA_SPEC_CKPT_NONE;
}
kv.ckpt.selected_spec_mode = kv.ckpt.fixed_spec_mode;
return kv.ckpt.selected_spec_mode;
}
int requested = mode;
int resolved = LLAMA_SPEC_CKPT_NONE;
// prefer PER_STEP → GPU_FALLBACK → CPU
if (requested == LLAMA_SPEC_CKPT_AUTO) {
@ -7159,22 +7290,53 @@ int llama_spec_ckpt_init(struct llama_context * ctx, int mode, int max_tokens) {
if (requested == LLAMA_SPEC_CKPT_PER_STEP) {
if (spec_ckpt_try_per_step(kv, ctx->model, max_tokens)) {
kv.ckpt.selected_spec_mode = LLAMA_SPEC_CKPT_PER_STEP;
return (int)LLAMA_SPEC_CKPT_PER_STEP;
}
if (mode == LLAMA_SPEC_CKPT_PER_STEP) {
LLAMA_LOG_WARN("%s: per-step not available, falling back to GPU fallback mode\n", __func__);
}
resolved = LLAMA_SPEC_CKPT_PER_STEP;
} else if (mode == LLAMA_SPEC_CKPT_PER_STEP) {
LLAMA_LOG_ERROR("%s: failed to preallocate per-step checkpoint buffers for max_tokens=%d; --recurrent-ckpt-mode=%s requires startup allocation\n",
__func__, max_tokens, llama_spec_ckpt_mode_name(mode));
return (int)LLAMA_SPEC_CKPT_NONE;
} else {
LLAMA_LOG_WARN("%s: auto checkpoint mode could not preallocate per-step buffers for max_tokens=%d; falling back to gpu-fallback\n",
__func__, max_tokens);
requested = LLAMA_SPEC_CKPT_GPU_FALLBACK;
}
if (requested == LLAMA_SPEC_CKPT_GPU_FALLBACK) {
kv.ckpt.selected_spec_mode = LLAMA_SPEC_CKPT_GPU_FALLBACK;
return (int)LLAMA_SPEC_CKPT_GPU_FALLBACK;
}
kv.ckpt.selected_spec_mode = LLAMA_SPEC_CKPT_CPU;
return (int)LLAMA_SPEC_CKPT_CPU;
if (resolved == LLAMA_SPEC_CKPT_NONE && requested == LLAMA_SPEC_CKPT_GPU_FALLBACK) {
if (kv.checkpoint_alloc_shadows()) {
resolved = LLAMA_SPEC_CKPT_GPU_FALLBACK;
} else if (mode == LLAMA_SPEC_CKPT_GPU_FALLBACK) {
LLAMA_LOG_ERROR("%s: failed to preallocate gpu-fallback checkpoint shadows at startup; --recurrent-ckpt-mode=%s requires startup allocation\n",
__func__, llama_spec_ckpt_mode_name(mode));
return (int)LLAMA_SPEC_CKPT_NONE;
} else {
LLAMA_LOG_WARN("%s: auto checkpoint mode could not preallocate gpu-fallback checkpoint shadows; falling back to cpu\n",
__func__);
requested = LLAMA_SPEC_CKPT_CPU;
}
}
if (resolved == LLAMA_SPEC_CKPT_NONE) {
resolved = LLAMA_SPEC_CKPT_CPU;
}
if (resolved == LLAMA_SPEC_CKPT_CPU) {
const size_t cpu_reserve = llama_spec_ckpt_cpu_state_reserve(ctx, 0);
kv.ckpt.cpu_state_data.clear();
kv.ckpt.cpu_state_data.reserve(cpu_reserve);
LLAMA_LOG_INFO("%s: CPU serialized checkpoint reserve = %8.2f MiB (per seq)\n",
__func__, cpu_reserve / 1024.0 / 1024.0);
}
kv.ckpt.fixed_spec_mode = resolved;
kv.ckpt.fixed_max_tokens = resolved == LLAMA_SPEC_CKPT_PER_STEP ? max_tokens : 0;
kv.ckpt.selected_spec_mode = resolved;
LLAMA_LOG_INFO("%s: fixed recurrent checkpoint mode = %s%s\n",
__func__, llama_spec_ckpt_mode_name(resolved),
resolved == LLAMA_SPEC_CKPT_PER_STEP ? (std::string(" (max_tokens=") + std::to_string(max_tokens) + ")").c_str() : "");
return resolved;
}
bool llama_spec_ckpt_save(struct llama_context * ctx, llama_seq_id seq_id) {
@ -7183,7 +7345,7 @@ bool llama_spec_ckpt_save(struct llama_context * ctx, llama_seq_id seq_id) {
switch (kv.ckpt.selected_spec_mode) {
case LLAMA_SPEC_CKPT_PER_STEP:
kv.save_per_step_ssm = true;
return kv.checkpoint_save(ctx->sched);
return kv.per_step_save(ctx->sched);
case LLAMA_SPEC_CKPT_GPU_FALLBACK:
return kv.checkpoint_save(ctx->sched);