diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 6bb46172..52b63f7f 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -111,6 +111,15 @@ bool server_context::load_model(const gpt_params& params_) { } // Load draft model for speculative decoding if specified if (has_draft_model) { + + if (llama_model_has_recurrent(model)) { + LLAMA_LOG_WARN("\n=======================================================================\n"); + LLAMA_LOG_WARN(" Speculative decoding is not supported for recurrent/hybrid models\n"); + LLAMA_LOG_WARN(" --> bailing out\n"); + LLAMA_LOG_WARN("========================================================================\n\n"); + GGML_ABORT("Fatal error"); + } + LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n"); gpt_params params_dft; @@ -1470,9 +1479,9 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) // // TODO: try to make this conditional on the context or the memory module, instead of the model type params_base.do_checkpoint = do_checkpoint; - if (slot.n_buffer != 0) { - LLAMA_LOG_WARN("banned strings is not supported by recurrent model, it will be disabled.\n"); - } + if (slot.n_buffer != 0) { + LLAMA_LOG_WARN("banned strings is not supported by recurrent model, it will be disabled.\n"); + } if (params_base.ctx_shift) { params_base.ctx_shift = false; LOG_WARNING("%s\n", "ctx_shift is not supported by recurrent model, it will be disabled"); diff --git a/src/llama.cpp b/src/llama.cpp index de62849b..da1c8a12 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4130,7 +4130,7 @@ static int llama_decode_internal( if (n_outputs_new) { GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs); GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size); - + if (res->ne[1] == n_tokens && n_outputs_new < n_tokens) { int32_t i_out = 0; if (u_batch.logits 
&& !embd_pooled) {