diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 3246ec2f..b272162e 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -721,9 +721,8 @@ void server_slot::release() { command = SLOT_COMMAND_RELEASE; state = SLOT_STATE_IDLE; task.reset(); - llama_decode_reset(); } - + llama_decode_reset(); } @@ -4545,8 +4544,14 @@ void server_context::process_batch_tokens(int32_t & n_batch) { for (auto& slot : slots) { slot.state = SLOT_STATE_PROCESSING; slot.command = SLOT_COMMAND_NONE; - slot.release(); - if (ret != user_cancel) { + if (ret == user_cancel) { + llama_pos cur_pos = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); + slot.n_past = slot.cache_tokens.size_up_to_pos(cur_pos + 1); + slot.cache_tokens.keep_first(slot.n_past); + slot.release(); + } + else { + slot.release(); LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size()); send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size."); } diff --git a/src/llama.cpp b/src/llama.cpp index 6b680187..0043d1d9 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4688,9 +4688,6 @@ static int llama_decode_internal( kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(max_cell, pad))); } } - if (stop_internal_decode) { - return -3; - } #if IK_PRINT_TIMING auto tim2 = ggml_time_us(); @@ -4922,6 +4919,9 @@ static int llama_decode_internal( // empty context, but for the sake of correctness let's just do it. lctx.prev.reset(); } + if (stop_internal_decode) { + return -3; + } } // set to total number of outputs in the batch, for use in llama_get_logits_ith