From de6f727aaec7dc477629946d80c803a0bb7af0a1 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Mon, 1 Jun 2026 23:01:38 +0800 Subject: [PATCH] llama: limit max outputs of `llama_context` (#23861) * llama: save more VRAM by reserving n_outputs == n_seqs when possible * add n_outputs_per_seq * move n_outputs_max to server-context * change ubatch to batch everywhere --- common/common.cpp | 1 + common/common.h | 1 + include/llama.h | 1 + src/llama-context.cpp | 21 ++++++++---- src/llama-cparams.h | 1 + tools/server/server-context.cpp | 57 ++++++++++++++++++++++++++++++--- 6 files changed, 71 insertions(+), 11 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 97daf28178..81b8b75002 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1563,6 +1563,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.n_ctx = params.n_ctx; cparams.n_seq_max = params.n_parallel; cparams.n_rs_seq = params.speculative.need_n_rs_seq(); + cparams.n_outputs_max = std::max(params.n_outputs_max, 0); cparams.n_batch = params.n_batch; cparams.n_ubatch = params.n_ubatch; cparams.n_threads = params.cpuparams.n_threads; diff --git a/common/common.h b/common/common.h index 99898800d1..92064a0e4d 100644 --- a/common/common.h +++ b/common/common.h @@ -431,6 +431,7 @@ struct common_params { int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) int32_t n_parallel = 1; // number of parallel sequences to decode int32_t n_sequences = 1; // number of sequences to decode + int32_t n_outputs_max = 0; // max outputs in a batch (0 = n_batch) int32_t grp_attn_n = 1; // group-attention factor int32_t grp_attn_w = 512; // group-attention width int32_t n_print = -1; // print token count every n tokens (-1 = disabled) diff --git a/include/llama.h b/include/llama.h index e8374c53b7..a79a491c59 100644 --- a/include/llama.h +++ b/include/llama.h @@ -339,6 +339,7 @@ extern "C" { uint32_t n_ubatch; // physical maximum batch size uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models) uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL] + uint32_t n_outputs_max; // max outputs in a ubatch (0 = n_batch) int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 99bdf092b3..a7dfb24924 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -182,6 +182,8 @@ llama_context::llama_context( cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + cparams.n_outputs_max = params.n_outputs_max == 0 ? cparams.n_batch : params.n_outputs_max; + cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; @@ -531,7 +533,7 @@ void llama_context::sched_reserve() { // note: n_outputs must match n_tokens for embedding models with mean/rank pooling, // because build_pooling creates inp_mean with shape [n_tokens, n_seqs] and multiplies // it with t_embd which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens, - // the ggml_mul_mat assertion fails. this matches the pp reservation below (line ~553). + // the ggml_mul_mat assertion fails. const uint32_t n_tokens_ch = 16*n_seqs; auto * gf = graph_reserve(n_tokens_ch, n_seqs, n_tokens_ch, mctx.get(), true); if (!gf) { @@ -577,16 +579,18 @@ void llama_context::sched_reserve() { int n_splits_tg = -1; int n_nodes_tg = -1; + const uint32_t n_outputs_pp = std::min(n_tokens, cparams.n_outputs_max); + // reserve pp (prompt processing) graph first so that buffers are only allocated once { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), + auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get(), model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr); if (!gf) { if (cparams.pipeline_parallel) { LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); cparams.pipeline_parallel = false; sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); - gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get()); } if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); @@ -614,7 +618,7 @@ void llama_context::sched_reserve() { // // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); // - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); + auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get(), model.hparams.no_alloc); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -774,7 +778,9 @@ bool llama_context::memory_update(bool optimize) { const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + const uint32_t n_outputs_max = std::min(n_tokens, cparams.n_outputs_max); + + auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_max, mctx.get()); if (!gf) { LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); } @@ -2140,6 +2146,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { this->n_outputs = 0; + GGML_ASSERT(n_outputs_max <= cparams.n_outputs_max); + return n_outputs_max; } @@ -2226,8 +2234,6 @@ ggml_cgraph * llama_context::graph_reserve( if (n_tokens % n_seqs != 0) { n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs - n_outputs = std::max(n_outputs, n_tokens); - LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); } @@ -3337,6 +3343,7 @@ llama_context_params llama_context_default_params() { /*.n_ubatch =*/ 512, /*.n_seq_max =*/ 1, /*.n_rs_seq =*/ 0, + /*.n_outputs_max =*/ 0, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT, diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 20ec59fe33..ba4951a09a 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -13,6 +13,7 @@ struct llama_cparams { uint32_t n_ubatch; uint32_t n_seq_max; uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback + uint32_t n_outputs_max; // max outputs supported by the context int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index bfe3443c1d..a5a6882545 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -37,6 +37,49 @@ using json = nlohmann::ordered_json; constexpr int HTTP_POLLING_SECONDS = 1; +static uint32_t server_n_outputs_max(const common_params & params) { + const uint32_t n_batch = params.n_batch; + + if (params.embedding || + (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED && params.pooling_type != LLAMA_POOLING_TYPE_NONE)) { + return n_batch; + } + + uint32_t n_outputs_per_seq = 1; + + for (const auto type : params.speculative.types) { + switch (type) { + case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: + case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: + case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: + n_outputs_per_seq = std::max(n_outputs_per_seq, 1 + std::max(0, params.speculative.draft.n_max)); + break; + case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: + n_outputs_per_seq = std::max(n_outputs_per_seq, 1 + params.speculative.ngram_simple.size_m); + break; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: + n_outputs_per_seq = std::max(n_outputs_per_seq, 1 + params.speculative.ngram_map_k.size_m); + break; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: + n_outputs_per_seq = std::max(n_outputs_per_seq, 1 + params.speculative.ngram_map_k4v.size_m); + break; + case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: + n_outputs_per_seq = std::max(n_outputs_per_seq, 1 + std::max(0, params.speculative.ngram_mod.n_max)); + break; + case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: + n_outputs_per_seq = std::max(n_outputs_per_seq, 1 + 8); + break; + case COMMON_SPECULATIVE_TYPE_NONE: + case COMMON_SPECULATIVE_TYPE_COUNT: + break; + } + } + + const uint64_t n_outputs = (uint64_t) params.n_parallel * n_outputs_per_seq; + + return std::max(1, std::min(n_batch, n_outputs)); +} + // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 enum slot_state { SLOT_STATE_IDLE, @@ -753,6 +796,7 @@ private: SRV_INF("loading model '%s'\n", params.model.path.c_str()); params_base = params; + params_base.n_outputs_max = server_n_outputs_max(params_base); std::string & mmproj_path = params_base.mmproj.path; bool has_mmproj = !mmproj_path.empty(); @@ -818,6 +862,10 @@ private: measure_model_bytes = false; } + if (!has_draft) { + params_dft.n_outputs_max = params_base.n_parallel; + } + auto mparams_dft = common_model_params_to_llama(params_dft); auto cparams_dft = common_context_params_to_llama(params_dft); if (spec_mtp) { @@ -941,10 +989,11 @@ private: params_base.model.path.c_str()); auto cparams_mtp = common_context_params_to_llama(params_base); - cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP; - cparams_mtp.type_k = params_base.speculative.draft.cache_type_k; - cparams_mtp.type_v = params_base.speculative.draft.cache_type_v; - cparams_mtp.n_rs_seq = 0; + cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP; + cparams_mtp.type_k = params_base.speculative.draft.cache_type_k; + cparams_mtp.type_v = params_base.speculative.draft.cache_type_v; + cparams_mtp.n_rs_seq = 0; + cparams_mtp.n_outputs_max = params_base.n_parallel; ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp)); if (ctx_dft == nullptr) {