diff --git a/src/llama.cpp b/src/llama.cpp index a0007c6c..149e1c21 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3451,7 +3451,7 @@ static int llama_decode_internal( #endif if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) { lctx.prev = std::make_unique(llama_context::Prev{ - (int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n, + (int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n, cparams.mtp_op_type, gf}); } } else { @@ -3459,7 +3459,7 @@ static int llama_decode_internal( gf = lctx.prev->graph; } - if (cparams.mtp_op_type != MTP_OP_NONE) { + if (cparams.mtp_op_type != MTP_OP_NONE) { if (!prepare_mtp_graph_inputs(lctx)) { return GGML_STATUS_FAILED; } @@ -3472,7 +3472,7 @@ static int llama_decode_internal( if (lctx.n_outputs == 0) { // no output res = nullptr; - } + } else { const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.model.mtp; if (cparams.embeddings || has_mtp) { @@ -3514,8 +3514,12 @@ static int llama_decode_internal( printf("graph_compute(...): %d us\n", int(tim2-tim1)); #endif + bool reset_previous = false; // update the kv ring buffer { + if (llama_model_has_recurrent(&lctx.model) && kv_self.head == 0) { + reset_previous = true; + } kv_self.head += n_tokens; // Ensure kv cache head points to a valid index. @@ -3536,7 +3540,7 @@ static int llama_decode_internal( #endif // Do not process logits if MTP is only updating the KV cache. if (cparams.mtp_op_type != MTP_OP_WARMUP && - cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) { + cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(lctx.logits != nullptr); @@ -3607,6 +3611,12 @@ static int llama_decode_internal( } n_outputs_prev += lctx.n_outputs; cur_token += n_tokens; + if (reset_previous) { + // We need to discard this graph. Otherwise, iwith CUDA graphs enabled, the graph will get resused and this will reset the + // recurrent state for each new token. This is probably not very relevant in practice because we basically never run TG with + // empty context, but for the sake of correctness let's just do it. + lctx.prev.reset(); + } } // set to total number of outputs in the batch, for use in llama_get_logits_ith