mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Discard very first compute graph for recurrent models (#1393)
This commit is contained in:
parent
f90b4c2f27
commit
cda15bf175
@ -3451,7 +3451,7 @@ static int llama_decode_internal(
|
||||
#endif
|
||||
if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
|
||||
lctx.prev = std::make_unique<llama_context::Prev>(llama_context::Prev{
|
||||
(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n,
|
||||
(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n,
|
||||
cparams.mtp_op_type, gf});
|
||||
}
|
||||
} else {
|
||||
@ -3459,7 +3459,7 @@ static int llama_decode_internal(
|
||||
gf = lctx.prev->graph;
|
||||
}
|
||||
|
||||
if (cparams.mtp_op_type != MTP_OP_NONE) {
|
||||
if (cparams.mtp_op_type != MTP_OP_NONE) {
|
||||
if (!prepare_mtp_graph_inputs(lctx)) {
|
||||
return GGML_STATUS_FAILED;
|
||||
}
|
||||
@ -3472,7 +3472,7 @@ static int llama_decode_internal(
|
||||
if (lctx.n_outputs == 0) {
|
||||
// no output
|
||||
res = nullptr;
|
||||
}
|
||||
}
|
||||
else {
|
||||
const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.model.mtp;
|
||||
if (cparams.embeddings || has_mtp) {
|
||||
@ -3514,8 +3514,12 @@ static int llama_decode_internal(
|
||||
printf("graph_compute(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
|
||||
bool reset_previous = false;
|
||||
// update the kv ring buffer
|
||||
{
|
||||
if (llama_model_has_recurrent(&lctx.model) && kv_self.head == 0) {
|
||||
reset_previous = true;
|
||||
}
|
||||
kv_self.head += n_tokens;
|
||||
|
||||
// Ensure kv cache head points to a valid index.
|
||||
@ -3536,7 +3540,7 @@ static int llama_decode_internal(
|
||||
#endif
|
||||
// Do not process logits if MTP is only updating the KV cache.
|
||||
if (cparams.mtp_op_type != MTP_OP_WARMUP &&
|
||||
cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) {
|
||||
cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) {
|
||||
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
|
||||
GGML_ASSERT(backend_res != nullptr);
|
||||
GGML_ASSERT(lctx.logits != nullptr);
|
||||
@ -3607,6 +3611,12 @@ static int llama_decode_internal(
|
||||
}
|
||||
n_outputs_prev += lctx.n_outputs;
|
||||
cur_token += n_tokens;
|
||||
if (reset_previous) {
|
||||
// We need to discard this graph. Otherwise, with CUDA graphs enabled, the graph will get reused and this will reset the
|
||||
// recurrent state for each new token. This is probably not very relevant in practice because we basically never run TG with
|
||||
// empty context, but for the sake of correctness let's just do it.
|
||||
lctx.prev.reset();
|
||||
}
|
||||
}
|
||||
|
||||
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user