llama: limit max outputs of llama_context (#23861)

* llama: save more VRAM by reserving n_outputs == n_seqs when possible

* add n_outputs_per_seq

* move n_outputs_max to server-context

* change ubatch to batch everywhere
This commit is contained in:
Aman Gupta 2026-06-01 23:01:38 +08:00 committed by GitHub
parent 95b8b8ec1a
commit de6f727aae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 71 additions and 11 deletions

View File

@ -1563,6 +1563,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.n_ctx = params.n_ctx;
cparams.n_seq_max = params.n_parallel;
cparams.n_rs_seq = params.speculative.need_n_rs_seq();
cparams.n_outputs_max = std::max(params.n_outputs_max, 0);
cparams.n_batch = params.n_batch;
cparams.n_ubatch = params.n_ubatch;
cparams.n_threads = params.cpuparams.n_threads;

View File

@ -431,6 +431,7 @@ struct common_params {
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_parallel = 1; // number of parallel sequences to decode
int32_t n_sequences = 1; // number of sequences to decode
int32_t n_outputs_max = 0; // max outputs in a batch (0 = n_batch)
int32_t grp_attn_n = 1; // group-attention factor
int32_t grp_attn_w = 512; // group-attention width
int32_t n_print = -1; // print token count every n tokens (-1 = disabled)

View File

@ -339,6 +339,7 @@ extern "C" {
uint32_t n_ubatch; // physical maximum batch size
uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL]
uint32_t n_outputs_max; // max outputs in a ubatch (0 = n_batch)
int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing

View File

@ -182,6 +182,8 @@ llama_context::llama_context(
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
cparams.n_outputs_max = params.n_outputs_max == 0 ? cparams.n_batch : params.n_outputs_max;
cparams.op_offload = params.op_offload;
cparams.kv_unified = params.kv_unified;
@ -531,7 +533,7 @@ void llama_context::sched_reserve() {
// note: n_outputs must match n_tokens for embedding models with mean/rank pooling,
// because build_pooling creates inp_mean with shape [n_tokens, n_seqs] and multiplies
// it with t_embd which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens,
// the ggml_mul_mat assertion fails. this matches the pp reservation below (line ~553).
// the ggml_mul_mat assertion fails.
const uint32_t n_tokens_ch = 16*n_seqs;
auto * gf = graph_reserve(n_tokens_ch, n_seqs, n_tokens_ch, mctx.get(), true);
if (!gf) {
@ -577,16 +579,18 @@ void llama_context::sched_reserve() {
int n_splits_tg = -1;
int n_nodes_tg = -1;
const uint32_t n_outputs_pp = std::min(n_tokens, cparams.n_outputs_max);
// reserve pp (prompt processing) graph first so that buffers are only allocated once
{
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get(),
model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
if (!gf) {
if (cparams.pipeline_parallel) {
LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
cparams.pipeline_parallel = false;
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get());
}
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
@ -614,7 +618,7 @@ void llama_context::sched_reserve() {
//
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
//
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get(), model.hparams.no_alloc);
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
}
@ -774,7 +778,9 @@ bool llama_context::memory_update(bool optimize) {
const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
const uint32_t n_outputs_max = std::min(n_tokens, cparams.n_outputs_max);
auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_max, mctx.get());
if (!gf) {
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
}
@ -2140,6 +2146,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
this->n_outputs = 0;
GGML_ASSERT(n_outputs_max <= cparams.n_outputs_max);
return n_outputs_max;
}
@ -2226,8 +2234,6 @@ ggml_cgraph * llama_context::graph_reserve(
if (n_tokens % n_seqs != 0) {
n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
n_outputs = std::max(n_outputs, n_tokens);
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
}
@ -3337,6 +3343,7 @@ llama_context_params llama_context_default_params() {
/*.n_ubatch =*/ 512,
/*.n_seq_max =*/ 1,
/*.n_rs_seq =*/ 0,
/*.n_outputs_max =*/ 0,
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
/*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT,

View File

@ -13,6 +13,7 @@ struct llama_cparams {
uint32_t n_ubatch;
uint32_t n_seq_max;
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback
uint32_t n_outputs_max; // max outputs supported by the context
int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing

View File

@ -37,6 +37,49 @@ using json = nlohmann::ordered_json;
constexpr int HTTP_POLLING_SECONDS = 1;
static uint32_t server_n_outputs_max(const common_params & params) {
const uint32_t n_batch = params.n_batch;
if (params.embedding ||
(params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED && params.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
return n_batch;
}
uint32_t n_outputs_per_seq = 1;
for (const auto type : params.speculative.types) {
switch (type) {
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE:
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3:
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP:
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + std::max(0, params.speculative.draft.n_max));
break;
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + params.speculative.ngram_simple.size_m);
break;
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + params.speculative.ngram_map_k.size_m);
break;
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V:
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + params.speculative.ngram_map_k4v.size_m);
break;
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD:
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + std::max(0, params.speculative.ngram_mod.n_max));
break;
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE:
n_outputs_per_seq = std::max<uint32_t>(n_outputs_per_seq, 1 + 8);
break;
case COMMON_SPECULATIVE_TYPE_NONE:
case COMMON_SPECULATIVE_TYPE_COUNT:
break;
}
}
const uint64_t n_outputs = (uint64_t) params.n_parallel * n_outputs_per_seq;
return std::max<uint32_t>(1, std::min<uint64_t>(n_batch, n_outputs));
}
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
enum slot_state {
SLOT_STATE_IDLE,
@ -753,6 +796,7 @@ private:
SRV_INF("loading model '%s'\n", params.model.path.c_str());
params_base = params;
params_base.n_outputs_max = server_n_outputs_max(params_base);
std::string & mmproj_path = params_base.mmproj.path;
bool has_mmproj = !mmproj_path.empty();
@ -818,6 +862,10 @@ private:
measure_model_bytes = false;
}
if (!has_draft) {
params_dft.n_outputs_max = params_base.n_parallel;
}
auto mparams_dft = common_model_params_to_llama(params_dft);
auto cparams_dft = common_context_params_to_llama(params_dft);
if (spec_mtp) {
@ -941,10 +989,11 @@ private:
params_base.model.path.c_str());
auto cparams_mtp = common_context_params_to_llama(params_base);
cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
cparams_mtp.type_k = params_base.speculative.draft.cache_type_k;
cparams_mtp.type_v = params_base.speculative.draft.cache_type_v;
cparams_mtp.n_rs_seq = 0;
cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
cparams_mtp.type_k = params_base.speculative.draft.cache_type_k;
cparams_mtp.type_v = params_base.speculative.draft.cache_type_v;
cparams_mtp.n_rs_seq = 0;
cparams_mtp.n_outputs_max = params_base.n_parallel;
ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp));
if (ctx_dft == nullptr) {