Discard very first compute graph for recurrent models (#1393)

This commit is contained in:
Kawrakow 2026-03-10 09:41:47 +01:00 committed by GitHub
parent f90b4c2f27
commit cda15bf175
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3451,7 +3451,7 @@ static int llama_decode_internal(
#endif
if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
lctx.prev = std::make_unique<llama_context::Prev>(llama_context::Prev{
(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n,
(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n,
cparams.mtp_op_type, gf});
}
} else {
@ -3459,7 +3459,7 @@ static int llama_decode_internal(
gf = lctx.prev->graph;
}
if (cparams.mtp_op_type != MTP_OP_NONE) {
if (cparams.mtp_op_type != MTP_OP_NONE) {
if (!prepare_mtp_graph_inputs(lctx)) {
return GGML_STATUS_FAILED;
}
@ -3472,7 +3472,7 @@ static int llama_decode_internal(
if (lctx.n_outputs == 0) {
// no output
res = nullptr;
}
}
else {
const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.model.mtp;
if (cparams.embeddings || has_mtp) {
@ -3514,8 +3514,12 @@ static int llama_decode_internal(
printf("graph_compute(...): %d us\n", int(tim2-tim1));
#endif
bool reset_previous = false;
// update the kv ring buffer
{
if (llama_model_has_recurrent(&lctx.model) && kv_self.head == 0) {
reset_previous = true;
}
kv_self.head += n_tokens;
// Ensure kv cache head points to a valid index.
@ -3536,7 +3540,7 @@ static int llama_decode_internal(
#endif
// Do not process logits if MTP is only updating the KV cache.
if (cparams.mtp_op_type != MTP_OP_WARMUP &&
cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) {
cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(lctx.logits != nullptr);
@ -3607,6 +3611,12 @@ static int llama_decode_internal(
}
n_outputs_prev += lctx.n_outputs;
cur_token += n_tokens;
if (reset_previous) {
// We need to discard this graph. Otherwise, iwith CUDA graphs enabled, the graph will get resused and this will reset the
// recurrent state for each new token. This is probably not very relevant in practice because we basically never run TG with
// empty context, but for the sake of correctness let's just do it.
lctx.prev.reset();
}
}
// set to total number of outputs in the batch, for use in llama_get_logits_ith