Revert "server: defer recurrent-state reset to graph build (addresses #1696 r…" (#1704)

This reverts commit fb05c2e9a2ded4d42861bc000ac01778cd1ba4c9.
This commit is contained in:
Kawrakow 2026-04-28 12:33:34 +02:00 committed by GitHub
parent fb05c2e9a2
commit 453a027c17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 3 additions and 61 deletions

View File

@ -59,14 +59,6 @@ ggml_cgraph * llm_build_context::build_qwen35moe() {
ggml_build_forward_expand(gf, cur);
// Consume the recurrent-state reset flags. Any slot that was marked by
// llama_kv_cache_seq_rm has had its reset op embedded in this graph
// through delta_net's reset_state branch, so we can clear the flags now;
// the reset will fire when the graph executes.
std::fill(lctx.kv_self.pending_recurrent_reset.begin(),
lctx.kv_self.pending_recurrent_reset.end(),
false);
return gf;
}
@ -147,14 +139,6 @@ ggml_cgraph * llm_build_context::build_qwen35() {
ggml_build_forward_expand(gf, cur);
// Consume the recurrent-state reset flags. Any slot that was marked by
// llama_kv_cache_seq_rm has had its reset op embedded in this graph
// through delta_net's reset_state branch, so we can clear the flags now;
// the reset will fire when the graph executes.
std::fill(lctx.kv_self.pending_recurrent_reset.begin(),
lctx.kv_self.pending_recurrent_reset.end(),
false);
return gf;
}

View File

@ -85,13 +85,5 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_build_forward_expand(gf, cur);
// Consume the recurrent-state reset flags. Any slot that was marked by
// llama_kv_cache_seq_rm has had its reset op embedded in this graph
// through delta_net's reset_state branch, so we can clear the flags now;
// the reset will fire when the graph executes.
std::fill(lctx.kv_self.pending_recurrent_reset.begin(),
lctx.kv_self.pending_recurrent_reset.end(),
false);
return gf;
}

View File

@ -62,15 +62,6 @@ struct llama_kv_cache {
// When true, the delta_net graph builder will enable per-step SSM state saves
bool save_per_step_ssm = false;
// Set by llama_kv_cache_seq_rm when a hybrid/recurrent slot's cell is
// fully emptied. Read once during the next graph build to inject a
// state-reset op into the recurrent layers, then cleared. Indexed by
// slot/cell index (= column in cache.s_l[layer]). Sized by
// qnext_state_slots when the cache is initialized; empty for non-
// hybrid models, in which case all reads short-circuit through the
// size() bounds check.
std::vector<bool> pending_recurrent_reset;
std::vector<llama_split_tensor> split_k_l;
std::vector<llama_split_tensor> split_v_l;
std::vector<llama_split_tensor> split_s_l;

View File

@ -646,12 +646,8 @@ ggml_tensor * delta_net::build_layer_attn_linear(ggml_context * ctx0, ggml_cgrap
GGML_ASSERT(model.layers[il].wqkv_gate != nullptr || model.layers[il].ssm_in != nullptr);
if (all_same_seq) {
const uint32_t state_seq_id = (uint32_t) token_seq_ids.front();
const bool needs_recurrent_reset =
state_seq_id < lctx.kv_self.pending_recurrent_reset.size() &&
lctx.kv_self.pending_recurrent_reset[state_seq_id];
bool reset_state = (batch.pos != nullptr && batch.pos[0] == 0) || needs_recurrent_reset;
return build_layer_attn_linear_core(ctx0, gf, cur, lctx.inp_s_seq_qnext, inp_out_ids, state_seq_id, reset_state, il, cb);
bool reset_state = batch.pos != nullptr && batch.pos[0] == 0;
return build_layer_attn_linear_core(ctx0, gf, cur, lctx.inp_s_seq_qnext, inp_out_ids, token_seq_ids.front(), reset_state, il, cb);
}
GGML_ASSERT(has_unique_seq_ids && "qwen3next mixed-sequence batches require unique sequence IDs per token");
@ -661,11 +657,8 @@ ggml_tensor * delta_net::build_layer_attn_linear(ggml_context * ctx0, ggml_cgrap
ggml_tensor * cur_i = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (size_t) i * cur->nb[1]);
ggml_tensor * inp_s_seq_qnext_i = ggml_view_2d(ctx0, lctx.inp_s_seq_qnext, 1, 1, lctx.inp_s_seq_qnext->nb[1], (size_t) i * lctx.inp_s_seq_qnext->nb[1]);
const bool reset_state_i = batch.pos != nullptr && batch.pos[i] == 0;
const uint32_t state_seq_id_i = (uint32_t) token_seq_ids[i];
const bool needs_recurrent_reset_i =
state_seq_id_i < lctx.kv_self.pending_recurrent_reset.size() &&
lctx.kv_self.pending_recurrent_reset[state_seq_id_i];
const bool reset_state_i = (batch.pos != nullptr && batch.pos[i] == 0) || needs_recurrent_reset_i;
ggml_tensor * out_i = build_layer_attn_linear_core(ctx0, gf, cur_i, inp_s_seq_qnext_i, inp_out_ids, state_seq_id_i, reset_state_i, il, cb);
out = out == nullptr ? out_i : ggml_concat(ctx0, out, out_i, 1);

View File

@ -851,7 +851,6 @@ static bool llama_kv_cache_init(
LLAMA_LOG_WARN("%s: reducing qwen3next state slots from %u to %u to fit KV cache size\n",
__func__, std::max<uint32_t>(1, cparams.n_seq_max), qnext_state_slots);
}
cache.pending_recurrent_reset.assign(qnext_state_slots, false);
int n_mla = 0;
const int64_t n_mtp_first_layer = n_layer - hparams.nextn_predict_layers;
@ -1598,14 +1597,6 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
for (auto & buf : cache.bufs) {
ggml_backend_buffer_clear(buf, 0);
}
// The full clear above zeroed every recurrent state buffer, so any
// pending per-slot resets recorded by an earlier seq_rm are now
// redundant. Drop them so the next graph build does not emit a
// spurious in-graph reset op.
std::fill(cache.pending_recurrent_reset.begin(),
cache.pending_recurrent_reset.end(),
false);
}
static bool llama_kv_cache_seq_rm(
@ -1655,15 +1646,6 @@ static bool llama_kv_cache_seq_rm(
cache.cells[i].pos = -1;
if (has_qnext_state) {
cache.cells[i].src = i;
// Defer the recurrent-state reset to graph build time:
// delta-net's existing reset path (ggml_scale state, 0.0f)
// does the zeroing inside the compute graph, so we just
// record which slots need reset here. The flags are
// consumed and cleared at the end of build_qwen3next /
// build_qwen35 / build_qwen35moe.
if ((uint32_t) i < cache.pending_recurrent_reset.size()) {
cache.pending_recurrent_reset[i] = true;
}
}
if (new_head == cache.size) new_head = i;
}