// Mirror of https://github.com/ikawrakow/ik_llama.cpp.git
// (synced 2026-06-28 04:30:15 -05:00; 2444 lines, 103 KiB, C++)
#include "speculative.h"
|
|
|
|
#include "common.h"
|
|
#include "ggml.h"
|
|
#include "llama.h"
|
|
#include "log.h"
|
|
#include "ngram-cache.h"
|
|
#include "ngram-map.h"
|
|
#include "ngram-mod.h"
|
|
#include "sampling.h"
|
|
#include "suffix-tree.h"
|
|
|
|
#include <algorithm>
|
|
#include <atomic>
|
|
#include <cstdlib>
|
|
#include <cstring>
|
|
#include <iomanip>
|
|
#include <limits>
|
|
#include <map>
|
|
#include <sstream>
|
|
#include <unordered_map>
|
|
|
|
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
|
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
|
|
|
void llama_set_mtp_target_context(struct llama_context * ctx, struct llama_context * target_ctx);
|
|
|
|
// All speculative decoding implementation types, in the order they are listed here.
// The order is significant to anything that iterates this vector.
const std::vector<enum common_speculative_type> common_speculative_types = {
    COMMON_SPECULATIVE_TYPE_NONE,

    COMMON_SPECULATIVE_TYPE_DRAFT,
    COMMON_SPECULATIVE_TYPE_DFLASH,
    COMMON_SPECULATIVE_TYPE_MTP,
    COMMON_SPECULATIVE_TYPE_EAGLE3,

    // ngram-* family
    COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
    COMMON_SPECULATIVE_TYPE_NGRAM_MOD,
    COMMON_SPECULATIVE_TYPE_NGRAM_CACHE,

    COMMON_SPECULATIVE_TYPE_SUFFIX
};
|
|
|
|
const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
|
|
{"none", COMMON_SPECULATIVE_TYPE_NONE},
|
|
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
|
|
{"dflash", COMMON_SPECULATIVE_TYPE_DFLASH},
|
|
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
|
|
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
|
|
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
|
|
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
|
|
{"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
|
|
{"ngram_mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD},
|
|
{"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE},
|
|
{"suffix", COMMON_SPECULATIVE_TYPE_SUFFIX}
|
|
};
|
|
|
|
void common_speculative_checkpoint::clear() {
|
|
valid = false;
|
|
per_step_enabled = false;
|
|
n_past = 0;
|
|
sampled = LLAMA_TOKEN_NULL;
|
|
|
|
if (sampler != nullptr) {
|
|
common_sampler_free(sampler);
|
|
sampler = nullptr;
|
|
}
|
|
}
|
|
|
|
struct common_speculative_config {
|
|
common_speculative_stage_params stage;
|
|
common_speculative_type type;
|
|
common_params_speculative params;
|
|
|
|
common_speculative_config(
|
|
const common_speculative_stage_params & s,
|
|
const common_params_speculative & p = common_params_speculative{})
|
|
: stage(s), type(s.type), params(p) {}
|
|
};
|
|
|
|
static bool common_speculative_are_compatible(
|
|
const llama_model * model_tgt,
|
|
const llama_model * model_dft) {
|
|
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
|
|
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
|
|
|
|
const auto vocab_type_tgt = llama_vocab_type(vocab_tgt);
|
|
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
|
|
|
|
const auto vocab_type_dft = llama_vocab_type(vocab_dft);
|
|
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
|
|
|
|
if (vocab_type_tgt != vocab_type_dft) {
|
|
LOG_DBG("%s: draft model vocab type must match target model to use speculation but ", __func__);
|
|
LOG_DBG("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
|
|
return false;
|
|
}
|
|
|
|
if (
|
|
llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
|
|
llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
|
|
llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
|
|
llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
|
|
) {
|
|
LOG_DBG("%s: draft model special tokens must match target model to use speculation\n", __func__);
|
|
return false;
|
|
}
|
|
|
|
{
|
|
const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
|
|
const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
|
|
const int vocab_diff = n_vocab_tgt > n_vocab_dft
|
|
? n_vocab_tgt - n_vocab_dft
|
|
: n_vocab_dft - n_vocab_tgt;
|
|
|
|
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
|
|
LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
|
|
LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
|
|
n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
|
|
return false;
|
|
}
|
|
|
|
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
|
|
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
|
|
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
|
|
|
|
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
|
LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
|
|
LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
|
|
common_token_to_piece(vocab_tgt, i).c_str(),
|
|
common_token_to_piece(vocab_dft, i).c_str());
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
static bool common_speculative_are_dflash_compatible(
|
|
const llama_model * model_tgt,
|
|
const llama_model * model_dft) {
|
|
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
|
|
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
|
|
|
|
if (llama_vocab_type(vocab_tgt) != llama_vocab_type(vocab_dft)) {
|
|
LOG_DBG("%s: DFlash draft model vocab type must match the target model\n", __func__);
|
|
return false;
|
|
}
|
|
|
|
const bool add_bos_tgt = llama_vocab_get_add_bos(vocab_tgt);
|
|
const bool add_bos_dft = llama_vocab_get_add_bos(vocab_dft);
|
|
const bool add_eos_tgt = llama_vocab_get_add_eos(vocab_tgt);
|
|
const bool add_eos_dft = llama_vocab_get_add_eos(vocab_dft);
|
|
const llama_token bos_tgt = llama_vocab_bos(vocab_tgt);
|
|
const llama_token bos_dft = llama_vocab_bos(vocab_dft);
|
|
const llama_token eos_tgt = llama_vocab_eos(vocab_tgt);
|
|
const llama_token eos_dft = llama_vocab_eos(vocab_dft);
|
|
|
|
if (add_bos_tgt != add_bos_dft || add_eos_tgt != add_eos_dft ||
|
|
(add_bos_tgt && bos_tgt != bos_dft) ||
|
|
(add_eos_tgt && eos_tgt != eos_dft)) {
|
|
LOG_DBG("%s: DFlash draft special tokens must match the target model (add_bos=%d/%d add_eos=%d/%d bos=%d/%d eos=%d/%d)\n",
|
|
__func__,
|
|
(int) add_bos_tgt,
|
|
(int) add_bos_dft,
|
|
(int) add_eos_tgt,
|
|
(int) add_eos_dft,
|
|
(int) bos_tgt,
|
|
(int) bos_dft,
|
|
(int) eos_tgt,
|
|
(int) eos_dft);
|
|
return false;
|
|
}
|
|
|
|
const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
|
|
const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
|
|
const int vocab_diff = n_vocab_tgt > n_vocab_dft
|
|
? n_vocab_tgt - n_vocab_dft
|
|
: n_vocab_dft - n_vocab_tgt;
|
|
|
|
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
|
|
LOG_DBG("%s: DFlash draft vocab size differs too much from the target model (%d vs %d)\n",
|
|
__func__, n_vocab_dft, n_vocab_tgt);
|
|
return false;
|
|
}
|
|
|
|
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
|
|
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
|
|
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
|
|
|
|
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
|
LOG_DBG("%s: DFlash draft token %d differs - target '%s', draft '%s'\n", __func__, i,
|
|
common_token_to_piece(vocab_tgt, i).c_str(),
|
|
common_token_to_piece(vocab_dft, i).c_str());
|
|
return false;
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// state of an implementation of speculative decoding
|
|
//
|
|
// each implementation has a unique type and a state that is implementation-specific
|
|
// in a subclass of common_speculative_state
|
|
struct common_speculative_state {
|
|
const enum common_speculative_type type;
|
|
|
|
size_t n_call_begin = 0; // number of times this implementation was called for refresh.
|
|
size_t n_call_draft = 0; // number of times this implementation was called for generation.
|
|
size_t n_call_accept = 0; // number of times this implementation was called for accumulation.
|
|
|
|
size_t n_gen_drafts = 0; // number of times a draft or part was generated by this implementation.
|
|
size_t n_acc_drafts = 0; // number of times a draft or part was accepted by the target model.
|
|
size_t n_gen_tokens = 0; // number of tokens generated by this implementation.
|
|
size_t n_acc_tokens = 0; // number of tokens accepted by the target model.
|
|
|
|
// TODO: track performance of most recent calls
|
|
const bool gen_perf = true; // whether to generate performance stats.
|
|
|
|
int64_t t_begin_us = 0; // total time spent in refresh of this implementation in microseconds.
|
|
int64_t t_draft_us = 0; // total time spent in generating drafts in this implementation in microseconds.
|
|
int64_t t_accept_us = 0; // total time spent in accumulation of this implementation in microseconds.
|
|
|
|
common_speculative_state(enum common_speculative_type type) : type(type) {}
|
|
|
|
virtual ~common_speculative_state() = default;
|
|
|
|
virtual void begin(const llama_tokens & prompt) = 0;
|
|
|
|
virtual void draft(
|
|
const common_params_speculative & params,
|
|
const llama_tokens & prompt_tgt,
|
|
llama_token id_last,
|
|
llama_tokens & result) = 0;
|
|
|
|
virtual void draft(
|
|
const common_params_speculative & params,
|
|
const llama_tokens & prompt_tgt,
|
|
llama_token id_last,
|
|
llama_pos draft_base_pos,
|
|
llama_seq_id draft_seq_id,
|
|
llama_tokens & result) {
|
|
GGML_UNUSED(draft_base_pos);
|
|
GGML_UNUSED(draft_seq_id);
|
|
draft(params, prompt_tgt, id_last, result);
|
|
}
|
|
|
|
virtual void accept(uint16_t n_accepted) = 0;
|
|
};
|
|
|
|
struct common_speculative_state_mtp;
|
|
struct common_speculative_state_dflash;
|
|
|
|
static void dflash_contract_log_append(
|
|
const common_speculative_state_dflash & state,
|
|
llama_seq_id seq_id,
|
|
const std::vector<llama_pos> & new_positions);
|
|
static void dflash_contract_log_draft(
|
|
const common_speculative_state_dflash & state,
|
|
int32_t n_keep,
|
|
size_t result_size);
|
|
|
|
static common_speculative_state_mtp * common_speculative_get_mtp_state(common_speculative * spec);
|
|
static const common_speculative_state_mtp * common_speculative_get_mtp_state(const common_speculative * spec);
|
|
static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec);
|
|
static const common_speculative_state_dflash * common_speculative_get_dflash_state(const common_speculative * spec);
|
|
static int32_t common_speculative_feature_width(const common_speculative * spec);
|
|
static void dflash_materialize_target_window_features(common_speculative_state_dflash & state);
|
|
static void dflash_ring_reset_rows(common_speculative_state_dflash & state, const float * rows, int32_t n_rows);
|
|
static void dflash_append_target_features(
|
|
common_speculative_state_dflash & state,
|
|
const float * feature_rows,
|
|
int32_t n_rows);
|
|
static void dflash_clear_target_features(common_speculative_state_dflash & state);
|
|
static void mtp_invalidate_cached_drafts(common_speculative_state_mtp & state);
|
|
static bool common_speculative_checkpoint_save(
|
|
common_speculative_checkpoint & ckpt,
|
|
llama_model * model,
|
|
llama_context * ctx,
|
|
common_sampler * sampler_src,
|
|
const common_params_sampling & sparams,
|
|
llama_seq_id seq_id,
|
|
llama_pos n_past,
|
|
llama_token sampled,
|
|
int max_tokens,
|
|
int ckpt_mode);
|
|
|
|
static std::vector<llama_token> mtp_speculative_gen_draft(
|
|
common_speculative_state_mtp & state,
|
|
struct common_sampler * smpl,
|
|
struct llama_context * ctx,
|
|
int n_draft,
|
|
float p_min,
|
|
llama_token id_last,
|
|
llama_pos n_past,
|
|
llama_seq_id seq_id,
|
|
bool constant_draft_positions = false);
|
|
|
|
static int32_t mtp_update_kv_cache(struct llama_context * ctx, const llama_batch & batch, bool is_prompt_warmup);
|
|
|
|
// Shared parser for the boolean DFlash debug environment switches.
//
// Returns true when the variable `name` is set to a non-empty value other than
// "0", "false" or "off". The comparison is exact (case-sensitive), so e.g.
// "OFF" or "no" count as enabled.
static bool dflash_env_toggle_is_on(const char * name) {
    const char * value = std::getenv(name);
    if (value == nullptr || *value == '\0') {
        return false;
    }

    static const char * const disabled_values[] = { "0", "false", "off" };
    for (const char * disabled : disabled_values) {
        if (std::strcmp(value, disabled) == 0) {
            return false;
        }
    }

    return true;
}

// True when IK_DFLASH_CONTRACT_LOG is set to an enabling value.
// The environment is re-read on every call.
static bool dflash_contract_log_enabled() {
    return dflash_env_toggle_is_on("IK_DFLASH_CONTRACT_LOG");
}

// True when IK_DFLASH_STATS_LOG is set to an enabling value.
// The environment is re-read on every call.
static bool dflash_stats_log_enabled() {
    return dflash_env_toggle_is_on("IK_DFLASH_STATS_LOG");
}
|
|
|
|
// Formats `values` as a bracketed, comma-separated list for log output.
//
// Lists with more than 2 * edge_count elements are abbreviated to the first
// and last `edge_count` elements with "..." in between, e.g. with the default
// edge_count of 4 and nine values: "[1,2,3,4,...,6,7,8,9]". Shorter lists are
// printed in full; an empty list yields "[]".
//
// With edge_count == 0 a non-empty list is rendered as "[...]" (no stray
// separators around the ellipsis).
template <typename T>
static std::string dflash_contract_format_values(
        const std::vector<T> & values,
        size_t edge_count = 4) {
    const size_t n_values = values.size();
    const size_t n_head   = std::min(edge_count, n_values);

    // Same predicate as `n_values > 2 * edge_count`, written so that a very
    // large edge_count cannot overflow the multiplication.
    const bool abbreviate = n_values - n_head > edge_count;

    std::ostringstream oss;
    oss << '[';

    const char * sep = "";
    const auto emit_range = [&](size_t first, size_t last) {
        for (size_t i = first; i < last; ++i) {
            oss << sep << values[i];
            sep = ",";
        }
    };

    if (abbreviate) {
        emit_range(0, n_head);
        oss << sep << "...";
        sep = ",";
        // abbreviate implies n_values > 2 * edge_count, so this cannot underflow
        emit_range(n_values - edge_count, n_values);
    } else {
        emit_range(0, n_values);
    }

    oss << ']';
    return oss.str();
}
|
|
|
|
struct dflash_contract_pos_summary {
|
|
llama_pos first = -1;
|
|
llama_pos last = -1;
|
|
int32_t gap_count = 0;
|
|
int32_t nonmono_count = 0;
|
|
};
|
|
|
|
// Builds a dflash_contract_pos_summary for `positions`.
//
// Each adjacent pair (prev, cur) is classified as:
//   - non-monotonic when cur <= prev,
//   - a gap when cur > prev + 1,
//   - contiguous (not counted) when cur == prev + 1.
// An empty input returns the default-initialized summary.
static dflash_contract_pos_summary dflash_contract_summarize_positions(
        const std::vector<llama_pos> & positions) {
    dflash_contract_pos_summary result;
    if (positions.empty()) {
        return result;
    }

    result.first = positions.front();
    result.last  = positions.back();

    llama_pos prev = positions.front();
    for (auto it = positions.begin() + 1; it != positions.end(); ++it) {
        const llama_pos cur = *it;
        if (cur <= prev) {
            result.nonmono_count++;
        } else if (cur != prev + 1) {
            result.gap_count++;
        }
        prev = cur;
    }

    return result;
}
|
|
|
|
// Per-sequence cache entry stored in common_speculative_state_mtp::draft_cache_by_seq.
struct mtp_last_embd {
    std::vector<float> embd;        // cached embedding values
    float              prob{0.0f};  // probability associated with the cached entry
    int                last_id{-1}; // -1 until an id has been recorded
};
|
|
|
|
struct common_speculative_state_mtp : public common_speculative_state {
|
|
llama_context * ctx_tgt;
|
|
llama_context * ctx_mtp = nullptr;
|
|
common_sampler * smpl;
|
|
// For Gemma 4 external MTP assistant: draft positions are held constant
|
|
bool constant_draft_positions = false;
|
|
int n_embd = 0;
|
|
std::unordered_map<llama_seq_id, std::vector<float>> target_hidden_by_seq;
|
|
std::unordered_map<llama_seq_id, mtp_last_embd> draft_cache_by_seq;
|
|
|
|
common_speculative_state_mtp(
|
|
enum common_speculative_type type,
|
|
llama_context * ctx_tgt,
|
|
llama_context * ctx_mtp,
|
|
bool constant_draft_positions = false)
|
|
: common_speculative_state(type)
|
|
, ctx_tgt(ctx_tgt)
|
|
, ctx_mtp(ctx_mtp)
|
|
, constant_draft_positions(constant_draft_positions)
|
|
{
|
|
struct common_params_sampling sparams;
|
|
sparams.samplers_sequence = {
|
|
llama_sampler_type::DIST,
|
|
};
|
|
smpl = common_sampler_init(llama_get_model(ctx_mtp), sparams);
|
|
llama_set_mtp_target_context(ctx_mtp, ctx_tgt);
|
|
n_embd = llama_mtp_state_n_embd(ctx_mtp);
|
|
|
|
LOG_INF("%s: MTP context ready (n_ctx=%d, constant_draft_positions=%s)\n", __func__,
|
|
llama_n_ctx(ctx_mtp), constant_draft_positions ? "true" : "false");
|
|
}
|
|
|
|
~common_speculative_state_mtp() override {
|
|
common_sampler_free(smpl);
|
|
if (ctx_mtp) {
|
|
llama_free(ctx_mtp);
|
|
}
|
|
}
|
|
|
|
void begin(const llama_tokens & prompt) override {
|
|
GGML_UNUSED(prompt);
|
|
target_hidden_by_seq.clear();
|
|
draft_cache_by_seq.clear();
|
|
}
|
|
|
|
void draft(
|
|
const common_params_speculative & params,
|
|
const llama_tokens & prompt_tgt,
|
|
llama_token id_last,
|
|
llama_tokens & result) override {
|
|
draft(params, prompt_tgt, id_last, -1, 0, result);
|
|
}
|
|
|
|
void draft(
|
|
const common_params_speculative & params,
|
|
const llama_tokens & prompt_tgt,
|
|
llama_token id_last,
|
|
llama_pos draft_base_pos,
|
|
llama_seq_id seq_id,
|
|
llama_tokens & result) override {
|
|
|
|
const llama_pos mtp_pos_max = llama_kv_cache_seq_pos_max(ctx_mtp, seq_id);
|
|
const bool has_draft_base_pos = draft_base_pos >= 0;
|
|
// Prefer the target slot position when the caller has it. Gemma4 external MTP reads
|
|
// the target KV cache directly, so ctx_mtp's own KV position is not authoritative.
|
|
const llama_pos n_past = has_draft_base_pos
|
|
? draft_base_pos
|
|
: (mtp_pos_max >= 0 ? mtp_pos_max + 1 : (llama_pos) prompt_tgt.size());
|
|
|
|
if (!has_draft_base_pos && !prompt_tgt.empty() && mtp_pos_max < (llama_pos)prompt_tgt.size() - 1) {
|
|
LOG_WRN("%s: MTP context not fully warmed up: pos_max = %d, expected = %d\n",
|
|
__func__, (int)mtp_pos_max, (int)prompt_tgt.size() - 1);
|
|
}
|
|
if (has_draft_base_pos && !constant_draft_positions && mtp_pos_max < n_past - 1) {
|
|
LOG_WRN("%s: MTP context not fully warmed up: pos_max = %d, expected >= %d\n",
|
|
__func__, (int)mtp_pos_max, (int)n_past - 1);
|
|
}
|
|
|
|
llama_context * ctx = ctx_mtp;
|
|
|
|
const auto hidden_it = target_hidden_by_seq.find(seq_id);
|
|
if (hidden_it == target_hidden_by_seq.end() || (int) hidden_it->second.size() != n_embd) {
|
|
LOG_WRN("%s: missing target hidden state for seq_id %d\n", __func__, (int) seq_id);
|
|
result.clear();
|
|
return;
|
|
}
|
|
|
|
if (!llama_set_draft_input_hidden_state_copy(ctx, hidden_it->second.data(), hidden_it->second.size())) {
|
|
result.clear();
|
|
return;
|
|
}
|
|
|
|
result = mtp_speculative_gen_draft(
|
|
*this,
|
|
smpl,
|
|
ctx,
|
|
params.n_max,
|
|
params.p_min,
|
|
id_last,
|
|
n_past,
|
|
seq_id,
|
|
constant_draft_positions
|
|
);
|
|
}
|
|
|
|
void accept(uint16_t n_accepted) override {
|
|
GGML_UNUSED(n_accepted);
|
|
}
|
|
};
|
|
|
|
#include "speculative-impl.h"
|
|
|
|
// Forward declaration; used by common_speculative_capture_output_hidden() below.
static bool common_speculative_capture_target_features(
        common_speculative                    * spec,
        const common_speculative_feature_view & features);
|
|
|
|
static bool common_speculative_feature_view_from_hidden_rows(
|
|
const std::vector<float> & hidden_rows,
|
|
int32_t width,
|
|
llama_seq_id seq_id,
|
|
llama_pos pos_base,
|
|
common_speculative_feature_view & view) {
|
|
view = {};
|
|
view.kind = COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE;
|
|
view.width = width;
|
|
|
|
if (width <= 0 || hidden_rows.empty() || hidden_rows.size() % (size_t) width != 0) {
|
|
return false;
|
|
}
|
|
|
|
const size_t n_rows = hidden_rows.size() / (size_t) width;
|
|
view.rows.reserve(n_rows);
|
|
for (size_t i = 0; i < n_rows; ++i) {
|
|
view.rows.push_back({
|
|
/* .seq_id = */ seq_id,
|
|
/* .pos = */ pos_base + (llama_pos) i,
|
|
/* .data = */ hidden_rows.data() + i * (size_t) width,
|
|
});
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
static bool common_speculative_collect_target_batch_features(
|
|
const common_speculative * spec,
|
|
llama_context * ctx,
|
|
const llama_batch & batch,
|
|
common_speculative_feature_view & features) {
|
|
features = {};
|
|
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
|
|
return llama_spec_get_dflash_feature_view(ctx, batch, features);
|
|
}
|
|
|
|
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
|
|
return true;
|
|
}
|
|
|
|
if (!llama_spec_get_hidden_feature_view(ctx, batch, features)) {
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
static bool common_speculative_collect_target_seq_batch_features(
|
|
const common_speculative * spec,
|
|
llama_context * ctx,
|
|
const llama_batch & batch,
|
|
llama_seq_id seq_id,
|
|
common_speculative_feature_view & features) {
|
|
features = {};
|
|
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
|
|
return llama_spec_get_dflash_feature_view_for_seq(ctx, batch, seq_id, features);
|
|
}
|
|
|
|
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
|
|
return true;
|
|
}
|
|
|
|
if (!llama_spec_get_hidden_feature_view_for_seq(ctx, batch, seq_id, features)) {
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// Captures the hidden state at `output_index` of `ctx` as the target feature
// for (`seq_id`, `pos`).
//
// Only relevant for MTP: without an MTP stage this is a no-op that returns
// true. Returns false when the hidden-state view cannot be obtained, otherwise
// the result of common_speculative_capture_target_features().
bool common_speculative_capture_output_hidden(
        common_speculative * spec,
        llama_context * ctx,
        int32_t output_index,
        llama_seq_id seq_id,
        llama_pos pos) {
    const bool uses_mtp = common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP);
    if (!uses_mtp) {
        return true;
    }

    common_speculative_feature_view hidden_view;
    const bool have_view =
        llama_spec_get_hidden_feature_view_from_output_index(ctx, output_index, seq_id, pos, hidden_view);

    return have_view ? common_speculative_capture_target_features(spec, hidden_view) : false;
}
|
|
|
|
// Makes sure a hidden state is available for `seq_id` before MTP drafting.
//
// Returns true immediately when MTP is not in use or the sequence already has
// a hidden state. Otherwise captures one via
// common_speculative_capture_output_hidden() with output index -1.
bool common_speculative_ensure_sequence_hidden(
        common_speculative * spec,
        llama_context * ctx,
        llama_seq_id seq_id,
        llama_pos pos) {
    if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
        return true;
    }
    if (common_speculative_has_sequence_hidden(spec, seq_id)) {
        return true;
    }

    const int32_t output_index = -1;
    return common_speculative_capture_output_hidden(spec, ctx, output_index, seq_id, pos);
}
|
|
|
|
// Drafts tokens and reports which implementation produced them.
//
// For MTP, a hidden state for `draft_seq_id` at position draft_base_pos - 1 is
// ensured first; if that fails an error is logged and the value-initialized
// result is returned without drafting.
//
// Returns the drafted tokens together with the type of spec->curr_impl
// (COMMON_SPECULATIVE_TYPE_NONE when spec or curr_impl is null).
common_speculative_draft_result common_speculative_draft_ex(
        common_speculative * spec,
        llama_context * ctx,
        common_params_speculative & params,
        const llama_tokens & prompt_tgt,
        llama_token id_last,
        llama_pos draft_base_pos,
        llama_seq_id draft_seq_id) {
    common_speculative_draft_result out = {};

    const bool uses_mtp = common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP);
    if (uses_mtp && !common_speculative_ensure_sequence_hidden(spec, ctx, draft_seq_id, draft_base_pos - 1)) {
        LOG_ERR("%s: seq_id=%d MTP hidden state is empty during speculation\n",
                __func__, (int) draft_seq_id);
        return out;
    }

    out.tokens = common_speculative_draft(spec, params, prompt_tgt, id_last, draft_base_pos, draft_seq_id);

    // curr_impl is read after drafting, as in the original statement order.
    if (spec != nullptr && spec->curr_impl != nullptr) {
        out.type = spec->curr_impl->type;
    } else {
        out.type = COMMON_SPECULATIVE_TYPE_NONE;
    }

    return out;
}
|
|
|
|
static bool common_speculative_has_target_features(const common_speculative * spec) {
|
|
return common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) ||
|
|
common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH);
|
|
}
|
|
|
|
// Loads the draft model described by `params` and stores the model handle and
// its context parameters back into `params` (model_dft, cparams_dft,
// mparams_dft.path).
//
// The draft parameters start from the speculative settings, inherit selected
// fields from `params_base`, and may then be overridden by the extra command
// line in `params.params`.
//
// Returns true on success or when no draft model is configured; false when the
// extra command line fails to parse or the model cannot be loaded.
bool common_speculative_load_draft_model(
        common_params_speculative & params,
        const gpt_params & params_base) {
    if (!params.has_dft()) {
        return true;
    }

    const bool is_dflash = params.has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH);

    gpt_params params_dft;
    params_dft.devices          = params.devices;
    params_dft.model            = params.model;
    params_dft.main_gpu         = params_base.main_gpu;
    params_dft.n_gpu_layers     = params.n_gpu_layers;
    params_dft.rpc_servers      = params_base.rpc_servers;
    // an empty draft cache type means "same as the base model"
    params_dft.cache_type_k     = params.cache_type_k.empty() ? params_base.cache_type_k : params.cache_type_k;
    params_dft.cache_type_v     = params.cache_type_v.empty() ? params_base.cache_type_v : params.cache_type_v;
    params_dft.flash_attn       = params_base.flash_attn;
    params_dft.k_cache_hadamard = params_base.k_cache_hadamard;
    params_dft.v_cache_hadamard = params_base.v_cache_hadamard;

    // DFlash additionally inherits the base model's split/scheduling settings.
    if (is_dflash) {
        params_dft.split_mode = params_base.split_mode;
        for (size_t i = 0; i < std::size(params_dft.tensor_split); ++i) {
            params_dft.tensor_split[i] = params_base.tensor_split[i];
        }
        params_dft.attn_max_batch              = params_base.attn_max_batch;
        params_dft.graph_reuse                 = params_base.graph_reuse;
        params_dft.split_mode_graph_scheduling = params_base.split_mode_graph_scheduling;
        params_dft.scheduler_async             = params_base.scheduler_async;
        params_dft.max_extra_alloc_MiB         = params_base.max_extra_alloc_MiB;
        params_dft.reduce_type                 = params_base.reduce_type;
    }

    // Extra draft-only arguments are parsed last so they override everything above.
    if (!params.params.empty()) {
        auto [argc, argv] = parse_command_line("llama-server " + params.params);
        const bool parsed = gpt_params_parse(argc, argv, params_dft);
        if (!parsed) {
            gpt_params_print_usage(argc, argv, params_dft);
        }
        free_command_line(argc, argv);
        if (!parsed) {
            return false;
        }
    }

    LOG_INF("%s: loading draft model '%s'\n", __func__, params_dft.model.c_str());

    if (params_dft.n_ctx == 0) {
        params_dft.n_ctx = params.n_ctx;
    }
    if (is_dflash && params_dft.n_gpu_layers < 0) {
        params_dft.n_gpu_layers = params_base.n_gpu_layers;
    }
    if (params_dft.n_ctx == 0) {
        // Per-slot share of the base context. A non-positive n_parallel is
        // treated as 1 so the division can never fault.
        const auto n_slots = params_base.n_parallel > 0 ? params_base.n_parallel : 1;
        params_dft.n_ctx = params_base.n_ctx / n_slots;
    }
    params_dft.n_parallel = 1;
    params_dft.n_batch    = params_dft.n_ctx;

    params.mparams_dft.path = params_dft.model;

    llama_model_params mparams_dft = common_model_params_to_llama(params_dft);
    llama_model * loaded_model = llama_model_load_from_file(params_dft.model.c_str(), mparams_dft);
    if (loaded_model == nullptr) {
        // Report the path that was actually used: the extra command line may
        // have replaced params.model.
        LOG_ERR("%s: failed to load draft model '%s'\n", __func__, params_dft.model.c_str());
        return false;
    }

    params.model_dft   = loaded_model;
    params.cparams_dft = common_context_params_to_llama(params_dft);
    return true;
}
|
|
|
|
// Prepares the context parameters used for the MTP stage.
//
// Returns false when MTP is not part of the stage chain, or when the model has
// no NextN layers and there is no external MTP model - in the latter case the
// MTP stage is removed from the chain (and the draft settings cleared if no
// remaining stage needs a draft model).
//
// Otherwise returns true after setting params.cparams_dft up for MTP warmup
// with embeddings enabled. Without an external MTP model, cparams_dft is first
// rebuilt from `params_base` with pooling disabled.
bool common_speculative_prepare_mtp_runtime(
        common_params_speculative & params,
        const gpt_params & params_base,
        const llama_model * model,
        bool has_external_mtp) {
    if (!params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
        return false;
    }

    const bool model_lacks_nextn = llama_model_n_nextn_layer(model) == 0;
    if (model_lacks_nextn && !has_external_mtp) {
        LOG_WRN("%s: MTP speculative stage requested, but model has 0 NextN layers. Removing MTP from the configured stage chain.\n",
                __func__);
        params.remove_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
        if (!params.needs_dft_model()) {
            params.clear_dft();
        }
        return false;
    }

    if (!has_external_mtp) {
        // built-in MTP: derive the context parameters from the base model's
        gpt_params base_copy = params_base;
        base_copy.pooling_type = LLAMA_POOLING_TYPE_NONE;
        params.cparams_dft = common_context_params_to_llama(base_copy);
    }

    auto & cparams = params.cparams_dft;
    cparams.mtp         = true;
    cparams.mtp_op_type = MTP_OP_WARMUP;
    cparams.embeddings  = true;

    return true;
}
|
|
|
|
// Initializes speculative decoding and classifies the outcome.
//
//   SKIPPED        - no stage chain is configured,
//   READY          - initialization succeeded; *out_spec receives the instance,
//   ERR_RECURRENT  - initialization failed and the target model is recurrent,
//   ERR_MTP        - initialization failed and an MTP stage was requested,
//   ERR_GENERIC    - any other failure.
//
// *out_spec (when out_spec is non-null) is reset to nullptr up front.
// NOTE(review): when out_spec is null and initialization succeeds, the created
// instance is not handed to anyone - confirm callers always pass out_spec.
common_speculative_init_status common_speculative_try_init(
        common_params_speculative & params,
        llama_context * ctx_tgt,
        common_speculative ** out_spec) {
    const bool wants_result = out_spec != nullptr;
    if (wants_result) {
        *out_spec = nullptr;
    }

    if (!params.has_stage_chain()) {
        return COMMON_SPECULATIVE_INIT_SKIPPED;
    }

    common_speculative * spec = common_speculative_init(params, ctx_tgt);
    if (spec == nullptr) {
        // Work out the most specific failure reason.
        if (ctx_tgt != nullptr) {
            const llama_model * model = llama_get_model(ctx_tgt);
            if (model != nullptr && llama_model_has_recurrent(model)) {
                return COMMON_SPECULATIVE_INIT_ERR_RECURRENT;
            }
        }
        return params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)
            ? COMMON_SPECULATIVE_INIT_ERR_MTP
            : COMMON_SPECULATIVE_INIT_ERR_GENERIC;
    }

    if (wants_result) {
        *out_spec = spec;
    }
    return COMMON_SPECULATIVE_INIT_READY;
}
|
|
|
|
// Normalizes the speculative settings in `params_base` before models are loaded.
//
//   - Unless `allow_parallel_mtp` is set, the MTP stage is removed when more
//     than one parallel slot is configured (a warning is logged).
//   - The draft settings are cleared when no remaining stage needs a draft model.
//   - params_base.has_mtp is updated to reflect the resulting stage chain.
void common_speculative_prepare_startup(
        gpt_params & params_base,
        bool allow_parallel_mtp) {
    auto & spec_params = params_base.speculative;

    const bool multiple_slots = params_base.n_parallel > 1;
    if (!allow_parallel_mtp && multiple_slots && spec_params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
        LOG_WRN("%s: MTP is not supported with parallel slots yet, removing the MTP stage to avoid cross-slot corruption. n_parallel=%d, stage_chain=%s\n",
                __func__, params_base.n_parallel, common_speculative_stage_chain_to_str(spec_params).c_str());
        spec_params.remove_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
    }

    if (!spec_params.needs_dft_model()) {
        spec_params.clear_dft();
    }

    params_base.has_mtp = spec_params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
}
|
|
|
|
// Completes speculative startup once the target `model` is available.
//
//   1. Clears the draft settings if no stage needs a draft model.
//   2. Loads the draft model when one is configured (returns false on failure).
//   3. Prepares the MTP runtime and records in params_base.has_mtp whether MTP
//      is actually usable; pooling is disabled on the base params when it is.
//
// Returns false only when the draft model fails to load.
bool common_speculative_finalize_startup(
        gpt_params & params_base,
        const llama_model * model) {
    auto & spec_params = params_base.speculative;

    if (!spec_params.needs_dft_model()) {
        spec_params.clear_dft();
    }

    if (spec_params.has_dft()) {
        LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");
        const bool loaded = common_speculative_load_draft_model(spec_params, params_base);
        if (!loaded) {
            return false;
        }
    }

    // has_mtp is written before the runtime is prepared because params_base is
    // passed on (and copied) by common_speculative_prepare_mtp_runtime().
    params_base.has_mtp = spec_params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);

    // NOTE(review): model_dft may be null here when MTP runs without a separate
    // draft model - confirm llama_model_is_gemma4_mtp_assistant() accepts nullptr.
    bool has_external_mtp = false;
    if (params_base.has_mtp) {
        has_external_mtp = llama_model_is_gemma4_mtp_assistant(spec_params.model_dft);
    }

    params_base.has_mtp = common_speculative_prepare_mtp_runtime(
            spec_params, params_base, model, has_external_mtp);

    if (params_base.has_mtp) {
        params_base.pooling_type = LLAMA_POOLING_TYPE_NONE;
    }

    return true;
}
|
|
|
|
// Saves a checkpoint into spec->checkpoint before a draft is generated.
//
// Returns false when `spec` is null; otherwise forwards every argument to
// common_speculative_checkpoint_save() and returns its result.
bool common_speculative_before_draft(
        common_speculative * spec,
        llama_model * model,
        llama_context * ctx,
        common_sampler * sampler_src,
        const common_params_sampling & sparams,
        llama_seq_id seq_id,
        llama_pos n_past,
        llama_token sampled,
        int max_tokens,
        int ckpt_mode) {
    return spec != nullptr &&
           common_speculative_checkpoint_save(
                   spec->checkpoint, model, ctx, sampler_src, sparams,
                   seq_id, n_past, sampled, max_tokens, ckpt_mode);
}
|
|
|
|
// Feeds the target-model results of `batch` for sequence `seq_id` to the
// speculative implementation.
//
// Steps:
//   1. Non-DFlash only: verify that the target context and the companion
//      context (or the target itself when there is none) report the same,
//      positive hidden-state width.
//   2. During prompt warmup, if the batch is not exactly one sequence, extract
//      the rows of `seq_id` into a temporary batch.
//   3. Collect the target features (timed for DFlash during warmup).
//   4. Forward batch + features to common_speculative_on_target_batch().
//
// Returns 0 when there is nothing to do (null context, empty batch, or no
// tokens for the sequence), -1 on failure, otherwise the result of
// common_speculative_on_target_batch().
int32_t common_speculative_on_target_seq_batch(
        common_speculative * spec,
        llama_context * ctx_tgt,
        const llama_batch & batch,
        llama_seq_id seq_id,
        bool is_prompt_warmup) {
    if (ctx_tgt == nullptr || batch.n_tokens <= 0) {
        return 0;
    }

    if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
        // Fall back to the target context when there is no companion context.
        // ctx_tgt was checked above, so the selected context is never null.
        llama_context * ctx_companion = common_speculative_get_companion_ctx(spec);
        llama_context * ctx_mtp       = ctx_companion != nullptr ? ctx_companion : ctx_tgt;

        const int n_embd_src = common_speculative_ctx_mtp_n_embd(ctx_tgt);
        const int n_embd_dst = common_speculative_ctx_mtp_n_embd(ctx_mtp);
        if (n_embd_src <= 0 || n_embd_dst <= 0) {
            return -1;
        }
        if (n_embd_src != n_embd_dst) {
            LOG_ERR("MTP warmup hidden state width mismatch: n_embd_src = %d, n_embd_dst = %d\n", n_embd_src, n_embd_dst);
            return -1;
        }
    }

    common_speculative_feature_view feature_view;
    llama_batch seq_batch = {};

    const bool split_by_seq = is_prompt_warmup && !common_speculative_batch_is_exact_single_seq(batch, seq_id);
    auto * dflash_state     = common_speculative_get_dflash_state(spec);
    const bool time_collect = dflash_state != nullptr && is_prompt_warmup;

    // Number of rows accounted to the warmup statistics.
    size_t n_collect_rows = (size_t) batch.n_tokens;

    if (split_by_seq) {
        const int n_seq_tokens = common_speculative_copy_seq_batch(batch, seq_id, seq_batch);
        if (n_seq_tokens <= 0) {
            return n_seq_tokens < 0 ? -1 : 0;
        }
        n_collect_rows = (size_t) n_seq_tokens;
    }

    // Features are always collected from the original batch; only the batch
    // handed to the implementation is the per-sequence copy.
    const int64_t t_collect_us = time_collect ? ggml_time_us() : 0;
    const bool collected = split_by_seq
        ? common_speculative_collect_target_seq_batch_features(spec, ctx_tgt, batch, seq_id, feature_view)
        : common_speculative_collect_target_batch_features(spec, ctx_tgt, batch, feature_view);

    int32_t ret = -1;
    if (collected) {
        if (time_collect) {
            dflash_state->t_warmup_collect_us += (uint64_t) (ggml_time_us() - t_collect_us);
            dflash_state->n_warmup_collect_calls++;
            dflash_state->n_warmup_collect_rows += n_collect_rows;
        }

        const llama_batch * batch_for_spec = split_by_seq ? &seq_batch : &batch;
        ret = common_speculative_on_target_batch(spec, *batch_for_spec, feature_view, is_prompt_warmup);
    }

    // The temporary batch exists only on the split path.
    if (split_by_seq) {
        llama_batch_free(seq_batch);
    }

    return ret;
}
|
|
|
|
bool common_speculative_copy_output_hidden_rows(
|
|
const common_speculative * spec,
|
|
llama_context * ctx,
|
|
const std::vector<int32_t> & output_indices,
|
|
std::vector<float> & hidden_rows) {
|
|
hidden_rows.clear();
|
|
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
|
|
return llama_spec_copy_dflash_rows_from_output_indices(ctx, output_indices, hidden_rows);
|
|
}
|
|
|
|
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
|
|
return true;
|
|
}
|
|
|
|
return llama_spec_copy_hidden_rows_from_output_indices(ctx, output_indices, hidden_rows);
|
|
}
|
|
|
|
// Builds the token list that is paired row-for-row with the accepted hidden
// rows when committing them to the speculative state.
//
//  - `ids` empty: `commit_tokens` is left empty and the call succeeds.
//  - MTP drafts:  `commit_tokens` is a copy of `ids`.
//  - otherwise:   `commit_tokens` is `sampled_before` followed by every entry
//                 of `ids` except the last one.
//
// Returns true when `commit_tokens` ends up with exactly as many entries as
// `ids`.
static bool common_speculative_build_commit_tokens(
        common_speculative_type spec_type_used,
        llama_token sampled_before,
        const std::vector<llama_token> & ids,
        std::vector<llama_token> & commit_tokens) {
    commit_tokens.clear();

    const size_t n_ids = ids.size();
    if (n_ids == 0) {
        return true;
    }

    if (spec_type_used == COMMON_SPECULATIVE_TYPE_MTP) {
        commit_tokens.assign(ids.begin(), ids.end());
        return true;
    }

    // Shift the accepted tokens right by one: the previously sampled token
    // goes first and the final entry of `ids` is dropped.
    commit_tokens.reserve(n_ids);
    commit_tokens.push_back(sampled_before);
    for (size_t i = 0; i + 1 < n_ids; ++i) {
        commit_tokens.push_back(ids[i]);
    }

    return commit_tokens.size() == n_ids;
}
|
|
|
|
// Feeds already-copied hidden rows for the tokens `ids` (placed at positions
// pos_base, pos_base + 1, ...) into the speculative state through
// common_speculative_on_target_batch().
//
// Returns true when there is nothing to apply (feature width <= 0 or `ids`
// empty) or when the target-batch hook returns 0. Returns false when
// `hidden_rows` does not hold exactly ids.size() * feature_width floats, when
// the feature view cannot be built, or when the hook reports an error.
static bool common_speculative_apply_hidden_rows(
        common_speculative * spec,
        llama_seq_id seq_id,
        llama_pos pos_base,
        const std::vector<llama_token> & ids,
        const std::vector<float> & hidden_rows) {
    const int32_t row_width = common_speculative_feature_width(spec);
    if (row_width <= 0 || ids.empty()) {
        return true;
    }

    const size_t n_rows = ids.size();
    if (hidden_rows.size() != n_rows * (size_t) row_width) {
        return false;
    }

    // Temporary batch describing the accepted tokens; released before returning.
    llama_batch rows_batch = llama_batch_init(ids.size(), 0, 1);
    for (size_t row = 0; row < n_rows; ++row) {
        common_batch_add(rows_batch, ids[row], pos_base + (llama_pos) row, { seq_id }, true);
    }

    common_speculative_feature_view view;
    int32_t status = -1;
    if (common_speculative_feature_view_from_hidden_rows(hidden_rows, row_width, seq_id, pos_base, view)) {
        status = common_speculative_on_target_batch(spec, rows_batch, view, false);
    }

    llama_batch_free(rows_batch);
    return status == 0;
}
|
|
|
|
// Commits hidden rows that were already copied out of the target context.
//
// The accepted `ids` are first converted into the per-row commit tokens (see
// common_speculative_build_commit_tokens) and then applied starting at
// `pos_base`. When a DFlash state exists, the elapsed time, call count and row
// count of the apply step are added to its accept-commit counters, whether or
// not the apply step succeeded.
//
// Returns true when there is nothing to commit or the rows were applied.
bool common_speculative_commit_accepted_hidden_rows(
        common_speculative * spec,
        common_speculative_type spec_type_used,
        llama_seq_id seq_id,
        llama_pos pos_base,
        llama_token sampled_before,
        const std::vector<llama_token> & ids,
        const std::vector<float> & hidden_rows) {
    if (common_speculative_feature_width(spec) <= 0 || ids.empty()) {
        return true;
    }

    std::vector<llama_token> row_tokens;
    const bool have_tokens = common_speculative_build_commit_tokens(spec_type_used, sampled_before, ids, row_tokens);
    if (!have_tokens) {
        return false;
    }

    auto * dflash = common_speculative_get_dflash_state(spec);
    const bool profile = dflash != nullptr;

    const int64_t t_start_us = profile ? ggml_time_us() : 0;
    const bool applied = common_speculative_apply_hidden_rows(spec, seq_id, pos_base, row_tokens, hidden_rows);

    if (profile) {
        dflash->t_accept_commit_us += (uint64_t) (ggml_time_us() - t_start_us);
        dflash->n_accept_commit_calls++;
        dflash->n_accept_commit_rows += row_tokens.size();
    }

    return applied;
}
|
|
|
|
// Copies the hidden rows at `output_indices` out of `ctx` and commits them for
// the accepted tokens `ids`.
//
// When a DFlash state exists, a successful copy is accounted in its
// accept-output-copy counters (time, calls, rows). A failed copy returns false
// immediately and is not accounted.
//
// Returns true when there is nothing to commit (feature width <= 0 or `ids`
// empty) or when both the copy and the commit succeed.
bool common_speculative_commit_accepted_output(
        common_speculative * spec,
        llama_context * ctx,
        common_speculative_type spec_type_used,
        llama_seq_id seq_id,
        llama_pos pos_base,
        llama_token sampled_before,
        const std::vector<llama_token> & ids,
        const std::vector<int32_t> & output_indices) {
    if (common_speculative_feature_width(spec) <= 0 || ids.empty()) {
        return true;
    }

    std::vector<float> rows;

    auto * dflash = common_speculative_get_dflash_state(spec);
    const bool profile = dflash != nullptr;
    const int64_t t_start_us = profile ? ggml_time_us() : 0;

    const bool copied = common_speculative_copy_output_hidden_rows(spec, ctx, output_indices, rows);
    if (!copied) {
        return false;
    }

    if (profile) {
        dflash->t_accept_output_copy_us += (uint64_t) (ggml_time_us() - t_start_us);
        dflash->n_accept_output_copy_calls++;
        dflash->n_accept_output_copy_rows += output_indices.size();
    }

    return common_speculative_commit_accepted_hidden_rows(
            spec, spec_type_used, seq_id, pos_base, sampled_before, ids, rows);
}
|
|
|
|
// Captures a speculative checkpoint for `seq_id` into `ckpt`.
//
// Steps, in order:
//   1. reset `ckpt` and record `n_past` / `sampled`;
//   2. initialise the context-side checkpoint with the requested mode; the
//      mode actually granted decides whether per-step restore is enabled;
//   3. save the sequence state in `ctx`;
//   4. create a fresh sampler for the checkpoint and, when `sampler_src` is
//      given, clone its state into it.
//
// Returns false when the context grants no checkpoint mode, when saving the
// sequence fails (the context checkpoint is discarded), or when the sampler
// cannot be created (the whole checkpoint is discarded).
static bool common_speculative_checkpoint_save(
        common_speculative_checkpoint & ckpt,
        llama_model * model,
        llama_context * ctx,
        common_sampler * sampler_src,
        const common_params_sampling & sparams,
        llama_seq_id seq_id,
        llama_pos n_past,
        llama_token sampled,
        int max_tokens,
        int ckpt_mode) {
    ckpt.clear();
    ckpt.n_past  = n_past;
    ckpt.sampled = sampled;

    const int granted_mode = llama_spec_ckpt_init(ctx, ckpt_mode, max_tokens);
    if (granted_mode == LLAMA_SPEC_CKPT_NONE) {
        return false;
    }
    ckpt.per_step_enabled = (granted_mode == LLAMA_SPEC_CKPT_PER_STEP);

    ckpt.valid = llama_spec_ckpt_save(ctx, seq_id);
    if (!ckpt.valid) {
        llama_spec_ckpt_discard(ctx);
        return false;
    }

    ckpt.sampler = common_sampler_init(model, sparams);
    if (ckpt.sampler == nullptr) {
        common_speculative_checkpoint_discard(ckpt, ctx);
        return false;
    }

    // Snapshot the live sampler so it can be rewound on restore.
    if (sampler_src != nullptr) {
        common_sampler_clone(sampler_src, ckpt.sampler);
    }

    return true;
}
|
|
|
|
// Read-only access to the checkpoint embedded in `spec`; nullptr when `spec`
// itself is null.
const common_speculative_checkpoint * common_speculative_get_checkpoint(const common_speculative * spec) {
    if (spec == nullptr) {
        return nullptr;
    }
    return &spec->checkpoint;
}
|
|
|
|
// Drops a checkpoint: resets the caller-side bookkeeping in `ckpt` first, then
// asks `ctx` to discard its checkpoint data.
void common_speculative_checkpoint_discard(
        common_speculative_checkpoint & ckpt,
        llama_context * ctx) {
    // Caller-side state.
    ckpt.clear();

    // Context-side state.
    llama_spec_ckpt_discard(ctx);
}
|
|
|
|
// Rolls the target context and sampler back to `ckpt` after draft tokens were
// rejected, then re-establishes the state for the accepted tokens `ids`.
//
// Does nothing when the checkpoint is not valid. Otherwise one of two paths is
// taken and the checkpoint is discarded afterwards:
//
//  * full restore (per-step disabled): the context is restored to step 0; when
//    `ids` is non-empty, the checkpointed sampled token followed by all but
//    the last entry of `ids` is decoded again, the re-decoded outputs are
//    committed as target features, and `ids` is replayed into the sampler.
//  * per-step restore: the context is restored directly to step
//    ids.size() - 1, `ids` is replayed into the sampler, and the hidden rows
//    captured before the restore (`mtp_hidden_state_pre`, positioned from
//    `mtp_n_past_base`) are committed when available.
//
// In both paths a failed feature commit clears the cached hidden state of the
// sequence.
void common_speculative_checkpoint_restore(
        common_speculative_checkpoint & ckpt,
        common_speculative * spec,
        llama_context * ctx,
        common_sampler * sampler_dst,
        llama_seq_id seq_id,
        common_speculative_type spec_type_used,
        llama_token sampled_before,
        const std::vector<llama_token> & ids,
        int n_draft,
        const std::vector<float> & mtp_hidden_state_pre,
        int32_t mtp_n_past_base) {
    if (!ckpt.valid) {
        return;
    }

    // Reset the live sampler to the state captured with the checkpoint.
    const auto rewind_sampler = [&]() {
        if (ckpt.sampler != nullptr && sampler_dst != nullptr) {
            common_sampler_clone(ckpt.sampler, sampler_dst);
        }
    };

    // Feed the accepted tokens back into the live sampler.
    const auto replay_accepted = [&]() {
        if (sampler_dst == nullptr) {
            return;
        }
        for (const llama_token tok : ids) {
            common_sampler_accept(sampler_dst, ctx, tok, true);
        }
    };

    if (!ckpt.per_step_enabled) {
        llama_spec_ckpt_restore(ctx, seq_id, ckpt.n_past, 0);
        rewind_sampler();

        if (!ids.empty()) {
            const int n_re = (int) ids.size();

            // Re-decode batch: checkpointed sampled token first, then every
            // accepted id except the last; only the final row requests output.
            llama_batch re_batch = llama_batch_init(n_re, 0, 1);
            for (int k = 0; k < n_re; ++k) {
                const llama_token tok = (k == 0) ? ckpt.sampled : ids[k - 1];
                common_batch_add(re_batch, tok, ckpt.n_past + k, { seq_id }, k == n_re - 1);
            }

            // With MTP every row requests output and embeddings are enabled.
            if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
                for (int k = 0; k < re_batch.n_tokens; ++k) {
                    re_batch.logits[k] = true;
                }
                llama_set_embeddings(ctx, true);
            }

            const int ret = llama_decode(ctx, re_batch);
            if (ret != 0) {
                LOG_ERR("%s: seq_id=%d failed to re-decode accepted tokens after checkpoint restore: %d\n",
                        __func__, (int) seq_id, ret);
            }

            if (common_speculative_has_target_features(spec)) {
                // Output rows of the re-decode are simply 0 .. n_re - 1.
                std::vector<int32_t> redecoded_indices;
                redecoded_indices.reserve(n_re);
                for (int k = 0; k < n_re; ++k) {
                    redecoded_indices.push_back(k);
                }

                const bool committed = common_speculative_commit_accepted_output(
                        spec, ctx, spec_type_used, seq_id, ckpt.n_past, sampled_before, ids, redecoded_indices);
                if (!committed) {
                    common_speculative_clear_sequence_hidden(spec, seq_id);
                }
            }

            replay_accepted();

            llama_batch_free(re_batch);
            LOG_DBG("%s: seq_id=%d spec checkpoint restored: re-decoded %d tokens (rejected %d drafts)\n",
                    __func__, (int) seq_id, n_re, (int) (n_draft - (ids.size() - 1)));
        }
    } else {
        const int step = (int) ids.size() - 1;
        llama_spec_ckpt_restore(ctx, seq_id, ckpt.n_past, step);

        rewind_sampler();
        replay_accepted();

        if (common_speculative_has_target_features(spec) && !mtp_hidden_state_pre.empty()) {
            const bool committed = common_speculative_commit_accepted_hidden_rows(
                    spec, spec_type_used, seq_id, mtp_n_past_base, sampled_before, ids, mtp_hidden_state_pre);
            if (!committed) {
                common_speculative_clear_sequence_hidden(spec, seq_id);
            } else if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) {
                LOG_DBG("%s: seq_id=%d synced MTP target hidden state from accepted-prefix rows after per-step restore\n",
                        __func__, (int) seq_id);
            }
        }

        LOG_DBG("%s: seq_id=%d per-step restore: step=%d (rejected %d drafts)\n",
                __func__, (int) seq_id, step, (int) (n_draft - (ids.size() - 1)));
    }

    common_speculative_checkpoint_discard(ckpt, ctx);
}
|
|
|
|
// Finalises one speculative step for `seq_id`, given the accepted tokens
// `ids` out of `n_draft` drafted tokens.
//
// The acceptance count (ids.size() - 1) is reported to the speculative state
// first. Then:
//  * if at least one draft was rejected and a valid checkpoint exists, the
//    hidden rows at `accepted_output_indices` are snapshotted (when target
//    features are in use) and the checkpoint is restored; the function returns
//    right after the restore;
//  * otherwise the accepted output rows are committed as target features (if
//    applicable), the KV cache of the sequence is truncated from
//    pos_base + ids.size() - 1 onwards, and the checkpoint is discarded.
//
// `spec` must be non-null and `ids` must be non-empty.
void common_speculative_commit(
        common_speculative * spec,
        llama_context * ctx,
        common_sampler * sampler_dst,
        llama_seq_id seq_id,
        llama_token sampled_before,
        const std::vector<llama_token> & ids,
        int n_draft,
        llama_pos pos_base,
        const std::vector<int32_t> & accepted_output_indices) {
    GGML_ASSERT(spec != nullptr);
    GGML_ASSERT(!ids.empty());

    common_speculative_checkpoint & ckpt = spec->checkpoint;

    // Read the active implementation type before reporting the acceptance.
    common_speculative_type spec_type_used = COMMON_SPECULATIVE_TYPE_NONE;
    if (spec->curr_impl != nullptr) {
        spec_type_used = spec->curr_impl->type;
    }

    const bool any_rejected = (int) ids.size() - 1 < n_draft;

    common_speculative_accept(spec, ids.size() - 1);

    if (any_rejected && ckpt.valid) {
        // Snapshot the accepted rows before the restore rewinds the context.
        std::vector<float> mtp_hidden_state_pre;
        if (common_speculative_has_target_features(spec) && !accepted_output_indices.empty()) {
            const bool copied = common_speculative_copy_output_hidden_rows(
                    spec, ctx, accepted_output_indices, mtp_hidden_state_pre);
            if (!copied) {
                mtp_hidden_state_pre.clear();
            }
        }

        common_speculative_checkpoint_restore(
                ckpt, spec, ctx, sampler_dst, seq_id, spec_type_used,
                sampled_before, ids, n_draft, mtp_hidden_state_pre, pos_base);
        return;
    }

    if (common_speculative_has_target_features(spec) && !accepted_output_indices.empty()) {
        const bool committed = common_speculative_commit_accepted_output(
                spec, ctx, spec_type_used, seq_id, pos_base, sampled_before, ids, accepted_output_indices);
        if (!committed) {
            common_speculative_clear_sequence_hidden(spec, seq_id);
        } else if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) {
            LOG_DBG("%s: seq_id=%d synced MTP target hidden state from accepted-prefix rows\n",
                    __func__, (int) seq_id);
        }
    }

    const llama_pos first_removed_pos = pos_base + (llama_pos) (ids.size() - 1);
    llama_kv_cache_seq_rm(ctx, seq_id, first_removed_pos, -1);
    common_speculative_checkpoint_discard(ckpt, ctx);
}
|
|
|
|
void common_speculative_print_stats(const common_speculative * spec, double slot_tps, int n_decoded, int n_past, common_params_speculative * active_params) {
|
|
if (spec == nullptr) {
|
|
return;
|
|
}
|
|
|
|
for (const auto & impl : spec->impls) {
|
|
std::string str_perf;
|
|
if (impl->gen_perf) {
|
|
std::ostringstream oss;
|
|
oss << std::fixed << std::setprecision(3) << impl->t_begin_us / 1000.0 << ", ";
|
|
oss << std::fixed << std::setprecision(3) << impl->t_draft_us / 1000.0 << ", ";
|
|
oss << std::fixed << std::setprecision(3) << impl->t_accept_us / 1000.0;
|
|
str_perf = ", dur(b,g,a) = " + oss.str() + " ms";
|
|
} else {
|
|
str_perf = "";
|
|
}
|
|
|
|
LOG_INF("statistics %s: #calls(b,g,a) = %zu %zu %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n",
|
|
common_speculative_type_to_str(impl->type).c_str(),
|
|
impl->n_call_begin, impl->n_call_draft, impl->n_call_accept,
|
|
impl->n_gen_drafts,
|
|
impl->n_acc_drafts,
|
|
impl->n_gen_tokens,
|
|
impl->n_acc_tokens,
|
|
str_perf.c_str());
|
|
|
|
if (impl->type == COMMON_SPECULATIVE_TYPE_DFLASH) {
|
|
const auto * dflash_state = dynamic_cast<const common_speculative_state_dflash *>(impl.get());
|
|
if (dflash_state != nullptr && dflash_stats_log_enabled()) {
|
|
llama_dflash_profile_stats capture_stats;
|
|
llama_dflash_profile_stats graph_stats;
|
|
const bool have_capture = llama_dflash_profile_get_stats(dflash_state->ctx_tgt, &capture_stats);
|
|
const bool have_graph = llama_dflash_profile_get_stats(dflash_state->ctx_dft, &graph_stats);
|
|
|
|
LOG_INF("statistics dflash detail: cross_ctx=%d, window_rows=%d, pos=[%d..%d], window_updates=%zu, rows_seen=%zu, rows_dropped=%zu, shifts=%zu, draft_fail(empty/set/decode)=%zu/%zu/%zu, next_draft_pos=%d\n",
|
|
dflash_state->cross_ctx,
|
|
dflash_state->target_window_rows,
|
|
dflash_state->target_window_pos.empty() ? -1 : (int) dflash_state->target_window_pos.front(),
|
|
dflash_state->target_window_pos.empty() ? -1 : (int) dflash_state->target_window_pos.back(),
|
|
dflash_state->n_window_updates,
|
|
dflash_state->n_rows_seen,
|
|
dflash_state->n_rows_dropped,
|
|
dflash_state->n_context_shifts,
|
|
dflash_state->n_draft_empty,
|
|
dflash_state->n_set_target_fail,
|
|
dflash_state->n_decode_fail,
|
|
(int) dflash_state->last_draft_pos_base);
|
|
|
|
if (have_capture || have_graph) {
|
|
const double kv_cache_total_ms = (double) (
|
|
graph_stats.graph_kv_cache_build_us +
|
|
graph_stats.graph_kv_cache_reserve_us +
|
|
graph_stats.graph_kv_cache_reset_us +
|
|
graph_stats.graph_kv_cache_alloc_us +
|
|
graph_stats.graph_kv_cache_feature_upload_us +
|
|
graph_stats.graph_kv_cache_pos_upload_us +
|
|
graph_stats.graph_kv_cache_compute_us +
|
|
graph_stats.graph_kv_cache_sync_us +
|
|
graph_stats.graph_kv_cache_read_concat_pad_us) / 1000.0;
|
|
const double kv_upload_feature_ms = (double) graph_stats.graph_kv_cache_feature_upload_us / 1000.0;
|
|
const double kv_upload_pos_ms = (double) graph_stats.graph_kv_cache_pos_upload_us / 1000.0;
|
|
const double kv_upload_total_ms = kv_upload_feature_ms + kv_upload_pos_ms;
|
|
const double kv_compute_ms = (double) graph_stats.graph_kv_cache_compute_us / 1000.0;
|
|
const double kv_sync_ms = (double) graph_stats.graph_kv_cache_sync_us / 1000.0;
|
|
const double kv_workspace_total_ms = (double) (
|
|
graph_stats.graph_kv_workspace_build_us +
|
|
graph_stats.graph_kv_workspace_reserve_us +
|
|
graph_stats.graph_kv_workspace_reset_us +
|
|
graph_stats.graph_kv_workspace_alloc_us +
|
|
graph_stats.graph_kv_workspace_compute_us +
|
|
graph_stats.graph_kv_workspace_sync_us) / 1000.0;
|
|
const double draft_kv_traffic_ms = (double) (
|
|
graph_stats.graph_main_node_k_ctx_view_us +
|
|
graph_stats.graph_main_node_v_ctx_view_us +
|
|
graph_stats.graph_main_node_k_concat_us +
|
|
graph_stats.graph_main_node_v_concat_us +
|
|
graph_stats.graph_main_node_k_pad_us +
|
|
graph_stats.graph_main_node_v_pad_us +
|
|
graph_stats.graph_main_node_k_perm_cont_us +
|
|
graph_stats.graph_main_node_v_perm_cont_us) / 1000.0;
|
|
const double draft_main_profiled_ms = (double) (
|
|
graph_stats.graph_main_node_qcur_us +
|
|
graph_stats.graph_main_node_k_draft_us +
|
|
graph_stats.graph_main_node_v_draft_us +
|
|
graph_stats.graph_main_node_flash_attn_us +
|
|
graph_stats.graph_main_node_attn_out_us +
|
|
graph_stats.graph_main_node_ffn_us +
|
|
graph_stats.graph_main_node_result_rows_us +
|
|
graph_stats.graph_main_node_result_norm_us +
|
|
graph_stats.graph_main_node_result_us) / 1000.0;
|
|
const double replay_append_ms = (double) dflash_state->t_accept_append_us / 1000.0;
|
|
const double feature_path_ms = (double) (
|
|
capture_stats.capture_prepare_sync_us +
|
|
capture_stats.capture_materialize_us +
|
|
graph_stats.set_target_copy_us +
|
|
graph_stats.graph_feature_copy_us +
|
|
graph_stats.graph_pos_copy_us +
|
|
graph_stats.graph_mask_build_us) / 1000.0;
|
|
const double decode_internal_ms = (double) (
|
|
graph_stats.decode_prelude_us +
|
|
graph_stats.decode_sched_reset_us +
|
|
graph_stats.decode_build_graph_us +
|
|
graph_stats.decode_sched_alloc_graph_us +
|
|
graph_stats.decode_prepare_us +
|
|
graph_stats.decode_set_inputs_us +
|
|
graph_stats.decode_graph_compute_us +
|
|
graph_stats.decode_result_us +
|
|
graph_stats.decode_embedding_us +
|
|
graph_stats.decode_final_sched_reset_us) / 1000.0;
|
|
|
|
LOG_INF("statistics dflash profile: capture(sync/materialize)=%.3f/%.3f ms calls=%llu/%llu bytes=%llu phase(prompt/verify batches changes)=%llu/%llu %llu/%llu, set_target=%.3f ms rows=%llu bytes=%llu, decode(llama_output_reserve/prepare)=%.3f/%.3f ms calls=%llu/%llu realloc(bytes)=%llu/%llu, prep(total/features/pos/mask)=%.3f/%.3f/%.3f/%.3f ms kv_cache(total/build/reserve/reset/alloc/up_f/up_p/compute/sync/read)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls(prepare/cache/read)=%llu/%llu/%llu bytes(feature/pos/mask/read)=%llu/%llu/%llu/%llu host_layers=%d, fallback_pos(copy/graph)=%llu/%llu, nonmono(copy/graph)=%llu/%llu, capture_fail=%llu/%llu decode_prepare_fail=%llu, visible_kv_max=%llu, last(rows=%d width=%d left_pad=%d n_tokens=%d n_kv=%d pos=[%d..%d])\n",
|
|
(double) capture_stats.capture_prepare_sync_us / 1000.0,
|
|
(double) capture_stats.capture_materialize_us / 1000.0,
|
|
(unsigned long long) capture_stats.capture_prepare_calls,
|
|
(unsigned long long) capture_stats.capture_materialize_calls,
|
|
(unsigned long long) capture_stats.capture_materialize_bytes,
|
|
(unsigned long long) capture_stats.capture_prompt_batches,
|
|
(unsigned long long) capture_stats.capture_prompt_shape_changes,
|
|
(unsigned long long) capture_stats.capture_verify_batches,
|
|
(unsigned long long) capture_stats.capture_verify_shape_changes,
|
|
(double) graph_stats.set_target_copy_us / 1000.0,
|
|
(unsigned long long) graph_stats.set_target_rows,
|
|
(unsigned long long) graph_stats.set_target_copy_bytes,
|
|
(double) graph_stats.decode_output_reserve_us / 1000.0,
|
|
(double) graph_stats.decode_prepare_us / 1000.0,
|
|
(unsigned long long) graph_stats.decode_output_reserve_calls,
|
|
(unsigned long long) graph_stats.decode_prepare_calls,
|
|
(unsigned long long) graph_stats.decode_output_reserve_reallocs,
|
|
(unsigned long long) graph_stats.decode_output_reserve_realloc_bytes,
|
|
(double) graph_stats.graph_prepare_total_us / 1000.0,
|
|
(double) graph_stats.graph_feature_copy_us / 1000.0,
|
|
(double) graph_stats.graph_pos_copy_us / 1000.0,
|
|
(double) graph_stats.graph_mask_build_us / 1000.0,
|
|
kv_cache_total_ms,
|
|
(double) graph_stats.graph_kv_cache_build_us / 1000.0,
|
|
(double) graph_stats.graph_kv_cache_reserve_us / 1000.0,
|
|
(double) graph_stats.graph_kv_cache_reset_us / 1000.0,
|
|
(double) graph_stats.graph_kv_cache_alloc_us / 1000.0,
|
|
(double) graph_stats.graph_kv_cache_feature_upload_us / 1000.0,
|
|
(double) graph_stats.graph_kv_cache_pos_upload_us / 1000.0,
|
|
(double) graph_stats.graph_kv_cache_compute_us / 1000.0,
|
|
(double) graph_stats.graph_kv_cache_sync_us / 1000.0,
|
|
(double) graph_stats.graph_kv_cache_read_concat_pad_us / 1000.0,
|
|
(unsigned long long) graph_stats.graph_prepare_calls,
|
|
(unsigned long long) graph_stats.graph_kv_cache_calls,
|
|
(unsigned long long) graph_stats.graph_kv_cache_read_concat_pad_calls,
|
|
(unsigned long long) graph_stats.graph_feature_bytes,
|
|
(unsigned long long) graph_stats.graph_pos_bytes,
|
|
(unsigned long long) graph_stats.graph_mask_bytes,
|
|
(unsigned long long) graph_stats.graph_kv_cache_cached_bytes,
|
|
graph_stats.last_kv_cache_host_layers,
|
|
(unsigned long long) graph_stats.set_target_missing_positions,
|
|
(unsigned long long) graph_stats.graph_pos_fallbacks,
|
|
(unsigned long long) graph_stats.set_target_non_monotonic_positions,
|
|
(unsigned long long) graph_stats.graph_pos_non_monotonic,
|
|
(unsigned long long) capture_stats.capture_prepare_failures,
|
|
(unsigned long long) capture_stats.capture_materialize_failures,
|
|
(unsigned long long) graph_stats.decode_prepare_failures,
|
|
(unsigned long long) graph_stats.graph_visible_kv_max,
|
|
graph_stats.last_n_rows,
|
|
graph_stats.last_width,
|
|
graph_stats.last_left_pad,
|
|
graph_stats.last_n_tokens,
|
|
graph_stats.last_n_kv_total,
|
|
(int) graph_stats.last_pos_first,
|
|
(int) graph_stats.last_pos_last);
|
|
|
|
LOG_INF("statistics dflash features: total=%.3f ms capture(sync/materialize)=%.3f/%.3f ms set_target=%.3f ms prep(feature/pos/mask)=%.3f/%.3f/%.3f ms rows(materialize/set_target)=%llu/%llu bytes(materialize/set_target/feature/pos/mask)=%llu/%llu/%llu/%llu/%llu\n",
|
|
feature_path_ms,
|
|
(double) capture_stats.capture_prepare_sync_us / 1000.0,
|
|
(double) capture_stats.capture_materialize_us / 1000.0,
|
|
(double) graph_stats.set_target_copy_us / 1000.0,
|
|
(double) graph_stats.graph_feature_copy_us / 1000.0,
|
|
(double) graph_stats.graph_pos_copy_us / 1000.0,
|
|
(double) graph_stats.graph_mask_build_us / 1000.0,
|
|
(unsigned long long) capture_stats.capture_materialize_rows,
|
|
(unsigned long long) graph_stats.set_target_rows,
|
|
(unsigned long long) capture_stats.capture_materialize_bytes,
|
|
(unsigned long long) graph_stats.set_target_copy_bytes,
|
|
(unsigned long long) graph_stats.graph_feature_bytes,
|
|
(unsigned long long) graph_stats.graph_pos_bytes,
|
|
(unsigned long long) graph_stats.graph_mask_bytes);
|
|
|
|
LOG_INF("statistics dflash kv: total=%.3f ms build/reserve/reset/alloc/upload_f/upload_p/compute/sync/read=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu cached_bytes=%llu host_layers=%d\n",
|
|
kv_cache_total_ms,
|
|
(double) graph_stats.graph_kv_cache_build_us / 1000.0,
|
|
(double) graph_stats.graph_kv_cache_reserve_us / 1000.0,
|
|
(double) graph_stats.graph_kv_cache_reset_us / 1000.0,
|
|
(double) graph_stats.graph_kv_cache_alloc_us / 1000.0,
|
|
(double) graph_stats.graph_kv_cache_feature_upload_us / 1000.0,
|
|
(double) graph_stats.graph_kv_cache_pos_upload_us / 1000.0,
|
|
(double) graph_stats.graph_kv_cache_compute_us / 1000.0,
|
|
(double) graph_stats.graph_kv_cache_sync_us / 1000.0,
|
|
(double) graph_stats.graph_kv_cache_read_concat_pad_us / 1000.0,
|
|
(unsigned long long) graph_stats.graph_kv_cache_calls,
|
|
(unsigned long long) graph_stats.graph_kv_cache_cached_bytes,
|
|
graph_stats.last_kv_cache_host_layers);
|
|
|
|
if (graph_stats.graph_kv_workspace_calls > 0) {
|
|
LOG_INF("statistics dflash kv workspace: total=%.3f ms build/reserve/reset/alloc/compute/sync=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu\n",
|
|
kv_workspace_total_ms,
|
|
(double) graph_stats.graph_kv_workspace_build_us / 1000.0,
|
|
(double) graph_stats.graph_kv_workspace_reserve_us / 1000.0,
|
|
(double) graph_stats.graph_kv_workspace_reset_us / 1000.0,
|
|
(double) graph_stats.graph_kv_workspace_alloc_us / 1000.0,
|
|
(double) graph_stats.graph_kv_workspace_compute_us / 1000.0,
|
|
(double) graph_stats.graph_kv_workspace_sync_us / 1000.0,
|
|
(unsigned long long) graph_stats.graph_kv_workspace_calls);
|
|
}
|
|
|
|
if (graph_stats.decode_internal_chunks > 0) {
|
|
LOG_INF("statistics dflash decode: llama_decode(total)=%.3f ms calls=%zu chunks=%llu rebuilds=%llu sync_points=%llu internal(total/prelude/sched_reset/build/alloc/prepare/set_inputs/compute/get_result/get_embedding/final_reset)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms\n",
|
|
(double) dflash_state->t_draft_decode_us / 1000.0,
|
|
dflash_state->n_call_draft,
|
|
(unsigned long long) graph_stats.decode_internal_chunks,
|
|
(unsigned long long) graph_stats.decode_graph_rebuilds,
|
|
(unsigned long long) graph_stats.decode_sync_profile_points,
|
|
decode_internal_ms,
|
|
(double) graph_stats.decode_prelude_us / 1000.0,
|
|
(double) graph_stats.decode_sched_reset_us / 1000.0,
|
|
(double) graph_stats.decode_build_graph_us / 1000.0,
|
|
(double) graph_stats.decode_sched_alloc_graph_us / 1000.0,
|
|
(double) graph_stats.decode_prepare_us / 1000.0,
|
|
(double) graph_stats.decode_set_inputs_us / 1000.0,
|
|
(double) graph_stats.decode_graph_compute_us / 1000.0,
|
|
(double) graph_stats.decode_result_us / 1000.0,
|
|
(double) graph_stats.decode_embedding_us / 1000.0,
|
|
(double) graph_stats.decode_final_sched_reset_us / 1000.0);
|
|
}
|
|
|
|
if (graph_stats.graph_kv_node_fused_target_calls > 0 ||
|
|
graph_stats.graph_kv_node_k_proj_calls > 0 ||
|
|
graph_stats.graph_kv_node_k_norm_calls > 0 ||
|
|
graph_stats.graph_kv_node_k_rope_calls > 0 ||
|
|
graph_stats.graph_kv_node_v_proj_calls > 0 ||
|
|
graph_stats.graph_kv_node_k_store_calls > 0 ||
|
|
graph_stats.graph_kv_node_v_store_calls > 0) {
|
|
LOG_INF("statistics dflash kv nodes: fused_target/k_proj/k_norm/k_rope/v_proj/k_store/v_store=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu/%llu/%llu/%llu/%llu/%llu/%llu\n",
|
|
(double) graph_stats.graph_kv_node_fused_target_us / 1000.0,
|
|
(double) graph_stats.graph_kv_node_k_proj_us / 1000.0,
|
|
(double) graph_stats.graph_kv_node_k_norm_us / 1000.0,
|
|
(double) graph_stats.graph_kv_node_k_rope_us / 1000.0,
|
|
(double) graph_stats.graph_kv_node_v_proj_us / 1000.0,
|
|
(double) graph_stats.graph_kv_node_k_store_us / 1000.0,
|
|
(double) graph_stats.graph_kv_node_v_store_us / 1000.0,
|
|
(unsigned long long) graph_stats.graph_kv_node_fused_target_calls,
|
|
(unsigned long long) graph_stats.graph_kv_node_k_proj_calls,
|
|
(unsigned long long) graph_stats.graph_kv_node_k_norm_calls,
|
|
(unsigned long long) graph_stats.graph_kv_node_k_rope_calls,
|
|
(unsigned long long) graph_stats.graph_kv_node_v_proj_calls,
|
|
(unsigned long long) graph_stats.graph_kv_node_k_store_calls,
|
|
(unsigned long long) graph_stats.graph_kv_node_v_store_calls);
|
|
}
|
|
|
|
if (graph_stats.graph_main_node_qcur_calls > 0 ||
|
|
graph_stats.graph_main_node_k_draft_calls > 0 ||
|
|
graph_stats.graph_main_node_v_draft_calls > 0 ||
|
|
graph_stats.graph_main_node_flash_attn_calls > 0 ||
|
|
graph_stats.graph_main_node_attn_out_calls > 0 ||
|
|
graph_stats.graph_main_node_ffn_calls > 0 ||
|
|
graph_stats.graph_main_node_result_rows_calls > 0 ||
|
|
graph_stats.graph_main_node_result_norm_calls > 0 ||
|
|
graph_stats.graph_main_node_result_calls > 0) {
|
|
LOG_INF("statistics dflash draft nodes: profiled=%.3f ms graph_compute=%.3f ms qcur/k_draft/v_draft/flash_attn/attn_out/ffn/result_rows/result_norm/result=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu/%llu/%llu/%llu/%llu/%llu/%llu/%llu/%llu\n",
|
|
draft_main_profiled_ms,
|
|
(double) graph_stats.decode_graph_compute_us / 1000.0,
|
|
(double) graph_stats.graph_main_node_qcur_us / 1000.0,
|
|
(double) graph_stats.graph_main_node_k_draft_us / 1000.0,
|
|
(double) graph_stats.graph_main_node_v_draft_us / 1000.0,
|
|
(double) graph_stats.graph_main_node_flash_attn_us / 1000.0,
|
|
(double) graph_stats.graph_main_node_attn_out_us / 1000.0,
|
|
(double) graph_stats.graph_main_node_ffn_us / 1000.0,
|
|
(double) graph_stats.graph_main_node_result_rows_us / 1000.0,
|
|
(double) graph_stats.graph_main_node_result_norm_us / 1000.0,
|
|
(double) graph_stats.graph_main_node_result_us / 1000.0,
|
|
(unsigned long long) graph_stats.graph_main_node_qcur_calls,
|
|
(unsigned long long) graph_stats.graph_main_node_k_draft_calls,
|
|
(unsigned long long) graph_stats.graph_main_node_v_draft_calls,
|
|
(unsigned long long) graph_stats.graph_main_node_flash_attn_calls,
|
|
(unsigned long long) graph_stats.graph_main_node_attn_out_calls,
|
|
(unsigned long long) graph_stats.graph_main_node_ffn_calls,
|
|
(unsigned long long) graph_stats.graph_main_node_result_rows_calls,
|
|
(unsigned long long) graph_stats.graph_main_node_result_norm_calls,
|
|
(unsigned long long) graph_stats.graph_main_node_result_calls);
|
|
}
|
|
|
|
if (graph_stats.graph_main_node_k_ctx_view_calls > 0 ||
|
|
graph_stats.graph_main_node_v_ctx_view_calls > 0 ||
|
|
graph_stats.graph_main_node_k_concat_calls > 0 ||
|
|
graph_stats.graph_main_node_v_concat_calls > 0 ||
|
|
graph_stats.graph_main_node_k_pad_calls > 0 ||
|
|
graph_stats.graph_main_node_v_pad_calls > 0 ||
|
|
graph_stats.graph_main_node_k_perm_cont_calls > 0 ||
|
|
graph_stats.graph_main_node_v_perm_cont_calls > 0) {
|
|
LOG_INF("statistics dflash draft kv traffic: total=%.3f ms graph_compute=%.3f ms k_ctx_view/v_ctx_view/k_concat/v_concat/k_pad/v_pad/k_perm_cont/v_perm_cont=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu/%llu/%llu/%llu/%llu/%llu/%llu/%llu\n",
|
|
draft_kv_traffic_ms,
|
|
(double) graph_stats.decode_graph_compute_us / 1000.0,
|
|
(double) graph_stats.graph_main_node_k_ctx_view_us / 1000.0,
|
|
(double) graph_stats.graph_main_node_v_ctx_view_us / 1000.0,
|
|
(double) graph_stats.graph_main_node_k_concat_us / 1000.0,
|
|
(double) graph_stats.graph_main_node_v_concat_us / 1000.0,
|
|
(double) graph_stats.graph_main_node_k_pad_us / 1000.0,
|
|
(double) graph_stats.graph_main_node_v_pad_us / 1000.0,
|
|
(double) graph_stats.graph_main_node_k_perm_cont_us / 1000.0,
|
|
(double) graph_stats.graph_main_node_v_perm_cont_us / 1000.0,
|
|
(unsigned long long) graph_stats.graph_main_node_k_ctx_view_calls,
|
|
(unsigned long long) graph_stats.graph_main_node_v_ctx_view_calls,
|
|
(unsigned long long) graph_stats.graph_main_node_k_concat_calls,
|
|
(unsigned long long) graph_stats.graph_main_node_v_concat_calls,
|
|
(unsigned long long) graph_stats.graph_main_node_k_pad_calls,
|
|
(unsigned long long) graph_stats.graph_main_node_v_pad_calls,
|
|
(unsigned long long) graph_stats.graph_main_node_k_perm_cont_calls,
|
|
(unsigned long long) graph_stats.graph_main_node_v_perm_cont_calls);
|
|
}
|
|
|
|
LOG_INF("statistics dflash hot: kv(upload_f/upload_p/upload/compute/sync)=%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu replay(accepted_prefix_append)=%.3f ms calls=%zu rows=%zu\n",
|
|
kv_upload_feature_ms,
|
|
kv_upload_pos_ms,
|
|
kv_upload_total_ms,
|
|
kv_compute_ms,
|
|
kv_sync_ms,
|
|
(unsigned long long) graph_stats.graph_kv_cache_calls,
|
|
replay_append_ms,
|
|
dflash_state->n_accept_append_calls,
|
|
dflash_state->n_accept_append_rows);
|
|
|
|
LOG_INF("statistics dflash stages: draft(decode/sample)=%.3f/%.3f ms warmup(collect/append)=%.3f/%.3f ms calls=%zu/%zu rows=%zu/%zu accept(total/output_copy/append)=%.3f/%.3f/%.3f ms calls=%zu/%zu/%zu rows=%zu/%zu/%zu\n",
|
|
(double) dflash_state->t_draft_decode_us / 1000.0,
|
|
(double) dflash_state->t_draft_sample_us / 1000.0,
|
|
(double) dflash_state->t_warmup_collect_us / 1000.0,
|
|
(double) dflash_state->t_warmup_append_us / 1000.0,
|
|
dflash_state->n_warmup_collect_calls,
|
|
dflash_state->n_warmup_append_calls,
|
|
dflash_state->n_warmup_collect_rows,
|
|
dflash_state->n_warmup_append_rows,
|
|
(double) dflash_state->t_accept_commit_us / 1000.0,
|
|
(double) dflash_state->t_accept_output_copy_us / 1000.0,
|
|
(double) dflash_state->t_accept_append_us / 1000.0,
|
|
dflash_state->n_accept_commit_calls,
|
|
dflash_state->n_accept_output_copy_calls,
|
|
dflash_state->n_accept_append_calls,
|
|
dflash_state->n_accept_commit_rows,
|
|
dflash_state->n_accept_output_copy_rows,
|
|
dflash_state->n_accept_append_rows);
|
|
|
|
if (dflash_state->n_accept_append_calls > 0) {
|
|
LOG_INF("statistics dflash replay: append(filter/window_alloc/replace/keep_old/new_rows/commit/log)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%zu replace/slide=%zu/%zu\n",
|
|
(double) dflash_state->t_accept_append_filter_us / 1000.0,
|
|
(double) dflash_state->t_accept_append_window_alloc_us / 1000.0,
|
|
(double) dflash_state->t_accept_append_replace_us / 1000.0,
|
|
(double) dflash_state->t_accept_append_keep_old_us / 1000.0,
|
|
(double) dflash_state->t_accept_append_new_rows_us / 1000.0,
|
|
(double) dflash_state->t_accept_append_commit_detail_us / 1000.0,
|
|
(double) dflash_state->t_accept_append_log_us / 1000.0,
|
|
dflash_state->n_accept_append_calls,
|
|
dflash_state->n_accept_append_replace_calls,
|
|
dflash_state->n_accept_append_slide_calls);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if (spec->tuner && spec->tuner->enabled && slot_tps > 0.0 && n_decoded > 0) {
|
|
auto * mutable_spec = const_cast<common_speculative *>(spec);
|
|
if (active_params) {
|
|
mutable_spec->tuner->end_of_request(slot_tps, n_past, *active_params);
|
|
} else {
|
|
common_params_speculative tmp_params;
|
|
mutable_spec->tuner->end_of_request(slot_tps, n_past, tmp_params);
|
|
}
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
// MTP
|
|
// ----------------------------------------------------------------------------
|
|
|
|
// Looks up the MTP implementation registered on `spec`.
//
// Walks `spec->impls` in order and returns the first entry that is tagged
// COMMON_SPECULATIVE_TYPE_MTP and also downcasts to common_speculative_state_mtp.
// Returns nullptr when `spec` is null or no entry qualifies.
static common_speculative_state_mtp * common_speculative_get_mtp_state(common_speculative * spec) {
    if (spec == nullptr) {
        return nullptr;
    }

    for (auto & candidate : spec->impls) {
        if (candidate->type == COMMON_SPECULATIVE_TYPE_MTP) {
            // A tag match alone is not trusted; keep scanning if the cast fails.
            auto * as_mtp = dynamic_cast<common_speculative_state_mtp *>(candidate.get());
            if (as_mtp != nullptr) {
                return as_mtp;
            }
        }
    }

    return nullptr;
}
|
|
|
|
// Const overload: strips constness from `spec` to reuse the mutable lookup and
// returns the result as a pointer-to-const.
static const common_speculative_state_mtp * common_speculative_get_mtp_state(const common_speculative * spec) {
    common_speculative * mutable_spec = const_cast<common_speculative *>(spec);
    return common_speculative_get_mtp_state(mutable_spec);
}
|
|
|
|
// Looks up the DFlash implementation registered on `spec`.
//
// Walks `spec->impls` in order and returns the first entry that is tagged
// COMMON_SPECULATIVE_TYPE_DFLASH and also downcasts to
// common_speculative_state_dflash. Returns nullptr when `spec` is null or no
// entry qualifies.
static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec) {
    if (spec == nullptr) {
        return nullptr;
    }

    for (auto & candidate : spec->impls) {
        if (candidate->type == COMMON_SPECULATIVE_TYPE_DFLASH) {
            // A tag match alone is not trusted; keep scanning if the cast fails.
            auto * as_dflash = dynamic_cast<common_speculative_state_dflash *>(candidate.get());
            if (as_dflash != nullptr) {
                return as_dflash;
            }
        }
    }

    return nullptr;
}
|
|
|
|
// Const overload: strips constness from `spec` to reuse the mutable lookup and
// returns the result as a pointer-to-const.
static const common_speculative_state_dflash * common_speculative_get_dflash_state(const common_speculative * spec) {
    common_speculative * mutable_spec = const_cast<common_speculative *>(spec);
    return common_speculative_get_dflash_state(mutable_spec);
}
|
|
|
|
static int32_t common_speculative_feature_width(const common_speculative * spec) {
|
|
if (const auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) {
|
|
return dflash_state->n_target_features;
|
|
}
|
|
|
|
if (const auto * mtp_state = common_speculative_get_mtp_state(spec); mtp_state != nullptr) {
|
|
return mtp_state->n_embd;
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
// Returns the per-sequence draft cache entry for `seq_id`, creating it on
// first use, and guarantees its embedding buffer holds `state.n_embd` floats.
static mtp_last_embd & mtp_get_last_embd(common_speculative_state_mtp & state, llama_seq_id seq_id) {
    // operator[] default-constructs the entry when the sequence is new.
    auto & entry = state.draft_cache_by_seq[seq_id];

    const bool size_ok = (int) entry.embd.size() == state.n_embd;
    if (!size_ok) {
        entry.embd.resize(state.n_embd);
    }

    return entry;
}
|
|
|
|
// Marks the cached draft of `seq_id` as unusable (token id -1, probability 0).
// Does nothing, and creates no entry, when the sequence has no cache record.
static void mtp_invalidate_cached_draft(common_speculative_state_mtp & state, llama_seq_id seq_id) {
    const auto found = state.draft_cache_by_seq.find(seq_id);
    if (found != state.draft_cache_by_seq.end()) {
        auto & cached = found->second;
        cached.last_id = -1;
        cached.prob    = 0.0f;
    }
}
|
|
|
|
// Marks the cached draft of every known sequence as unusable
// (token id -1, probability 0). Embedding buffers are left untouched.
static void mtp_invalidate_cached_drafts(common_speculative_state_mtp & state) {
    auto & cache = state.draft_cache_by_seq;
    for (auto it = cache.begin(); it != cache.end(); ++it) {
        auto & cached = it->second;
        cached.last_id = -1;
        cached.prob    = 0.0f;
    }
}
|
|
|
|
// Stores a private copy of one target hidden-state row for `seq_id`.
//
// `hidden` must point at `width` floats. The call is a no-op when `hidden` is
// null or `width` is not positive; otherwise any previous row is overwritten.
static void mtp_store_target_hidden(
        common_speculative_state_mtp & state,
        llama_seq_id seq_id,
        const float * hidden,
        int32_t width) {
    const bool has_payload = hidden != nullptr && width > 0;
    if (!has_payload) {
        return;
    }

    state.target_hidden_by_seq[seq_id].assign(hidden, hidden + width);
}
|
|
|
|
// Forgets everything kept for `seq_id`: both the stored target hidden row and
// the cached draft entry are erased (the two maps are independent).
static void mtp_clear_target_hidden(common_speculative_state_mtp & state, llama_seq_id seq_id) {
    auto & hidden_rows = state.target_hidden_by_seq;
    auto & draft_cache = state.draft_cache_by_seq;

    hidden_rows.erase(seq_id);
    draft_cache.erase(seq_id);
}
|
|
|
|
// DFlash target-window replay and maintenance helpers.
|
|
// Per-call timing breakdown filled in by dflash_append_target_features().
// All durations are in microseconds and accumulate with `+=`.
struct dflash_append_breakdown {
    uint64_t filter_us{0};       // selecting the rows that belong to the requested sequence
    uint64_t window_alloc_us{0}; // resizing the staging position vector (slide path)
    uint64_t replace_us{0};      // rebuilding the whole window (replace path)
    uint64_t keep_old_us{0};     // copying the retained old positions (slide path)
    uint64_t new_rows_us{0};     // appending the new rows and positions (slide path)
    uint64_t commit_us{0};       // publishing the new window state
    uint64_t log_us{0};          // dflash_contract_log_append() call
    bool     replace_call{false}; // true when the replace path was taken
};
|
|
|
|
// Records how the target window was last modified and bumps its version.
//
// Negative `keep_rows` / `append_rows` are stored as 0. `replace` is stored
// verbatim, and `target_window_version` is incremented on every call.
static void dflash_record_window_update(
        common_speculative_state_dflash & state,
        int32_t keep_rows,
        int32_t append_rows,
        bool replace) {
    const int32_t kept     = keep_rows   < 0 ? 0 : keep_rows;
    const int32_t appended = append_rows < 0 ? 0 : append_rows;

    state.target_window_keep_rows   = kept;
    state.target_window_append_rows = appended;
    state.target_window_replace     = replace;
    state.target_window_version++;
}
|
|
|
|
// Re-initialises the ring buffer so that it contains exactly `n_rows` rows,
// copied from `rows`, starting at slot 0.
//
// With no input (`rows` null or `n_rows <= 0`) only the write cursor and fill
// count are zeroed. Otherwise the ring is sized to `cross_ctx` rows, the rows
// are copied in, and the materialized copy is flagged stale.
// NOTE(review): the copy assumes n_rows <= cross_ctx and cross_ctx > 0; the
// visible callers satisfy this, but it is not checked here.
static void dflash_ring_reset_rows(
        common_speculative_state_dflash & state,
        const float * rows,
        int32_t n_rows) {
    if (rows == nullptr || n_rows <= 0) {
        state.target_window_ring_write_pos = 0;
        state.target_window_ring_filled    = 0;
        return;
    }

    const size_t floats_per_row = (size_t) state.n_target_features;
    const size_t ring_floats    = (size_t) state.cross_ctx * floats_per_row;
    if (state.target_window_ring.size() != ring_floats) {
        state.target_window_ring.resize(ring_floats);
    }

    const size_t copy_bytes = (size_t) n_rows * floats_per_row * sizeof(float);
    std::memcpy(state.target_window_ring.data(), rows, copy_bytes);

    // The next write lands right after the copied rows (wrapping to 0 when full).
    state.target_window_ring_write_pos = n_rows % state.cross_ctx;
    state.target_window_ring_filled    = n_rows;
    state.target_window_materialized   = false;
}
|
|
|
|
// Appends `n_rows` feature rows to the ring buffer, wrapping around and
// overwriting the oldest slots as needed.
//
// The call is a no-op when there is nothing to append (`rows` null or
// `n_rows <= 0`) or when the ring has no capacity (`cross_ctx <= 0`).
// On return the write cursor points at the slot after the last written row,
// the fill count is capped at `cross_ctx`, and the materialized copy is
// flagged stale.
static void dflash_ring_append_rows(
        common_speculative_state_dflash & state,
        const float * rows,
        int32_t n_rows) {
    if (n_rows <= 0 || rows == nullptr) {
        return;
    }

    // A ring without capacity cannot store anything; without this guard the
    // loop below would evaluate `% 0`.
    if (state.cross_ctx <= 0) {
        return;
    }

    const size_t row_width   = (size_t) state.n_target_features;
    const size_t ring_floats = (size_t) state.cross_ctx * row_width;
    if (state.target_window_ring.size() != ring_floats) {
        state.target_window_ring.resize(ring_floats);
    }

    int32_t write_pos = state.target_window_ring_write_pos;
    if (write_pos < 0 || write_pos >= state.cross_ctx) {
        // A cursor outside the ring would make the chunk size non-positive and
        // turn into a huge memcpy length; restart at slot 0 instead.
        write_pos = 0;
    }

    int32_t remaining = n_rows;
    const float * src = rows;
    while (remaining > 0) {
        // Copy up to the physical end of the ring, then wrap.
        const int32_t chunk_rows = std::min<int32_t>(remaining, state.cross_ctx - write_pos);
        std::memcpy(
            state.target_window_ring.data() + (size_t) write_pos * row_width,
            src,
            (size_t) chunk_rows * row_width * sizeof(float));
        src       += (size_t) chunk_rows * row_width;
        remaining -= chunk_rows;
        write_pos  = (write_pos + chunk_rows) % state.cross_ctx;
    }

    state.target_window_ring_write_pos = write_pos;
    state.target_window_ring_filled    = std::min(state.cross_ctx, state.target_window_ring_filled + n_rows);
    state.target_window_materialized   = false;
}
|
|
|
|
// Linearises the newest `target_window_rows` rows of the ring buffer into
// `state.target_window`, oldest row first.
//
// Skipped when the linear copy is already up to date or the window is empty.
// The rows may straddle the physical end of the ring, so the copy is done in
// at most two pieces.
static void dflash_materialize_target_window_features(common_speculative_state_dflash & state) {
    if (state.target_window_rows <= 0 || state.target_window_materialized) {
        return;
    }

    const size_t floats_per_row = (size_t) state.n_target_features;
    state.target_window.resize((size_t) state.target_window_rows * floats_per_row);

    // Slot of the oldest row: step back `target_window_rows` from the write cursor.
    const int32_t oldest_slot =
        (state.target_window_ring_write_pos - state.target_window_rows + state.cross_ctx) % state.cross_ctx;
    const int32_t head_rows = std::min<int32_t>(state.target_window_rows, state.cross_ctx - oldest_slot);
    const int32_t tail_rows = state.target_window_rows - head_rows;

    float *       dst  = state.target_window.data();
    const float * ring = state.target_window_ring.data();

    // Piece 1: from the oldest slot up to the end of the ring (or fewer rows).
    std::memcpy(dst, ring + (size_t) oldest_slot * floats_per_row, (size_t) head_rows * floats_per_row * sizeof(float));

    // Piece 2: whatever wrapped around to the start of the ring.
    if (tail_rows > 0) {
        std::memcpy(dst + (size_t) head_rows * floats_per_row, ring, (size_t) tail_rows * floats_per_row * sizeof(float));
    }

    state.target_window_materialized = true;
}
|
|
|
|
// Feeds freshly computed target hidden-state rows of `seq_id` into the DFlash
// target window.
//
// Only rows of `features` that belong to `seq_id` and carry data are used.
// Two paths exist:
//   * replace: at least `cross_ctx` rows arrived, so the window is rebuilt from
//     the newest `cross_ctx` of them;
//   * slide:   fewer rows arrived, so as many old rows as still fit are kept
//     and the new rows are appended behind them.
//
// `batch` is unused. When `breakdown` is non-null the time spent in each phase
// is added to it. Returns false when the view is unusable (wrong kind or
// width, no rows, no window capacity) or when no row matched `seq_id`.
static bool dflash_append_target_features(
        common_speculative_state_dflash & state,
        const common_speculative_feature_view & features,
        const llama_batch & batch,
        llama_seq_id seq_id,
        dflash_append_breakdown * breakdown = nullptr) {
    GGML_UNUSED(batch);

    const bool usable =
        features.kind == COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE &&
        features.width == state.n_target_features &&
        !features.rows.empty() &&
        state.cross_ctx > 0;
    if (!usable) {
        return false;
    }

    // Adds the time elapsed since `started_us` to one bucket of `breakdown`;
    // the clock is only read when a breakdown was requested.
    const auto charge = [breakdown](uint64_t dflash_append_breakdown::* bucket, int64_t started_us) {
        if (breakdown != nullptr) {
            breakdown->*bucket += (uint64_t) (ggml_time_us() - started_us);
        }
    };

    const size_t floats_per_row = (size_t) state.n_target_features;

    std::vector<float>     picked_rows;
    std::vector<llama_pos> picked_pos;
    picked_rows.reserve(features.rows.size() * floats_per_row);
    picked_pos.reserve(features.rows.size());

    // Phase 1: keep only the rows of the requested sequence.
    const int64_t filter_started_us = ggml_time_us();
    for (const auto & row : features.rows) {
        if (row.seq_id == seq_id && row.data != nullptr) {
            picked_pos.push_back(row.pos);
            picked_rows.insert(picked_rows.end(), row.data, row.data + floats_per_row);
        }
    }
    charge(&dflash_append_breakdown::filter_us, filter_started_us);

    if (picked_pos.empty()) {
        return false;
    }

    const int32_t n_rows = (int32_t) picked_pos.size();
    state.n_window_updates++;
    state.n_rows_seen += (size_t) n_rows;

    if (n_rows >= state.cross_ctx) {
        // Replace path: every old row is dropped, plus the surplus new rows.
        state.n_rows_dropped += (size_t) state.target_window_rows + (size_t) (n_rows - state.cross_ctx);
        const int32_t first_kept = n_rows - state.cross_ctx;

        const int64_t replace_started_us = ggml_time_us();
        state.target_window_pos.assign(picked_pos.begin() + first_kept, picked_pos.end());
        state.target_window_append_features.assign(
            picked_rows.begin() + (ptrdiff_t) first_kept * (ptrdiff_t) floats_per_row,
            picked_rows.end());
        dflash_ring_reset_rows(state, state.target_window_append_features.data(), state.cross_ctx);
        charge(&dflash_append_breakdown::replace_us, replace_started_us);
        if (breakdown != nullptr) {
            breakdown->replace_call = true;
        }

        const int64_t commit_started_us = ggml_time_us();
        state.target_window_rows        = state.cross_ctx;
        state.target_window_ring_filled = state.target_window_rows;
        state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
        dflash_record_window_update(state, 0, state.target_window_rows, true);
        charge(&dflash_append_breakdown::commit_us, commit_started_us);

        const int64_t log_started_us = ggml_time_us();
        dflash_contract_log_append(state, seq_id, picked_pos);
        charge(&dflash_append_breakdown::log_us, log_started_us);

        return true;
    }

    // Slide path: keep as many of the old rows as still fit next to the new ones.
    const int32_t keep_old_rows = std::min<int32_t>(state.target_window_rows, state.cross_ctx - n_rows);
    state.n_rows_dropped += (size_t) std::max<int32_t>(0, state.target_window_rows - keep_old_rows);

    const int64_t alloc_started_us = ggml_time_us();
    std::vector<llama_pos> & staged_pos = state.target_window_pos_stage;
    staged_pos.resize((size_t) (keep_old_rows + n_rows));
    charge(&dflash_append_breakdown::window_alloc_us, alloc_started_us);

    if (keep_old_rows > 0) {
        const int64_t keep_started_us = ggml_time_us();
        std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), staged_pos.begin());
        charge(&dflash_append_breakdown::keep_old_us, keep_started_us);
    }

    const int64_t new_rows_started_us = ggml_time_us();
    state.target_window_append_features.assign(picked_rows.begin(), picked_rows.end());
    dflash_ring_append_rows(state, state.target_window_append_features.data(), n_rows);
    std::copy(picked_pos.begin(), picked_pos.end(), staged_pos.begin() + keep_old_rows);
    charge(&dflash_append_breakdown::new_rows_us, new_rows_started_us);

    const int64_t commit_started_us = ggml_time_us();
    // Publish the staged positions; the stage keeps the old buffer for reuse.
    state.target_window_pos.swap(staged_pos);
    staged_pos.clear();
    state.target_window_rows        = keep_old_rows + n_rows;
    state.target_window_ring_filled = state.target_window_rows;
    state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
    dflash_record_window_update(state, keep_old_rows, n_rows, false);
    charge(&dflash_append_breakdown::commit_us, commit_started_us);

    const int64_t log_started_us = ggml_time_us();
    dflash_contract_log_append(state, seq_id, picked_pos);
    charge(&dflash_append_breakdown::log_us, log_started_us);

    return true;
}
|
|
|
|
// Drops the whole DFlash target window and resets the draft context's DFlash
// KV-cache state. `target_window_version` and the ring storage itself are
// left as they are.
static void dflash_clear_target_features(common_speculative_state_dflash & state) {
    // Buffers.
    state.target_window.clear();
    state.target_window_pos.clear();
    state.target_window_stage.clear();
    state.target_window_pos_stage.clear();
    state.target_window_append_features.clear();

    // Window and ring bookkeeping.
    state.target_window_rows           = 0;
    state.target_window_ring_write_pos = 0;
    state.target_window_ring_filled    = 0;
    state.target_window_keep_rows      = 0;
    state.target_window_append_rows    = 0;
    state.target_window_replace        = false;
    state.target_window_materialized   = false;
    state.last_target_pos              = -1;

    llama_reset_dflash_kv_cache_state(state.ctx_dft);
}
|
|
|
|
// Applies a KV context shift to the DFlash target window.
//
// Rows whose position lies in [kv_keep, kv_keep + kv_discard) are removed.
// Rows in [kv_keep + kv_discard, kv_past) have their position lowered by
// `kv_discard`. All other rows are kept unchanged. The ring buffer is then
// rebuilt from the surviving rows, the update is recorded as a full
// replacement, and the draft context's DFlash KV-cache state is reset.
// Nothing happens when `kv_discard <= 0` or the window is empty.
static void dflash_context_shift(
        common_speculative_state_dflash & state,
        llama_pos kv_keep,
        llama_pos kv_discard,
        llama_pos kv_past) {
    const bool nothing_to_shift =
        kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window_pos.empty();
    if (nothing_to_shift) {
        return;
    }

    // The rows are read from the linear copy, so make sure it is current.
    dflash_materialize_target_window_features(state);

    const size_t    floats_per_row = (size_t) state.n_target_features;
    const llama_pos discard_begin  = kv_keep;
    const llama_pos discard_end    = kv_keep + kv_discard;

    std::vector<float>     surviving_rows;
    std::vector<llama_pos> surviving_pos;
    surviving_rows.reserve(state.target_window.size());
    surviving_pos.reserve(state.target_window_pos.size());

    for (int32_t row = 0; row < state.target_window_rows; ++row) {
        const llama_pos old_pos = state.target_window_pos[(size_t) row];

        const bool discarded = old_pos >= discard_begin && old_pos < discard_end;
        if (discarded) {
            continue;
        }

        const bool      moved   = old_pos >= discard_end && old_pos < kv_past;
        const llama_pos new_pos = moved ? old_pos - kv_discard : old_pos;

        const float * row_src = state.target_window.data() + (size_t) row * floats_per_row;
        surviving_rows.insert(surviving_rows.end(), row_src, row_src + floats_per_row);
        surviving_pos.push_back(new_pos);
    }

    state.target_window      = std::move(surviving_rows);
    state.target_window_pos  = std::move(surviving_pos);
    state.target_window_rows = (int32_t) state.target_window_pos.size();

    dflash_ring_reset_rows(state, state.target_window.data(), state.target_window_rows);
    state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
    dflash_record_window_update(state, 0, state.target_window_rows, true);
    llama_reset_dflash_kv_cache_state(state.ctx_dft);
    state.n_context_shifts++;
}
|
|
|
|
// Stores each hidden-state row of `features` as the latest target hidden
// state of its sequence (MTP only) and invalidates that sequence's cached draft.
//
// Returns true when at least one row was stored. Returns false when no MTP
// state exists, the view is not a hidden-state view, its width is not
// positive, or every row lacks data.
static bool common_speculative_capture_target_features(common_speculative * spec, const common_speculative_feature_view & features) {
    auto * mtp = common_speculative_get_mtp_state(spec);

    const bool applicable =
        mtp != nullptr &&
        features.kind == COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE &&
        features.width > 0;
    if (!applicable) {
        return false;
    }

    bool stored_any = false;
    for (const auto & row : features.rows) {
        if (row.data != nullptr) {
            mtp_store_target_hidden(*mtp, row.seq_id, row.data, features.width);
            mtp_invalidate_cached_draft(*mtp, row.seq_id);
            stored_any = true;
        }
    }

    return stored_any;
}
|
|
|
|
bool common_speculative_has_sequence_hidden(const common_speculative * spec, llama_seq_id seq_id) {
|
|
const auto * mtp_state = common_speculative_get_mtp_state(spec);
|
|
if (mtp_state == nullptr) {
|
|
return false;
|
|
}
|
|
|
|
auto it = mtp_state->target_hidden_by_seq.find(seq_id);
|
|
return it != mtp_state->target_hidden_by_seq.end() && !it->second.empty();
|
|
}
|
|
|
|
// Discards the feature state kept for a sequence.
//
// For MTP, the stored hidden row and cached draft of `seq_id` are erased.
// For DFlash, the entire target window is cleared (the call does not take
// `seq_id` into account there).
void common_speculative_clear_sequence_hidden(common_speculative * spec, llama_seq_id seq_id) {
    auto * mtp = common_speculative_get_mtp_state(spec);
    if (mtp != nullptr) {
        mtp_clear_target_hidden(*mtp, seq_id);
    }

    auto * dflash = common_speculative_get_dflash_state(spec);
    if (dflash != nullptr) {
        dflash_clear_target_features(*dflash);
    }
}
|
|
|
|
// Resets the speculative state associated with a sequence.
//
// Clears the checkpoint, the current implementation pointer and the
// per-step counters on `spec` (when non-null), then drops the stored feature
// state for `seq_id`. When `clear_companion_ctx` is set and a companion
// context exists, its entire KV cache is cleared as well.
void common_speculative_clear_sequence(
        common_speculative * spec,
        llama_seq_id seq_id,
        bool clear_companion_ctx) {
    if (spec != nullptr) {
        spec->checkpoint.clear();
        spec->curr_impl       = nullptr;
        spec->last_n_drafted  = 0;
        spec->t_step_start_us = 0;
    }

    common_speculative_clear_sequence_hidden(spec, seq_id);

    if (!clear_companion_ctx) {
        return;
    }

    auto * companion = common_speculative_get_companion_ctx(spec);
    if (companion != nullptr) {
        llama_kv_cache_clear(companion);
    }
}
|
|
|
|
// Removes the KV entries of `seq_id` from `pos_begin` to the end in the target
// context and, if that succeeded, in the companion context as well.
//
// Returns the target result when there is no companion context; otherwise
// returns true only when both removals succeeded. The companion removal is
// not attempted when the target removal failed.
bool common_speculative_trim_sequence(
        common_speculative * spec,
        llama_context * ctx,
        llama_seq_id seq_id,
        llama_pos pos_begin) {
    const bool target_ok = llama_kv_cache_seq_rm(ctx, seq_id, pos_begin, -1);

    auto * companion = common_speculative_get_companion_ctx(spec);
    if (companion == nullptr) {
        return target_ok;
    }

    if (!target_ok) {
        return false;
    }

    return llama_kv_cache_seq_rm(companion, seq_id, pos_begin, -1);
}
|
|
|
|
// Resets the speculative state of `seq_id` and removes the whole sequence
// from the KV cache of the target context and of the companion context, if
// one exists.
void common_speculative_clear_sequence_kv(
        common_speculative * spec,
        llama_context * ctx,
        llama_seq_id seq_id) {
    common_speculative_clear_sequence(spec, seq_id);

    // (-1, -1) is passed as the position range for both contexts.
    llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);

    auto * companion = common_speculative_get_companion_ctx(spec);
    if (companion != nullptr) {
        llama_kv_cache_seq_rm(companion, seq_id, -1, -1);
    }
}
|
|
|
|
// Returns the auxiliary llama context used by the speculative implementation:
// the MTP context when an MTP state is registered, otherwise the DFlash draft
// context, otherwise nullptr.
llama_context * common_speculative_get_companion_ctx(common_speculative * spec) {
    auto * mtp = common_speculative_get_mtp_state(spec);
    if (mtp != nullptr) {
        return mtp->ctx_mtp;
    }

    auto * dflash = common_speculative_get_dflash_state(spec);
    if (dflash != nullptr) {
        return dflash->ctx_dft;
    }

    return nullptr;
}
|
|
|
|
// Replays an accepted target batch through the MTP context and pre-computes
// the next cached draft token for `seq_id`.
//
// `hidden_rows` must hold `accepted_batch.n_tokens * state.n_embd` floats.
// Returns 0 on success, or when there is nothing to do (empty batch or null
// `hidden_rows`), and -1 when handing over the hidden state or updating the
// MTP KV cache fails. When the MTP context yields no embedding for the last
// token, the cache entry is left as it was and 0 is returned.
static int32_t mtp_accept_batch(
        common_speculative_state_mtp & state,
        const llama_batch & accepted_batch,
        llama_seq_id seq_id,
        const float * hidden_rows) {
    const auto n_accepted = accepted_batch.n_tokens;
    if (n_accepted == 0 || hidden_rows == nullptr) {
        return 0;
    }

    const size_t n_hidden_floats = (size_t) n_accepted * state.n_embd;
    const bool   hidden_set = llama_set_draft_input_hidden_state_copy(state.ctx_mtp, hidden_rows, n_hidden_floats);
    if (!hidden_set) {
        return -1;
    }

    const int32_t update_ret = mtp_update_kv_cache(state.ctx_mtp, accepted_batch, false);
    if (update_ret != 0) {
        return -1;
    }

    auto &        cache    = mtp_get_last_embd(state, seq_id);
    const auto    last_row = n_accepted - 1;
    const float * out_embd = llama_get_embeddings_ith(state.ctx_mtp, last_row);
    if (out_embd == nullptr) {
        return 0;
    }

    // Keep a private copy of the last embedding and feed it back as the draft input.
    std::memcpy(cache.embd.data(), out_embd, cache.embd.size() * sizeof(float));
    if (!llama_set_draft_input_hidden_state_copy(state.ctx_mtp, cache.embd.data(), cache.embd.size())) {
        return -1;
    }

    cache.last_id = common_sampler_sample_speculative(nullptr, state.ctx_mtp, last_row, &cache.prob);
    return 0;
}
|
|
|
|
int32_t common_speculative_on_target_batch(
|
|
common_speculative * spec,
|
|
const llama_batch & batch,
|
|
const common_speculative_feature_view & features,
|
|
bool is_prompt_warmup) {
|
|
if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) {
|
|
if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || batch.n_tokens <= 0) {
|
|
return 0;
|
|
}
|
|
|
|
if (features.width != dflash_state->n_target_features) {
|
|
LOG_ERR("%s: DFlash feature width mismatch: got %d expected %d\n",
|
|
__func__, features.width, dflash_state->n_target_features);
|
|
return -1;
|
|
}
|
|
|
|
if (batch.n_seq_id == nullptr || batch.seq_id == nullptr || batch.n_seq_id[0] <= 0 || batch.seq_id[0] == nullptr) {
|
|
return -1;
|
|
}
|
|
|
|
const llama_seq_id seq_id = batch.seq_id[0][0];
|
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
|
if (batch.n_seq_id[i] != 1 || batch.seq_id[i] == nullptr || batch.seq_id[i][0] != seq_id) {
|
|
return -1;
|
|
}
|
|
}
|
|
|
|
dflash_append_breakdown append_breakdown;
|
|
const int64_t t_append_us = ggml_time_us();
|
|
if (!dflash_append_target_features(*dflash_state, features, batch, seq_id, &append_breakdown)) {
|
|
return -1;
|
|
}
|
|
|
|
const uint64_t append_us = (uint64_t) (ggml_time_us() - t_append_us);
|
|
if (is_prompt_warmup) {
|
|
dflash_state->t_warmup_append_us += append_us;
|
|
dflash_state->n_warmup_append_calls++;
|
|
dflash_state->n_warmup_append_rows += (size_t) batch.n_tokens;
|
|
} else {
|
|
dflash_state->t_accept_append_us += append_us;
|
|
dflash_state->t_accept_append_filter_us += append_breakdown.filter_us;
|
|
dflash_state->t_accept_append_window_alloc_us += append_breakdown.window_alloc_us;
|
|
dflash_state->t_accept_append_replace_us += append_breakdown.replace_us;
|
|
dflash_state->t_accept_append_keep_old_us += append_breakdown.keep_old_us;
|
|
dflash_state->t_accept_append_new_rows_us += append_breakdown.new_rows_us;
|
|
dflash_state->t_accept_append_commit_detail_us += append_breakdown.commit_us;
|
|
dflash_state->t_accept_append_log_us += append_breakdown.log_us;
|
|
dflash_state->n_accept_append_calls++;
|
|
dflash_state->n_accept_append_rows += (size_t) batch.n_tokens;
|
|
if (append_breakdown.replace_call) {
|
|
dflash_state->n_accept_append_replace_calls++;
|
|
} else {
|
|
dflash_state->n_accept_append_slide_calls++;
|
|
}
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
auto * mtp_state = common_speculative_get_mtp_state(spec);
|
|
if (mtp_state == nullptr) {
|
|
return 0;
|
|
}
|
|
|
|
if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || features.width <= 0 || batch.n_tokens <= 0) {
|
|
return 0;
|
|
}
|
|
|
|
if (batch.n_seq_id == nullptr || batch.seq_id == nullptr || batch.n_seq_id[0] <= 0 || batch.seq_id[0] == nullptr) {
|
|
return -1;
|
|
}
|
|
|
|
const llama_seq_id seq_id = batch.seq_id[0][0];
|
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
|
if (batch.n_seq_id[i] != 1 || batch.seq_id[i] == nullptr || batch.seq_id[i][0] != seq_id) {
|
|
return -1;
|
|
}
|
|
}
|
|
|
|
std::vector<float> hidden_rows_storage;
|
|
if (!common_speculative_feature_view_copy_batch_rows(features, batch, seq_id, &hidden_rows_storage)) {
|
|
return -1;
|
|
}
|
|
|
|
const float * first_hidden = hidden_rows_storage.data();
|
|
const float * last_hidden = hidden_rows_storage.data() + (size_t) (batch.n_tokens - 1) * features.width;
|
|
mtp_store_target_hidden(*mtp_state, seq_id, last_hidden, features.width);
|
|
|
|
if (mtp_state->constant_draft_positions) {
|
|
mtp_invalidate_cached_draft(*mtp_state, seq_id);
|
|
return 0;
|
|
}
|
|
|
|
if (is_prompt_warmup) {
|
|
if (!llama_set_draft_input_hidden_state_copy(mtp_state->ctx_mtp, hidden_rows_storage.data(), hidden_rows_storage.size())) {
|
|
return -1;
|
|
}
|
|
const int32_t ret = mtp_update_kv_cache(mtp_state->ctx_mtp, batch, true);
|
|
mtp_invalidate_cached_draft(*mtp_state, seq_id);
|
|
return ret;
|
|
}
|
|
|
|
return mtp_accept_batch(*mtp_state, batch, seq_id, first_hidden);
|
|
}
|
|
|
|
// Type of the implementation currently selected on `spec`, or
// COMMON_SPECULATIVE_TYPE_NONE when `spec` is null or nothing is selected.
common_speculative_type common_speculative_current_type(const common_speculative * spec) {
    const bool has_current = spec != nullptr && spec->curr_impl != nullptr;
    if (has_current) {
        return spec->curr_impl->type;
    }
    return COMMON_SPECULATIVE_TYPE_NONE;
}
|
|
|
|
// Mirrors a target-side context shift in the speculative state.
//
// In the companion context (if any) the KV entries of `seq_id` in
// [kv_keep, kv_keep + kv_discard) are removed and those in
// [kv_keep + kv_discard, kv_past) are moved back by `kv_discard`.
// A DFlash target window, if present, is shifted the same way.
void common_speculative_context_shift(
        common_speculative * spec,
        llama_seq_id seq_id,
        llama_pos kv_keep,
        llama_pos kv_discard,
        llama_pos kv_past) {
    const llama_pos discard_end = kv_keep + kv_discard;

    auto * companion = common_speculative_get_companion_ctx(spec);
    if (companion != nullptr) {
        llama_kv_cache_seq_rm(companion, seq_id, kv_keep, discard_end);
        llama_kv_cache_seq_add(companion, seq_id, discard_end, kv_past, -kv_discard);
    }

    auto * dflash = common_speculative_get_dflash_state(spec);
    if (dflash != nullptr) {
        dflash_context_shift(*dflash, kv_keep, kv_discard, kv_past);
    }
}
|
|
|
|
// Generates up to `n_draft` draft tokens for `seq_id` with the MTP head.
//
// Parameters:
//   state   - MTP state holding the per-sequence draft cache.
//   smpl    - sampler used for drafting; a null sampler yields no drafts.
//   ctx     - context the MTP decode steps run on.
//   n_draft - maximum number of tokens to draft; <= 0 invalidates the cached
//             draft of the sequence and yields no drafts.
//   p_min   - when > 0, drafting stops once a sampled token's probability
//             falls below it (the first draft is always kept).
//   id_last - last committed token, used as the first draft input.
//   n_past  - position of the first draft input.
//   constant_draft_positions - decode every draft step at `n_past` instead of
//             advancing the position.
//
// If a cached draft token exists for the sequence (set by a previous accept
// step) it becomes the first draft without a decode. Returns the drafted
// tokens, possibly fewer than requested when a decode or hidden-state
// hand-over fails.
std::vector<llama_token> mtp_speculative_gen_draft(
        common_speculative_state_mtp & state,
        struct common_sampler * smpl,
        struct llama_context * ctx,
        int n_draft,
        float p_min,
        llama_token id_last,
        llama_pos n_past,
        llama_seq_id seq_id,
        bool constant_draft_positions) {

    llama_tokens drafts;

    if (!smpl) return drafts;

    if (n_draft <= 0) {
        mtp_invalidate_cached_draft(state, seq_id);
        return drafts;
    }

    // Reserve only once n_draft is known to be positive: a negative value
    // would convert to a huge size_type and make reserve() throw.
    drafts.reserve((size_t) n_draft);

    common_sampler_reset(smpl);

    llama_batch mtp_batch = llama_batch_init(1, 0, 1);
    llama_set_mtp_op_type(ctx, MTP_OP_DRAFT_GEN);

    float prob = 0.0f;
    auto prob_ptr = p_min > 0 ? &prob : nullptr;

    llama_token current_input_id = id_last;
    llama_pos current_n_past = n_past;
    const int n_embd = llama_mtp_state_n_embd(ctx);

    auto & last = mtp_get_last_embd(state, seq_id);

    // Never copy more floats than the cache buffer holds, even if the context
    // reports a different embedding width than the state.
    const size_t n_embd_copy = std::min<size_t>(last.embd.size(), (size_t) std::max(n_embd, 0));

    int i0 = 0;
    if (last.last_id >= 0) {
        // Use the token pre-sampled by the last accept step as the first draft.
        if (last.prob < p_min) {
            n_draft = 1;
        }
        current_input_id = last.last_id;
        last.last_id = -1;
        drafts.push_back(current_input_id);
        current_n_past++;
        if (!llama_set_draft_input_hidden_state_copy(ctx, last.embd.data(), last.embd.size())) {
            llama_batch_free(mtp_batch);
            llama_set_mtp_op_type(ctx, MTP_OP_NONE);
            return drafts;
        }
        i0 = 1;
    }

    int n_decode = 0;
    for (int i = i0; i < n_draft; ++i) {
        mtp_batch.n_tokens = 0;
        const llama_pos draft_pos = constant_draft_positions ? n_past : current_n_past;
        common_batch_add(mtp_batch, current_input_id, draft_pos, {seq_id}, true);

        ++n_decode;
        if (llama_decode(ctx, mtp_batch) != 0) {
            break;
        }

        llama_token id_next = common_sampler_sample_speculative(smpl, ctx, 0, prob_ptr);

        // Past the first draft, a low-confidence token is dropped, not emitted.
        if (i > 0 && prob_ptr && prob < p_min) {
            break;
        }

        drafts.push_back(id_next);

        const float * emb = llama_get_embeddings_ith(ctx, 0);
        if (!emb) {
            break;
        }

        // Keep a stable copy because later decode steps reuse ctx->embd storage.
        std::memcpy(last.embd.data(), emb, n_embd_copy * sizeof(float));
        if (!llama_set_draft_input_hidden_state_copy(ctx, last.embd.data(), last.embd.size())) {
            break;
        }

        current_input_id = id_next;
        current_n_past++;

        // The first draft is always emitted, but stops the chain when unlikely.
        if (prob_ptr && prob < p_min) {
            break;
        }
    }
    llama_batch_free(mtp_batch);
    llama_set_mtp_op_type(ctx, MTP_OP_NONE);

    // Purge the metadata for the draft tokens.
    // This prevents cache state corruption where two cells map to the same logical position.
    // If the state contained in `last` had a valid token id and probability, it means that we
    // have previously run an "accept" batch, where the token sampled from the main model was included.
    // Even in that case, the token at `n_past` is already committed and must remain in the KV cache,
    // so we only discard the speculative tail starting at `n_past + 1`.
    if (n_decode > 0) {
        llama_kv_cache_seq_rm(ctx, seq_id, n_past + 1, n_past + n_decode + 2);
    }

    return drafts;
}
|
|
|
|
|
|
int32_t mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) {
|
|
if (batch.n_tokens == 0) {
|
|
return 0;
|
|
}
|
|
|
|
llama_seq_id seq_id = batch.seq_id[0][0];
|
|
llama_pos start_pos = batch.pos[0];
|
|
|
|
if (llama_kv_cache_seq_pos_max(ctx, seq_id) >= start_pos) {
|
|
llama_kv_cache_seq_rm(ctx, seq_id, start_pos, -1);
|
|
}
|
|
|
|
LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens for seq_id %d from pos %d...\n",
|
|
is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens, seq_id, (int)start_pos);
|
|
|
|
// We never need all logits. We only need the logits of the last token so we can sample
|
|
// the next draft token. In the MTP_OP_WARMUP case we do not need logits at all, but just
|
|
// in case we also get the logits of the last token.
|
|
llama_batch mtp_batch = batch;
|
|
for (int i = 0; i < mtp_batch.n_tokens; ++i) {
|
|
mtp_batch.logits[i] = false;
|
|
}
|
|
mtp_batch.logits[mtp_batch.n_tokens-1] = true;
|
|
if (is_prompt_warmup) {
|
|
llama_set_mtp_op_type(ctx, MTP_OP_WARMUP);
|
|
} else {
|
|
llama_set_mtp_op_type(ctx, MTP_OP_UPDATE_ACCEPTED);
|
|
}
|
|
|
|
const int32_t ret = llama_decode(ctx, mtp_batch);
|
|
llama_set_mtp_op_type(ctx, MTP_OP_NONE);
|
|
return ret;
|
|
}
|