diff --git a/common/common.cpp b/common/common.cpp index a3593b07..dccfe1db 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -95,6 +95,195 @@ common_time_meas::~common_time_meas() { t_acc += ggml_time_us() - t_start_us; } } + +bool common_speculative_type_is_self_spec(enum common_speculative_type type) { + switch (type) { + case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: + case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: + case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: + case COMMON_SPECULATIVE_TYPE_SUFFIX: + return true; + default: + return false; + } +} + +static int32_t common_speculative_stage_effective_n_max( + const common_params_speculative & params, + const common_speculative_stage_params & stage) { + return stage.has_n_max_override() ? stage.n_max : params.n_max; +} + +static int32_t common_speculative_stage_effective_n_min( + const common_params_speculative & params, + const common_speculative_stage_params & stage) { + return stage.has_n_min_override() ? stage.n_min : params.n_min; +} + +std::vector common_params_speculative::get_resolved_stages() const { + if (!stages.empty()) { + return stages; + } + + if (type == COMMON_SPECULATIVE_TYPE_NONE) { + return {}; + } + + return {{ .type = type }}; +} + +common_params_speculative common_params_speculative::with_stage_overrides(const common_speculative_stage_params & stage) const { + common_params_speculative result = *this; + + result.type = stage.type; + + if (stage.has_n_max_override()) { + result.n_max = stage.n_max; + } + if (stage.has_n_min_override()) { + result.n_min = stage.n_min; + } + if (stage.has_p_min_override()) { + result.p_min = stage.p_min; + } + if (stage.has_ngram_size_n_override()) { + result.ngram_size_n = stage.ngram_size_n; + result.ngram_mod.reset(); + } + if (stage.has_ngram_size_m_override()) { + result.ngram_size_m = stage.ngram_size_m; + } + if (stage.has_ngram_min_hits_override()) { + result.ngram_min_hits = stage.ngram_min_hits; + } + if (stage.has_suffix_min_match_len_override()) { + result.suffix_min_match_len = stage.suffix_min_match_len; + } + if (stage.has_suffix_max_depth_override()) { + result.suffix_max_depth = stage.suffix_max_depth; + } + + result.n_max = std::max(result.n_max, 0); + result.n_min = std::max(0, std::min(result.n_min, result.n_max)); + result.stages.clear(); + + return result; +} + +bool common_params_speculative::has_stage_chain() const { + return !get_resolved_stages().empty(); +} + +bool common_params_speculative::has_stage_type(common_speculative_type stage_type) const { + const auto resolved = get_resolved_stages(); + return std::any_of(resolved.begin(), resolved.end(), [stage_type](const common_speculative_stage_params & stage) { + return stage.type == stage_type; + }); +} + +bool common_params_speculative::has_composite_stage_chain() const { + return get_resolved_stages().size() > 1; +} + +int32_t common_params_speculative::get_max_stage_n_max() const { + const auto resolved = get_resolved_stages(); + if (resolved.empty()) { + return std::max(n_max, 0); + } + + int32_t max_n_max = 0; + for (const auto & stage : resolved) { + max_n_max = std::max(max_n_max, common_speculative_stage_effective_n_max(*this, stage)); + } + + return std::max(max_n_max, 0); +} + +int32_t common_params_speculative::get_min_usable_stage_n_min() const { + const auto resolved = get_resolved_stages(); + if (resolved.empty()) { + return std::max(0, std::min(n_min, n_max)); + } + + int32_t min_n_min = INT_MAX; + for (const auto & stage : resolved) { + min_n_min = std::min(min_n_min, std::max(0, std::min(common_speculative_stage_effective_n_min(*this, stage), common_speculative_stage_effective_n_max(*this, stage)))); + } + + return min_n_min == INT_MAX ? 0 : min_n_min; +} + +bool common_speculative_validate_chain(const common_params_speculative & params, std::string * error) { + const auto fail = [error](const std::string & msg) { + if (error != nullptr) { + *error = msg; + } + return false; + }; + + const auto resolved = params.get_resolved_stages(); + if (resolved.empty()) { + return true; + } + + if (resolved.size() > 2) { + return fail("at most two speculative stages are supported in this PR"); + } + + std::unordered_set seen_types; + for (const auto & stage : resolved) { + if (stage.type == COMMON_SPECULATIVE_TYPE_NONE && resolved.size() > 1) { + return fail("the 'none' speculative stage cannot be combined with other stages"); + } + + if (!seen_types.insert((int) stage.type).second) { + return fail("duplicate speculative stage type in chain: " + common_speculative_type_to_str(stage.type)); + } + + const auto stage_params = params.with_stage_overrides(stage); + if (stage_params.n_min > stage_params.n_max) { + return fail("speculative stage has n_min greater than n_max"); + } + + if (stage.type == COMMON_SPECULATIVE_TYPE_DRAFT && !params.has_dft()) { + return fail("draft speculative stage requires a draft model or draft params"); + } + } + + if (resolved.size() == 2) { + const auto first = resolved[0].type; + const auto second = resolved[1].type; + + if (!common_speculative_type_is_self_spec(first)) { + return fail("two-stage speculative mode currently requires a self-spec stage first"); + } + + if (second != COMMON_SPECULATIVE_TYPE_MTP && second != COMMON_SPECULATIVE_TYPE_DRAFT) { + return fail("two-stage speculative mode currently supports only MTP or draft-model fallback after self-spec"); + } + } + + return true; +} + +std::string common_speculative_stage_chain_to_str(const common_params_speculative & params) { + const auto resolved = params.get_resolved_stages(); + if (resolved.empty()) { + return "none"; + } + + std::ostringstream oss; + for (size_t i = 0; i < resolved.size(); ++i) { + if (i > 0) { + oss << " -> "; + } + oss << common_speculative_type_to_str(resolved[i].type); + } + + return oss.str(); +} // // Environment variable utils // @@ -419,6 +608,38 @@ static bool is_autoy(const std::string & value) { return value == "auto" || value == "-1"; } +static void common_speculative_finalize_stages(gpt_params & params) { + auto & spec = params.speculative; + + if (!spec.stages.empty()) { + spec.type = spec.stages.front().type; + params.has_mtp = spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP); + return; + } + + const bool wants_mtp = params.has_mtp; + const bool wants_draft = spec.has_dft(); + + if (spec.type != COMMON_SPECULATIVE_TYPE_NONE) { + spec.stages.push_back({ .type = spec.type }); + + if (common_speculative_type_is_self_spec(spec.type)) { + if (wants_mtp) { + spec.stages.push_back({ .type = COMMON_SPECULATIVE_TYPE_MTP }); + } else if (wants_draft) { + spec.stages.push_back({ .type = COMMON_SPECULATIVE_TYPE_DRAFT }); + } + } + } else if (wants_mtp) { + spec.stages.push_back({ .type = COMMON_SPECULATIVE_TYPE_MTP }); + } else if (wants_draft) { + spec.stages.push_back({ .type = COMMON_SPECULATIVE_TYPE_DRAFT }); + } + + spec.type = spec.stages.empty() ? COMMON_SPECULATIVE_TYPE_NONE : spec.stages.front().type; + params.has_mtp = spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP); +} + bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { bool invalid_param = false; std::string arg; @@ -480,6 +701,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.use_jinja ? "" : "\nnote: llama.cpp was started without --jinja, we only support commonly used templates" )); } + + common_speculative_finalize_stages(params); + + std::string spec_error; + if (!common_speculative_validate_chain(params.speculative, &spec_error)) { + throw std::invalid_argument("error: invalid speculative stage configuration: " + spec_error); + } + return true; } @@ -586,6 +815,124 @@ std::vector> string_split_pairs(const std::string & str, char d } return values; } + +static std::string common_normalize_spec_stage_key(std::string key) { + while (!key.empty() && key.front() == '-') { + key.erase(key.begin()); + } + + std::replace(key.begin(), key.end(), '-', '_'); + + if (key.rfind("spec_", 0) == 0) { + key.erase(0, 5); + } + + return key; +} + +static void common_speculative_remove_explicit_stage(common_params_speculative & params, common_speculative_type type) { + params.stages.erase(std::remove_if(params.stages.begin(), params.stages.end(), [type](const common_speculative_stage_params & stage) { + return stage.type == type; + }), params.stages.end()); + + if (params.stages.empty() && params.type == type) { + params.type = COMMON_SPECULATIVE_TYPE_NONE; + } +} + +static void common_speculative_stage_apply_kv( + common_speculative_stage_params & stage, + const std::string & key_raw, + const std::string & value_raw) { + const std::string key = common_normalize_spec_stage_key(key_raw); + + if (key == "draft" || key == "draft_max" || key == "draft_n" || key == "n_max") { + stage.n_max = std::stoi(value_raw); + if (stage.n_max < 0) { + throw std::invalid_argument("speculative stage n_max must be >= 0"); + } + return; + } + if (key == "draft_min" || key == "draft_n_min" || key == "n_min") { + stage.n_min = std::stoi(value_raw); + if (stage.n_min < 0) { + throw std::invalid_argument("speculative stage n_min must be >= 0"); + } + return; + } + if (key == "draft_p_min" || key == "p_min") { + stage.p_min = std::stof(value_raw); + if (stage.p_min < 0.0f) { + throw std::invalid_argument("speculative stage p_min must be >= 0"); + } + return; + } + if (key == "ngram_size_n") { + stage.ngram_size_n = std::stoi(value_raw); + if (stage.ngram_size_n < 1 || stage.ngram_size_n > 1024) { + throw std::invalid_argument("speculative stage ngram_size_n must be between 1 and 1024 inclusive"); + } + return; + } + if (key == "ngram_size_m") { + stage.ngram_size_m = std::stoi(value_raw); + if (stage.ngram_size_m < 1 || stage.ngram_size_m > 1024) { + throw std::invalid_argument("speculative stage ngram_size_m must be between 1 and 1024 inclusive"); + } + return; + } + if (key == "ngram_min_hits") { + stage.ngram_min_hits = std::stoi(value_raw); + if (stage.ngram_min_hits < 1) { + throw std::invalid_argument("speculative stage ngram_min_hits must be at least 1"); + } + return; + } + if (key == "suffix_min_match_len" || key == "suffix_pattern_len") { + stage.suffix_min_match_len = std::stoi(value_raw); + if (stage.suffix_min_match_len < 1) { + throw std::invalid_argument("speculative stage suffix_min_match_len must be at least 1"); + } + return; + } + if (key == "suffix_max_depth") { + stage.suffix_max_depth = std::stoi(value_raw); + if (stage.suffix_max_depth < 1) { + throw std::invalid_argument("speculative stage suffix_max_depth must be at least 1"); + } + return; + } + + throw std::invalid_argument("unknown speculative stage parameter: " + key_raw); +} + +static common_speculative_stage_params common_speculative_stage_from_arg(const std::string & value) { + const auto spec_pos = value.find(':'); + const std::string type_name = value.substr(0, spec_pos); + + common_speculative_stage_params stage; + stage.type = common_speculative_type_from_name(type_name); + if (stage.type == COMMON_SPECULATIVE_TYPE_COUNT) { + throw std::invalid_argument("unknown speculative stage type: " + type_name); + } + + if (spec_pos == std::string::npos) { + return stage; + } + + std::stringstream ss(value.substr(spec_pos + 1)); + std::string kv; + while (std::getline(ss, kv, ',')) { + const auto eq_pos = kv.find('='); + if (eq_pos == std::string::npos) { + throw std::invalid_argument("invalid speculative stage option: " + kv); + } + + common_speculative_stage_apply_kv(stage, kv.substr(0, eq_pos), kv.substr(eq_pos + 1)); + } + + return stage; +} } #define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; } @@ -1092,26 +1439,37 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.speculative.model = argv[i]; return true; } + if (arg == "--spec-stage") { + CHECK_ARG + + if (params.speculative.stages.empty()) { + if (params.speculative.type != COMMON_SPECULATIVE_TYPE_NONE) { + throw std::invalid_argument("--spec-stage cannot be combined with --spec-type; use only --spec-stage for explicit stage chains"); + } + if (params.has_mtp) { + throw std::invalid_argument("--spec-stage cannot be combined with -mtp/--multi-token-prediction; add the mtp fallback explicitly with --spec-stage mtp[:k=v,...]"); + } + } + + params.speculative.stages.push_back(common_speculative_stage_from_arg(argv[i])); + if (params.speculative.stages.size() == 1) { + params.speculative.type = params.speculative.stages.front().type; + } + params.has_mtp = params.speculative.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP); + return true; + } if (arg == "--spec-type") { CHECK_ARG - std::string value = argv[i]; - if (value == "none") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE; - } else if (value == "ngram-cache") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE; - } else if (value == "ngram-simple") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE; - } else if (value == "ngram-map-k") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K; - } else if (value == "ngram-map-k4v") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V; - } else if (value == "ngram-mod") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD; - } else if (value == "suffix") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_SUFFIX; - } else if (value == "mtp") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP; - params.has_mtp = true; + if (!params.speculative.stages.empty()) { + throw std::invalid_argument("--spec-type cannot be combined with --spec-stage; use only --spec-stage for explicit stage chains"); + } + + const auto type = common_speculative_type_from_name(argv[i]); + if (type == COMMON_SPECULATIVE_TYPE_NONE || type == COMMON_SPECULATIVE_TYPE_MTP || common_speculative_type_is_self_spec(type)) { + params.speculative.type = type; + if (type == COMMON_SPECULATIVE_TYPE_MTP) { + params.has_mtp = true; + } } else { throw std::invalid_argument("unknown speculative decoding type"); } @@ -1588,11 +1946,16 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-mtp" || arg == "--multi-token-prediction") { + if (!params.speculative.stages.empty()) { + throw std::invalid_argument("-mtp/--multi-token-prediction cannot be combined with --spec-stage; add the mtp fallback explicitly with --spec-stage mtp[:k=v,...]"); + } + params.has_mtp = true; return true; } if (arg == "-no-mtp" || arg == "--no-multi-token-prediction") { params.has_mtp = false; + common_speculative_remove_explicit_stage(params.speculative, COMMON_SPECULATIVE_TYPE_MTP); return true; } if (arg == "-draft" || arg == "--draft-params") { @@ -2766,18 +3129,21 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" }); options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" }); options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" }); - options.push_back({ "*", "-mtp, --multi-token-prediction", "whether to use multi-token-prediction (if supported) (default: %s)", params.has_mtp ? "true" : "false" }); - options.push_back({ "*", "-no-mtp, --no-multi-token-prediction", "whether to use multi-token-prediction (if supported) (default: %s)", !params.has_mtp ? "true" : "false" }); + options.push_back({ "*", "-mtp, --multi-token-prediction", "legacy shortcut for enabling MTP when --spec-stage is not used (default: %s)", params.has_mtp ? "true" : "false" }); + options.push_back({ "*", "-no-mtp, --no-multi-token-prediction", "disable the legacy MTP shortcut or remove an explicit MTP stage (default: %s)", !params.has_mtp ? "true" : "false" }); options.push_back({ "*", "--draft-max, --draft, --draft-n N", - "number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max }); - options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" }); - options.push_back({ "*", "--draft-p-min P", "minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min }); + "global default number of tokens to draft for speculative decoding or for stages without an explicit n_max override (default: %d)", params.speculative.n_max }); + options.push_back({ "*", "--draft-min, --draft-n-min N", "global default minimum draft threshold or fallback threshold for stages without an explicit n_min override" }); + options.push_back({ "*", "--draft-p-min P", "global default minimum speculative decoding probability (greedy) for stages without an explicit p_min override (default: %.1f)", (double)params.speculative.p_min }); options.push_back({ "*", "--recurrent-ckpt-mode MODE", "checkpoint strategy for recurrent/hybrid speculative decoding\n" " auto auto-select: per-step if CUDA full-GPU, gpu-fallback otherwise (default)\n" " per-step save SSM state per draft step in VRAM; no re-decode on rejection\n" " gpu-fallback copy state to GPU buffer; re-decode on rejection\n" " cpu serialise state via llama_state_seq; re-decode on rejection" }); - options.push_back({ "*", "--spec-type Name [none | mtp | ngram - cache | ngram - simple | ngram - map - k | ngram - map - k4v | ngram - mod | suffix]", "type of speculative decoding to use (default: %d)\n", (int)params.speculative.type}); + options.push_back({ "*", "--spec-stage SPEC[:k=v,...]", "explicit speculative stage. repeat once for a supported two-stage chain.\n" + "examples: --spec-stage ngram-mod:n_max=64,n_min=2 --spec-stage mtp:n_max=1\n" + "supported two-stage shape in this PR: self-spec first, then mtp or draft fallback" }); + options.push_back({ "*", "--spec-type Name [none | mtp | ngram-cache | ngram-simple | ngram-map-k | ngram-map-k4v | ngram-mod | suffix]", "single-stage speculative selection when --spec-stage is not used (default: %d)\n", (int)params.speculative.type}); options.push_back({ "*", "--spec-ngram-size-n N", "ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)\n",params.speculative.ngram_size_n }); options.push_back({ "*", "--spec-ngram-size-m N", "ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)\n", params.speculative.ngram_size_m }); @@ -3600,7 +3966,7 @@ struct llama_model_params common_model_params_to_llama(const gpt_params & params mparams.validate_quants = params.validate_quants; mparams.merge_qkv = params.merge_qkv; mparams.merge_up_gate_exps = params.merge_up_gate_exps; - mparams.mtp = params.has_mtp; + mparams.mtp = params.speculative.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP); mparams.flash_attn = params.flash_attn; mparams.defer_experts = params.defer_experts; if (params.kv_overrides.empty()) { @@ -3685,7 +4051,7 @@ struct llama_context_params common_context_params_to_llama(const gpt_params & pa cparams.thresh_experts = params.thresh_experts; cparams.only_active_experts = params.only_active_exps; cparams.max_extra_alloc = params.max_extra_alloc_MiB; - cparams.mtp = params.has_mtp; + cparams.mtp = params.speculative.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP); cparams.mtp_op_type = MTP_OP_NONE; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); diff --git a/common/common.h b/common/common.h index 62d0295f..2ea38928 100644 --- a/common/common.h +++ b/common/common.h @@ -151,6 +151,35 @@ enum common_speculative_type { COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type }; +std::string common_speculative_type_name_str(); +enum common_speculative_type common_speculative_type_from_name(const std::string & name); +std::string common_speculative_type_to_str(enum common_speculative_type type); +bool common_speculative_type_is_self_spec(enum common_speculative_type type); + +struct common_speculative_stage_params { + common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; + + int32_t n_max = -1; + int32_t n_min = -1; + float p_min = -1.0f; + + uint16_t ngram_size_n = 0; + uint16_t ngram_size_m = 0; + uint16_t ngram_min_hits = 0; + + int32_t suffix_min_match_len = -1; + int32_t suffix_max_depth = -1; + + bool has_n_max_override() const { return n_max >= 0; } + bool has_n_min_override() const { return n_min >= 0; } + bool has_p_min_override() const { return p_min >= 0.0f; } + bool has_ngram_size_n_override() const { return ngram_size_n > 0; } + bool has_ngram_size_m_override() const { return ngram_size_m > 0; } + bool has_ngram_min_hits_override() const { return ngram_min_hits > 0; } + bool has_suffix_min_match_len_override() const { return suffix_min_match_len >= 0; } + bool has_suffix_max_depth_override() const { return suffix_max_depth >= 0; } +}; + struct common_params_model { std::string path = ""; // model local path // NOLINT std::string url = ""; // model url to download // NOLINT @@ -174,6 +203,7 @@ struct common_params_speculative { int32_t n_max = 16; // number of tokens to draft during speculative decoding int32_t n_min = 0; // minimum number of tokens to draft during speculative decoding + std::vector stages; // explicit stage chain for single-spec or self-spec + model fallback float p_split = 0.1f; // speculative decoding split probability float p_min = 0.75f; // minimum speculative decoding probability (greedy) @@ -216,8 +246,19 @@ struct common_params_speculative { //return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty(); } + std::vector get_resolved_stages() const; + common_params_speculative with_stage_overrides(const common_speculative_stage_params & stage) const; + bool has_stage_chain() const; + bool has_stage_type(common_speculative_type stage_type) const; + bool has_composite_stage_chain() const; + int32_t get_max_stage_n_max() const; + int32_t get_min_usable_stage_n_min() const; + }; +bool common_speculative_validate_chain(const common_params_speculative & params, std::string * error = nullptr); +std::string common_speculative_stage_chain_to_str(const common_params_speculative & params); + struct gpt_params { std::string devices; uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed diff --git a/common/speculative.cpp b/common/speculative.cpp index 10f365b3..e8291727 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -49,11 +49,14 @@ const std::map common_speculative_typ }; struct common_speculative_config { + common_speculative_stage_params stage; common_speculative_type type; common_params_speculative params; - common_speculative_config(common_speculative_type t, - const common_params_speculative & p = common_params_speculative{}) : type(t), params(p) {} + common_speculative_config( + const common_speculative_stage_params & s, + const common_params_speculative & p = common_params_speculative{}) + : stage(s), type(s.type), params(p) {} }; static bool common_speculative_are_compatible( @@ -165,6 +168,8 @@ struct common_speculative_state { virtual void accept(uint16_t n_accepted) = 0; }; +static void mtp_invalidate_cached_draft(const llama_context * ctx); + struct common_speculative_state_mtp : public common_speculative_state { llama_context * ctx_tgt; llama_context * ctx_mtp = nullptr; @@ -202,6 +207,7 @@ struct common_speculative_state_mtp : public common_speculative_state { void begin(const llama_tokens & prompt) override { GGML_UNUSED(prompt); + mtp_invalidate_cached_draft(ctx_mtp); } void draft( @@ -952,6 +958,7 @@ struct common_speculative_state_suffix : public common_speculative_state { }; struct common_speculative { + std::vector configs; // resolved stage config for each implementation std::vector> impls; // list of implementations to use and their states common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) std::unique_ptr tuner; @@ -959,6 +966,46 @@ struct common_speculative { int64_t t_step_start_us = 0; }; +static bool common_speculative_stage_chain_matches( + const std::vector & stages, + const std::vector & configs) { + if (stages.size() != configs.size()) { + return false; + } + + for (size_t i = 0; i < stages.size(); ++i) { + if (stages[i].type != configs[i].type) { + return false; + } + } + + return true; +} + +static common_params_speculative common_speculative_get_runtime_params( + const common_speculative_config & config, + const common_params_speculative & params, + const common_speculative_stage_params & stage) { + common_params_speculative result = config.params; + + result.type = config.type; + result.n_max = stage.has_n_max_override() ? stage.n_max : params.n_max; + result.n_min = stage.has_n_min_override() ? stage.n_min : params.n_min; + result.p_min = stage.has_p_min_override() ? stage.p_min : params.p_min; + + if (config.type == COMMON_SPECULATIVE_TYPE_SUFFIX) { + result.suffix_min_match_len = stage.has_suffix_min_match_len_override() + ? stage.suffix_min_match_len + : params.suffix_min_match_len; + } + + result.n_max = std::max(result.n_max, 0); + result.n_min = std::max(0, std::min(result.n_min, result.n_max)); + result.stages.clear(); + + return result; +} + static common_ngram_map get_common_ngram_map(const common_speculative_config & config) { uint16_t size_key = config.params.ngram_size_n; uint16_t size_value = config.params.ngram_size_m; @@ -1010,7 +1057,10 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) { } enum common_speculative_type common_speculative_type_from_name(const std::string & name) { - const auto it = common_speculative_type_from_name_map.find(name); + std::string normalized = name; + std::replace(normalized.begin(), normalized.end(), '-', '_'); + + const auto it = common_speculative_type_from_name_map.find(normalized); if (it == common_speculative_type_from_name_map.end()) { return COMMON_SPECULATIVE_TYPE_COUNT; } @@ -1053,8 +1103,36 @@ done: common_speculative * common_speculative_init( common_params_speculative & params, llama_context * ctx_tgt) { + std::string chain_error; + if (!common_speculative_validate_chain(params, &chain_error)) { + LOG_ERR("%s: invalid speculative stage chain: %s\n", __func__, chain_error.c_str()); + return nullptr; + } + + const auto stages = params.get_resolved_stages(); + if (params.model_dft && llama_model_is_gemma4_mtp_assistant(params.model_dft)) { + const bool has_draft_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) { + return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT; + }); + + if (has_draft_stage) { + LOG_ERR("%s: Gemma4 assistant models only support MTP stages; omit -md for self-spec-only runs or use -mtp/--spec-stage mtp for assistant-backed MTP\n", __func__); + return nullptr; + } + } + + const bool needs_draft_ctx = std::any_of(stages.begin(), stages.end(), [¶ms](const common_speculative_stage_params & stage) { + return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT || + (stage.type == COMMON_SPECULATIVE_TYPE_MTP && params.model_dft != nullptr); + }); + llama_context * ctx_dft = nullptr; - if (params.model_dft) { + if (needs_draft_ctx) { + if (!params.model_dft) { + LOG_ERR("%s: draft speculative stage requires a loaded draft model\n", __func__); + return nullptr; + } + ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft); if (ctx_dft == nullptr) { LOG_ERR("%s", "failed to create draft context\n"); @@ -1062,68 +1140,30 @@ common_speculative * common_speculative_init( } } - // Compute the implementations to use based on the config and their order of preference - std::vector configs = {}; // list of speculative configs to try - { - bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 - bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP); - bool has_draft = !params.mparams_dft.path.empty() && !has_mtp; + // Compute the implementations to use based on the resolved stage chain. + std::vector configs = {}; + configs.reserve(stages.size()); - bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE); - bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE); - bool has_ngram_map_k = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); - bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V); - bool has_ngram_mod = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD); - bool has_suffix = (params.type == COMMON_SPECULATIVE_TYPE_SUFFIX); + for (const auto & stage : stages) { + common_params_speculative stage_params = params.with_stage_overrides(stage); - // In a more complex implementation we could use the same implementation but with different parameters. - // This was initially used in PR-18471 but removed to simplify the code. - if (has_ngram_simple) { - // This implementation can guess a lot of tokens without any draft model. - configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, params)); - } - if (has_ngram_map_k) { - configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, params)); - } - if (has_ngram_map_k4v) { - // This implementation can guess tokens with high acceptance rate but is more expensive. - configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params)); - } - if (has_ngram_mod) { - // shared instance for all speculative decoding contexts - if (!params.ngram_mod) { - params.ngram_mod = std::make_shared(params.ngram_size_n, 4*1024*1024); + if (stage.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD && !stage_params.ngram_mod) { + stage_params.ngram_mod = std::make_shared(stage_params.ngram_size_n, 4*1024*1024); - LOG_INF("%s: initialized ngram_mod with n=%d, size=%zu (%.3f MB)\n", __func__, - params.ngram_size_n, params.ngram_mod->size(), - (float)(params.ngram_mod->size_bytes())/1024/1024); + LOG_INF("%s: initialized ngram_mod with n=%d, size=%zu (%.3f MB)\n", __func__, + stage_params.ngram_size_n, stage_params.ngram_mod->size(), + (float)(stage_params.ngram_mod->size_bytes())/1024/1024); - if (params.ngram_size_n < 16) { - LOG_WRN("%s: ngram_mod n=%d is too small - poor quality is possible, see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, params.ngram_size_n); - } + if (stage_params.ngram_size_n < 16) { + LOG_WRN("%s: ngram_mod n=%d is too small - poor quality is possible, see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, stage_params.ngram_size_n); } + } - configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, params)); - } - if (has_ngram_cache) { - configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params)); - } - if (has_suffix) { - configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_SUFFIX, params)); - } - if (has_mtp) { - configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params)); - } - if (has_draft) { - configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params)); - } - if (has_draft_eagle3) { - configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params)); - } + configs.push_back(common_speculative_config(stage, stage_params)); } if (!configs.empty() && llama_model_has_recurrent(llama_get_model(ctx_tgt))) { - const int ckpt_tokens = std::max(1, params.n_max + 1); + const int ckpt_tokens = std::max(1, params.get_max_stage_n_max() + 1); const int actual_mode = llama_spec_ckpt_init(ctx_tgt, params.recurrent_ckpt_mode, ckpt_tokens); if (actual_mode == LLAMA_SPEC_CKPT_NONE) { LOG_ERR("%s: failed to prepare recurrent checkpoint mode '%s' during speculative init (max_tokens=%d)\n", @@ -1152,7 +1192,7 @@ common_speculative * common_speculative_init( impls.push_back(std::make_unique(config.type, /* .ctx_tgt = */ ctx_tgt, /* .ctx_dft = */ ctx_dft, - /* .replacements = */ params.replacements + /* .replacements = */ config.params.replacements )); break; } @@ -1160,7 +1200,7 @@ common_speculative * common_speculative_init( llama_context * ctx_mtp = ctx_dft; if (!ctx_mtp) { const llama_model * model = llama_get_model(ctx_tgt); - ctx_mtp = llama_init_from_model(const_cast(model), params.cparams_dft); + ctx_mtp = llama_init_from_model(const_cast(model), config.params.cparams_dft); if (!ctx_mtp) { LOG_ERR("%s: failed to create MTP context\n", __func__); return nullptr; @@ -1209,7 +1249,7 @@ common_speculative * common_speculative_init( } case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: { auto state = create_state_ngram_cache( - params.lookup_cache_static, params.lookup_cache_dynamic, config); + config.params.lookup_cache_static, config.params.lookup_cache_dynamic, config); impls.push_back(std::make_unique(state)); break; } @@ -1231,11 +1271,14 @@ common_speculative * common_speculative_init( } auto * result = new common_speculative { + /* .configs = */ std::move(configs), /* .impls = */ std::move(impls) }; // initialize autotune if requested - if (params.autotune && !result->impls.empty()) { + if (params.autotune && params.has_composite_stage_chain()) { + LOG_WRN("Autotune disabled — explicit speculative stage chains are not supported yet\n"); + } else if (params.autotune && !result->impls.empty()) { auto actual_type = result->impls[0]->type; if (actual_type != COMMON_SPECULATIVE_TYPE_NONE && actual_type != COMMON_SPECULATIVE_TYPE_EAGLE3) { @@ -1290,6 +1333,16 @@ static mtp_last_embd & mtp_get_last_embd(const llama_context * ctx) { return last; } +static void mtp_invalidate_cached_draft(const llama_context * ctx) { + if (ctx == nullptr) { + return; + } + + auto & last = mtp_get_last_embd(ctx); + last.last_id = -1; + last.prob = 0.0f; +} + llama_tokens common_speculative_draft( common_speculative * spec, common_params_speculative & params, @@ -1306,26 +1359,42 @@ llama_tokens common_speculative_draft( spec->tuner->propose(params); } + const auto runtime_stages = params.get_resolved_stages(); + const bool use_runtime_stage_overrides = common_speculative_stage_chain_matches(runtime_stages, spec->configs); + spec->curr_impl = nullptr; // reset current implementation - for (auto & impl : spec->impls) { + for (size_t i = 0; i < spec->impls.size(); ++i) { + auto & impl = spec->impls[i]; + const auto & runtime_stage = use_runtime_stage_overrides ? runtime_stages[i] : spec->configs[i].stage; + common_params_speculative impl_params = common_speculative_get_runtime_params(spec->configs[i], params, runtime_stage); + result.clear(); + { common_time_meas tm(impl->t_draft_us, !impl->gen_perf); - impl->draft(params, prompt_tgt, id_last, draft_base_pos, draft_seq_id, result); + impl->draft(impl_params, prompt_tgt, id_last, draft_base_pos, draft_seq_id, result); impl->n_call_draft++; } - if (!result.empty()) { - LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, - common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(), - impl.get()->n_call_draft, result.size()); - - spec->curr_impl = impl.get(); // set current implementation for stats - impl->n_gen_drafts++; - impl->n_gen_tokens += result.size(); - - break; // We have a draft, so break out of the loop and return it. + if (result.empty()) { + continue; } + + if (common_speculative_type_is_self_spec(impl->type) && impl_params.n_min > 0 && (int)result.size() < impl_params.n_min) { + LOG_DBG("%s: impl %s drafted %zu tokens, below fallback threshold %d - trying next implementation\n", + __func__, common_speculative_type_to_str(impl->type).c_str(), result.size(), impl_params.n_min); + result.clear(); + continue; + } + LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, + common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(), + impl.get()->n_call_draft, result.size()); + + spec->curr_impl = impl.get(); + impl->n_gen_drafts++; + impl->n_gen_tokens += result.size(); + + break; // We have a draft, so break out of the loop and return it. } // store draft count for tuner feedback @@ -1362,6 +1431,12 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { impl->accept(n_accepted); impl->n_call_accept++; } + + if (impl->type != COMMON_SPECULATIVE_TYPE_MTP) { + if (auto * ctx_mtp = common_speculative_get_mtp_ctx(spec); ctx_mtp != nullptr) { + mtp_invalidate_cached_draft(ctx_mtp); + } + } } void common_speculative_print_stats(const common_speculative * spec, double slot_tps, int n_decoded, int n_past, common_params_speculative * active_params) { @@ -1420,6 +1495,14 @@ llama_context * common_speculative_get_mtp_ctx(common_speculative * spec) { return nullptr; } +common_speculative_type common_speculative_current_type(const common_speculative * spec) { + if (spec == nullptr || spec->curr_impl == nullptr) { + return COMMON_SPECULATIVE_TYPE_NONE; + } + + return spec->curr_impl->type; +} + void common_speculative_context_shift( common_speculative * spec, llama_seq_id seq_id, @@ -1447,6 +1530,11 @@ std::vector mtp_speculative_gen_draft( if (!smpl) return drafts; + if (n_draft <= 0) { + mtp_invalidate_cached_draft(ctx); + return drafts; + } + common_sampler_reset(smpl); llama_batch mtp_batch = llama_batch_init(1, 0, 1); diff --git a/common/speculative.h b/common/speculative.h index 9695ce05..981e42b5 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -46,6 +46,7 @@ void common_speculative_print_stats(const common_speculative * spec, double slot // get the MTP context from the speculative object (nullptr if not MTP type) llama_context * common_speculative_get_mtp_ctx(common_speculative * spec); +common_speculative_type common_speculative_current_type(const common_speculative * spec); // Context shift for MTP to match how server handle main model void common_speculative_context_shift( diff --git a/docs/parameters.md b/docs/parameters.md index 765c4094..77ec4fa2 100644 --- a/docs/parameters.md +++ b/docs/parameters.md @@ -124,6 +124,7 @@ Check the details [here](./speculative.md). | `--spec-ngram-size-m N` | ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram | 48 | [PR 1261](https://github.com/ikawrakow/ik_llama.cpp/pull/1261) | | `--spec-ngram-min-hits N` | minimum hits for ngram-map speculative decoding | 1 | [PR 1261](https://github.com/ikawrakow/ik_llama.cpp/pull/1261) | | `--spec-type Name` | Comma-separated list of draft model parameters | - | none / ngram - cache / ngram - simple / ngram - map - k / ngram - map - k4v / ngram - mod / suffix [PR 1261](https://github.com/ikawrakow/ik_llama.cpp/pull/1261) [PR 1646](https://github.com/ikawrakow/ik_llama.cpp/pull/1646) | +| `--spec-stage SPEC[:k=v,...]` | Add an explicit speculative stage; repeat once for a supported two-stage chain | - | Supported two-stage shape: self-spec first, then `mtp` or `draft` fallback. [PR 1789](https://github.com/ikawrakow/ik_llama.cpp/pull/1789) | | `-mtp, --multi-token-prediction` | | - | MTP decoding [PR 1270](https://github.com/ikawrakow/ik_llama.cpp/pull/1270) [1698](https://github.com/ikawrakow/ik_llama.cpp/pull/1698) | | `-no-mtp, --no-multi-token-prediction` | | - | MTP decoding [PR 1270](https://github.com/ikawrakow/ik_llama.cpp/pull/1270) [1698](https://github.com/ikawrakow/ik_llama.cpp/pull/1698) | | `--draft-max` | | - | MTP decoding [PR 1270](https://github.com/ikawrakow/ik_llama.cpp/pull/1270) [1698](https://github.com/ikawrakow/ik_llama.cpp/pull/1698) | @@ -131,6 +132,13 @@ Check the details [here](./speculative.md). | `--spec-autotune` | Automatically tune speculative params to maximize tokens/sec | - | Automatically determines the near-optimal arguments for the type of speculation being performed [PR 1595](https://github.com/ikawrakow/ik_llama.cpp/pull/1595) | | `--recurrent-ckpt-mode MODE` | Checkpoint strategy for recurrent/hybrid speculative decoding | auto | One of: - `auto` auto-select: per-step if CUDA full-GPU, gpu-fallback otherwise - `per-step` save SSM state per draft step in VRAM; no re-decode on rejection - `gpu-fallback` copy state to GPU buffer; re-decode on rejection - `cpu` serialise state via llama_state_seq; re-decode on rejection [PR 1669](https://github.com/ikawrakow/ik_llama.cpp/pull/1669) [PR 1774](https://github.com/ikawrakow/ik_llama.cpp/pull/1774) | +Notes: + +- `--spec-type` cannot be combined with `--spec-stage`. +- Explicit stage chains currently support at most two stages. +- Supported self-spec stage names are `ngram-cache`, `ngram-simple`, `ngram-map-k`, `ngram-map-k4v`, `ngram-mod`, and `suffix`. +- Composite stage chains disable speculative autotune. + ## Cache Prompt to Host Memory When user starts a new conversation, the old conversation's kv cache will be saved in ram and can be retrieved later. This greatly reduces prompt processing time when switching between conversations and can have as many conversation as your ram is allowed. @@ -369,10 +377,26 @@ WIP | `--check-tensors` | Check model tensor data for invalid values | false | | | `--override-kv KEY=TYPE:VALUE` | Override model metadata by key | - | Advanced option to override model metadata by key. May be specified multiple times. types: int, float, bool, str. Example: `--override-kv tokenizer.ggml.add_bos_token=bool:false` | | `-m, --model FNAME` | Model path | models/$filename | Mandatory, the GGUF model file to be served. | -| `-md, --model-draft FNAME` | Draft model for speculative decoding | unused | | -| `--draft-max, --draft, --draft-n N` | Number of tokens to draft for speculative decoding | 16 | | -| `--draft-min, --draft-n-min N` | Minimum number of draft tokens to use for speculative decoding | - | | -| `--draft-p-min P` | Minimum speculative decoding probability (greedy) | 0.8 | | +| `-md, --model-draft FNAME` | Draft model for speculative decoding | unused | Required when an explicit `draft` stage is used. | +| `--draft-max, --draft, --draft-n N` | Global speculative draft cap, or fallback value for stages without an explicit `n_max` override | 16 | Also used by single-stage MTP and draft-model speculation. | +| `--draft-min, --draft-n-min N` | Global minimum speculative draft threshold, or fallback value for stages without an explicit `n_min` override | 0 | | +| `--draft-p-min P` | Global minimum speculative decoding probability (greedy), or fallback value for stages without an explicit `p_min` override | 0.8 | | + +### Request-Level Speculative Overrides + +When the server is started with speculative decoding enabled, request JSON may override: + +- `speculative.n_max` +- `speculative.n_min` +- `speculative.p_min` +- `speculative.stages` + +Request-level `speculative.stages` is constrained: + +- The number of stages must match the stage chain configured at server startup. +- Each request stage must keep the same `type` as the corresponding startup stage. +- Only `type`, `n_max`, `n_min`, and `p_min` are accepted per request. +- Structural stage parameters such as ngram sizes, ngram hit thresholds, and suffix depth remain startup-only. ## Server Options diff --git a/examples/server/README.md b/examples/server/README.md index 2e674eb9..be3a7f74 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -210,6 +210,10 @@ model: -m, --model FNAME model path (default: models/$filename with filename from --hf-file or --model-url if set, otherwise models/7B/ggml-model-f16.gguf) -md, --model-draft FNAME draft model for speculative decoding (default: unused) + --spec-stage SPEC[:k=v,...] + explicit speculative stage. repeat once for a supported two-stage chain + examples: --spec-stage ngram-mod:n_max=64,n_min=2 --spec-stage mtp:n_max=1 + supported two-stage shape: self-spec first, then mtp or draft fallback -mu, --model-url MODEL_URL model download url (default: unused) -hfr, --hf-repo REPO Hugging Face model repository (default: unused) -hff, --hf-file FILE Hugging Face model file (default: unused) @@ -960,6 +964,35 @@ To know the `id` of the adapter, use GET `/lora-adapters` ## More examples +### Composite speculative decoding + +Use `--spec-stage` for explicit stage chains. The currently supported two-stage shape is self-spec first, then `mtp` or `draft` fallback. + +Example with `ngram-mod` plus MTP fallback: + +```bash +./build/bin/llama-server \ + --model /models/target-mtp.gguf \ + --spec-stage ngram-mod:n_max=64,n_min=2,ngram_size_n=8 \ + --spec-stage mtp:n_max=1,p_min=0.0 +``` + +Example with `ngram-mod` plus draft-model fallback: + +```bash +./build/bin/llama-server \ + --model /models/target.gguf \ + --model-draft /models/draft.gguf \ + --spec-stage ngram-mod:n_max=64,n_min=2,ngram_size_n=8 \ + --spec-stage draft:n_max=4,p_min=0.0 +``` + +Notes: + +- Use `--spec-type` when you want a single self-spec stage only. +- `--spec-type` cannot be combined with `--spec-stage`. +- Explicit stage chains currently support at most two stages. + ### Change system prompt on runtime To use the server example to serve multiple chat-type clients while keeping the same system prompt, you can utilize the option `system_prompt`. This only needs to be used once. diff --git a/examples/server/server-common.h b/examples/server/server-common.h index e150ef28..9036e1f2 100644 --- a/examples/server/server-common.h +++ b/examples/server/server-common.h @@ -93,12 +93,45 @@ using raw_buffer = std::vector; void server_log(const char* level, const char* function, int line, const char* message, const json& extra); +static const json * json_value_ptr(const json & body, const std::string & key) { + auto direct = body.find(key); + if (direct != body.end()) { + return &(*direct); + } + + const json * current = &body; + size_t start = 0; + + while (start < key.size()) { + const size_t dot = key.find('.', start); + const std::string segment = key.substr(start, dot == std::string::npos ? std::string::npos : dot - start); + + if (!current->is_object()) { + return nullptr; + } + + auto it = current->find(segment); + if (it == current->end()) { + return nullptr; + } + + if (dot == std::string::npos) { + return &(*it); + } + + current = &(*it); + start = dot + 1; + } + + return nullptr; +} + template static T json_value(const json& body, const std::string& key, const T& default_value) { // Fallback null to default value - if (body.contains(key) && !body.at(key).is_null()) { + if (const json * value = json_value_ptr(body, key); value != nullptr && !value->is_null()) { try { - return body.at(key); + return *value; } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const& err) { std::stringstream ss; diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index cbd76a6d..69217ba2 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -104,6 +104,32 @@ static void cache_and_sync_slot_mtp_hidden_from_rows(server_slot & slot, llama_c cache_and_sync_slot_mtp_hidden(slot, ctx, rows.data() + (n_rows - 1) * n_embd, n_embd); } +static const float * mtp_hidden_last_row(const std::vector & rows, int n_embd) { + if (n_embd <= 0 || rows.size() < (size_t) n_embd) { + return nullptr; + } + + const size_t n_rows = rows.size() / n_embd; + if (n_rows == 0) { + return nullptr; + } + + return rows.data() + (n_rows - 1) * n_embd; +} + +static bool sync_external_mtp_after_non_mtp_accept( + server_slot & slot, + llama_context * ctx, + const std::vector & mtp_commit_states, + int n_embd) { + if (!slot.use_gemma4_external_mtp || mtp_commit_states.empty() || n_embd <= 0) { + return false; + } + + cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, mtp_commit_states, n_embd); + return true; +} + static void apply_slot_mtp_accept( server_slot & slot, llama_context * ctx, @@ -183,6 +209,12 @@ static int32_t server_mtp_media_warmup_callback(void * user_data, const llama_ba return server_mtp_warmup_batch(data->ctx_tgt, get_slot_mtp_ctx(*data->slot, data->ctx_tgt), batch, *data->slot); } +static bool server_response_needs_chat_parse(oaicompat_type oaicompat) { + return oaicompat == OAICOMPAT_TYPE_CHAT || + oaicompat == OAICOMPAT_TYPE_ANTHROPIC || + oaicompat == OAICOMPAT_TYPE_RESP; +} + void server_speculative_checkpoint::clear() { valid = false; per_step_enabled = false; @@ -229,6 +261,104 @@ static bool save_speculative_checkpoint(server_slot & slot, llama_model * model, return true; } +static void server_remove_speculative_stage(common_params_speculative & spec, common_speculative_type type) { + spec.stages.erase(std::remove_if(spec.stages.begin(), spec.stages.end(), [type](const common_speculative_stage_params & stage) { + return stage.type == type; + }), spec.stages.end()); + + if (spec.type == type) { + spec.type = COMMON_SPECULATIVE_TYPE_NONE; + const auto resolved = spec.get_resolved_stages(); + spec.type = resolved.empty() ? COMMON_SPECULATIVE_TYPE_NONE : resolved.front().type; + } +} + +static bool server_speculative_has_mtp(const common_params_speculative & spec) { + return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP); +} + +static bool server_speculative_same_stage_types( + const common_params_speculative & lhs, + const common_params_speculative & rhs) { + const auto lhs_stages = lhs.get_resolved_stages(); + const auto rhs_stages = rhs.get_resolved_stages(); + + if (lhs_stages.size() != rhs_stages.size()) { + return false; + } + + for (size_t i = 0; i < lhs_stages.size(); ++i) { + if (lhs_stages[i].type != rhs_stages[i].type) { + return false; + } + } + + return true; +} + +static void server_reject_dead_speculative_request_overrides(const json & data) { + if (json_value_ptr(data, "speculative.type") != nullptr) { + throw std::runtime_error("Error: speculative.type request override is not supported; keep the startup stage types and use speculative.stages or speculative.n_max/n_min/p_min"); + } + + if (json_value_ptr(data, "speculative.ngram_size_n") != nullptr || + json_value_ptr(data, "speculative.ngram_size_m") != nullptr || + json_value_ptr(data, "speculative.ngram_min_hits") != nullptr || + json_value_ptr(data, "speculative.suffix_min_match_len") != nullptr || + json_value_ptr(data, "speculative.suffix_max_depth") != nullptr) { + throw std::runtime_error("Error: structural speculative overrides are startup-only; per-request overrides only support speculative.n_max, speculative.n_min, speculative.p_min, and speculative.stages"); + } +} + +static common_speculative_stage_params server_parse_speculative_stage_json(const json & stage_json) { + if (!stage_json.is_object()) { + throw std::runtime_error("Error: speculative.stages entries must be objects"); + } + if (!stage_json.contains("type") || !stage_json["type"].is_string()) { + throw std::runtime_error("Error: speculative.stages entries must include a string 'type'"); + } + + common_speculative_stage_params stage; + stage.type = common_speculative_type_from_name(stage_json["type"].get()); + if (stage.type == COMMON_SPECULATIVE_TYPE_COUNT) { + throw std::runtime_error("Error: unknown speculative stage type in speculative.stages"); + } + + for (const auto & item : stage_json.items()) { + if (item.key() == "type") { + continue; + } + + if (item.key() == "n_max") { + stage.n_max = item.value().get(); + if (stage.n_max < 0) { + throw std::runtime_error("Error: speculative.stages[].n_max must be >= 0"); + } + continue; + } + + if (item.key() == "n_min") { + stage.n_min = item.value().get(); + if (stage.n_min < 0) { + throw std::runtime_error("Error: speculative.stages[].n_min must be >= 0"); + } + continue; + } + + if (item.key() == "p_min") { + stage.p_min = item.value().get(); + if (stage.p_min < 0.0f) { + throw std::runtime_error("Error: speculative.stages[].p_min must be >= 0"); + } + continue; + } + + throw std::runtime_error("Error: per-request speculative.stages only support type, n_max, n_min, and p_min; structural stage overrides are startup-only"); + } + + return stage; +} + server_context::~server_context() { if (ctx) { llama_free(ctx); @@ -328,9 +458,13 @@ bool server_context::load_model(const gpt_params& params_) { LOG_ERROR("%s\n", "err: speculative decode is not supported by multimodal"); return false; } - if (params_base.speculative.type != COMMON_SPECULATIVE_TYPE_NONE && - params_base.speculative.type != COMMON_SPECULATIVE_TYPE_MTP) { + const auto spec_stages = params_base.speculative.get_resolved_stages(); + const bool multimodal_spec_supported = spec_stages.empty() || + (spec_stages.size() == 1 && spec_stages.front().type == COMMON_SPECULATIVE_TYPE_MTP); + if (!multimodal_spec_supported) { params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE; + params_base.speculative.stages.clear(); + params_base.has_mtp = false; SRV_WRN("%s\n", "speculative decoding is not supported by multimodal, it will be disabled"); } } @@ -380,9 +514,12 @@ bool server_context::load_model(const gpt_params& params_) { params_base.speculative.cparams_dft = cparams_dft; } - else if (params_base.has_mtp && llama_model_n_nextn_layer(model) == 0) { - LOG_WARNING("WARNING: -mtp flag provided, but model has 0 NextN layers. MTP will be disabled.\n", {}); + if (server_speculative_has_mtp(params_base.speculative) && + llama_model_n_nextn_layer(model) == 0 && + !params_use_gemma4_external_mtp(params_base)) { + LOG_WARNING("WARNING: MTP speculative stage requested, but model has 0 NextN layers. MTP will be disabled.\n", {}); params_base.has_mtp = false; + server_remove_speculative_stage(params_base.speculative, COMMON_SPECULATIVE_TYPE_MTP); } return true; } @@ -433,13 +570,14 @@ void server_context::init() { slot.ga_n = ga_n; slot.ga_w = ga_w; + slot.params.speculative = params_base.speculative; slot.sparams = params_base.sparams; - if (params_base.has_mtp) { + const bool wants_mtp_stage = server_speculative_has_mtp(params_base.speculative); + if (wants_mtp_stage) { const bool has_external_mtp = params_use_gemma4_external_mtp(params_base); if (llama_model_n_nextn_layer(model) > 0 || has_external_mtp) { - params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP; params_base.pooling_type = LLAMA_POOLING_TYPE_NONE; if (!has_external_mtp) { @@ -452,25 +590,24 @@ void server_context::init() { slot.has_mtp = true; slot.use_gemma4_external_mtp = has_external_mtp; - slot.params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP; - slot.params.speculative.n_min = 0; slot.params.speculative.cparams_dft = params_base.speculative.cparams_dft; - slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); + slot.batch_spec = llama_batch_init(slot.params.speculative.get_max_stage_n_max() + 1, 0, 1); SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens); SRV_INF("%s\n", "MTP needs embeddings on decode, enabling"); llama_set_embeddings(ctx, true); } else { - SRV_WRN("%s\n", "MTP enabled via flag, but model has 0 NextN layers. Disabling speculative."); - params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE; + SRV_WRN("%s\n", "MTP speculative stage requested, but model has 0 NextN layers. Removing MTP from the configured stage chain."); + params_base.has_mtp = false; + server_remove_speculative_stage(params_base.speculative, COMMON_SPECULATIVE_TYPE_MTP); + slot.params.speculative = params_base.speculative; slot.has_mtp = false; } } - const bool requested_spec = params_base.speculative.type != COMMON_SPECULATIVE_TYPE_NONE || - params_base.speculative.has_dft(); + const bool requested_spec = !params_base.speculative.get_resolved_stages().empty(); bool can_spec = true; if (!params_base.dry_run) { @@ -633,6 +770,7 @@ void server_slot::reset() { checkpoint_pos = 0; image_just_processed = false; do_checkpoint = false; + mtp_hidden_state.clear(); positional_bans.clear(); ban_phrases.clear(); @@ -711,7 +849,7 @@ int server_slot::get_n_draft_max() const { } // determine the max draft that fits the current slot state - int n_draft_max = params.speculative.n_max; + int n_draft_max = params.speculative.get_max_stage_n_max(); // note: slot.prompt is not yet expanded with the `id` token sampled above // also, need to leave space for 1 extra token to allow context shifts @@ -723,8 +861,9 @@ int server_slot::get_n_draft_max() const { SLT_DBG(*this, "max possible draft: %d\n", n_draft_max); - if (n_draft_max < params.speculative.n_min) { - SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, params.speculative.n_min); + const int min_usable_draft = params.speculative.get_min_usable_stage_n_min(); + if (n_draft_max < min_usable_draft) { + SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, min_usable_draft); n_draft_max = 0; } return n_draft_max; @@ -1239,40 +1378,83 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) slot.params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); // speculative decoding parameters - slot.params.speculative.n_max = json_value(data, "speculative.n_max", params_base.speculative.n_max); - slot.params.speculative.n_min = json_value(data, "speculative.n_min", params_base.speculative.n_min); - slot.params.speculative.p_min = json_value(data, "speculative.p_min", params_base.speculative.p_min); + try { + slot.params.speculative = defaults.speculative; + slot.params.speculative.n_max = json_value(data, "speculative.n_max", params_base.speculative.n_max); + slot.params.speculative.n_min = json_value(data, "speculative.n_min", params_base.speculative.n_min); + slot.params.speculative.p_min = json_value(data, "speculative.p_min", params_base.speculative.p_min); - slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min); - slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 0); - slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0); + server_reject_dead_speculative_request_overrides(data); - slot.params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type))); + const json stages = json_value(data, "speculative.stages", json()); + if (!stages.is_null()) { + if (!stages.is_array()) { + throw std::runtime_error("Error: speculative.stages must be an array"); + } - // Clamp speculative parameters - slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min); - slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 0); - slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0); + const auto default_stages = defaults.speculative.get_resolved_stages(); + if (stages.size() != default_stages.size()) { + throw std::runtime_error("Error: speculative.stages must provide the same number of stages configured at server startup"); + } - if (slot.can_speculate() && - llama_model_has_recurrent(model) && - slot.params.speculative.n_max > params_base.speculative.n_max) { - send_error(task, - "Error: speculative.n_max=" + std::to_string(slot.params.speculative.n_max) + - " exceeds the recurrent speculative startup limit of " + std::to_string(params_base.speculative.n_max) + - "; restart the server with a higher --draft-max to reserve checkpoint capacity", - ERROR_TYPE_INVALID_REQUEST); + slot.params.speculative.stages = default_stages; + for (size_t i = 0; i < stages.size(); ++i) { + const auto stage_override = server_parse_speculative_stage_json(stages[i]); + if (stage_override.type != default_stages[i].type) { + throw std::runtime_error("Error: speculative.stages must preserve the stage types configured at server startup"); + } + + if (stage_override.has_n_max_override()) { + slot.params.speculative.stages[i].n_max = stage_override.n_max; + } + if (stage_override.has_n_min_override()) { + slot.params.speculative.stages[i].n_min = stage_override.n_min; + } + if (stage_override.has_p_min_override()) { + slot.params.speculative.stages[i].p_min = stage_override.p_min; + } + } + + const auto resolved = slot.params.speculative.get_resolved_stages(); + slot.params.speculative.type = resolved.empty() ? COMMON_SPECULATIVE_TYPE_NONE : resolved.front().type; + } + + slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min); + slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 0); + slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0); + + if (slot.can_speculate() && + llama_model_has_recurrent(model) && + slot.params.speculative.n_max > params_base.speculative.n_max) { + send_error(task, + "Error: speculative.n_max=" + std::to_string(slot.params.speculative.n_max) + + " exceeds the recurrent speculative startup limit of " + std::to_string(params_base.speculative.n_max) + + "; restart the server with a higher --draft-max to reserve checkpoint capacity", + ERROR_TYPE_INVALID_REQUEST); + return false; + } + + if (!server_speculative_same_stage_types(slot.params.speculative, defaults.speculative)) { + throw std::runtime_error("Error: per-request speculative stages must match the server startup stage types; only stage parameter overrides are supported"); + } + + if (slot.params.speculative.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) && !slot.has_mtp) { + throw std::runtime_error("Error: MTP speculative stage requested, but the server was not started with MTP support"); + } + + if (slot.params.speculative.has_stage_type(COMMON_SPECULATIVE_TYPE_DRAFT) && !params_base.speculative.has_dft()) { + throw std::runtime_error("Error: draft speculative stage requested, but no draft model is loaded"); + } + + std::string spec_error; + if (!common_speculative_validate_chain(slot.params.speculative, &spec_error)) { + throw std::runtime_error("Error: invalid speculative request configuration: " + spec_error); + } + } catch (const std::exception & e) { + send_error(task, e.what(), ERROR_TYPE_INVALID_REQUEST); return false; } - slot.params.speculative.ngram_size_n = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n); - slot.params.speculative.ngram_size_m = json_value(data, "speculative.ngram_size_m", defaults.speculative.ngram_size_m); - slot.params.speculative.ngram_min_hits = json_value(data, "speculative.ngram_m_hits", defaults.speculative.ngram_min_hits); - - slot.params.speculative.ngram_size_n = std::max(std::min(1, (int)slot.params.speculative.ngram_size_n), 1024); - slot.params.speculative.ngram_size_m = std::max(std::min(1, (int)slot.params.speculative.ngram_size_m), 1024); - slot.params.speculative.ngram_min_hits = std::max(std::min(1, (int)slot.params.speculative.ngram_min_hits), 1024); - if (slot.sparams.penalty_last_n < -1) { throw std::runtime_error("Error: repeat_last_n must be >= -1"); @@ -2378,31 +2560,33 @@ void server_context::send_partial_response(server_slot& slot, completion_token_o {"id_slot", slot.id}, {"multimodal", false} }; - slot.update_chat_msg(true, res->oaicompat_msg_diffs); + if (server_response_needs_chat_parse(slot.params.oaicompat)) { + slot.update_chat_msg(true, res->oaicompat_msg_diffs); - res->anthropic_has_reasoning = !slot.chat_msg.reasoning_content.empty(); + res->anthropic_has_reasoning = !slot.chat_msg.reasoning_content.empty(); - res->anthropic_thinking_block_started = slot.anthropic_thinking_block_started; - res->anthropic_text_block_started = slot.anthropic_text_block_started; + res->anthropic_thinking_block_started = slot.anthropic_thinking_block_started; + res->anthropic_text_block_started = slot.anthropic_text_block_started; - res->oai_resp_thinking_block_started = slot.oai_resp_thinking_block_started; - res->oai_resp_text_block_started = slot.oai_resp_text_block_started; + res->oai_resp_thinking_block_started = slot.oai_resp_thinking_block_started; + res->oai_resp_text_block_started = slot.oai_resp_text_block_started; - for (const auto& diff : res->oaicompat_msg_diffs) { - if (!diff.reasoning_content_delta.empty() && !slot.anthropic_thinking_block_started) { - slot.anthropic_thinking_block_started = true; - } - if (!diff.content_delta.empty() && !slot.anthropic_text_block_started) { - slot.anthropic_text_block_started = true; - } - if (!diff.reasoning_content_delta.empty() && !slot.oai_resp_thinking_block_started) { - slot.oai_resp_thinking_block_started = true; - } - if (!diff.content_delta.empty() && !slot.oai_resp_text_block_started) { - slot.oai_resp_text_block_started = true; - } - if (!diff.tool_call_delta.name.empty()) { - slot.oai_resp_fc_id = diff.tool_call_delta.id; + for (const auto& diff : res->oaicompat_msg_diffs) { + if (!diff.reasoning_content_delta.empty() && !slot.anthropic_thinking_block_started) { + slot.anthropic_thinking_block_started = true; + } + if (!diff.content_delta.empty() && !slot.anthropic_text_block_started) { + slot.anthropic_text_block_started = true; + } + if (!diff.reasoning_content_delta.empty() && !slot.oai_resp_thinking_block_started) { + slot.oai_resp_thinking_block_started = true; + } + if (!diff.content_delta.empty() && !slot.oai_resp_text_block_started) { + slot.oai_resp_text_block_started = true; + } + if (!diff.tool_call_delta.name.empty()) { + slot.oai_resp_fc_id = diff.tool_call_delta.id; + } } } @@ -2439,7 +2623,9 @@ void server_context::send_final_response(server_slot& slot) { res->post_sampling_probs = slot.params.post_sampling_probs; res->oaicompat = slot.params.oaicompat; res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; - res->oaicompat_msg = slot.update_chat_msg(false, res->oaicompat_msg_diffs); + if (server_response_needs_chat_parse(slot.params.oaicompat)) { + res->oaicompat_msg = slot.update_chat_msg(false, res->oaicompat_msg_diffs); + } res->oai_resp_id = slot.oai_resp_id; res->oai_resp_reasoning_id = slot.oai_resp_reasoning_id; res->oai_resp_message_id = slot.oai_resp_message_id; @@ -3452,7 +3638,7 @@ void server_context::add_sampled_tokens() { } static const llama_tokens empty_prompt; - const llama_tokens & cached_text_tokens = slot.has_mtp + const llama_tokens & cached_text_tokens = slot.has_mtp && !slot.params.speculative.has_composite_stage_chain() ? empty_prompt : slot.cache_tokens.get_text_tokens(); @@ -3491,14 +3677,14 @@ void server_context::add_sampled_tokens() { common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true); slot.cache_tokens.push_back(slot.sampled); - if (slot.params.speculative.n_min > (int)draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min); + const int min_usable_draft = slot.params.speculative.get_min_usable_stage_n_min(); + if (min_usable_draft > (int)draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), min_usable_draft); // fallback to normal decoding slot.i_batch = slot.i_batch_dft[0]; slot.drafted.clear(); slot.i_batch_dft.clear(); - } - else { + } else { // keep track of total number of drafted tokens tested slot.n_draft_total += draft.size(); @@ -3586,6 +3772,7 @@ void server_context::apply_checkpoint(server_slot & slot) { slot.n_past_prompt = 0; slot.n_past_se = 0; slot.ga_i = 0; + slot.cache_tokens.keep_first(0); pos_next = 0; common_sampler_reset(slot.ctx_sampling); } @@ -4028,7 +4215,11 @@ void server_context::extend_context(const int32_t n_tokens) { // Restore recurrent state and re-decode accepted tokens after speculative-decode rejection. static void restore_speculative_checkpoint( server_slot & slot, llama_context * ctx, llama_model * model, + common_speculative_type spec_type_used, const std::vector & ids, int n_draft, + const std::vector & mtp_commit_tokens, + const std::vector & mtp_commit_states, + const std::vector & mtp_hidden_state_seed, const std::vector & mtp_hidden_state_pre, int32_t mtp_n_past_base) { if (slot.spec_ckpt.per_step_enabled) { const int step = (int)ids.size() - 1; @@ -4043,8 +4234,36 @@ static void restore_speculative_checkpoint( // Update MTP KV cache and hidden state using embeddings collected before checkpoint restore. if (slot.has_mtp && !mtp_hidden_state_pre.empty()) { - const int n_embd = get_ctx_mtp_n_embd(ctx); - apply_slot_mtp_accept(slot, ctx, mtp_hidden_state_pre, ids, mtp_n_past_base, n_embd); + llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec); + llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx; + + if (spec_type_used == COMMON_SPECULATIVE_TYPE_MTP) { + const int n_embd = get_ctx_mtp_n_embd(ctx); + apply_slot_mtp_accept(slot, ctx, mtp_hidden_state_pre, ids, mtp_n_past_base, n_embd); + } else if (!mtp_commit_tokens.empty() && !mtp_commit_states.empty()) { + const int n_embd = get_ctx_mtp_n_embd(ctx); + if (sync_external_mtp_after_non_mtp_accept(slot, ctx, mtp_commit_states, n_embd)) { + SLT_DBG(slot, "%s", "synced external MTP hidden state from accepted-prefix rows after per-step restore"); + } else { + const float * seed_hidden = mtp_hidden_last_row(mtp_hidden_state_seed, n_embd); + + if (seed_hidden == nullptr) { + SLT_WRN(slot, "%s", "missing MTP seed hidden state for accepted-prefix replay after per-step restore"); + slot.mtp_hidden_state.clear(); + } else { + llama_batch accepted_batch = llama_batch_init(mtp_commit_tokens.size(), 0, 1); + for (size_t i = 0; i < mtp_commit_tokens.size(); ++i) { + common_batch_add(accepted_batch, mtp_commit_tokens[i], mtp_n_past_base + i, { slot.id }, true); + } + + llama_set_draft_input_hidden_state(mtp_target, seed_hidden); + mtp_update_kv_cache(mtp_target, accepted_batch, false); + llama_batch_free(accepted_batch); + + slot.mtp_hidden_state.assign(mtp_commit_states.end() - n_embd, mtp_commit_states.end()); + } + } + } } SLT_DBG(slot, "per-step restore: step=%d (rejected %d drafts)\n", @@ -4124,7 +4343,10 @@ void server_context::speculative_decoding_accept() { continue; } + const llama_token sampled_before = slot.sampled; + const common_speculative_type spec_type_used = common_speculative_current_type(slot.spec); size_t n_draft = slot.drafted.size(); + const std::vector mtp_hidden_state_seed = slot.has_mtp ? slot.mtp_hidden_state : std::vector{}; slot.ctx_sampling->to_generated_text = &slot.generated_text; if (n_draft > 0) { @@ -4155,6 +4377,8 @@ void server_context::speculative_decoding_accept() { int32_t mtp_n_past_base = 0; std::vector mtp_hidden_state_pre; + std::vector mtp_commit_tokens; + std::vector mtp_commit_states; if (slot.has_mtp) { const int32_t n_pre_spec_tokens = slot.cache_tokens.n_tokens() - (int32_t)(slot.drafted.size() + 1); mtp_n_past_base = slot.cache_tokens.pos_next(n_pre_spec_tokens); @@ -4168,6 +4392,20 @@ void server_context::speculative_decoding_accept() { memcpy(mtp_hidden_state_pre.data() + i * n_embd, emb_i, n_embd * sizeof(float)); } } + + if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) { + mtp_commit_tokens.reserve(ids.size()); + mtp_commit_tokens.push_back(sampled_before); + mtp_commit_tokens.insert(mtp_commit_tokens.end(), ids.begin(), ids.end() - 1); + + mtp_commit_states.resize(ids.size() * n_embd); + for (size_t i = 0; i < ids.size(); ++i) { + const float * emb_i = llama_get_embeddings_ith(ctx, slot.i_batch_dft[i]); + if (emb_i) { + memcpy(mtp_commit_states.data() + i * n_embd, emb_i, n_embd * sizeof(float)); + } + } + } } else { const float* emb0 = llama_get_embeddings_ith(ctx, 0); if (emb0) { @@ -4204,11 +4442,39 @@ void server_context::speculative_decoding_accept() { // for recurrent/hybrid models: if any drafts were rejected, restore recurrent state const bool any_rejected = (ids.size() - 1) < n_draft; if (any_rejected && slot.spec_ckpt.valid) { - restore_speculative_checkpoint(slot, ctx, model, ids, n_draft, mtp_hidden_state_pre, mtp_n_past_base); + restore_speculative_checkpoint(slot, ctx, model, spec_type_used, ids, n_draft, mtp_commit_tokens, mtp_commit_states, mtp_hidden_state_seed, mtp_hidden_state_pre, mtp_n_past_base); } else { if (slot.has_mtp && !mtp_hidden_state_pre.empty()) { - const int n_embd = get_ctx_mtp_n_embd(ctx); - apply_slot_mtp_accept(slot, ctx, mtp_hidden_state_pre, ids, mtp_n_past_base, n_embd); + llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec); + llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx; + + if (spec_type_used == COMMON_SPECULATIVE_TYPE_MTP) { + const int n_embd = get_ctx_mtp_n_embd(ctx); + apply_slot_mtp_accept(slot, ctx, mtp_hidden_state_pre, ids, mtp_n_past_base, n_embd); + } else if (!mtp_commit_tokens.empty() && !mtp_commit_states.empty()) { + const int n_embd = get_ctx_mtp_n_embd(ctx); + if (sync_external_mtp_after_non_mtp_accept(slot, ctx, mtp_commit_states, n_embd)) { + SLT_DBG(slot, "%s", "synced external MTP hidden state from accepted-prefix rows"); + } else { + const float * seed_hidden = mtp_hidden_last_row(mtp_hidden_state_seed, n_embd); + + if (seed_hidden == nullptr) { + SLT_WRN(slot, "%s", "missing MTP seed hidden state for accepted-prefix replay"); + slot.mtp_hidden_state.clear(); + } else { + llama_batch accepted_batch = llama_batch_init(mtp_commit_tokens.size(), 0, 1); + for (size_t i = 0; i < mtp_commit_tokens.size(); ++i) { + common_batch_add(accepted_batch, mtp_commit_tokens[i], mtp_n_past_base + i, { slot.id }, true); + } + + llama_set_draft_input_hidden_state(mtp_target, seed_hidden); + mtp_update_kv_cache(mtp_target, accepted_batch, false); + llama_batch_free(accepted_batch); + + slot.mtp_hidden_state.assign(mtp_commit_states.end() - n_embd, mtp_commit_states.end()); + } + } + } } llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(slot.n_past), -1); discard_speculative_checkpoint(slot, ctx); @@ -4586,8 +4852,8 @@ void server_context::process_batch_tokens(int32_t & n_batch) { continue; // continue loop of n_batch } - server_slot * mtp_warmup_slot = nullptr; - if (params_base.has_mtp) { + server_slot * mtp_warmup_slot = nullptr; + if (server_speculative_has_mtp(params_base.speculative)) { for (auto& slot : slots) { if ((slot.state == SLOT_STATE_PROCESSING && slot.n_decoded == 0) || (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT)) { @@ -4632,7 +4898,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) { if (slot.n_decoded == 0 && slot.can_speculate()) { static const llama_tokens empty_prompt; - const llama_tokens & spec_prompt = slot.has_mtp + const llama_tokens & spec_prompt = slot.has_mtp && !slot.params.speculative.has_composite_stage_chain() ? empty_prompt : slot.cache_tokens.get_text_tokens(); common_speculative_begin(slot.spec, spec_prompt); diff --git a/examples/server/webui_llamacpp/src/lib/types/api.d.ts b/examples/server/webui_llamacpp/src/lib/types/api.d.ts index eda280cb..696bd130 100644 --- a/examples/server/webui_llamacpp/src/lib/types/api.d.ts +++ b/examples/server/webui_llamacpp/src/lib/types/api.d.ts @@ -1,5 +1,12 @@ import type { ChatMessagePromptProgress } from './chat'; +export type ApiSpeculativeStage = { + type: string; + n_max?: number; + n_min?: number; + p_min?: number; +}; + export interface ApiChatMessageContentPart { type: 'text' | 'image_url' | 'input_audio'; text?: string; @@ -124,6 +131,7 @@ export interface ApiLlamaCppServerProps { 'speculative.n_max': number; 'speculative.n_min': number; 'speculative.p_min': number; + 'speculative.stages'?: ApiSpeculativeStage[]; timings_per_token: boolean; post_sampling_probs: boolean; lora: Array<{ name: string; scale: number }>; @@ -284,6 +292,7 @@ export interface ApiSlotData { 'speculative.n_max': number; 'speculative.n_min': number; 'speculative.p_min': number; + 'speculative.stages'?: ApiSpeculativeStage[]; timings_per_token: boolean; post_sampling_probs: boolean; lora: Array<{ name: string; scale: number }>;