mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Merge pull request #1970 from SamuelOliveirads/feat/dflash-implementation
Add DFlash support
This commit is contained in:
commit
f9078e169b
2
.flake8
2
.flake8
@ -17,3 +17,5 @@ exclude =
|
||||
# This contains builds that we don't want to check
|
||||
dist # This is generated with `python build .` for package releases
|
||||
# max-complexity = 10
|
||||
per-file-ignores =
|
||||
gguf-py/gguf/constants.py: E201, E222
|
||||
|
||||
@ -157,6 +157,9 @@ common_params_speculative common_params_speculative::with_stage_overrides(const
|
||||
if (stage.has_p_min_override()) {
|
||||
result.p_min = stage.p_min;
|
||||
}
|
||||
if (stage.has_dflash_cross_ctx_override()) {
|
||||
result.dflash_cross_ctx = stage.dflash_cross_ctx;
|
||||
}
|
||||
if (stage.has_ngram_size_n_override()) {
|
||||
result.ngram_size_n = stage.ngram_size_n;
|
||||
result.ngram_mod.reset();
|
||||
@ -212,6 +215,7 @@ bool common_params_speculative::has_composite_stage_chain() const {
|
||||
|
||||
bool common_params_speculative::needs_dft_model() const {
|
||||
return has_stage_type(COMMON_SPECULATIVE_TYPE_DRAFT) ||
|
||||
has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH) ||
|
||||
(has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) && has_dft());
|
||||
}
|
||||
|
||||
@ -287,8 +291,12 @@ bool common_speculative_validate_chain(const common_params_speculative & params,
|
||||
return fail("speculative stage has n_min greater than n_max");
|
||||
}
|
||||
|
||||
if (stage.type == COMMON_SPECULATIVE_TYPE_DRAFT && !params.has_dft()) {
|
||||
return fail("draft speculative stage requires a draft model or draft params");
|
||||
if ((stage.type == COMMON_SPECULATIVE_TYPE_DRAFT || stage.type == COMMON_SPECULATIVE_TYPE_DFLASH) && !params.has_dft()) {
|
||||
return fail(common_speculative_type_to_str(stage.type) + " speculative stage requires a draft model or draft params");
|
||||
}
|
||||
|
||||
if (stage.type == COMMON_SPECULATIVE_TYPE_DFLASH && stage_params.dflash_cross_ctx < 1) {
|
||||
return fail("dflash speculative stage requires cross_ctx >= 1");
|
||||
}
|
||||
}
|
||||
|
||||
@ -906,6 +914,13 @@ static void common_speculative_stage_apply_kv(
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (key == "cross_ctx" || key == "dflash_cross_ctx") {
|
||||
stage.dflash_cross_ctx = std::stoi(value_raw);
|
||||
if (stage.dflash_cross_ctx < 1) {
|
||||
throw std::invalid_argument("speculative stage dflash cross_ctx must be at least 1");
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (key == "ngram_size_n") {
|
||||
stage.ngram_size_n = std::stoi(value_raw);
|
||||
if (stage.ngram_size_n < 1 || stage.ngram_size_n > 1024) {
|
||||
@ -3253,11 +3268,12 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
||||
" gpu-fallback copy state to GPU buffer; re-decode on rejection\n"
|
||||
" cpu serialise state via llama_state_seq; re-decode on rejection" });
|
||||
options.push_back({ "*", "--spec-type SPEC[:k=v,...]", "canonical speculative stage entry; repeat for a supported two-stage chain.\n"
|
||||
"types: none, draft, mtp, ngram-cache, ngram-simple, ngram-map-k, ngram-map-k4v, ngram-mod, suffix\n"
|
||||
"canonical keys: n_max,n_min,p_min,ngram_size_n,ngram_size_m,ngram_min_hits,suffix_min_match_len,suffix_max_depth,suffix_corpus\n"
|
||||
"types: none, draft, dflash, mtp, ngram-cache, ngram-simple, ngram-map-k, ngram-map-k4v, ngram-mod, suffix\n"
|
||||
"canonical keys: n_max,n_min,p_min,cross_ctx,ngram_size_n,ngram_size_m,ngram_min_hits,suffix_min_match_len,suffix_max_depth,suffix_corpus\n"
|
||||
"for comma-bearing string values, quote the value inside the stage payload for normal shell use\n"
|
||||
"if argv is passed directly without shell unescaping, the parser also accepts escaped commas as \\,\n"
|
||||
"examples: --spec-type mtp:n_max=1,p_min=0.0\n"
|
||||
" --model-draft draft.gguf --spec-type dflash:n_max=4,cross_ctx=512\n"
|
||||
" --spec-type ngram-mod:n_max=64,n_min=2,ngram_size_n=8 --spec-type mtp:n_max=1,p_min=0.0\n"
|
||||
" --spec-type \"suffix:n_max=16,n_min=2,suffix_min_match_len=5,suffix_max_depth=64,suffix_corpus='/tmp/spec,type-corpus.json'\"\n"
|
||||
"legacy --spec-stage, --draft-*, --spec-ngram-*, --suffix-* and -mtp flags are rejected" });
|
||||
|
||||
@ -140,6 +140,7 @@ thinking_tokens thinking_tokens_from_string(const std::string& format);
|
||||
enum common_speculative_type {
|
||||
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
|
||||
COMMON_SPECULATIVE_TYPE_DFLASH, // DFlash draft model
|
||||
COMMON_SPECULATIVE_TYPE_MTP, // MTP model
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
|
||||
@ -162,6 +163,7 @@ struct common_speculative_stage_params {
|
||||
int32_t n_max = -1;
|
||||
int32_t n_min = -1;
|
||||
float p_min = -1.0f;
|
||||
int32_t dflash_cross_ctx = -1;
|
||||
|
||||
uint16_t ngram_size_n = 0;
|
||||
uint16_t ngram_size_m = 0;
|
||||
@ -174,6 +176,7 @@ struct common_speculative_stage_params {
|
||||
bool has_n_max_override() const { return n_max >= 0; }
|
||||
bool has_n_min_override() const { return n_min >= 0; }
|
||||
bool has_p_min_override() const { return p_min >= 0.0f; }
|
||||
bool has_dflash_cross_ctx_override() const { return dflash_cross_ctx >= 0; }
|
||||
bool has_ngram_size_n_override() const { return ngram_size_n > 0; }
|
||||
bool has_ngram_size_m_override() const { return ngram_size_m > 0; }
|
||||
bool has_ngram_min_hits_override() const { return ngram_min_hits > 0; }
|
||||
@ -206,6 +209,7 @@ struct common_params_speculative {
|
||||
int32_t n_max = 16; // number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 0; // minimum number of tokens to draft during speculative decoding
|
||||
std::vector<common_speculative_stage_params> stages; // explicit stage chain for single-spec or self-spec + model fallback
|
||||
int32_t dflash_cross_ctx = 512; // target-feature context window for DFlash
|
||||
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
|
||||
530
common/speculative-dflash-impl.h
Normal file
530
common/speculative-dflash-impl.h
Normal file
@ -0,0 +1,530 @@
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
|
||||
static bool common_speculative_are_dflash_compatible(
|
||||
const llama_model * model_tgt,
|
||||
const llama_model * model_dft) {
|
||||
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
|
||||
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
|
||||
|
||||
if (llama_vocab_type(vocab_tgt) != llama_vocab_type(vocab_dft)) {
|
||||
LOG_DBG("%s: DFlash draft model vocab type must match the target model\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool add_bos_tgt = llama_vocab_get_add_bos(vocab_tgt);
|
||||
const bool add_bos_dft = llama_vocab_get_add_bos(vocab_dft);
|
||||
const bool add_eos_tgt = llama_vocab_get_add_eos(vocab_tgt);
|
||||
const bool add_eos_dft = llama_vocab_get_add_eos(vocab_dft);
|
||||
const llama_token bos_tgt = llama_vocab_bos(vocab_tgt);
|
||||
const llama_token bos_dft = llama_vocab_bos(vocab_dft);
|
||||
const llama_token eos_tgt = llama_vocab_eos(vocab_tgt);
|
||||
const llama_token eos_dft = llama_vocab_eos(vocab_dft);
|
||||
|
||||
if (add_bos_tgt != add_bos_dft || add_eos_tgt != add_eos_dft ||
|
||||
(add_bos_tgt && bos_tgt != bos_dft) ||
|
||||
(add_eos_tgt && eos_tgt != eos_dft)) {
|
||||
LOG_DBG("%s: DFlash draft special tokens must match the target model (add_bos=%d/%d add_eos=%d/%d bos=%d/%d eos=%d/%d)\n",
|
||||
__func__,
|
||||
(int) add_bos_tgt,
|
||||
(int) add_bos_dft,
|
||||
(int) add_eos_tgt,
|
||||
(int) add_eos_dft,
|
||||
(int) bos_tgt,
|
||||
(int) bos_dft,
|
||||
(int) eos_tgt,
|
||||
(int) eos_dft);
|
||||
return false;
|
||||
}
|
||||
|
||||
const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
|
||||
const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
|
||||
const int vocab_diff = n_vocab_tgt > n_vocab_dft
|
||||
? n_vocab_tgt - n_vocab_dft
|
||||
: n_vocab_dft - n_vocab_tgt;
|
||||
|
||||
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
|
||||
LOG_DBG("%s: DFlash draft vocab size differs too much from the target model (%d vs %d)\n",
|
||||
__func__, n_vocab_dft, n_vocab_tgt);
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
|
||||
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
|
||||
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
|
||||
|
||||
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
||||
LOG_DBG("%s: DFlash draft token %d differs - target '%s', draft '%s'\n", __func__, i,
|
||||
common_token_to_piece(vocab_tgt, i).c_str(),
|
||||
common_token_to_piece(vocab_dft, i).c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Forward declaration so the helper prototype below can name the state type
// before the struct is defined.
struct common_speculative_state_dflash;
|
||||
// Declared ahead of the struct because common_speculative_state_dflash::draft()
// calls it; the definition appears later in this file.
static void dflash_materialize_target_window_features(common_speculative_state_dflash & state);
|
||||
|
||||
// DFlash runtime state and draft path.
|
||||
struct common_speculative_state_dflash : public common_speculative_state {
|
||||
llama_context * ctx_tgt;
|
||||
llama_context * ctx_dft;
|
||||
|
||||
llama_batch batch = {};
|
||||
|
||||
int32_t block_size = 0;
|
||||
int32_t mask_token_id = -1;
|
||||
int32_t n_target_features = 0;
|
||||
int32_t cross_ctx = 0;
|
||||
bool ready = false;
|
||||
|
||||
std::vector<int32_t> target_layer_ids;
|
||||
std::vector<float> target_window;
|
||||
std::vector<llama_pos> target_window_pos;
|
||||
std::vector<float> target_window_stage;
|
||||
std::vector<llama_pos> target_window_pos_stage;
|
||||
std::vector<float> target_window_ring;
|
||||
std::vector<float> target_window_append_features;
|
||||
int32_t target_window_rows = 0;
|
||||
int32_t target_window_ring_write_pos = 0;
|
||||
int32_t target_window_ring_filled = 0;
|
||||
uint64_t target_window_version = 0;
|
||||
int32_t target_window_keep_rows = 0;
|
||||
int32_t target_window_append_rows = 0;
|
||||
bool target_window_replace = false;
|
||||
bool target_window_materialized = false;
|
||||
llama_pos last_target_pos = -1;
|
||||
|
||||
common_speculative_state_dflash(
|
||||
enum common_speculative_type type,
|
||||
llama_context * ctx_tgt,
|
||||
llama_context * ctx_dft,
|
||||
int32_t cross_ctx)
|
||||
: common_speculative_state(type)
|
||||
, ctx_tgt(ctx_tgt)
|
||||
, ctx_dft(ctx_dft)
|
||||
, cross_ctx(std::max(1, cross_ctx))
|
||||
{
|
||||
const llama_model * model_tgt = llama_get_model(ctx_tgt);
|
||||
const llama_model * model_dft = llama_get_model(ctx_dft);
|
||||
|
||||
if (!common_speculative_are_dflash_compatible(model_tgt, model_dft)) {
|
||||
LOG_ERR("%s: DFlash draft model vocab/tokenizer is incompatible with the target model\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
block_size = llama_model_dflash_block_size(model_dft);
|
||||
mask_token_id = llama_model_dflash_mask_token_id(model_dft);
|
||||
n_target_features = llama_model_dflash_n_target_features(model_dft);
|
||||
const int32_t n_target_layers = llama_model_dflash_n_target_layers(model_dft);
|
||||
|
||||
if (block_size <= 0 || mask_token_id < 0 || n_target_features <= 0 || n_target_layers <= 0) {
|
||||
LOG_ERR("%s: invalid DFlash metadata (block_size=%d, mask_token_id=%d, n_target_features=%d, n_target_layers=%d)\n",
|
||||
__func__, block_size, mask_token_id, n_target_features, n_target_layers);
|
||||
return;
|
||||
}
|
||||
|
||||
target_layer_ids.resize((size_t) n_target_layers);
|
||||
if (llama_model_dflash_target_layer_ids(model_dft, target_layer_ids.data(), n_target_layers) != n_target_layers) {
|
||||
LOG_ERR("%s: failed to read DFlash target layer ids\n", __func__);
|
||||
target_layer_ids.clear();
|
||||
return;
|
||||
}
|
||||
|
||||
const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
|
||||
const int32_t target_vocab_size = llama_vocab_n_tokens(vocab_tgt);
|
||||
const int32_t target_hidden_size = llama_model_n_embd(model_tgt);
|
||||
const int32_t draft_hidden_size = llama_model_n_embd(model_dft);
|
||||
const int32_t target_mask_token_id = llama_model_dflash_target_mask_token_id(model_tgt);
|
||||
const int32_t expected_n_target_features = target_hidden_size > 0 ? target_hidden_size * n_target_layers : 0;
|
||||
|
||||
if (target_mask_token_id != (int32_t) LLAMA_TOKEN_NULL && mask_token_id != target_mask_token_id) {
|
||||
LOG_ERR("%s: DFlash mask token mismatch (draft=%d target=%d)\n",
|
||||
__func__, mask_token_id, target_mask_token_id);
|
||||
return;
|
||||
}
|
||||
|
||||
if (target_hidden_size <= 0 || draft_hidden_size <= 0) {
|
||||
LOG_ERR("%s: invalid DFlash hidden sizes (draft=%d target=%d)\n",
|
||||
__func__, draft_hidden_size, target_hidden_size);
|
||||
return;
|
||||
}
|
||||
|
||||
if (expected_n_target_features <= 0 || n_target_features != expected_n_target_features) {
|
||||
LOG_ERR("%s: DFlash target feature width mismatch (metadata=%d expected=%d target_hidden=%d target_layers=%d)\n",
|
||||
__func__, n_target_features, expected_n_target_features, target_hidden_size, n_target_layers);
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<int32_t> sorted_target_layer_ids = target_layer_ids;
|
||||
std::sort(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end());
|
||||
if (std::adjacent_find(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()) != sorted_target_layer_ids.end()) {
|
||||
LOG_ERR("%s: duplicate DFlash target layer ids survived into runtime validation\n", __func__);
|
||||
target_layer_ids.clear();
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t n_target_model_layers = llama_n_layer(model_tgt);
|
||||
for (int32_t layer_id : target_layer_ids) {
|
||||
if (layer_id < 0 || layer_id >= n_target_model_layers) {
|
||||
LOG_ERR("%s: invalid DFlash target layer id %d for target model with %d layers\n",
|
||||
__func__, layer_id, n_target_model_layers);
|
||||
target_layer_ids.clear();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const int32_t io_mode = llama_model_dflash_io_mode(model_dft, model_tgt);
|
||||
if (io_mode == LLAMA_DFLASH_IO_MODE_INVALID) {
|
||||
LOG_ERR("%s: DFlash draft is missing required IO tensors after target sharing\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
if (io_mode == LLAMA_DFLASH_IO_MODE_MIXED) {
|
||||
LOG_ERR("%s: DFlash IO contract must be fully shared or fully self-contained, but resolved to mixed mode\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
if (io_mode == LLAMA_DFLASH_IO_MODE_SELF_CONTAINED && !llama_model_dflash_io_tensors_match(model_dft, target_hidden_size, target_vocab_size)) {
|
||||
LOG_ERR("%s: DFlash self-contained IO tensors do not match the target hidden/vocab contract (target_hidden=%d target_vocab=%d)\n",
|
||||
__func__,
|
||||
target_hidden_size,
|
||||
target_vocab_size);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!llama_set_dflash_capture_layers(ctx_tgt, target_layer_ids.data(), (int32_t) target_layer_ids.size())) {
|
||||
LOG_ERR("%s: failed to configure DFlash target capture callback\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
batch = llama_batch_init(std::max(1, block_size), 0, 1);
|
||||
target_window.reserve((size_t) this->cross_ctx * (size_t) n_target_features);
|
||||
target_window_stage.reserve((size_t) this->cross_ctx * (size_t) n_target_features);
|
||||
target_window_ring.resize((size_t) this->cross_ctx * (size_t) n_target_features);
|
||||
target_window_append_features.reserve((size_t) this->cross_ctx * (size_t) n_target_features);
|
||||
target_window_pos.reserve((size_t) this->cross_ctx);
|
||||
target_window_pos_stage.reserve((size_t) this->cross_ctx);
|
||||
ready = true;
|
||||
|
||||
llama_set_dflash_visible_cross_ctx(ctx_dft, this->cross_ctx);
|
||||
LOG_INF("%s: DFlash context ready (n_ctx=%d, block_size=%d, cross_ctx=%d, n_target_features=%d, n_target_layers=%d)\n",
|
||||
__func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features, n_target_layers);
|
||||
}
|
||||
|
||||
~common_speculative_state_dflash() override {
|
||||
llama_clear_dflash_capture(ctx_tgt);
|
||||
if (ctx_dft) {
|
||||
llama_free(ctx_dft);
|
||||
}
|
||||
if (batch.token != nullptr) {
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
}
|
||||
|
||||
void begin(const llama_tokens & prompt) override {
|
||||
GGML_UNUSED(prompt);
|
||||
llama_kv_cache_clear(ctx_dft);
|
||||
llama_reset_dflash_kv_cache_state(ctx_dft);
|
||||
}
|
||||
|
||||
void draft(
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
GGML_UNUSED(prompt_tgt);
|
||||
|
||||
result.clear();
|
||||
if (!ready || target_window_rows <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t n_keep = std::min<int32_t>(params.n_max, block_size - 1);
|
||||
if (n_keep <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float * target_features = nullptr;
|
||||
size_t target_feature_floats = 0;
|
||||
llama_dflash_window_update window_update = {
|
||||
target_window_version,
|
||||
target_window_keep_rows,
|
||||
target_window_append_rows,
|
||||
target_window_replace,
|
||||
target_window_append_features.empty() ? nullptr : target_window_append_features.data(),
|
||||
target_window_append_features.size(),
|
||||
};
|
||||
const llama_dflash_kv_cache_transition cache_plan =
|
||||
llama_plan_dflash_kv_cache_transition_for_ctx(ctx_dft, window_update, target_window_rows);
|
||||
|
||||
if (cache_plan.rebuild_cache) {
|
||||
dflash_materialize_target_window_features(*this);
|
||||
target_features = target_window.data();
|
||||
target_feature_floats = target_window.size();
|
||||
window_update.append_features = target_window.data();
|
||||
window_update.append_floats = target_window.size();
|
||||
window_update.append_rows = target_window_rows;
|
||||
}
|
||||
|
||||
if (!llama_set_dflash_target_features_view(ctx_dft, target_features, target_feature_floats, target_window_rows, target_window_pos.data(), &window_update)) {
|
||||
LOG_ERR("%s: failed to set DFlash target features\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
llama_kv_cache_clear(ctx_dft);
|
||||
batch.n_tokens = 0;
|
||||
const int32_t batch_len = n_keep + 1;
|
||||
const llama_pos draft_pos_base = last_target_pos >= 0 ? last_target_pos + 1 : (llama_pos) target_window_rows;
|
||||
const llama_pos seed_pos = last_target_pos >= 0 ? last_target_pos : draft_pos_base - 1;
|
||||
common_batch_add(batch, id_last, seed_pos, { 0 }, false);
|
||||
for (int32_t i = 1; i < batch_len; ++i) {
|
||||
common_batch_add(batch, mask_token_id, draft_pos_base + (i - 1), { 0 }, i <= n_keep);
|
||||
}
|
||||
|
||||
if (llama_decode(ctx_dft, batch) != 0) {
|
||||
LOG_ERR("%s: llama_decode() failed for DFlash draft batch\n", __func__);
|
||||
batch.n_tokens = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
result.reserve((size_t) n_keep);
|
||||
for (int32_t i = 0; i < n_keep; ++i) {
|
||||
llama_token id = llama_get_dflash_draft_token_ith(ctx_dft, i);
|
||||
if (id == LLAMA_TOKEN_NULL) {
|
||||
id = common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr);
|
||||
}
|
||||
result.push_back(id);
|
||||
}
|
||||
|
||||
batch.n_tokens = 0;
|
||||
}
|
||||
|
||||
void accept(uint16_t n_accepted) override {
|
||||
GGML_UNUSED(n_accepted);
|
||||
}
|
||||
};
|
||||
|
||||
// Stores the description of the latest window change on `state` and bumps the
// window version counter.
//
// keep_rows / append_rows - row counts of the change; negative values are stored as 0
// replace                 - whether the change replaced the whole window
static void dflash_record_window_update(
        common_speculative_state_dflash & state,
        int32_t keep_rows,
        int32_t append_rows,
        bool replace) {
    const int32_t kept     = keep_rows   > 0 ? keep_rows   : 0;
    const int32_t appended = append_rows > 0 ? append_rows : 0;

    state.target_window_replace     = replace;
    state.target_window_append_rows = appended;
    state.target_window_keep_rows   = kept;
    ++state.target_window_version;
}
|
||||
|
||||
// Replaces the ring contents with `n_rows` rows taken from `rows`, stored from
// ring row 0 onwards.
//
// With no rows (n_rows <= 0 or rows == nullptr) only the write position and
// the fill count are zeroed; the materialized flag is left as it was.
// NOTE(review): n_rows is not checked against state.cross_ctx here; the callers
// visible in this file never pass more than cross_ctx rows — confirm for any
// new caller.
static void dflash_ring_reset_rows(
        common_speculative_state_dflash & state,
        const float * rows,
        int32_t n_rows) {
    const size_t floats_per_row = (size_t) state.n_target_features;

    const bool has_rows = n_rows > 0 && rows != nullptr;
    if (!has_rows) {
        state.target_window_ring_filled = 0;
        state.target_window_ring_write_pos = 0;
        return;
    }

    // Make sure the ring has its full cross_ctx-row size before writing.
    const size_t ring_floats = (size_t) state.cross_ctx * floats_per_row;
    if (state.target_window_ring.size() != ring_floats) {
        state.target_window_ring.resize(ring_floats);
    }

    const size_t n_bytes = (size_t) n_rows * floats_per_row * sizeof(float);
    std::memcpy(state.target_window_ring.data(), rows, n_bytes);

    state.target_window_ring_filled = n_rows;
    // A completely full ring wraps the write position back to row 0.
    state.target_window_ring_write_pos = n_rows % state.cross_ctx;
    state.target_window_materialized = false;
}
|
||||
|
||||
// Writes `n_rows` rows from `rows` into the ring, starting at the current
// write position and wrapping around at cross_ctx. Afterwards the write
// position is advanced, the fill count grows (capped at cross_ctx) and the
// linear window copy is marked stale. Does nothing when there are no rows.
static void dflash_ring_append_rows(
        common_speculative_state_dflash & state,
        const float * rows,
        int32_t n_rows) {
    const size_t floats_per_row = (size_t) state.n_target_features;
    if (rows == nullptr || n_rows <= 0) {
        return;
    }

    // Make sure the ring has its full cross_ctx-row size before writing.
    const size_t ring_floats = (size_t) state.cross_ctx * floats_per_row;
    if (state.target_window_ring.size() != ring_floats) {
        state.target_window_ring.resize(ring_floats);
    }

    int32_t cursor = state.target_window_ring_write_pos;
    const float * next_src = rows;
    for (int32_t left = n_rows; left > 0; ) {
        // Copy up to the physical end of the ring, then continue from row 0.
        const int32_t rows_until_wrap = state.cross_ctx - cursor;
        const int32_t n_copy = std::min<int32_t>(left, rows_until_wrap);
        const size_t n_floats = (size_t) n_copy * floats_per_row;

        float * dst = state.target_window_ring.data() + (size_t) cursor * floats_per_row;
        std::memcpy(dst, next_src, n_floats * sizeof(float));

        next_src += n_floats;
        left -= n_copy;
        cursor = (cursor + n_copy) % state.cross_ctx;
    }

    state.target_window_ring_write_pos = cursor;
    state.target_window_ring_filled = std::min(state.cross_ctx, state.target_window_ring_filled + n_rows);
    state.target_window_materialized = false;
}
|
||||
|
||||
// Rebuilds state.target_window as a linear, oldest-first copy of the
// target_window_rows most recent ring rows. Returns immediately when the copy
// is already up to date or the window holds no rows.
static void dflash_materialize_target_window_features(common_speculative_state_dflash & state) {
    const int32_t n_rows = state.target_window_rows;
    if (n_rows <= 0 || state.target_window_materialized) {
        return;
    }

    const size_t floats_per_row = (size_t) state.n_target_features;
    state.target_window.resize((size_t) n_rows * floats_per_row);

    // The newest row sits just before the write position, so the oldest of the
    // n_rows rows is n_rows slots behind it (modulo the ring size).
    const int32_t ring_rows = state.cross_ctx;
    const int32_t oldest    = (state.target_window_ring_write_pos - n_rows + ring_rows) % ring_rows;
    const int32_t head_rows = std::min<int32_t>(n_rows, ring_rows - oldest); // rows up to the ring's physical end
    const int32_t tail_rows = n_rows - head_rows;                            // rows that wrapped to the front

    float *       out  = state.target_window.data();
    const float * ring = state.target_window_ring.data();

    std::memcpy(out, ring + (size_t) oldest * floats_per_row, (size_t) head_rows * floats_per_row * sizeof(float));

    if (tail_rows > 0) {
        std::memcpy(out + (size_t) head_rows * floats_per_row, ring, (size_t) tail_rows * floats_per_row * sizeof(float));
    }

    state.target_window_materialized = true;
}
|
||||
|
||||
// Adds the feature rows of sequence `seq_id` from `features` to the window.
//
// Rows belonging to other sequences, or rows without data, are ignored.
// If the accepted rows alone fill the window (count >= cross_ctx) the window
// is replaced by the newest cross_ctx of them; otherwise they are appended and
// the oldest existing rows are dropped as needed to stay within cross_ctx.
//
// Returns false, leaving the state untouched, when the view is not a
// hidden-state view of the expected width, has no rows, cross_ctx is not
// positive, or no row matches `seq_id`. Returns true after the window, the
// ring, last_target_pos and the recorded window update have been refreshed.
static bool dflash_append_target_features(
        common_speculative_state_dflash & state,
        const common_speculative_feature_view & features,
        llama_seq_id seq_id) {
    const bool usable_view =
        features.kind == COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE &&
        features.width == state.n_target_features &&
        !features.rows.empty() &&
        state.cross_ctx > 0;
    if (!usable_view) {
        return false;
    }

    const size_t floats_per_row = (size_t) state.n_target_features;
    const size_t n_candidates = features.rows.size();

    // Gather the matching rows and their positions in arrival order.
    std::vector<float> gathered;
    std::vector<llama_pos> gathered_pos;
    gathered.reserve(n_candidates * floats_per_row);
    gathered_pos.reserve(n_candidates);

    for (const auto & row : features.rows) {
        if (row.seq_id == seq_id && row.data != nullptr) {
            gathered_pos.push_back(row.pos);
            gathered.insert(gathered.end(), row.data, row.data + floats_per_row);
        }
    }

    if (gathered_pos.empty()) {
        return false;
    }

    const int32_t n_rows = (int32_t) gathered_pos.size();
    const int32_t capacity = state.cross_ctx;

    if (n_rows >= capacity) {
        // The new rows overflow the window on their own: keep only the newest
        // `capacity` rows and discard everything that was there before.
        const int32_t n_drop = n_rows - capacity;
        state.target_window_pos.assign(gathered_pos.begin() + n_drop, gathered_pos.end());
        state.target_window_append_features.assign(
            gathered.begin() + (ptrdiff_t) n_drop * (ptrdiff_t) floats_per_row,
            gathered.end());
        dflash_ring_reset_rows(state, state.target_window_append_features.data(), capacity);

        state.target_window_rows = capacity;
        state.target_window_ring_filled = capacity;
        state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
        dflash_record_window_update(state, 0, capacity, true);
        return true;
    }

    // Incremental case: keep as many of the existing rows as still fit.
    const int32_t n_kept = std::min<int32_t>(state.target_window_rows, capacity - n_rows);
    std::vector<llama_pos> & staged_pos = state.target_window_pos_stage;
    staged_pos.resize((size_t) (n_kept + n_rows));

    // Positions: the surviving old tail first, then the new rows.
    if (n_kept > 0) {
        std::copy(state.target_window_pos.end() - n_kept, state.target_window_pos.end(), staged_pos.begin());
    }
    std::copy(gathered_pos.begin(), gathered_pos.end(), staged_pos.begin() + n_kept);

    state.target_window_append_features.assign(gathered.begin(), gathered.end());
    dflash_ring_append_rows(state, state.target_window_append_features.data(), n_rows);

    state.target_window_pos.swap(staged_pos);
    staged_pos.clear(); // the scratch buffer now holds the previous positions; drop them

    const int32_t n_total = n_kept + n_rows;
    state.target_window_rows = n_total;
    state.target_window_ring_filled = n_total;
    state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
    dflash_record_window_update(state, n_kept, n_rows, false);
    return true;
}
|
||||
|
||||
// Empties the feature window: clears the window, staging and append buffers,
// zeroes the row / ring counters and the recorded update, marks the window as
// empty (last_target_pos = -1) and resets the DFlash KV-cache state of the
// draft context. The ring storage and the version counter are left as they are.
static void dflash_clear_target_features(common_speculative_state_dflash & state) {
    // feature buffers
    state.target_window.clear();
    state.target_window_stage.clear();
    state.target_window_append_features.clear();

    // position buffers
    state.target_window_pos.clear();
    state.target_window_pos_stage.clear();

    // row and ring bookkeeping
    state.target_window_rows = 0;
    state.target_window_ring_filled = 0;
    state.target_window_ring_write_pos = 0;
    state.target_window_materialized = false;
    state.last_target_pos = -1;

    // description of the last recorded update
    state.target_window_keep_rows = 0;
    state.target_window_append_rows = 0;
    state.target_window_replace = false;

    llama_reset_dflash_kv_cache_state(state.ctx_dft);
}
|
||||
|
||||
// Applies a context shift to the feature window.
//
// Rows whose position lies in [kv_keep, kv_keep + kv_discard) are removed.
// Rows whose position lies in [kv_keep + kv_discard, kv_past) have their
// position reduced by kv_discard. All other rows are kept unchanged.
// The ring is then reloaded from the surviving rows, the change is recorded as
// a full window replacement and the draft context's DFlash KV-cache state is
// reset. Nothing happens when kv_discard <= 0 or the window is empty.
static void dflash_context_shift(
        common_speculative_state_dflash & state,
        llama_pos kv_keep,
        llama_pos kv_discard,
        llama_pos kv_past) {
    const bool nothing_to_shift =
        kv_discard <= 0 ||
        state.target_window_rows <= 0 ||
        state.target_window_pos.empty();
    if (nothing_to_shift) {
        return;
    }

    // Work on the linear copy of the window so rows can be indexed directly.
    dflash_materialize_target_window_features(state);

    const size_t floats_per_row = (size_t) state.n_target_features;
    const llama_pos hole_begin = kv_keep;
    const llama_pos hole_end   = kv_keep + kv_discard;

    std::vector<float> surviving_rows;
    std::vector<llama_pos> surviving_pos;
    surviving_rows.reserve(state.target_window.size());
    surviving_pos.reserve(state.target_window_pos.size());

    for (int32_t r = 0; r < state.target_window_rows; ++r) {
        const llama_pos old_pos = state.target_window_pos[(size_t) r];

        const bool discarded = old_pos >= hole_begin && old_pos < hole_end;
        if (discarded) {
            continue;
        }

        // Rows behind the hole slide down to close it.
        const bool slides_down = old_pos >= hole_end && old_pos < kv_past;
        const llama_pos new_pos = slides_down ? old_pos - kv_discard : old_pos;

        const float * src = state.target_window.data() + (size_t) r * floats_per_row;
        surviving_rows.insert(surviving_rows.end(), src, src + floats_per_row);
        surviving_pos.push_back(new_pos);
    }

    state.target_window = std::move(surviving_rows);
    state.target_window_pos = std::move(surviving_pos);
    state.target_window_rows = (int32_t) state.target_window_pos.size();
    dflash_ring_reset_rows(state, state.target_window.data(), state.target_window_rows);
    state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
    dflash_record_window_update(state, 0, state.target_window_rows, true);
    llama_reset_dflash_kv_cache_state(state.ctx_dft);
}
|
||||
@ -11,9 +11,13 @@
|
||||
#include "suffix-tree.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <unordered_map>
|
||||
|
||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
||||
@ -24,6 +28,7 @@ void llama_set_mtp_target_context(struct llama_context * ctx, struct llama_conte
|
||||
const std::vector<enum common_speculative_type> common_speculative_types = {
|
||||
COMMON_SPECULATIVE_TYPE_NONE,
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT,
|
||||
COMMON_SPECULATIVE_TYPE_DFLASH,
|
||||
COMMON_SPECULATIVE_TYPE_MTP,
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
|
||||
@ -37,6 +42,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
|
||||
const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
|
||||
{"none", COMMON_SPECULATIVE_TYPE_NONE},
|
||||
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
|
||||
{"dflash", COMMON_SPECULATIVE_TYPE_DFLASH},
|
||||
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
|
||||
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
|
||||
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
|
||||
@ -180,9 +186,13 @@ struct common_speculative_state {
|
||||
};
|
||||
|
||||
struct common_speculative_state_mtp;
|
||||
struct common_speculative_state_dflash;
|
||||
|
||||
static common_speculative_state_mtp * common_speculative_get_mtp_state(common_speculative * spec);
|
||||
static const common_speculative_state_mtp * common_speculative_get_mtp_state(const common_speculative * spec);
|
||||
static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec);
|
||||
static const common_speculative_state_dflash * common_speculative_get_dflash_state(const common_speculative * spec);
|
||||
static int32_t common_speculative_feature_width(const common_speculative * spec);
|
||||
static void mtp_invalidate_cached_drafts(common_speculative_state_mtp & state);
|
||||
static bool common_speculative_checkpoint_save(
|
||||
common_speculative_checkpoint & ckpt,
|
||||
@ -325,6 +335,8 @@ struct common_speculative_state_mtp : public common_speculative_state {
|
||||
}
|
||||
};
|
||||
|
||||
#include "speculative-dflash-impl.h"
|
||||
|
||||
struct common_speculative_state_draft : public common_speculative_state {
|
||||
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
|
||||
llama_context * ctx_dft;
|
||||
@ -1025,17 +1037,13 @@ struct common_speculative_state_suffix : public common_speculative_state {
|
||||
};
|
||||
|
||||
struct common_speculative {
|
||||
common_speculative_checkpoint checkpoint;
|
||||
std::vector<common_speculative_config> configs; // resolved stage config for each implementation
|
||||
std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states
|
||||
common_speculative_checkpoint checkpoint;
|
||||
common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
|
||||
std::unique_ptr<spec_tuner> tuner;
|
||||
int last_n_drafted = 0;
|
||||
int64_t t_step_start_us = 0;
|
||||
|
||||
~common_speculative() {
|
||||
checkpoint.clear();
|
||||
}
|
||||
};
|
||||
|
||||
static bool common_speculative_stage_chain_matches(
|
||||
@ -1116,6 +1124,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
|
||||
switch (type) {
|
||||
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
|
||||
case COMMON_SPECULATIVE_TYPE_DFLASH: return "dflash";
|
||||
case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
|
||||
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
|
||||
@ -1188,13 +1197,18 @@ common_speculative * common_speculative_init(
|
||||
});
|
||||
|
||||
if (has_draft_stage) {
|
||||
LOG_ERR("%s: Gemma4 assistant models only support MTP stages; omit -md for self-spec-only runs or use --spec-type mtp:n_max=1,p_min=0.0 for assistant-backed MTP\n", __func__);
|
||||
LOG_ERR("%s: Gemma4 assistant models only support MTP stages; omit -md for self-spec-only runs or use -mtp/--spec-stage mtp for assistant-backed MTP\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
const bool has_dflash_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) {
|
||||
return stage.type == COMMON_SPECULATIVE_TYPE_DFLASH;
|
||||
});
|
||||
|
||||
const bool needs_draft_ctx = std::any_of(stages.begin(), stages.end(), [¶ms](const common_speculative_stage_params & stage) {
|
||||
return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT ||
|
||||
stage.type == COMMON_SPECULATIVE_TYPE_DFLASH ||
|
||||
(stage.type == COMMON_SPECULATIVE_TYPE_MTP && params.model_dft != nullptr);
|
||||
});
|
||||
|
||||
@ -1205,7 +1219,40 @@ common_speculative * common_speculative_init(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft);
|
||||
llama_context_params cparams_dft = params.cparams_dft;
|
||||
|
||||
if (has_dflash_stage) {
|
||||
if (!llama_model_share_dflash_io_tensors(params.model_dft, llama_get_model(ctx_tgt))) {
|
||||
LOG_ERR("%s: failed to share target IO tensors with DFlash draft model\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int32_t max_cross_ctx = 0;
|
||||
for (const auto & stage : stages) {
|
||||
if (stage.type != COMMON_SPECULATIVE_TYPE_DFLASH) {
|
||||
continue;
|
||||
}
|
||||
|
||||
max_cross_ctx = std::max(max_cross_ctx, params.with_stage_overrides(stage).dflash_cross_ctx);
|
||||
}
|
||||
|
||||
const int32_t block_size = llama_model_dflash_block_size(params.model_dft);
|
||||
if (block_size <= 0) {
|
||||
LOG_ERR("%s: invalid DFlash draft block size\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const int64_t required_n_ctx = (int64_t) max_cross_ctx + (int64_t) block_size;
|
||||
if (required_n_ctx > std::numeric_limits<int32_t>::max()) {
|
||||
LOG_ERR("%s: invalid DFlash draft context size cross_ctx=%d block_size=%d required_n_ctx=%lld\n",
|
||||
__func__, max_cross_ctx, block_size, (long long) required_n_ctx);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
cparams_dft.n_ctx = (uint32_t) required_n_ctx;
|
||||
}
|
||||
|
||||
ctx_dft = llama_init_from_model(params.model_dft, cparams_dft);
|
||||
if (ctx_dft == nullptr) {
|
||||
LOG_ERR("%s", "failed to create draft context\n");
|
||||
return nullptr;
|
||||
@ -1268,6 +1315,20 @@ common_speculative * common_speculative_init(
|
||||
));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_DFLASH: {
|
||||
auto state = std::make_unique<common_speculative_state_dflash>(
|
||||
config.type,
|
||||
ctx_tgt,
|
||||
ctx_dft,
|
||||
config.params.dflash_cross_ctx);
|
||||
if (!state->ready) {
|
||||
LOG_ERR("%s: failed to initialize DFlash speculative state\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
impls.push_back(std::move(state));
|
||||
ctx_dft = nullptr;
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_MTP: {
|
||||
llama_context * ctx_mtp = ctx_dft;
|
||||
if (!ctx_mtp) {
|
||||
@ -1343,7 +1404,6 @@ common_speculative * common_speculative_init(
|
||||
}
|
||||
|
||||
auto * result = new common_speculative {
|
||||
/* .checkpoint = */ {},
|
||||
/* .configs = */ std::move(configs),
|
||||
/* .impls = */ std::move(impls)
|
||||
};
|
||||
@ -1369,175 +1429,12 @@ common_speculative * common_speculative_init(
|
||||
return result;
|
||||
}
|
||||
|
||||
// Tries to create a speculative-decoding instance and reports the outcome as a status code.
//
// params   - speculative configuration; nothing is created when it has no stage chain
// ctx_tgt  - target context forwarded to common_speculative_init (may be null)
// out_spec - optional; reset to nullptr on entry and set to the new instance on success
//
// Returns SKIPPED when no stage chain is configured, READY on success, otherwise an
// error status chosen from what is known about the target model and the requested stages.
// NOTE(review): when out_spec is null a successfully created instance is not handed to anyone.
common_speculative_init_status common_speculative_try_init(
        common_params_speculative & params,
        llama_context * ctx_tgt,
        common_speculative ** out_spec) {
    // Leave the out-parameter in a defined state on every path.
    if (out_spec) {
        *out_spec = nullptr;
    }

    if (!params.has_stage_chain()) {
        return COMMON_SPECULATIVE_INIT_SKIPPED;
    }

    common_speculative * created = common_speculative_init(params, ctx_tgt);
    if (created == nullptr) {
        // Initialization failed: classify the failure for the caller.
        const llama_model * tgt_model = ctx_tgt ? llama_get_model(ctx_tgt) : nullptr;
        if (tgt_model && llama_model_has_recurrent(tgt_model)) {
            return COMMON_SPECULATIVE_INIT_ERR_RECURRENT;
        }
        return params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)
            ? COMMON_SPECULATIVE_INIT_ERR_MTP
            : COMMON_SPECULATIVE_INIT_ERR_GENERIC;
    }

    if (out_spec) {
        *out_spec = created;
    }
    return COMMON_SPECULATIVE_INIT_READY;
}
|
||||
|
||||
// Normalizes the speculative configuration before any model is loaded.
//
// params_base        - global parameters; params_base.speculative and params_base.has_mtp are updated
// allow_parallel_mtp - when false, the MTP stage is removed if more than one parallel slot is configured
void common_speculative_prepare_startup(
        gpt_params & params_base,
        bool allow_parallel_mtp) {
    auto & spec_params = params_base.speculative;

    const bool drop_mtp = !allow_parallel_mtp
        && params_base.n_parallel > 1
        && spec_params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
    if (drop_mtp) {
        LOG_WRN("%s: MTP is not supported with parallel slots yet, removing the MTP stage to avoid cross-slot corruption. n_parallel=%d, stage_chain=%s\n",
                __func__, params_base.n_parallel, common_speculative_stage_chain_to_str(spec_params).c_str());
        spec_params.remove_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
    }

    // Drop draft-model settings when no remaining stage needs a draft model.
    if (!spec_params.needs_dft_model()) {
        spec_params.clear_dft();
    }

    params_base.has_mtp = spec_params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
}
|
||||
|
||||
// Completes speculative startup: loads the draft model when one is configured and
// prepares the MTP runtime parameters.
//
// params_base - global parameters; speculative settings, has_mtp and pooling_type may be updated
// model       - target model, forwarded to common_speculative_prepare_mtp_runtime
//
// Returns false only when the draft model fails to load.
bool common_speculative_finalize_startup(
        gpt_params & params_base,
        const llama_model * model) {
    auto & spec_params = params_base.speculative;

    if (!spec_params.needs_dft_model()) {
        spec_params.clear_dft();
    }

    if (spec_params.has_dft()) {
        LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");
        const bool loaded = common_speculative_load_draft_model(spec_params, params_base);
        if (!loaded) {
            return false;
        }
    }

    // has_mtp is set before the runtime preparation call, which receives params_base.
    params_base.has_mtp = spec_params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);

    bool has_external_mtp = false;
    if (params_base.has_mtp) {
        has_external_mtp = llama_model_is_gemma4_mtp_assistant(spec_params.model_dft);
    }

    params_base.has_mtp = common_speculative_prepare_mtp_runtime(
            spec_params, params_base, model, has_external_mtp);

    if (params_base.has_mtp) {
        params_base.pooling_type = LLAMA_POOLING_TYPE_NONE;
    }
    return true;
}
|
||||
|
||||
// Loads the draft model described by the speculative parameters.
//
// params      - speculative configuration; on success receives model_dft, cparams_dft
//               and mparams_dft.path
// params_base - global parameters that supply defaults for the draft model
//
// A temporary gpt_params is built from the speculative settings plus selected base
// settings, then optionally overridden by parsing the extra argument string in
// params.params. Returns true when no draft model is configured or the load succeeds;
// returns false when the extra arguments fail to parse or the model cannot be loaded.
bool common_speculative_load_draft_model(
        common_params_speculative & params,
        const gpt_params & params_base) {
    if (!params.has_dft()) {
        return true;
    }

    gpt_params params_dft;
    params_dft.devices          = params.devices;
    params_dft.model            = params.model;
    params_dft.main_gpu         = params_base.main_gpu;
    params_dft.n_gpu_layers     = params.n_gpu_layers;
    params_dft.rpc_servers      = params_base.rpc_servers;
    // Draft-specific cache types win; otherwise fall back to the base cache types.
    params_dft.cache_type_k     = params.cache_type_k.empty() ? params_base.cache_type_k : params.cache_type_k;
    params_dft.cache_type_v     = params.cache_type_v.empty() ? params_base.cache_type_v : params.cache_type_v;
    params_dft.flash_attn       = params_base.flash_attn;
    params_dft.k_cache_hadamard = params_base.k_cache_hadamard;
    params_dft.v_cache_hadamard = params_base.v_cache_hadamard;

    if (!params.params.empty()) {
        // The extra argument string may override any field set above, including the model path.
        auto [argc, argv] = parse_command_line("llama-server " + params.params);
        if (!gpt_params_parse(argc, argv, params_dft)) {
            gpt_params_print_usage(argc, argv, params_dft);
            free_command_line(argc, argv);
            return false;
        }
        free_command_line(argc, argv);
    }

    LOG_INF("%s: loading draft model '%s'\n", __func__, params_dft.model.c_str());

    // Context size precedence: extra arguments, then speculative n_ctx, then a per-slot share of the base context.
    if (params_dft.n_ctx == 0) {
        params_dft.n_ctx = params.n_ctx;
    }
    params_dft.n_ctx = params_dft.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx;
    params_dft.n_parallel = 1;
    params_dft.n_batch = params_dft.n_ctx;

    params.mparams_dft.path = params_dft.model;

    llama_model_params mparams_dft = common_model_params_to_llama(params_dft);
    llama_model * loaded_model = llama_model_load_from_file(params_dft.model.c_str(), mparams_dft);
    if (loaded_model == nullptr) {
        // Report the path that was actually attempted; it can differ from params.model
        // when the extra argument string overrides the model.
        LOG_ERR("%s: failed to load draft model '%s'\n", __func__, params_dft.model.c_str());
        return false;
    }

    params.model_dft = loaded_model;
    params.cparams_dft = common_context_params_to_llama(params_dft);
    return true;
}
|
||||
|
||||
// Prepares the context parameters used by the MTP stage.
//
// params           - speculative configuration; cparams_dft is updated, and the MTP stage
//                    is removed when it cannot be served
// params_base      - global parameters, copied as the basis of the MTP context parameters
//                    when no external MTP model is used
// model            - target model, queried for its NextN layer count
// has_external_mtp - true when MTP is backed by a separately loaded model
//
// Returns true when MTP stays enabled, false when it was not requested or had to be removed.
bool common_speculative_prepare_mtp_runtime(
        common_params_speculative & params,
        const gpt_params & params_base,
        const llama_model * model,
        bool has_external_mtp) {
    if (!params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
        return false;
    }

    const bool target_lacks_nextn = llama_model_n_nextn_layer(model) == 0;
    if (target_lacks_nextn && !has_external_mtp) {
        LOG_WRN("%s: MTP speculative stage requested, but model has 0 NextN layers. Removing MTP from the configured stage chain.\n",
                __func__);
        params.remove_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
        if (!params.needs_dft_model()) {
            params.clear_dft();
        }
        return false;
    }

    if (!has_external_mtp) {
        // Derive the MTP context parameters from the base parameters with pooling disabled.
        gpt_params params_mtp = params_base;
        params_mtp.pooling_type = LLAMA_POOLING_TYPE_NONE;
        params.cparams_dft = common_context_params_to_llama(params_mtp);
    }

    auto & cparams = params.cparams_dft;
    cparams.mtp         = true;
    cparams.mtp_op_type = MTP_OP_WARMUP;
    cparams.embeddings  = true;

    return true;
}
|
||||
|
||||
void common_speculative_free(common_speculative * spec) {
|
||||
if (spec == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
spec->checkpoint.clear();
|
||||
delete spec;
|
||||
}
|
||||
|
||||
@ -1546,11 +1443,6 @@ void common_speculative_begin(common_speculative * spec, const llama_tokens & pr
|
||||
return;
|
||||
}
|
||||
|
||||
spec->checkpoint.clear();
|
||||
spec->curr_impl = nullptr;
|
||||
spec->last_n_drafted = 0;
|
||||
spec->t_step_start_us = 0;
|
||||
|
||||
for (auto & impl : spec->impls) {
|
||||
common_time_meas tm(impl->t_begin_us, !impl->gen_perf);
|
||||
impl->begin(prompt);
|
||||
@ -1654,34 +1546,6 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
|
||||
}
|
||||
}
|
||||
|
||||
// Saves a checkpoint into the speculative instance before drafting.
//
// All arguments after spec are forwarded unchanged to common_speculative_checkpoint_save
// together with spec->checkpoint. Returns false when spec is null, otherwise the result
// of the save call.
bool common_speculative_before_draft(
        common_speculative * spec,
        llama_model * model,
        llama_context * ctx,
        common_sampler * sampler_src,
        const common_params_sampling & sparams,
        llama_seq_id seq_id,
        llama_pos n_past,
        llama_token sampled,
        int max_tokens,
        int ckpt_mode) {
    return spec != nullptr
        ? common_speculative_checkpoint_save(
              spec->checkpoint, model, ctx, sampler_src, sparams,
              seq_id, n_past, sampled, max_tokens, ckpt_mode)
        : false;
}
|
||||
|
||||
static bool common_speculative_has_type(const common_speculative * spec, common_speculative_type type) {
|
||||
if (spec == nullptr) {
|
||||
return false;
|
||||
@ -1830,6 +1694,10 @@ static bool common_speculative_collect_target_batch_features(
|
||||
const llama_batch & batch,
|
||||
common_speculative_feature_view & features) {
|
||||
features = {};
|
||||
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
|
||||
return llama_spec_get_dflash_feature_view(ctx, batch, features);
|
||||
}
|
||||
|
||||
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
|
||||
return true;
|
||||
}
|
||||
@ -1848,6 +1716,10 @@ static bool common_speculative_collect_target_seq_batch_features(
|
||||
llama_seq_id seq_id,
|
||||
common_speculative_feature_view & features) {
|
||||
features = {};
|
||||
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
|
||||
return llama_spec_get_dflash_feature_view_for_seq(ctx, batch, seq_id, features);
|
||||
}
|
||||
|
||||
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
|
||||
return true;
|
||||
}
|
||||
@ -1921,27 +1793,246 @@ common_speculative_draft_result common_speculative_draft_ex(
|
||||
return result;
|
||||
}
|
||||
|
||||
static bool common_speculative_has_target_features(const common_speculative * spec) {
|
||||
return common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) ||
|
||||
common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH);
|
||||
}
|
||||
|
||||
// Loads the draft model described by the speculative parameters.
//
// params      - speculative configuration; on success receives model_dft, cparams_dft
//               and mparams_dft.path
// params_base - global parameters that supply defaults for the draft model
//
// A temporary gpt_params is built from the speculative settings plus selected base
// settings (more of them when a DFlash stage is configured), then optionally overridden
// by parsing the extra argument string in params.params. Returns true when no draft
// model is configured or the load succeeds; returns false when the extra arguments fail
// to parse or the model cannot be loaded.
bool common_speculative_load_draft_model(
        common_params_speculative & params,
        const gpt_params & params_base) {
    if (!params.has_dft()) {
        return true;
    }

    const bool has_dflash = params.has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH);

    gpt_params params_dft;
    params_dft.devices          = params.devices;
    params_dft.model            = params.model;
    params_dft.main_gpu         = params_base.main_gpu;
    params_dft.n_gpu_layers     = params.n_gpu_layers;
    params_dft.rpc_servers      = params_base.rpc_servers;
    // Draft-specific cache types win; otherwise fall back to the base cache types.
    params_dft.cache_type_k     = params.cache_type_k.empty() ? params_base.cache_type_k : params.cache_type_k;
    params_dft.cache_type_v     = params.cache_type_v.empty() ? params_base.cache_type_v : params.cache_type_v;
    params_dft.flash_attn       = params_base.flash_attn;
    params_dft.k_cache_hadamard = params_base.k_cache_hadamard;
    params_dft.v_cache_hadamard = params_base.v_cache_hadamard;

    if (has_dflash) {
        // With a DFlash stage the draft model also inherits the base split and scheduling settings.
        params_dft.split_mode = params_base.split_mode;
        for (size_t i = 0; i < std::size(params_dft.tensor_split); ++i) {
            params_dft.tensor_split[i] = params_base.tensor_split[i];
        }
        params_dft.attn_max_batch              = params_base.attn_max_batch;
        params_dft.graph_reuse                 = params_base.graph_reuse;
        params_dft.split_mode_graph_scheduling = params_base.split_mode_graph_scheduling;
        params_dft.scheduler_async             = params_base.scheduler_async;
        params_dft.max_extra_alloc_MiB         = params_base.max_extra_alloc_MiB;
        params_dft.reduce_type                 = params_base.reduce_type;
    }

    if (!params.params.empty()) {
        // The extra argument string may override any field set above, including the model path.
        auto [argc, argv] = parse_command_line("llama-server " + params.params);
        if (!gpt_params_parse(argc, argv, params_dft)) {
            gpt_params_print_usage(argc, argv, params_dft);
            free_command_line(argc, argv);
            return false;
        }
        free_command_line(argc, argv);
    }

    LOG_INF("%s: loading draft model '%s'\n", __func__, params_dft.model.c_str());

    // Context size precedence: extra arguments, then speculative n_ctx, then a per-slot share of the base context.
    if (params_dft.n_ctx == 0) {
        params_dft.n_ctx = params.n_ctx;
    }
    // A negative layer count with a DFlash stage falls back to the base GPU layer count.
    if (has_dflash && params_dft.n_gpu_layers < 0) {
        params_dft.n_gpu_layers = params_base.n_gpu_layers;
    }
    params_dft.n_ctx = params_dft.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx;
    params_dft.n_parallel = 1;
    params_dft.n_batch = params_dft.n_ctx;

    params.mparams_dft.path = params_dft.model;

    llama_model_params mparams_dft = common_model_params_to_llama(params_dft);
    llama_model * loaded_model = llama_model_load_from_file(params_dft.model.c_str(), mparams_dft);
    if (loaded_model == nullptr) {
        // Report the path that was actually attempted; it can differ from params.model
        // when the extra argument string overrides the model.
        LOG_ERR("%s: failed to load draft model '%s'\n", __func__, params_dft.model.c_str());
        return false;
    }

    params.model_dft = loaded_model;
    params.cparams_dft = common_context_params_to_llama(params_dft);
    return true;
}
|
||||
|
||||
// Configures params.cparams_dft for the MTP stage, or removes the stage when the
// target model cannot serve it.
//
// params           - speculative configuration (stage chain and cparams_dft may change)
// params_base      - global parameters used as the template for the MTP context
//                    parameters when has_external_mtp is false
// model            - target model whose NextN layer count is checked
// has_external_mtp - whether MTP is provided by a separately loaded model
//
// Returns whether MTP remains enabled.
bool common_speculative_prepare_mtp_runtime(
        common_params_speculative & params,
        const gpt_params & params_base,
        const llama_model * model,
        bool has_external_mtp) {
    const bool mtp_requested = params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
    if (!mtp_requested) {
        return false;
    }

    if (llama_model_n_nextn_layer(model) == 0 && !has_external_mtp) {
        LOG_WRN("%s: MTP speculative stage requested, but model has 0 NextN layers. Removing MTP from the configured stage chain.\n",
                __func__);

        params.remove_stage_type(COMMON_SPECULATIVE_TYPE_MTP);

        // The draft model may have been needed only by the stage just removed.
        const bool still_needs_dft = params.needs_dft_model();
        if (!still_needs_dft) {
            params.clear_dft();
        }
        return false;
    }

    if (!has_external_mtp) {
        gpt_params base_copy = params_base;
        base_copy.pooling_type = LLAMA_POOLING_TYPE_NONE;
        params.cparams_dft = common_context_params_to_llama(base_copy);
    }

    params.cparams_dft.mtp = true;
    params.cparams_dft.mtp_op_type = MTP_OP_WARMUP;
    params.cparams_dft.embeddings = true;
    return true;
}
|
||||
|
||||
// Creates a speculative-decoding instance when a stage chain is configured and
// translates the outcome into a status code.
//
// params   - speculative configuration
// ctx_tgt  - target context passed to common_speculative_init (may be null)
// out_spec - optional output; cleared on entry, set to the created instance on success
//
// Status: SKIPPED without a stage chain, READY on success, ERR_RECURRENT when the
// target model reports recurrence, ERR_MTP when an MTP stage was requested, otherwise
// ERR_GENERIC.
common_speculative_init_status common_speculative_try_init(
        common_params_speculative & params,
        llama_context * ctx_tgt,
        common_speculative ** out_spec) {
    const bool wants_result = out_spec != nullptr;
    if (wants_result) {
        *out_spec = nullptr;
    }

    if (!params.has_stage_chain()) {
        return COMMON_SPECULATIVE_INIT_SKIPPED;
    }

    if (common_speculative * instance = common_speculative_init(params, ctx_tgt)) {
        if (wants_result) {
            *out_spec = instance;
        }
        return COMMON_SPECULATIVE_INIT_READY;
    }

    // Failure path: pick the most specific error status available.
    if (ctx_tgt != nullptr) {
        const llama_model * tgt_model = llama_get_model(ctx_tgt);
        if (tgt_model != nullptr && llama_model_has_recurrent(tgt_model)) {
            return COMMON_SPECULATIVE_INIT_ERR_RECURRENT;
        }
    }
    if (params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
        return COMMON_SPECULATIVE_INIT_ERR_MTP;
    }
    return COMMON_SPECULATIVE_INIT_ERR_GENERIC;
}
|
||||
|
||||
// Adjusts the speculative configuration ahead of model loading.
//
// params_base        - global parameters; the speculative settings and has_mtp are rewritten
// allow_parallel_mtp - if false and n_parallel > 1, a configured MTP stage is removed with a warning
void common_speculative_prepare_startup(
        gpt_params & params_base,
        bool allow_parallel_mtp) {
    auto & spec_cfg = params_base.speculative;

    if (!allow_parallel_mtp && params_base.n_parallel > 1) {
        if (spec_cfg.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
            LOG_WRN("%s: MTP is not supported with parallel slots yet, removing the MTP stage to avoid cross-slot corruption. n_parallel=%d, stage_chain=%s\n",
                    __func__, params_base.n_parallel, common_speculative_stage_chain_to_str(spec_cfg).c_str());
            spec_cfg.remove_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
        }
    }

    // Clear draft-model settings that no configured stage requires.
    const bool needs_dft = spec_cfg.needs_dft_model();
    if (!needs_dft) {
        spec_cfg.clear_dft();
    }

    params_base.has_mtp = spec_cfg.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
}
|
||||
|
||||
// Finishes speculative startup once the target model is available.
//
// params_base - global parameters; speculative settings, has_mtp and pooling_type may change
// model       - target model, passed on to common_speculative_prepare_mtp_runtime
//
// Loads the draft model if one is configured, then prepares the MTP runtime.
// Returns false only if loading the draft model fails.
bool common_speculative_finalize_startup(
        gpt_params & params_base,
        const llama_model * model) {
    auto & spec_cfg = params_base.speculative;

    if (!spec_cfg.needs_dft_model()) {
        spec_cfg.clear_dft();
    }

    if (spec_cfg.has_dft()) {
        LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");
        if (!common_speculative_load_draft_model(spec_cfg, params_base)) {
            return false;
        }
    }

    // Set has_mtp first: params_base is handed to the runtime preparation call below.
    const bool mtp_configured = spec_cfg.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
    params_base.has_mtp = mtp_configured;

    const bool has_external_mtp =
        mtp_configured && llama_model_is_gemma4_mtp_assistant(spec_cfg.model_dft);

    const bool mtp_enabled = common_speculative_prepare_mtp_runtime(
            spec_cfg, params_base, model, has_external_mtp);
    params_base.has_mtp = mtp_enabled;

    if (mtp_enabled) {
        params_base.pooling_type = LLAMA_POOLING_TYPE_NONE;
    }

    return true;
}
|
||||
|
||||
// Records a checkpoint in spec->checkpoint ahead of a draft step.
//
// Every argument other than spec is passed through to common_speculative_checkpoint_save.
// Yields false for a null spec; otherwise yields whatever the save call returns.
bool common_speculative_before_draft(
        common_speculative * spec,
        llama_model * model,
        llama_context * ctx,
        common_sampler * sampler_src,
        const common_params_sampling & sparams,
        llama_seq_id seq_id,
        llama_pos n_past,
        llama_token sampled,
        int max_tokens,
        int ckpt_mode) {
    const bool have_spec = spec != nullptr;

    // Short-circuit keeps the save call from touching a null spec.
    return have_spec && common_speculative_checkpoint_save(
            spec->checkpoint,
            model, ctx,
            sampler_src, sparams,
            seq_id, n_past, sampled,
            max_tokens, ckpt_mode);
}
|
||||
|
||||
int32_t common_speculative_on_target_seq_batch(
|
||||
common_speculative * spec,
|
||||
llama_context * ctx_tgt,
|
||||
const llama_batch & batch,
|
||||
llama_seq_id seq_id,
|
||||
bool is_prompt_warmup) {
|
||||
llama_context * ctx_mtp = common_speculative_get_companion_ctx(spec);
|
||||
ctx_mtp = ctx_mtp ? ctx_mtp : ctx_tgt;
|
||||
if (ctx_tgt == nullptr || ctx_mtp == nullptr || batch.n_tokens <= 0) {
|
||||
if (ctx_tgt == nullptr || batch.n_tokens <= 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const int n_embd_src = common_speculative_ctx_mtp_n_embd(ctx_tgt);
|
||||
const int n_embd_dst = common_speculative_ctx_mtp_n_embd(ctx_mtp);
|
||||
if (n_embd_src <= 0 || n_embd_dst <= 0) {
|
||||
return -1;
|
||||
}
|
||||
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
|
||||
llama_context * ctx_mtp = common_speculative_get_companion_ctx(spec);
|
||||
ctx_mtp = ctx_mtp ? ctx_mtp : ctx_tgt;
|
||||
if (ctx_mtp == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (n_embd_src != n_embd_dst) {
|
||||
LOG_ERR("MTP warmup hidden state width mismatch: n_embd_src = %d, n_embd_dst = %d\n", n_embd_src, n_embd_dst);
|
||||
return -1;
|
||||
const int n_embd_src = common_speculative_ctx_mtp_n_embd(ctx_tgt);
|
||||
const int n_embd_dst = common_speculative_ctx_mtp_n_embd(ctx_mtp);
|
||||
if (n_embd_src <= 0 || n_embd_dst <= 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (n_embd_src != n_embd_dst) {
|
||||
LOG_ERR("MTP warmup hidden state width mismatch: n_embd_src = %d, n_embd_dst = %d\n", n_embd_src, n_embd_dst);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
common_speculative_feature_view feature_view;
|
||||
@ -1981,6 +2072,10 @@ bool common_speculative_copy_output_hidden_rows(
|
||||
const std::vector<int32_t> & output_indices,
|
||||
std::vector<float> & hidden_rows) {
|
||||
hidden_rows.clear();
|
||||
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
|
||||
return llama_spec_copy_dflash_rows_from_output_indices(ctx, output_indices, hidden_rows);
|
||||
}
|
||||
|
||||
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
|
||||
return true;
|
||||
}
|
||||
@ -2018,13 +2113,13 @@ static bool common_speculative_apply_hidden_rows(
|
||||
llama_pos pos_base,
|
||||
const std::vector<llama_token> & ids,
|
||||
const std::vector<float> & hidden_rows) {
|
||||
auto * mtp_state = common_speculative_get_mtp_state(spec);
|
||||
if (mtp_state == nullptr || ids.empty()) {
|
||||
const int32_t feature_width = common_speculative_feature_width(spec);
|
||||
if (feature_width <= 0 || ids.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const size_t expected_floats = ids.size() * (size_t) mtp_state->n_embd;
|
||||
if (mtp_state->n_embd <= 0 || hidden_rows.size() != expected_floats) {
|
||||
const size_t expected_floats = ids.size() * (size_t) feature_width;
|
||||
if (hidden_rows.size() != expected_floats) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -2035,7 +2130,7 @@ static bool common_speculative_apply_hidden_rows(
|
||||
|
||||
common_speculative_feature_view feature_view;
|
||||
const bool have_feature_view = common_speculative_feature_view_from_hidden_rows(
|
||||
hidden_rows, mtp_state->n_embd, seq_id, pos_base, feature_view);
|
||||
hidden_rows, feature_width, seq_id, pos_base, feature_view);
|
||||
const int32_t ret = have_feature_view
|
||||
? common_speculative_on_target_batch(spec, accepted_batch, feature_view, false)
|
||||
: -1;
|
||||
@ -2052,7 +2147,7 @@ bool common_speculative_commit_accepted_hidden_rows(
|
||||
llama_token sampled_before,
|
||||
const std::vector<llama_token> & ids,
|
||||
const std::vector<float> & hidden_rows) {
|
||||
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) || ids.empty()) {
|
||||
if (common_speculative_feature_width(spec) <= 0 || ids.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -2073,7 +2168,7 @@ bool common_speculative_commit_accepted_output(
|
||||
llama_token sampled_before,
|
||||
const std::vector<llama_token> & ids,
|
||||
const std::vector<int32_t> & output_indices) {
|
||||
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) || ids.empty()) {
|
||||
if (common_speculative_feature_width(spec) <= 0 || ids.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -2172,7 +2267,7 @@ void common_speculative_checkpoint_restore(
|
||||
}
|
||||
}
|
||||
|
||||
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) && !mtp_hidden_state_pre.empty()) {
|
||||
if (common_speculative_has_target_features(spec) && !mtp_hidden_state_pre.empty()) {
|
||||
if (!common_speculative_commit_accepted_hidden_rows(
|
||||
spec,
|
||||
spec_type_used,
|
||||
@ -2218,7 +2313,7 @@ void common_speculative_checkpoint_restore(
|
||||
__func__, (int) seq_id, ret);
|
||||
}
|
||||
|
||||
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
|
||||
if (common_speculative_has_target_features(spec)) {
|
||||
std::vector<int32_t> redecoded_indices(n_re);
|
||||
for (int j = 0; j < n_re; ++j) {
|
||||
redecoded_indices[j] = j;
|
||||
@ -2274,7 +2369,7 @@ void common_speculative_commit(
|
||||
|
||||
common_speculative_accept(spec, ids.size() - 1);
|
||||
|
||||
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) &&
|
||||
if (common_speculative_has_target_features(spec) &&
|
||||
any_rejected &&
|
||||
ckpt.valid &&
|
||||
!accepted_output_indices.empty()) {
|
||||
@ -2299,7 +2394,7 @@ void common_speculative_commit(
|
||||
return;
|
||||
}
|
||||
|
||||
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) && !accepted_output_indices.empty()) {
|
||||
if (common_speculative_has_target_features(spec) && !accepted_output_indices.empty()) {
|
||||
if (!common_speculative_commit_accepted_output(
|
||||
spec,
|
||||
ctx,
|
||||
@ -2345,6 +2440,7 @@ void common_speculative_print_stats(const common_speculative * spec, double slot
|
||||
impl->n_gen_tokens,
|
||||
impl->n_acc_tokens,
|
||||
str_perf.c_str());
|
||||
|
||||
}
|
||||
|
||||
if (spec->tuner && spec->tuner->enabled && slot_tps > 0.0 && n_decoded > 0) {
|
||||
@ -2384,6 +2480,40 @@ static const common_speculative_state_mtp * common_speculative_get_mtp_state(con
|
||||
return common_speculative_get_mtp_state(const_cast<common_speculative *>(spec));
|
||||
}
|
||||
|
||||
// Finds the first implementation in spec->impls that is tagged DFLASH and is
// dynamically a common_speculative_state_dflash. Returns nullptr for a null spec
// or when no such implementation exists.
static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec) {
    if (spec == nullptr) {
        return nullptr;
    }

    for (const auto & impl : spec->impls) {
        if (impl->type == COMMON_SPECULATIVE_TYPE_DFLASH) {
            // The type tag alone is not trusted; the dynamic type must match as well.
            auto * candidate = dynamic_cast<common_speculative_state_dflash *>(impl.get());
            if (candidate != nullptr) {
                return candidate;
            }
        }
    }

    return nullptr;
}
|
||||
|
||||
// Const overload: delegates to the mutable lookup and returns the result as a pointer to const.
static const common_speculative_state_dflash * common_speculative_get_dflash_state(const common_speculative * spec) {
    auto * mutable_spec = const_cast<common_speculative *>(spec);
    return common_speculative_get_dflash_state(mutable_spec);
}
|
||||
|
||||
static int32_t common_speculative_feature_width(const common_speculative * spec) {
|
||||
if (const auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) {
|
||||
return dflash_state->n_target_features;
|
||||
}
|
||||
|
||||
if (const auto * mtp_state = common_speculative_get_mtp_state(spec); mtp_state != nullptr) {
|
||||
return mtp_state->n_embd;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static mtp_last_embd & mtp_get_last_embd(common_speculative_state_mtp & state, llama_seq_id seq_id) {
|
||||
auto & last = state.draft_cache_by_seq[seq_id];
|
||||
if ((int) last.embd.size() != state.n_embd) {
|
||||
@ -2459,11 +2589,13 @@ bool common_speculative_has_sequence_hidden(const common_speculative * spec, lla
|
||||
|
||||
void common_speculative_clear_sequence_hidden(common_speculative * spec, llama_seq_id seq_id) {
|
||||
auto * mtp_state = common_speculative_get_mtp_state(spec);
|
||||
if (mtp_state == nullptr) {
|
||||
return;
|
||||
if (mtp_state != nullptr) {
|
||||
mtp_clear_target_hidden(*mtp_state, seq_id);
|
||||
}
|
||||
|
||||
mtp_clear_target_hidden(*mtp_state, seq_id);
|
||||
if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) {
|
||||
dflash_clear_target_features(*dflash_state);
|
||||
}
|
||||
}
|
||||
|
||||
void common_speculative_clear_sequence(
|
||||
@ -2515,6 +2647,10 @@ llama_context * common_speculative_get_companion_ctx(common_speculative * spec)
|
||||
return mtp_state->ctx_mtp;
|
||||
}
|
||||
|
||||
if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) {
|
||||
return dflash_state->ctx_dft;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -2553,6 +2689,34 @@ int32_t common_speculative_on_target_batch(
|
||||
const llama_batch & batch,
|
||||
const common_speculative_feature_view & features,
|
||||
bool is_prompt_warmup) {
|
||||
if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) {
|
||||
if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || batch.n_tokens <= 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (features.width != dflash_state->n_target_features) {
|
||||
LOG_ERR("%s: DFlash feature width mismatch: got %d expected %d\n",
|
||||
__func__, features.width, dflash_state->n_target_features);
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (batch.n_seq_id == nullptr || batch.seq_id == nullptr || batch.n_seq_id[0] <= 0 || batch.seq_id[0] == nullptr) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
const llama_seq_id seq_id = batch.seq_id[0][0];
|
||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||
if (batch.n_seq_id[i] != 1 || batch.seq_id[i] == nullptr || batch.seq_id[i][0] != seq_id) {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
if (!dflash_append_target_features(*dflash_state, features, seq_id)) {
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto * mtp_state = common_speculative_get_mtp_state(spec);
|
||||
if (mtp_state == nullptr) {
|
||||
return 0;
|
||||
@ -2617,6 +2781,10 @@ void common_speculative_context_shift(
|
||||
llama_kv_cache_seq_rm (ctx_mtp, seq_id, kv_keep, kv_keep + kv_discard);
|
||||
llama_kv_cache_seq_add(ctx_mtp, seq_id, kv_keep + kv_discard, kv_past, -kv_discard);
|
||||
}
|
||||
|
||||
if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) {
|
||||
dflash_context_shift(*dflash_state, kv_keep, kv_discard, kv_past);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
|
||||
@ -64,6 +64,7 @@ class Model:
|
||||
model_name: str | None
|
||||
metadata_override: Path | None
|
||||
dir_model_card: Path
|
||||
target_model_dir: Path | None
|
||||
|
||||
# subclasses should define this!
|
||||
model_arch: gguf.MODEL_ARCH
|
||||
@ -71,7 +72,8 @@ class Model:
|
||||
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
|
||||
use_temp_file: bool = False, eager: bool = False,
|
||||
metadata_override: Path | None = None, model_name: str | None = None,
|
||||
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
|
||||
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False,
|
||||
target_model_dir: Path | None = None):
|
||||
if type(self) is Model:
|
||||
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
|
||||
|
||||
@ -93,6 +95,7 @@ class Model:
|
||||
self.metadata_override = metadata_override
|
||||
self.model_name = model_name
|
||||
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
|
||||
self.target_model_dir = target_model_dir
|
||||
|
||||
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
|
||||
if self.ftype == gguf.LlamaFileType.GUESSED:
|
||||
@ -459,6 +462,14 @@ class Model:
|
||||
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
@staticmethod
def load_text_hparams(dir_model: Path) -> dict[str, Any]:
    """Return the hyperparameters of ``dir_model`` with any ``text_config`` flattened in.

    The mapping produced by ``Model.load_hparams`` is used as the base. When it
    carries a ``text_config`` entry that is a dict, a new dict is returned in
    which the nested entries override top-level entries of the same name (the
    ``text_config`` key itself is kept). Otherwise the loaded mapping is
    returned unchanged.
    """
    config = Model.load_hparams(dir_model)
    nested = config.get("text_config")
    if not isinstance(nested, dict):
        return config

    # Copy first so the nested values win on key collisions.
    merged = dict(config)
    merged.update(nested)
    return merged
|
||||
|
||||
@classmethod
|
||||
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
|
||||
assert names
|
||||
@ -500,13 +511,14 @@ class Model:
|
||||
return seems_special
|
||||
|
||||
# used for GPT-2 BPE and WordPiece vocabs
|
||||
def get_vocab_base(self) -> tuple[list[str], list[int], str]:
|
||||
def get_vocab_base(self, dir_model: Path | None = None, vocab_size: int | None = None) -> tuple[list[str], list[int], str]:
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
|
||||
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
|
||||
dir_model = dir_model or self.dir_model
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
|
||||
vocab_size = vocab_size or self.hparams.get("vocab_size", len(tokenizer.vocab))
|
||||
assert max(tokenizer.vocab.values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
@ -558,12 +570,12 @@ class Model:
|
||||
if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
|
||||
# ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
|
||||
res = "grok-2"
|
||||
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
|
||||
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
|
||||
res = "llama-bpe"
|
||||
if chkhsh == "972da7b59cec44d1f0a490a86c96df53859e486e481563e5dddac155013d87ac":
|
||||
# ref: https://huggingface.co/poolside/Laguna-XS.2
|
||||
res = "laguna"
|
||||
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
|
||||
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
|
||||
res = "llama-bpe"
|
||||
if chkhsh == "049ecf7629871e3041641907f3de7c733e4dbfdc736f57d882ba0b0845599754":
|
||||
# ref: https://huggingface.co/deepseek-ai/deepseek-llm-7b-base
|
||||
res = "deepseek-llm"
|
||||
@ -600,6 +612,18 @@ class Model:
|
||||
if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea":
|
||||
# ref: https://huggingface.co/Qwen/Qwen1.5-7B
|
||||
res = "qwen2"
|
||||
if chkhsh == "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4":
|
||||
# ref: https://huggingface.co/Qwen/Qwen3.5-9B-Instruct
|
||||
res = "qwen35"
|
||||
if chkhsh == "99cc61242f7106804ce24fdf3a6451e4a55251078dffd5453c806e11b2310db3":
|
||||
# ref: https://huggingface.co/Qwen/Qwen3.5-27B
|
||||
res = "qwen35"
|
||||
if chkhsh == "1444df51289cfa8063b96f0e62b1125440111bc79a52003ea14b6eac7016fd5f":
|
||||
# ref: https://huggingface.co/z-lab/Qwen3.5-27B-DFlash (uses Qwen3.5 tokenizer)
|
||||
res = "qwen35"
|
||||
if chkhsh == "4f53cda18c2baa0c0354bb5f9a3ecbe5ed12ab4d8e11ba873c2f11161202b945":
|
||||
# ref: https://huggingface.co/Qwen/Qwen3.6-35B-A3B (identical pre-tokenizer regex to qwen35)
|
||||
res = "qwen35"
|
||||
if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166":
|
||||
# ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf
|
||||
res = "olmo"
|
||||
@ -690,19 +714,20 @@ class Model:
|
||||
return res
|
||||
# Marker: End get_vocab_base_pre
|
||||
|
||||
def _set_vocab_gpt2(self) -> None:
|
||||
tokens, toktypes, tokpre = self.get_vocab_base()
|
||||
def _set_vocab_gpt2(self, dir_model: Path | None = None, vocab_size: int | None = None) -> None:
|
||||
dir_model = dir_model or self.dir_model
|
||||
tokens, toktypes, tokpre = self.get_vocab_base(dir_model=dir_model, vocab_size=vocab_size)
|
||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_qwen(self):
|
||||
dir_model = self.dir_model
|
||||
hparams = self.hparams
|
||||
def _set_vocab_qwen(self, dir_model: Path | None = None, hparams: dict[str, Any] | None = None):
|
||||
dir_model = dir_model or self.dir_model
|
||||
hparams = hparams or self.hparams
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
@ -2260,15 +2285,254 @@ class Qwen2MoeModel(Model):
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
# Handles checkpoints whose architecture is registered as "Qwen3ForCausalLM".
@Model.register("Qwen3ForCausalLM")
class Qwen3Model(Qwen2Model):
    # Only the GGUF architecture tag is overridden; the rest comes from Qwen2Model.
    model_arch = gguf.MODEL_ARCH.QWEN3
|
||||
|
||||
|
||||
# Handles checkpoints whose architecture is registered as "Qwen3MoeForCausalLM".
@Model.register("Qwen3MoeForCausalLM")
class Qwen3MoeModel(Qwen2MoeModel):
    # Only the GGUF architecture tag is overridden; the rest comes from Qwen2MoeModel.
    model_arch = gguf.MODEL_ARCH.QWEN3MOE
|
||||
|
||||
|
||||
@Model.register("DFlashDraftModel")
class DFlashDraftModel(Qwen3Model):
    """Converter for checkpoints registered under the "DFlashDraftModel" architecture.

    The tokenizer is taken from ``target_model_dir`` rather than from the draft
    checkpoint, and DFlash-specific metadata is written under ``{arch}.dflash.*``.
    """

    model_arch = gguf.MODEL_ARCH.DFLASH_DRAFT

    # Lazily filled caches for the target model's config.json.
    _target_hparams: dict[str, Any] | None = None
    _target_raw_hparams: dict[str, Any] | None = None
    # Set by modify_tensors() when the corresponding tensor is emitted.
    _saw_token_embd = False
    _saw_output = False

    def _require_target_model_dir(self) -> Path:
        """Return ``self.target_model_dir``; raise ``ValueError`` when it is unset."""
        target_dir = self.target_model_dir
        if target_dir is not None:
            return target_dir
        raise ValueError("DFlashDraftModel conversion requires --target-model-dir <matching target model directory>")

    def _get_target_hparams(self) -> dict[str, Any]:
        """Return the target's hyperparameters via ``Model.load_text_hparams`` (cached)."""
        cached = self._target_hparams
        if cached is None:
            cached = Model.load_text_hparams(self._require_target_model_dir())
            self._target_hparams = cached
        return cached

    def _get_target_raw_hparams(self) -> dict[str, Any]:
        """Return the target's hyperparameters via ``Model.load_hparams`` (cached)."""
        cached = self._target_raw_hparams
        if cached is None:
            cached = Model.load_hparams(self._require_target_model_dir())
            self._target_raw_hparams = cached
        return cached

    def _target_uses_gemma4_vocab(self) -> bool:
        """Tell whether the target config identifies itself as Gemma4.

        True when ``model_type`` starts with "gemma4", or when ``architectures``
        is a list containing an entry that starts with "Gemma4".
        """
        raw = self._get_target_raw_hparams()
        if str(raw.get("model_type", "")).startswith("gemma4"):
            return True

        arch_names = raw.get("architectures")
        if not isinstance(arch_names, list):
            return False
        for arch_name in arch_names:
            if str(arch_name).startswith("Gemma4"):
                return True
        return False

    def _get_target_hidden_size(self) -> int | None:
        """Look up the target hidden width.

        Tries, in order: top-level ``hidden_size``, top-level
        ``backbone_hidden_size``, then ``text_config["hidden_size"]``.
        Returns ``None`` when none of them is present.
        """
        raw = self._get_target_raw_hparams()
        for key in ("hidden_size", "backbone_hidden_size"):
            width = raw.get(key)
            if width is not None:
                return int(width)

        text_cfg = raw.get("text_config")
        if isinstance(text_cfg, dict):
            width = text_cfg.get("hidden_size")
            if width is not None:
                return int(width)
        return None

    def _set_vocab_gemma4(self, dir_model: Path, vocab_size: int | None = None) -> None:
        """Write a "gemma4" tokenizer section built from ``gguf.LlamaHfVocab(dir_model)``.

        A fixed set of marker tokens is forced to ``USER_DEFINED``; every other
        token keeps the type reported by the vocab. When ``vocab_size`` is given
        and differs from the number of tokens read, ``ValueError`` is raised
        before anything is written.
        """
        hf_vocab = gguf.LlamaHfVocab(dir_model)
        tokens = []
        scores = []
        toktypes = []
        visible_tokens = {
            "<|channel>",
            "<channel|>",
            "<|tool_call>",
            "<tool_call|>",
            "<|tool_response>",
            "<tool_response|>",
            "<|\"|>",
        }

        for text, score, toktype in hf_vocab.all_tokens():
            tokens.append(text)
            scores.append(score)
            text_str = text.decode()
            if text_str not in visible_tokens:
                toktypes.append(toktype)
                continue
            toktypes.append(gguf.TokenType.USER_DEFINED)
            logger.info(f"Token {text_str!r} is set to USER_DEFINED")

        if vocab_size is not None:
            expected = int(vocab_size)
            if len(tokens) != expected:
                raise ValueError(
                    f"DFlashDraftModel: Gemma4 tokenizer size {len(tokens)} does not match expected vocab_size={expected}"
                )

        writer = self.gguf_writer
        writer.add_tokenizer_model("gemma4")
        writer.add_token_list(tokens)
        writer.add_token_scores(scores)
        writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
        special_vocab.add_to_gguf(writer)
        writer.add_add_space_prefix(False)
        writer.add_add_bos_token(True)

    def set_vocab(self):
        """Write the tokenizer found in the target model directory.

        Uses the Gemma4 path when the target identifies as Gemma4, the GPT-2
        path otherwise; both receive the target's ``vocab_size`` (may be None).
        """
        expected_vocab = self._get_target_hparams().get("vocab_size")
        tokenizer_dir = self._require_target_model_dir()

        if self._target_uses_gemma4_vocab():
            self._set_vocab_gemma4(dir_model=tokenizer_dir, vocab_size=expected_vocab)
        else:
            self._set_vocab_gpt2(dir_model=tokenizer_dir, vocab_size=expected_vocab)

    def set_gguf_parameters(self):
        """Write the inherited parameters plus rope and ``{arch}.dflash.*`` metadata.

        ``block_size``, ``mask_token_id`` and ``target_layer_ids`` must be present
        in ``dflash_config`` or in the top-level hparams; ``ValueError`` is raised
        otherwise. ``n_target_features`` is taken from the config when present and
        is otherwise computed as target hidden size times the number of target
        layer ids.
        """
        super().set_gguf_parameters()
        writer = self.gguf_writer

        writer.add_causal_attention(False)
        writer.add_rope_dimension_count(self.hparams.get("head_dim", 128))

        rope_scaling = self.hparams.get("rope_scaling")
        if isinstance(rope_scaling, dict):
            rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
            rope_factor = rope_scaling.get("factor")

            # A scaling type is only recorded together with an explicit factor.
            if rope_factor is not None:
                if rope_type == "linear":
                    writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
                    writer.add_rope_scaling_factor(rope_factor)
                elif rope_type == "yarn":
                    writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
                    writer.add_rope_scaling_factor(rope_factor)

            # NOTE(review): the listing this block was restored from had lost its
            # indentation; the optional fields below are assumed to sit beside the
            # linear/yarn chain rather than inside the yarn branch - confirm
            # against the repository.
            orig_ctx_len = rope_scaling.get("original_max_position_embeddings")
            if orig_ctx_len is not None:
                writer.add_rope_scaling_orig_ctx_len(orig_ctx_len)
            yarn_ext_factor = rope_scaling.get("extrapolation_factor")
            if yarn_ext_factor is not None:
                writer.add_rope_scaling_yarn_ext_factor(yarn_ext_factor)
            yarn_attn_factor = rope_scaling.get("attention_factor", rope_scaling.get("attn_factor"))
            if yarn_attn_factor is not None:
                writer.add_rope_scaling_yarn_attn_factor(yarn_attn_factor)
            yarn_beta_fast = rope_scaling.get("beta_fast")
            if yarn_beta_fast is not None:
                writer.add_rope_scaling_yarn_beta_fast(yarn_beta_fast)
            yarn_beta_slow = rope_scaling.get("beta_slow")
            if yarn_beta_slow is not None:
                writer.add_rope_scaling_yarn_beta_slow(yarn_beta_slow)

        arch = writer.arch
        dflash_cfg = self.hparams.get("dflash_config")
        if not isinstance(dflash_cfg, dict):
            dflash_cfg = {}

        def dflash_required_value(name: str) -> Any:
            # dflash_config takes precedence over the top-level hparams.
            for source in (dflash_cfg, self.hparams):
                if name in source:
                    return source[name]
            raise ValueError(f"DFlashDraftModel conversion requires explicit {name} metadata")

        block_size = int(dflash_required_value("block_size"))
        writer.add_uint32(f"{arch}.dflash.block_size", block_size)

        mask_token_id = int(dflash_required_value("mask_token_id"))
        writer.add_uint32(f"{arch}.dflash.mask_token_id", mask_token_id)

        target_layer_ids = list(map(int, dflash_required_value("target_layer_ids")))
        if not target_layer_ids:
            raise ValueError("DFlashDraftModel conversion requires at least one target_layer_id")
        writer.add_array(f"{arch}.dflash.target_layer_ids", target_layer_ids)

        if "n_target_features" in dflash_cfg:
            n_target_features = int(dflash_cfg["n_target_features"])
        elif "n_target_features" in self.hparams:
            n_target_features = int(self.hparams["n_target_features"])
        else:
            # Nothing explicit: derive the width from the target model's config.
            target_hidden_size = self._get_target_hidden_size()
            if target_hidden_size is None:
                raise ValueError("DFlashDraftModel: target config is missing hidden_size")

            draft_hidden_size = self.hparams.get("hidden_size")
            if draft_hidden_size is None:
                raise ValueError("DFlashDraftModel: draft config is missing hidden_size")

            target_width = int(target_hidden_size)
            draft_width = int(draft_hidden_size)
            n_target_features = target_width * len(target_layer_ids)

            if target_width != draft_width:
                logger.warning(
                    "DFlashDraftModel: target hidden_size=%d differs from draft hidden_size=%d; using target hidden width for n_target_features",
                    target_width,
                    draft_width,
                )

            logger.info(
                "DFlashDraftModel: inferred n_target_features=%d from target hidden_size=%d and n_target_layers=%d",
                n_target_features,
                target_width,
                len(target_layer_ids),
            )

        writer.add_uint32(f"{arch}.dflash.n_target_features", n_target_features)

        logger.info(
            "DFlashDraftModel metadata: block_size=%s mask_token_id=%s target_layer_ids=%s n_target_features=%s",
            block_size,
            mask_token_id,
            target_layer_ids,
            n_target_features,
        )

    def prepare_tensors(self):
        """Run the inherited tensor preparation, then check and log the IO layout.

        Raises ``ValueError`` when an output tensor was emitted without a token
        embedding tensor.
        """
        super().prepare_tensors()

        has_embd = self._saw_token_embd
        has_output = self._saw_output

        if has_output and not has_embd:
            raise ValueError(
                "DFlashDraftModel conversion requires token_embd.weight when output.weight is present"
            )

        # After the check above, an output tensor implies an embedding tensor.
        if not has_embd:
            io_mode = "shared-target"
        elif has_output:
            io_mode = "self-contained"
        else:
            io_mode = "self-contained-tied"

        logger.info(
            "DFlashDraftModel IO contract: io=%s token_embd=%s output=%s target_model_dir=%s",
            io_mode,
            has_embd,
            has_output,
            self._require_target_model_dir(),
        )

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Map one checkpoint tensor to its GGUF name(s).

        ``fc.weight`` and ``hidden_norm.weight`` (with or without a ``model.``
        prefix) map directly to the DFlash tensor names. ``norm.weight`` and
        ``layers.*`` get a ``model.`` prefix before being handed to the parent
        implementation. Records whether the parent produced the token embedding
        or output tensor.
        """
        top_level_name = name.removeprefix("model.")

        if top_level_name == "fc.weight":
            return [(f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.DFLASH_FC]}.weight", data_torch)]
        if top_level_name == "hidden_norm.weight":
            return [(f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.DFLASH_HIDDEN_NORM]}.weight", data_torch)]

        if name == "norm.weight":
            name = "model.norm.weight"
        elif name.startswith("layers."):
            name = f"model.{name}"

        tensors = list(super().modify_tensors(data_torch, name, bid))
        token_embd_name = f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD]}.weight"
        output_name = f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT]}.weight"

        for mapped_name, _tensor in tensors:
            if mapped_name == token_embd_name:
                self._saw_token_embd = True
            elif mapped_name == output_name:
                self._saw_output = True

        return tensors
|
||||
|
||||
|
||||
@Model.register("MellumForCausalLM")
|
||||
class MellumModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.MELLUM
|
||||
@ -4617,6 +4881,7 @@ class JaisModel(Model):
|
||||
super().prepare_tensors()
|
||||
self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
|
||||
|
||||
|
||||
@Model.register("MiniMaxM2ForCausalLM")
|
||||
class MiniMaxM2Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.MINIMAXM2
|
||||
@ -4689,10 +4954,12 @@ class SmolLM3Model(LlamaModel):
|
||||
chat_template = tokenizer.chat_template.replace("[:]", "")
|
||||
self.gguf_writer.add_chat_template(chat_template)
|
||||
|
||||
|
||||
# Handles checkpoints whose architecture is registered as "SeedOssForCausalLM".
@Model.register("SeedOssForCausalLM")
class SeedOssModel(Model):
    # Only the GGUF architecture tag is set here; the rest comes from Model.
    model_arch = gguf.MODEL_ARCH.SEED_OSS
|
||||
|
||||
|
||||
@Model.register("Dots1ForCausalLM")
|
||||
class Dots1Model(Qwen2MoeModel):
|
||||
model_arch = gguf.MODEL_ARCH.DOTS1
|
||||
@ -4853,6 +5120,7 @@ class Glm4MoeModel(Model):
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration")
|
||||
class ChatGLMModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.CHATGLM
|
||||
@ -5035,6 +5303,7 @@ class ChatGLMModel(Model):
|
||||
name = name.removeprefix("transformer.")
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@Model.register("BailingMoeV2ForCausalLM")
|
||||
class BailingMoeV2Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.BAILINGMOE2
|
||||
@ -5392,6 +5661,10 @@ def parse_args() -> argparse.Namespace:
|
||||
"--metadata", type=Path,
|
||||
help="Specify the path for an authorship metadata override file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-model-dir", type=Path,
|
||||
help="matching target model directory; required for DFlash conversion to reuse tokenizer and infer target feature width",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@ -5471,7 +5744,8 @@ def main() -> None:
|
||||
metadata_override=args.metadata, model_name=args.model_name,
|
||||
split_max_tensors=args.split_max_tensors,
|
||||
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
|
||||
small_first_shard=args.no_tensor_first_split)
|
||||
small_first_shard=args.no_tensor_first_split,
|
||||
target_model_dir=args.target_model_dir)
|
||||
|
||||
if args.vocab_only:
|
||||
logger.info("Exporting model vocab...")
|
||||
|
||||
@ -78,6 +78,10 @@ models = [
|
||||
{"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", },
|
||||
{"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", },
|
||||
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
|
||||
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", "chkhsh": "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4", },
|
||||
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-27B", "chkhsh": "99cc61242f7106804ce24fdf3a6451e4a55251078dffd5453c806e11b2310db3", },
|
||||
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/z-lab/Qwen3.5-27B-DFlash", "chkhsh": "1444df51289cfa8063b96f0e62b1125440111bc79a52003ea14b6eac7016fd5f", },
|
||||
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.6-35B-A3B", "chkhsh": "4f53cda18c2baa0c0354bb5f9a3ecbe5ed12ab4d8e11ba873c2f11161202b945", },
|
||||
{"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
|
||||
{"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
|
||||
{"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM!
|
||||
@ -155,39 +159,46 @@ for model in models:
|
||||
if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM:
|
||||
continue
|
||||
|
||||
# Skip if the tokenizer folder does not exist or there are other download issues previously
|
||||
if not os.path.exists(f"models/tokenizers/{name}"):
|
||||
logger.warning(f"Directory for tokenizer {name} not found. Skipping...")
|
||||
continue
|
||||
chkhsh = model.get("chkhsh")
|
||||
|
||||
# create the tokenizer
|
||||
try:
|
||||
if name == "t5":
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
|
||||
except (OSError, TypeError) as e:
|
||||
logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}")
|
||||
continue # Skip to the next model if the tokenizer can't be loaded
|
||||
if chkhsh is None:
|
||||
# Skip if the tokenizer folder does not exist or there are other download issues previously
|
||||
if not os.path.exists(f"models/tokenizers/{name}"):
|
||||
logger.warning(f"Directory for tokenizer {name} not found. Skipping...")
|
||||
continue
|
||||
|
||||
chktok = tokenizer.encode(CHK_TXT)
|
||||
chkhsh = sha256(str(chktok).encode()).hexdigest()
|
||||
# create the tokenizer
|
||||
try:
|
||||
if name == "t5":
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
|
||||
except (OSError, TypeError) as e:
|
||||
logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}")
|
||||
continue # Skip to the next model if the tokenizer can't be loaded
|
||||
|
||||
chktok = tokenizer.encode(CHK_TXT)
|
||||
chkhsh = sha256(str(chktok).encode()).hexdigest()
|
||||
|
||||
logger.info(f"model: {name}")
|
||||
logger.info(f"tokt: {tokt}")
|
||||
logger.info(f"repo: {model['repo']}")
|
||||
logger.info(f"chktok: {chktok}")
|
||||
logger.info(f"chkhsh: {chkhsh}")
|
||||
|
||||
# print the "pre_tokenizer" content from the tokenizer.json
|
||||
with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
normalizer = cfg["normalizer"]
|
||||
logger.info("normalizer: " + json.dumps(normalizer, indent=4))
|
||||
pre_tokenizer = cfg["pre_tokenizer"]
|
||||
logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4))
|
||||
if "ignore_merges" in cfg["model"]:
|
||||
logger.info("ignore_merges: " + json.dumps(cfg["model"]["ignore_merges"], indent=4))
|
||||
if model.get("chkhsh") is None:
|
||||
logger.info(f"chktok: {chktok}")
|
||||
|
||||
# print the "pre_tokenizer" content from the tokenizer.json
|
||||
with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
normalizer = cfg["normalizer"]
|
||||
logger.info("normalizer: " + json.dumps(normalizer, indent=4))
|
||||
pre_tokenizer = cfg["pre_tokenizer"]
|
||||
logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4))
|
||||
if "ignore_merges" in cfg["model"]:
|
||||
logger.info("ignore_merges: " + json.dumps(cfg["model"]["ignore_merges"], indent=4))
|
||||
else:
|
||||
logger.info("using manually provided tokenizer hash")
|
||||
|
||||
logger.info("")
|
||||
|
||||
@ -354,6 +365,6 @@ logger.info("\nRun the following commands to generate the vocab files for testin
|
||||
for model in models:
|
||||
name = model["name"]
|
||||
|
||||
print(f"python3 convert_hf_to_gguf.py models/tokenizers/{name}/ --outfile models/ggml-vocab-{name}.gguf --vocab-only") # noqa: NP100
|
||||
logger.info(f"python3 convert_hf_to_gguf.py models/tokenizers/{name}/ --outfile models/ggml-vocab-{name}.gguf --vocab-only") # noqa: NP100
|
||||
|
||||
logger.info("\n")
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
#include "llama-spec-features.h"
|
||||
#include "log.h"
|
||||
#include "sampling.h"
|
||||
#include "speculative.h"
|
||||
@ -45,6 +46,16 @@ static void log_text(const gpt_params & params_base, const std::string & text) {
|
||||
}
|
||||
}
|
||||
|
||||
static bool server_slot_prompt_batch_overlaps(
|
||||
const server_slot & slot,
|
||||
int32_t batch_i0,
|
||||
int32_t batch_i1) {
|
||||
if (slot.prompt_batch_i0 < 0 || slot.prompt_batch_i1 <= slot.prompt_batch_i0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return slot.prompt_batch_i0 < batch_i1 && batch_i0 < slot.prompt_batch_i1;
|
||||
}
|
||||
struct server_mtp_warmup {
|
||||
llama_context * ctx_tgt;
|
||||
server_slot * slot;
|
||||
@ -67,6 +78,15 @@ static bool server_response_needs_chat_parse(oaicompat_type oaicompat) {
|
||||
oaicompat == OAICOMPAT_TYPE_RESP;
|
||||
}
|
||||
|
||||
static bool server_speculative_uses_target_features(const common_params_speculative & spec) {
|
||||
return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) ||
|
||||
spec.has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH);
|
||||
}
|
||||
|
||||
static bool server_speculative_requires_single_slot(const common_params_speculative & spec) {
|
||||
return spec.has_stage_chain();
|
||||
}
|
||||
|
||||
static bool server_speculative_same_stage_types(
|
||||
const common_params_speculative & lhs,
|
||||
const common_params_speculative & rhs) {
|
||||
@ -151,6 +171,15 @@ static common_speculative_stage_params server_parse_speculative_stage_json(const
|
||||
}
|
||||
|
||||
server_context::~server_context() {
|
||||
// Speculative state may reference the live target context during teardown.
|
||||
for (server_slot& slot : slots) {
|
||||
if (slot.ctx_sampling != nullptr) {
|
||||
common_sampler_free(slot.ctx_sampling);
|
||||
}
|
||||
common_speculative_free(slot.spec);
|
||||
slot.spec = nullptr;
|
||||
}
|
||||
|
||||
if (ctx) {
|
||||
llama_free(ctx);
|
||||
ctx = nullptr;
|
||||
@ -162,17 +191,7 @@ server_context::~server_context() {
|
||||
}
|
||||
// Free multimodal
|
||||
mtmd_free(mctx);
|
||||
|
||||
// Clear any sampling context
|
||||
for (server_slot& slot : slots) {
|
||||
if (slot.ctx_sampling != nullptr) {
|
||||
common_sampler_free(slot.ctx_sampling);
|
||||
}
|
||||
common_speculative_free(slot.spec);
|
||||
}
|
||||
|
||||
params_base.speculative.clear_dft();
|
||||
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
|
||||
@ -197,6 +216,12 @@ bool server_context::load_model(const gpt_params& params_) {
|
||||
|
||||
common_speculative_prepare_startup(params_base, false);
|
||||
|
||||
if (server_speculative_requires_single_slot(params_base.speculative) && params_base.n_parallel > 1) {
|
||||
LOG_ERROR("Speculative decoding is currently limited to a single server slot (-np 1).\n", {
|
||||
{"n_parallel", params_base.n_parallel},
|
||||
});
|
||||
return false;
|
||||
}
|
||||
const bool has_draft_model = params_base.speculative.has_dft();
|
||||
std::string & mmproj_path = params_base.mmproj.path;
|
||||
if (!mmproj_path.empty()) {
|
||||
@ -307,7 +332,6 @@ void server_context::init() {
|
||||
|
||||
slot.params.speculative = params_base.speculative;
|
||||
slot.sparams = params_base.sparams;
|
||||
|
||||
// try speculative decoding
|
||||
if (can_spec && requested_spec) {
|
||||
switch (common_speculative_try_init(params_base.speculative, slot.ctx, &slot.spec)) {
|
||||
@ -445,9 +469,12 @@ void server_slot::reset() {
|
||||
n_past_prompt = 0;
|
||||
n_discarded_prompt = 0;
|
||||
n_kept_prompt = 0;
|
||||
prompt_batch_i0 = -1;
|
||||
prompt_batch_i1 = -1;
|
||||
n_sent_text = 0;
|
||||
drafted.clear();
|
||||
i_batch_dft.clear();
|
||||
spec_prompt_warmup_failed = false;
|
||||
n_sent_token_probs = 0;
|
||||
infill = false;
|
||||
ga_i = 0;
|
||||
@ -540,7 +567,7 @@ void server_slot::add_token_string(const completion_token_output& token) {
|
||||
}
|
||||
|
||||
bool server_slot::can_speculate() const {
|
||||
return (!!spec || uses_mtp());
|
||||
return !spec_prompt_warmup_failed && (!!spec || uses_mtp());
|
||||
}
|
||||
|
||||
int server_slot::get_n_draft_max() const {
|
||||
@ -3186,7 +3213,7 @@ void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_sl
|
||||
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(kv_keep), slot.cache_tokens.pos_next(kv_keep + kv_discard));
|
||||
llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard);
|
||||
if (slot.uses_mtp() && slot.spec) {
|
||||
if (slot.spec) {
|
||||
common_speculative_context_shift(slot.spec, slot.id, kv_keep, kv_discard, kv_past);
|
||||
}
|
||||
if (slot.params.cache_prompt) {
|
||||
@ -3586,6 +3613,9 @@ bool server_context::create_checkpoint(server_slot & slot) {
|
||||
void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) {
|
||||
if (params_base.cont_batching || batch.n_tokens == 0) {
|
||||
for (auto& slot : slots) {
|
||||
slot.prompt_batch_i0 = -1;
|
||||
slot.prompt_batch_i1 = -1;
|
||||
|
||||
// this slot still has a prompt to be processed
|
||||
if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) {
|
||||
auto& prompt_tokens = slot.prompt_tokens;
|
||||
@ -3868,6 +3898,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
||||
int32_t ga_i = slot.ga_i;
|
||||
int32_t ga_n = slot.ga_n;
|
||||
int32_t ga_w = slot.ga_w;
|
||||
const int32_t prompt_batch_i0 = batch.n_tokens;
|
||||
|
||||
// add prompt tokens for processing in the current batch
|
||||
// TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
|
||||
@ -3902,6 +3933,9 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
||||
}
|
||||
|
||||
}
|
||||
slot.prompt_batch_i0 = prompt_batch_i0;
|
||||
slot.prompt_batch_i1 = batch.n_tokens;
|
||||
|
||||
LOG_VERBOSE("prompt processing progress", {
|
||||
{"id_slot", slot.id},
|
||||
{"n_past", slot.n_past},
|
||||
@ -4014,7 +4048,7 @@ void server_context::speculative_decoding_accept() {
|
||||
}
|
||||
|
||||
std::vector<int32_t> accepted_output_indices;
|
||||
if (slot.uses_mtp()) {
|
||||
if (server_speculative_uses_target_features(slot.params.speculative)) {
|
||||
if (!ids.empty()) {
|
||||
accepted_output_indices.assign(slot.i_batch_dft.begin(), slot.i_batch_dft.begin() + ids.size());
|
||||
}
|
||||
@ -4366,6 +4400,7 @@ void server_context::update_allowlist_state(server_slot& slot) {
|
||||
void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
||||
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
|
||||
bool finish_prompt_warmup_batch = false;
|
||||
extend_context(n_tokens);
|
||||
|
||||
llama_batch batch_view = {
|
||||
@ -4425,19 +4460,26 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
continue; // continue loop of n_batch
|
||||
}
|
||||
|
||||
if (params_base.speculative.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
|
||||
if (server_speculative_uses_target_features(params_base.speculative)) {
|
||||
for (auto & slot : slots) {
|
||||
if (!slot.spec || !slot.uses_mtp()) {
|
||||
if (!slot.spec || !server_speculative_uses_target_features(slot.params.speculative)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if ((slot.state != SLOT_STATE_PROCESSING || slot.n_decoded != 0) &&
|
||||
(slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_LOAD_PROMPT)) {
|
||||
if (slot.spec_prompt_warmup_failed) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!server_slot_prompt_batch_overlaps(slot, i, i + n_tokens)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (common_speculative_on_target_seq_batch(slot.spec, ctx, batch_view, slot.id, true) != 0) {
|
||||
LOG_ERROR("failed to warm up MTP state from prompt batch for slot %d\n", slot.id);
|
||||
common_speculative_clear_sequence_hidden(slot.spec, slot.id);
|
||||
slot.spec_prompt_warmup_failed = true;
|
||||
LOG_ERROR("failed to warm up speculative target-feature state from prompt batch for slot %d\n", slot.id);
|
||||
} else {
|
||||
finish_prompt_warmup_batch = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -4558,6 +4600,10 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
}
|
||||
// speculative decoding - main model sample and accept
|
||||
speculative_decoding_accept();
|
||||
|
||||
if (finish_prompt_warmup_batch) {
|
||||
llama_finish_dflash_capture_batch(ctx, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -4589,6 +4635,11 @@ void server_context::update_slots() {
|
||||
// start populating the batch for this iteration
|
||||
common_batch_clear(batch);
|
||||
|
||||
for (auto & slot : slots) {
|
||||
slot.prompt_batch_i0 = -1;
|
||||
slot.prompt_batch_i1 = -1;
|
||||
}
|
||||
|
||||
// first, add sampled tokens from any ongoing sequences
|
||||
add_sampled_tokens(); // Prepare batch for inference
|
||||
|
||||
|
||||
@ -51,6 +51,8 @@ struct server_slot {
|
||||
|
||||
int32_t i_batch = -1;
|
||||
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
|
||||
int32_t prompt_batch_i0 = -1;
|
||||
int32_t prompt_batch_i1 = -1;
|
||||
|
||||
int32_t n_prompt_tokens = 0;
|
||||
int32_t n_prompt_tokens_cache = 0;
|
||||
@ -157,6 +159,7 @@ struct server_slot {
|
||||
// expiring logit bias
|
||||
std::vector<common_sampler::elb_state> prev_elb_states;
|
||||
|
||||
bool spec_prompt_warmup_failed = false;
|
||||
// speculative decoding stats
|
||||
int32_t n_draft_total = 0; // Total draft tokens generated
|
||||
int32_t n_draft_accepted = 0; // Draft tokens actually accepted
|
||||
|
||||
@ -246,6 +246,8 @@ class MODEL_ARCH(IntEnum):
|
||||
GEMMA3 = auto()
|
||||
GEMMA4 = auto()
|
||||
GEMMA4_MTP = auto()
|
||||
DFLASH = auto()
|
||||
DFLASH_DRAFT = auto()
|
||||
STARCODER2 = auto()
|
||||
MAMBA = auto()
|
||||
XVERSE = auto()
|
||||
@ -272,6 +274,7 @@ class MODEL_ARCH(IntEnum):
|
||||
SEED_OSS = auto()
|
||||
LAGUNA = auto()
|
||||
|
||||
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
TOKEN_EMBD = auto()
|
||||
TOKEN_EMBD_NORM = auto()
|
||||
@ -379,6 +382,8 @@ class MODEL_TENSOR(IntEnum):
|
||||
MTP_POST_PROJ = auto()
|
||||
MTP_TOKEN_ORDERING = auto()
|
||||
MTP_CENTROIDS = auto()
|
||||
DFLASH_FC = auto()
|
||||
DFLASH_HIDDEN_NORM = auto()
|
||||
|
||||
|
||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
@ -416,6 +421,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.GEMMA3: "gemma3",
|
||||
MODEL_ARCH.GEMMA4: "gemma4",
|
||||
MODEL_ARCH.GEMMA4_MTP: "gemma4_mtp",
|
||||
MODEL_ARCH.DFLASH: "dflash",
|
||||
MODEL_ARCH.DFLASH_DRAFT: "dflash-draft",
|
||||
MODEL_ARCH.STARCODER2: "starcoder2",
|
||||
MODEL_ARCH.MAMBA: "mamba",
|
||||
MODEL_ARCH.XVERSE: "xverse",
|
||||
@ -551,6 +558,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.MTP_POST_PROJ: "mtp_post_proj",
|
||||
MODEL_TENSOR.MTP_TOKEN_ORDERING: "mtp_token_ordering",
|
||||
MODEL_TENSOR.MTP_CENTROIDS: "mtp_centroids",
|
||||
MODEL_TENSOR.DFLASH_FC: "dflash_fc",
|
||||
MODEL_TENSOR.DFLASH_HIDDEN_NORM: "dflash_hidden_norm",
|
||||
}
|
||||
|
||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
@ -1286,6 +1295,40 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
|
||||
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
|
||||
],
|
||||
MODEL_ARCH.DFLASH: [
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.DFLASH_FC,
|
||||
MODEL_TENSOR.DFLASH_HIDDEN_NORM,
|
||||
],
|
||||
MODEL_ARCH.DFLASH_DRAFT: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.DFLASH_FC,
|
||||
MODEL_TENSOR.DFLASH_HIDDEN_NORM,
|
||||
],
|
||||
MODEL_ARCH.BITNET: [
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
|
||||
@ -1096,6 +1096,10 @@ extern "C" {
|
||||
// returns NULL for invalid ids.
|
||||
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
// Get the argmax token ID for DFlash draft position i without materializing full logits.
|
||||
// Returns LLAMA_TOKEN_NULL if argmax is not available (falls back to logits path).
|
||||
LLAMA_API llama_token llama_get_dflash_draft_token_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
// Get all output token embeddings.
|
||||
// when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
|
||||
// the embeddings for which llama_batch.logits[i] != 0 are stored contiguously
|
||||
|
||||
@ -41,6 +41,8 @@ add_library(llama
|
||||
../include/llama.h
|
||||
llama.cpp
|
||||
llama-spec-features.cpp
|
||||
llama-spec-features-dflash.cpp
|
||||
llama-dflash.cpp
|
||||
llama-vocab.cpp
|
||||
llama-grammar.cpp
|
||||
llama-sampling.cpp
|
||||
@ -99,6 +101,7 @@ add_library(llama
|
||||
graphs/build_gemma2.cpp
|
||||
graphs/build_gemma3.cpp
|
||||
graphs/build_gemma4.cpp
|
||||
graphs/build_dflash.cpp
|
||||
graphs/build_mamba.cpp
|
||||
graphs/build_command_r.cpp
|
||||
graphs/build_olmo.cpp
|
||||
|
||||
439
src/graphs/build_dflash.cpp
Normal file
439
src/graphs/build_dflash.cpp
Normal file
@ -0,0 +1,439 @@
|
||||
#include "../llama-build-context.h"
|
||||
#include "../llama-context.h"
|
||||
#include "../llama-model.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
// Builds the graph that copies the per-layer DFlash context K/V cache tensors into the
// per-layer workspace tensors.
//
// For every layer the cache rows are first put into a fixed order along dim 2
// (see `ordered_rows` below), then dims 1 and 2 are swapped with a permute, made
// contiguous and copied into the first `ctx_len` entries of the workspace's dim 1.
//
// Returns the graph holding one K store and one V store per layer.
ggml_cgraph * llm_build_context::build_dflash_kv_workspace() {
    auto & kv = lctx.dflash.kv;

    const int64_t n_embd_head_k = hparams.n_embd_head_k(0);
    const int64_t n_embd_head_v = hparams.n_embd_head_v(0);

    // Context length: the explicit visible window when set, otherwise n_ctx minus the
    // DFlash block size, never less than one row.
    const int64_t ctx_len = lctx.dflash.visible_cross_ctx > 0
        ? (int64_t) lctx.dflash.visible_cross_ctx
        : std::max<int64_t>(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size);
    const int32_t ctx_len_i32 = (int32_t) ctx_len;

    // Number of rows reported as filled, limited to [0, ctx_len].
    const int32_t n_filled = std::clamp(kv.cache_view_n_filled, 0, ctx_len_i32);

    // Write position folded into [0, ctx_len); the double modulo also maps negative values.
    int32_t split_row = 0;
    if (ctx_len > 0) {
        split_row = ((kv.cache_view_write_pos % ctx_len_i32) + ctx_len_i32) % ctx_len_i32;
    }

    GGML_ASSERT(n_embd_head_k == n_embd_head_v);
    GGML_ASSERT(lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len));
    GGML_ASSERT((int32_t) lctx.dflash.kv.k_ctx_workspace.size() == n_layer);
    GGML_ASSERT((int32_t) lctx.dflash.kv.v_ctx_workspace.size() == n_layer);

    ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max<int64_t>(1, ctx_len)) + 16 * n_layer, false);

    // View of `n_rows` consecutive dim-2 slices of `t`, starting at slice `first_row`.
    const auto rows_view = [&](ggml_tensor * t, int64_t n_rows, int64_t first_row) -> ggml_tensor * {
        return ggml_view_3d(ctx0, t,
                t->ne[0], t->ne[1], n_rows,
                t->nb[1], t->nb[2],
                (size_t) first_row * t->nb[2]);
    };

    // Returns the cache rows in the order expected by the workspace:
    //  - no valid view or nothing filled: the cache as-is;
    //  - partially filled: the unfilled rows first, then the filled rows
    //    (the unfilled rows are presumably zero - not verifiable from this function);
    //  - full: rows from the write position to the end, followed by rows before it.
    const auto ordered_rows = [&](ggml_tensor * cache) -> ggml_tensor * {
        if (!kv.cache_view_valid || n_filled <= 0) {
            return cache;
        }

        if (n_filled < ctx_len) {
            ggml_tensor * unfilled = rows_view(cache, ctx_len - n_filled, n_filled);
            ggml_tensor * filled   = rows_view(cache, n_filled, 0);
            return ggml_concat(ctx0, unfilled, filled, 2);
        }

        if (split_row == 0) {
            return cache;
        }

        ggml_tensor * from_split = rows_view(cache, ctx_len - split_row, split_row);
        ggml_tensor * upto_split = rows_view(cache, split_row, 0);
        return ggml_concat(ctx0, from_split, upto_split, 2);
    };

    // View over the first `ctx_len` entries of dim 1 of a workspace tensor.
    const auto workspace_view = [&](ggml_tensor * ws) -> ggml_tensor * {
        return ggml_view_3d(ctx0, ws,
                ws->ne[0], ctx_len, ws->ne[2],
                ws->nb[1], ws->nb[2],
                0);
    };

    for (int il = 0; il < n_layer; ++il) {
        GGML_ASSERT((size_t) il < lctx.dflash.kv.k_ctx_cache.size());
        GGML_ASSERT((size_t) il < lctx.dflash.kv.v_ctx_cache.size());

        const size_t layer = (size_t) il;

        ggml_tensor * k_rows = ordered_rows(kv.k_ctx_cache[layer]);
        ggml_tensor * v_rows = ordered_rows(kv.v_ctx_cache[layer]);
        cb(k_rows, "dflash_workspace_k_ctx_view", il);
        cb(v_rows, "dflash_workspace_v_ctx_view", il);

        // swap dims 1 and 2 and materialise the result
        ggml_tensor * k_packed = ggml_cont(ctx0, ggml_permute(ctx0, k_rows, 0, 2, 1, 3));
        ggml_tensor * v_packed = ggml_cont(ctx0, ggml_permute(ctx0, v_rows, 0, 2, 1, 3));
        cb(k_packed, "dflash_workspace_k_perm_cont", il);
        cb(v_packed, "dflash_workspace_v_perm_cont", il);

        ggml_tensor * k_target = workspace_view(kv.k_ctx_workspace[layer]);
        ggml_tensor * v_target = workspace_view(kv.v_ctx_workspace[layer]);

        ggml_tensor * k_copy = ggml_cpy(ctx0, k_packed, k_target);
        ggml_tensor * v_copy = ggml_cpy(ctx0, v_packed, v_target);
        cb(k_copy, "dflash_workspace_k_store", il);
        cb(v_copy, "dflash_workspace_v_store", il);

        ggml_build_forward_expand(gf, k_copy);
        ggml_build_forward_expand(gf, v_copy);
    }

    return gf;
}
|
||||
|
||||
// Builds the graph that projects a batch of target features into K/V rows and stores
// them into the per-layer DFlash context cache.
//
// Graph inputs created here (and published on lctx.dflash.kv):
//   cache_input_target_features : F32 [n_target_features, update_rows]
//   cache_input_pos_ctx         : I32 [update_rows]
//
// The features go through model.dflash_fc and an RMS norm once; each layer then applies
// its own K projection (+ RMS norm + RoPE) and V projection. The resulting rows are
// written starting at `write_pos`; rows that do not fit before the end of the cache
// are written starting at row 0.
ggml_cgraph * llm_build_context::build_dflash_kv_cache() {
    auto & kv = lctx.dflash.kv;

    const int64_t n_embd_head_k     = hparams.n_embd_head_k(0);
    const int64_t n_embd_head_v     = hparams.n_embd_head_v(0);
    const int64_t n_target_features = hparams.dflash_n_target_features;

    // Context length: the explicit visible window when set, otherwise n_ctx minus the
    // DFlash block size, never less than one row.
    const int64_t ctx_len = lctx.dflash.visible_cross_ctx > 0
        ? (int64_t) lctx.dflash.visible_cross_ctx
        : std::max<int64_t>(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size);

    // A non-positive requested row count means "update the whole context".
    int64_t requested_rows = ctx_len;
    if (kv.cache_update_rows > 0) {
        requested_rows = kv.cache_update_rows;
    }
    const int64_t update_rows = std::max<int64_t>(1, requested_rows);
    const int32_t write_pos   = kv.cache_write_pos;

    GGML_ASSERT(n_embd_head_k == n_embd_head_v);
    GGML_ASSERT(n_target_features > 0);
    GGML_ASSERT(lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len));
    GGML_ASSERT(update_rows > 0 && update_rows <= ctx_len);
    GGML_ASSERT(write_pos >= 0 && write_pos < ctx_len);

    ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max<int64_t>(1, update_rows)) + 24 * n_layer, false);

    kv.cache_input_target_features = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_target_features, update_rows);
    ggml_set_input(kv.cache_input_target_features);
    cb(kv.cache_input_target_features, "dflash_kv_input_target_features", -1);

    kv.cache_input_pos_ctx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, update_rows);
    ggml_set_input(kv.cache_input_pos_ctx);
    cb(kv.cache_input_pos_ctx, "dflash_kv_input_pos_ctx", -1);

    // shared by all layers: fc projection followed by RMS norm
    ggml_tensor * fused = llm_build_lora_mm(lctx, ctx0, model.dflash_fc, kv.cache_input_target_features);
    fused = llm_build_norm(ctx0, fused, hparams, model.dflash_hidden_norm, nullptr, LLM_NORM_RMS, cb, -1);
    cb(fused, "dflash_kv_fused_target", -1);

    // How the update is split around the end of the cache (same for every layer).
    const int32_t rows_before_wrap = std::min<int32_t>((int32_t) update_rows, (int32_t) ctx_len - write_pos);
    const int32_t rows_after_wrap  = (int32_t) update_rows - rows_before_wrap;

    // View of `n_rows` consecutive dim-2 slices of `t`, starting at slice `first_row`.
    const auto rows_view = [&](ggml_tensor * t, int64_t n_rows, int64_t first_row) -> ggml_tensor * {
        return ggml_view_3d(ctx0, t,
                t->ne[0], t->ne[1], n_rows,
                t->nb[1], t->nb[2],
                (size_t) first_row * t->nb[2]);
    };

    // Copies `n_rows` rows of the freshly computed K/V (starting at `src_row`) into the
    // layer's cache at `dst_row`. When `use_whole_src` is set the source tensors are
    // used directly instead of through a view.
    const auto store_rows = [&](int il, ggml_tensor * k_new, ggml_tensor * v_new,
                                int64_t n_rows, int64_t src_row, int64_t dst_row, bool use_whole_src) {
        ggml_tensor * k_from = use_whole_src ? k_new : rows_view(k_new, n_rows, src_row);
        ggml_tensor * v_from = use_whole_src ? v_new : rows_view(v_new, n_rows, src_row);
        ggml_tensor * k_into = rows_view(kv.k_ctx_cache[(size_t) il], n_rows, dst_row);
        ggml_tensor * v_into = rows_view(kv.v_ctx_cache[(size_t) il], n_rows, dst_row);

        ggml_tensor * k_copy = ggml_cpy(ctx0, k_from, k_into);
        cb(k_copy, "dflash_kv_k_store", il);
        ggml_build_forward_expand(gf, k_copy);

        ggml_tensor * v_copy = ggml_cpy(ctx0, v_from, v_into);
        cb(v_copy, "dflash_kv_v_store", il);
        ggml_build_forward_expand(gf, v_copy);
    };

    for (int il = 0; il < n_layer; ++il) {
        GGML_ASSERT((size_t) il < lctx.dflash.kv.k_ctx_cache.size());
        GGML_ASSERT((size_t) il < lctx.dflash.kv.v_ctx_cache.size());

        // K: projection -> per-head RMS norm -> RoPE at the provided context positions
        ggml_tensor * k_proj = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused);
        cb(k_proj, "dflash_kv_k_proj", il);

        ggml_tensor * k_new = ggml_reshape_3d(ctx0, k_proj, n_embd_head_k, n_head_kv, update_rows);
        k_new = llm_build_norm(ctx0, k_new, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
        cb(k_new, "dflash_kv_k_norm", il);
        k_new = ggml_rope_ext(ctx0, k_new, kv.cache_input_pos_ctx, nullptr,
                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow);
        cb(k_new, "dflash_kv_k_rope", il);

        // V: projection only
        ggml_tensor * v_new = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, fused);
        cb(v_new, "dflash_kv_v_proj", il);
        v_new = ggml_reshape_3d(ctx0, v_new, n_embd_head_v, n_head_kv, update_rows);

        if (rows_before_wrap > 0) {
            store_rows(il, k_new, v_new, rows_before_wrap, 0, write_pos, rows_before_wrap == update_rows);
        }

        if (rows_after_wrap > 0) {
            store_rows(il, k_new, v_new, rows_after_wrap, rows_before_wrap, 0, false);
        }
    }

    return gf;
}
|
||||
|
||||
// Builds the main DFlash draft graph.
//
// Per layer: the batch tokens are projected to Q/K/V (Q and K get a per-head RMS norm
// and RoPE), the batch K/V are appended after the `ctx_len` context entries read from
// the per-layer workspace tensors, and attention is computed with ggml_flash_attn_ext
// against the DFlash KQ mask (a separate mask is used for SWA layers when available).
// This is followed by the output projection, residual, post-attention norm and a
// SiLU-gated parallel FFN.
//
// After the last layer the output head is applied and an argmax over the result is
// added to the graph; that tensor is published as lctx.dflash.draft_tokens_tensor.
ggml_cgraph * llm_build_context::build_dflash() {
    auto & dflash = lctx.dflash;

    const int64_t n_embd_head_k     = hparams.n_embd_head_k(0);
    const int64_t n_embd_head_v     = hparams.n_embd_head_v(0);
    const int64_t n_target_features = hparams.dflash_n_target_features;

    // Context length: the explicit visible window when set, otherwise n_ctx minus the
    // DFlash block size, never less than one row.
    const int64_t ctx_len = dflash.visible_cross_ctx > 0
        ? (int64_t) dflash.visible_cross_ctx
        : std::max<int64_t>(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size);
    const int32_t ctx_len_i32 = (int32_t) ctx_len;

    // Write position folded into [0, ctx_len); only validated below.
    const int32_t cache_write_pos = ctx_len > 0
        ? ((dflash.kv.cache_view_write_pos % ctx_len_i32) + ctx_len_i32) % ctx_len_i32
        : 0;

    // KV length is padded to a multiple of 256 with flash attention, 32 otherwise.
    const int64_t n_kv_used  = ctx_len + n_tokens;
    const int64_t n_kv_total = GGML_PAD(n_kv_used, flash_attn ? 256 : 32);
    const int64_t n_kv_pad   = n_kv_total - n_kv_used;

    GGML_ASSERT(n_embd_head_k == n_embd_head_v);
    GGML_ASSERT(n_target_features > 0);
    GGML_ASSERT(lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len));
    GGML_ASSERT(cache_write_pos >= 0 && cache_write_pos < ctx_len);

    ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max<int64_t>(n_tokens, ctx_len)) + 32 * n_layer, false);

    bool have_swa_layers = false;
    for (int il = 0; il < n_layer && !have_swa_layers; ++il) {
        if (hparams.swa_layers[il]) {
            have_swa_layers = true;
        }
    }

    // F32 mask input of shape [n_kv_total, padded n_tokens]
    const auto new_kq_mask = [&]() -> ggml_tensor * {
        return ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
    };

    dflash.inputs.kq_mask    = new_kq_mask();
    dflash.kv.kq_mask_tensor = dflash.inputs.kq_mask;
    ggml_set_input(dflash.inputs.kq_mask);
    cb(dflash.inputs.kq_mask, "dflash_kq_mask", -1);

    // with flash attention the mask is cast to F16
    ggml_tensor * mask_full = dflash.inputs.kq_mask;
    if (flash_attn) {
        mask_full = ggml_cast(ctx0, dflash.inputs.kq_mask, GGML_TYPE_F16);
    }

    ggml_tensor * mask_swa       = nullptr;
    dflash.inputs.kq_mask_swa    = nullptr;
    dflash.kv.kq_mask_swa_tensor = nullptr;
    if (have_swa_layers && hparams.n_swa > 0) {
        dflash.inputs.kq_mask_swa    = new_kq_mask();
        dflash.kv.kq_mask_swa_tensor = dflash.inputs.kq_mask_swa;
        ggml_set_input(dflash.inputs.kq_mask_swa);
        cb(dflash.inputs.kq_mask_swa, "dflash_kq_mask_swa", -1);

        mask_swa = dflash.inputs.kq_mask_swa;
        if (flash_attn) {
            mask_swa = ggml_cast(ctx0, dflash.inputs.kq_mask_swa, GGML_TYPE_F16);
        }
    }

    // If the model carries no token embedding tensor, a placeholder Q4_0 tensor of the
    // same logical shape is created so the graph can still be built.
    ggml_tensor * embd_table = model.tok_embd;
    if (embd_table == nullptr) {
        embd_table = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab);
    }

    ggml_tensor * inpL    = llm_build_inp_embd(ctx0, lctx, hparams, batch, embd_table, cb);
    ggml_tensor * inp_pos = build_inp_pos();

    // output-row selection is only needed when a multi-token batch has fewer outputs than tokens
    ggml_tensor * inp_out_ids = nullptr;
    if (n_tokens > 1 && n_outputs < n_tokens) {
        inp_out_ids = build_inp_out_ids();
    }
    bool rows_already_selected = false;

    const float kq_scale = 1.0f / std::sqrt((float) n_embd_head_k);

    // projection -> [head_dim, n_heads, n_tokens] -> per-head RMS norm -> RoPE -> named via cb
    const auto project_norm_rope = [&](ggml_tensor * weight, ggml_tensor * norm_weight, ggml_tensor * x,
                                       int64_t n_heads, const char * label, int il) -> ggml_tensor * {
        ggml_tensor * t = llm_build_lora_mm(lctx, ctx0, weight, x);
        t = ggml_reshape_3d(ctx0, t, n_embd_head_k, n_heads, n_tokens);
        t = llm_build_norm(ctx0, t, hparams, norm_weight, nullptr, LLM_NORM_RMS, cb, il);
        t = ggml_rope_ext(ctx0, t, inp_pos, nullptr,
                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow);
        cb(t, label, il);
        return t;
    };

    // view over the first `ctx_len` entries of dim 1 of a workspace tensor
    const auto workspace_view = [&](ggml_tensor * ws) -> ggml_tensor * {
        return ggml_view_3d(ctx0, ws,
                ws->ne[0], ctx_len, ws->ne[2],
                ws->nb[1], ws->nb[2],
                0);
    };

    for (int il = 0; il < n_layer; ++il) {
        const auto & layer = model.layers[il];

        ggml_tensor * attn_input = inpL;

        ggml_tensor * cur = llm_build_norm(ctx0, inpL, hparams, layer.attn_norm, nullptr, LLM_NORM_RMS, cb, il);
        cb(cur, "attn_norm", il);

        ggml_tensor * Qcur    = project_norm_rope(layer.wq, layer.attn_q_norm, cur, n_head,    "Qcur",       il);
        ggml_tensor * k_batch = project_norm_rope(layer.wk, layer.attn_k_norm, cur, n_head_kv, "Kcur_noise", il);

        ggml_tensor * v_batch = llm_build_lora_mm(lctx, ctx0, layer.wv, cur);
        v_batch = ggml_reshape_3d(ctx0, v_batch, n_embd_head_v, n_head_kv, n_tokens);
        cb(v_batch, "Vcur_noise", il);

        GGML_ASSERT((size_t) il < lctx.dflash.kv.k_ctx_workspace.size());
        GGML_ASSERT((size_t) il < lctx.dflash.kv.v_ctx_workspace.size());
        GGML_ASSERT(lctx.dflash.kv.k_ctx_workspace[(size_t) il] != nullptr);
        GGML_ASSERT(lctx.dflash.kv.v_ctx_workspace[(size_t) il] != nullptr);

        // context part of K/V comes from the workspace tensors
        ggml_tensor * k_context = workspace_view(dflash.kv.k_ctx_workspace[(size_t) il]);
        ggml_tensor * v_context = workspace_view(dflash.kv.v_ctx_workspace[(size_t) il]);
        cb(k_context, "Kcur_ctx_workspace", il);
        cb(v_context, "Vcur_ctx_workspace", il);

        // batch part: swap dims 1 and 2 so it can be appended along dim 1
        ggml_tensor * k_batch_perm = ggml_cont(ctx0, ggml_permute(ctx0, k_batch, 0, 2, 1, 3));
        ggml_tensor * v_batch_perm = ggml_cont(ctx0, ggml_permute(ctx0, v_batch, 0, 2, 1, 3));
        cb(k_batch_perm, "dflash_main_k_perm_cont", il);
        cb(v_batch_perm, "dflash_main_v_perm_cont", il);

        ggml_tensor * k = ggml_concat(ctx0, k_context, k_batch_perm, 1);
        ggml_tensor * v = ggml_concat(ctx0, v_context, v_batch_perm, 1);
        cb(k, "dflash_main_k_concat", il);
        cb(v, "dflash_main_v_concat", il);

        if (n_kv_pad > 0) {
            k = ggml_pad(ctx0, k, 0, (int) n_kv_pad, 0, 0);
            v = ggml_pad(ctx0, v, 0, (int) n_kv_pad, 0, 0);
            cb(k, "dflash_main_k_pad", il);
            cb(v, "dflash_main_v_pad", il);
        }

        cb(Qcur, "Qcur", il);

        ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);

        ggml_tensor * layer_mask = mask_full;
        if (hparams.swa_layers[il] && mask_swa != nullptr) {
            layer_mask = mask_swa;
        }
        cb(q, "q", il);

        cur = ggml_flash_attn_ext(ctx0, q, k, v, layer_mask, kq_scale, hparams.f_max_alibi_bias,
                hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
        cb(cur, "flash_attn", il);
        ggml_build_forward_expand(gf, cur);

        cur = ggml_reshape_2d(ctx0, cur, layer.wo->ne[0], n_tokens);
        cb(cur, "flash_attn_reshaped", il);

        cur = llm_build_lora_mm(lctx, ctx0, layer.wo, cur);
        cb(cur, "kqv_out", il);

        cur = ggml_add(ctx0, cur, attn_input);
        cb(cur, "attn_residual", il);

        // on the last layer keep only the requested output rows before the FFN
        if (inp_out_ids != nullptr && il == n_layer - 1) {
            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
            cb(cur, "result_output_rows", -1);
            rows_already_selected = true;
        }

        ggml_tensor * ffn_input = cur;
        cur = llm_build_norm(ctx0, cur, hparams, layer.attn_post_norm, nullptr, LLM_NORM_RMS, cb, il);
        cb(cur, "attn_post_norm", il);

        cur = llm_build_ffn(ctx0, lctx, nullptr, cur,
                layer.ffn_up,   nullptr, nullptr,
                layer.ffn_gate, nullptr, nullptr,
                layer.ffn_down, nullptr, nullptr,
                nullptr,
                LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, false, false);
        cb(cur, "ffn_out", il);

        cur = ggml_add(ctx0, cur, ffn_input);
        cb(cur, "l_out", il);

        inpL = cur;
    }

    // Same placeholder approach as for the embedding table when no output tensor is available.
    ggml_tensor * out_weight = const_cast<ggml_tensor *>(llama_model_dflash_output_tensor(&model));
    if (out_weight == nullptr) {
        out_weight = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab);
    }

    ggml_tensor * head_input = inpL;
    if (inp_out_ids && !rows_already_selected) {
        head_input = ggml_get_rows(ctx0, head_input, inp_out_ids);
        cb(head_input, "result_output_rows", -1);
    }

    ggml_tensor * result = build_output(lctx, ctx0, head_input, out_weight, model.output_norm, cb);
    cb(result, "result_output", -1);
    ggml_build_forward_expand(gf, result);

    // argmax over the output is part of the graph and published on the context
    dflash.draft_tokens_tensor = nullptr;
    ggml_tensor * argmax_ids = ggml_argmax(ctx0, result);
    ggml_set_name(argmax_ids, "draft_argmax");
    ggml_build_forward_expand(gf, argmax_ids);
    dflash.draft_tokens_tensor = argmax_ids;

    return gf;
}
|
||||
@ -83,6 +83,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_MISTRAL4, "mistral4" },
|
||||
{ LLM_ARCH_GEMMA4, "gemma4" },
|
||||
{ LLM_ARCH_GEMMA4_MTP, "gemma4_mtp" },
|
||||
{ LLM_ARCH_DFLASH_DRAFT, "dflash-draft" },
|
||||
{ LLM_ARCH_GEMMA4_ASSISTANT,"gemma4_assistant" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
@ -153,6 +154,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_MTP_USE_ORDERED_EMBEDDINGS, "%s.use_ordered_embeddings" },
|
||||
{ LLM_KV_MTP_CENTROID_COUNT, "%s.centroid_count" },
|
||||
{ LLM_KV_MTP_CENTROID_TOP_K, "%s.centroid_top_k" },
|
||||
{ LLM_KV_DFLASH_BLOCK_SIZE, "%s.dflash.block_size" },
|
||||
{ LLM_KV_DFLASH_MASK_TOKEN_ID, "%s.dflash.mask_token_id" },
|
||||
{ LLM_KV_DFLASH_TARGET_LAYER_IDS, "%s.dflash.target_layer_ids" },
|
||||
{ LLM_KV_DFLASH_N_TARGET_FEATURES, "%s.dflash.n_target_features" },
|
||||
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
|
||||
|
||||
@ -82,6 +82,7 @@ enum llm_arch {
|
||||
LLM_ARCH_MISTRAL4,
|
||||
LLM_ARCH_GEMMA4,
|
||||
LLM_ARCH_GEMMA4_MTP,
|
||||
LLM_ARCH_DFLASH_DRAFT,
|
||||
LLM_ARCH_GEMMA4_ASSISTANT,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
@ -143,6 +144,10 @@ enum llm_kv {
|
||||
LLM_KV_MTP_USE_ORDERED_EMBEDDINGS,
|
||||
LLM_KV_MTP_CENTROID_COUNT,
|
||||
LLM_KV_MTP_CENTROID_TOP_K,
|
||||
LLM_KV_DFLASH_BLOCK_SIZE,
|
||||
LLM_KV_DFLASH_MASK_TOKEN_ID,
|
||||
LLM_KV_DFLASH_TARGET_LAYER_IDS,
|
||||
LLM_KV_DFLASH_N_TARGET_FEATURES,
|
||||
|
||||
LLM_KV_ATTENTION_HEAD_COUNT,
|
||||
LLM_KV_ATTENTION_HEAD_COUNT_KV,
|
||||
@ -372,6 +377,8 @@ enum llm_tensor {
|
||||
LLM_TENSOR_MTP_POST_PROJ,
|
||||
LLM_TENSOR_MTP_TOKEN_ORDERING,
|
||||
LLM_TENSOR_MTP_CENTROIDS,
|
||||
LLM_TENSOR_DFLASH_FC,
|
||||
LLM_TENSOR_DFLASH_HIDDEN_NORM,
|
||||
|
||||
LLM_TENSOR_UNKNOWN,
|
||||
};
|
||||
|
||||
@ -35,7 +35,9 @@ llm_build_context::llm_build_context(
|
||||
const llm_build_cb & cb,
|
||||
bool worst_case,
|
||||
bool warmup,
|
||||
int n_outputs_) :
|
||||
int n_outputs_,
|
||||
bool clear_lctx_inputs,
|
||||
std::vector<uint8_t> * buf_compute_meta_override) :
|
||||
model (lctx.model),
|
||||
lctx (lctx),
|
||||
hparams (model.hparams),
|
||||
@ -82,8 +84,9 @@ llm_build_context::llm_build_context(
|
||||
thresh_experts (cparams.thresh_experts),
|
||||
pooling_type (cparams.pooling_type),
|
||||
rope_type (hparams.rope_type),
|
||||
clear_lctx_inputs(clear_lctx_inputs),
|
||||
cb (cb),
|
||||
buf_compute_meta (lctx.buf_compute_meta) {
|
||||
buf_compute_meta (buf_compute_meta_override ? *buf_compute_meta_override : lctx.buf_compute_meta) {
|
||||
// all initializations should be done in init()
|
||||
}
|
||||
|
||||
@ -96,22 +99,27 @@ void llm_build_context::init() {
|
||||
|
||||
ctx0 = ggml_init(params);
|
||||
|
||||
lctx.inp_tokens = nullptr;
|
||||
lctx.inp_embd = nullptr;
|
||||
lctx.inp_pos = nullptr;
|
||||
lctx.inp_out_ids = nullptr;
|
||||
lctx.inp_KQ_mask = nullptr;
|
||||
lctx.inp_KQ_mask_swa = nullptr;
|
||||
lctx.inp_K_shift = nullptr;
|
||||
lctx.inp_mean = nullptr;
|
||||
lctx.inp_cls = nullptr;
|
||||
lctx.inp_s_copy = nullptr;
|
||||
lctx.inp_s_mask = nullptr;
|
||||
lctx.inp_s_seq = nullptr;
|
||||
lctx.inp_s_seq_qnext = nullptr;
|
||||
lctx.inp_pos_bucket = nullptr;
|
||||
lctx.inp_embd_enc = nullptr;
|
||||
lctx.inp_KQ_mask_cross = nullptr;
|
||||
if (clear_lctx_inputs) {
|
||||
lctx.inp_tokens = nullptr;
|
||||
lctx.inp_embd = nullptr;
|
||||
lctx.inp_pos = nullptr;
|
||||
lctx.inp_out_ids = nullptr;
|
||||
lctx.inp_KQ_mask = nullptr;
|
||||
lctx.inp_KQ_mask_swa = nullptr;
|
||||
lctx.inp_K_shift = nullptr;
|
||||
lctx.inp_mean = nullptr;
|
||||
lctx.inp_cls = nullptr;
|
||||
lctx.inp_s_copy = nullptr;
|
||||
lctx.inp_s_mask = nullptr;
|
||||
lctx.inp_s_seq = nullptr;
|
||||
lctx.inp_s_seq_qnext = nullptr;
|
||||
lctx.inp_pos_bucket = nullptr;
|
||||
lctx.inp_embd_enc = nullptr;
|
||||
lctx.inp_KQ_mask_cross = nullptr;
|
||||
lctx.dflash.inputs.target_features = nullptr;
|
||||
lctx.dflash.inputs.pos_ctx = nullptr;
|
||||
lctx.dflash.inputs.kq_mask = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void llm_build_context::free() {
|
||||
@ -2199,6 +2207,80 @@ struct ggml_cgraph * llm_build_context::llama_build_graph_s_copy(llama_context &
|
||||
return result;
|
||||
}
|
||||
|
||||
// Builds the DFlash KV-cache update graph for `lctx`.
//
// A build context is created over an empty batch, with lctx.dflash.kv.cache_compute_meta
// as its compute-meta buffer and without clearing the context's input tensors, and
// build_dflash_kv_cache() is run on it.
struct ggml_cgraph * llm_build_context::llama_build_graph_dflash_kv_cache(llama_context & lctx) {
    llama_batch dummy;
    dummy.n_tokens = 0;

    // Tensor naming callback: "<name>" for il < 0, otherwise "<name>-<il>" truncated to
    // fit GGML_MAX_NAME. The "-<il>" suffix is dropped entirely when fewer than three
    // characters of room remain after the name.
    llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) {
        if (il < 0) {
            ggml_set_name(cur, name);
            return;
        }

        int len = 0;
        while (len < GGML_MAX_NAME - 1 && name[len] != '\0') {
            cur->name[len] = name[len];
            ++len;
        }

        if (len < GGML_MAX_NAME - 3) {
            cur->name[len++] = '-';
            const std::string layer_id = std::to_string(il);
            for (const char digit : layer_id) {
                if (len >= GGML_MAX_NAME - 1) {
                    break;
                }
                cur->name[len++] = digit;
            }
        }

        cur->name[len] = 0;
    };

    struct llm_build_context builder(lctx, dummy, cb, false, false, 0, false, &lctx.dflash.kv.cache_compute_meta);

    builder.init();
    struct ggml_cgraph * graph = builder.build_dflash_kv_cache();
    builder.free();

    return graph;
}
|
||||
|
||||
// Builds the DFlash KV-workspace graph for `lctx`.
//
// A build context is created over an empty batch, with
// lctx.dflash.kv.workspace_compute_meta as its compute-meta buffer and without clearing
// the context's input tensors, and build_dflash_kv_workspace() is run on it.
struct ggml_cgraph * llm_build_context::llama_build_graph_dflash_kv_workspace(llama_context & lctx) {
    llama_batch dummy;
    dummy.n_tokens = 0;

    // Tensor naming callback: "<name>" for il < 0, otherwise "<name>-<il>" truncated to
    // fit GGML_MAX_NAME. The "-<il>" suffix is dropped entirely when fewer than three
    // characters of room remain after the name.
    llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) {
        if (il < 0) {
            ggml_set_name(cur, name);
            return;
        }

        int len = 0;
        while (len < GGML_MAX_NAME - 1 && name[len] != '\0') {
            cur->name[len] = name[len];
            ++len;
        }

        if (len < GGML_MAX_NAME - 3) {
            cur->name[len++] = '-';
            const std::string layer_id = std::to_string(il);
            for (const char digit : layer_id) {
                if (len >= GGML_MAX_NAME - 1) {
                    break;
                }
                cur->name[len++] = digit;
            }
        }

        cur->name[len] = 0;
    };

    struct llm_build_context builder(lctx, dummy, cb, false, false, 0, false, &lctx.dflash.kv.workspace_compute_meta);

    builder.init();
    struct ggml_cgraph * graph = builder.build_dflash_kv_workspace();
    builder.free();

    return graph;
}
|
||||
|
||||
ggml_cgraph * llm_build_context::llama_build_graph(
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch,
|
||||
@ -2415,6 +2497,10 @@ ggml_cgraph * llm_build_context::llama_build_graph(
|
||||
{
|
||||
result = llm.build_gemma4_mtp();
|
||||
} break;
|
||||
case LLM_ARCH_DFLASH_DRAFT:
|
||||
{
|
||||
result = llm.build_dflash();
|
||||
} break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
result = llm.build_starcoder2();
|
||||
|
||||
@ -89,6 +89,7 @@ struct llm_build_context {
|
||||
|
||||
const enum llama_pooling_type pooling_type;
|
||||
const enum llama_rope_type rope_type;
|
||||
const bool clear_lctx_inputs;
|
||||
|
||||
const llm_build_cb & cb;
|
||||
|
||||
@ -103,7 +104,9 @@ struct llm_build_context {
|
||||
const llm_build_cb & cb,
|
||||
bool worst_case,
|
||||
bool warmup,
|
||||
int n_outputs = 0);
|
||||
int n_outputs = 0,
|
||||
bool clear_lctx_inputs = true,
|
||||
std::vector<uint8_t> * buf_compute_meta_override = nullptr);
|
||||
|
||||
void init();
|
||||
|
||||
@ -244,6 +247,12 @@ struct llm_build_context {
|
||||
|
||||
ggml_cgraph * build_gemma4_mtp();
|
||||
|
||||
ggml_cgraph * build_dflash();
|
||||
|
||||
ggml_cgraph * build_dflash_kv_cache();
|
||||
|
||||
ggml_cgraph * build_dflash_kv_workspace();
|
||||
|
||||
ggml_cgraph * build_starcoder2();
|
||||
|
||||
ggml_cgraph * build_mamba();
|
||||
@ -463,6 +472,10 @@ llm_expert_gating_func_type gating_op,
|
||||
|
||||
static ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx);
|
||||
|
||||
static ggml_cgraph * llama_build_graph_dflash_kv_cache(llama_context & lctx);
|
||||
|
||||
static ggml_cgraph * llama_build_graph_dflash_kv_workspace(llama_context & lctx);
|
||||
|
||||
static ggml_cgraph * llama_build_graph(llama_context & lctx, const llama_batch & batch, bool worst_case, int n_outputs = 0);
|
||||
|
||||
ggml_tensor * build_std_attention(ggml_cgraph * gf, ggml_tensor * attn_norm, ggml_tensor * cur,
|
||||
|
||||
@ -278,6 +278,102 @@ struct llama_context {
|
||||
size_t draft_input_hidden_state_n_floats = 0;
|
||||
std::vector<float> draft_input_hidden_state_owned;
|
||||
|
||||
struct dflash_runtime {
|
||||
struct target_window_state {
|
||||
const float * features = nullptr;
|
||||
size_t features_n_floats = 0;
|
||||
int32_t features_n_rows = 0;
|
||||
const float * append_features = nullptr;
|
||||
size_t append_features_n_floats = 0;
|
||||
int32_t append_features_n_rows = 0;
|
||||
const llama_pos * positions = nullptr;
|
||||
size_t positions_n = 0;
|
||||
uint64_t version = 0;
|
||||
int32_t keep_rows = 0;
|
||||
int32_t append_rows = 0;
|
||||
bool replace = false;
|
||||
std::vector<float> features_owned;
|
||||
std::vector<float> append_features_owned;
|
||||
std::vector<llama_pos> positions_owned;
|
||||
std::vector<float> features_padded;
|
||||
std::vector<llama_pos> pos_ctx_data;
|
||||
std::vector<float> kq_mask_data;
|
||||
std::vector<float> kq_mask_swa_data;
|
||||
};
|
||||
|
||||
struct kv_runtime_state {
|
||||
std::vector<struct ggml_tensor *> k_ctx_cache;
|
||||
std::vector<struct ggml_tensor *> v_ctx_cache;
|
||||
std::vector<struct ggml_tensor *> k_ctx_workspace;
|
||||
std::vector<struct ggml_tensor *> v_ctx_workspace;
|
||||
struct ggml_context * cache_ctx = nullptr;
|
||||
std::vector<ggml_backend_buffer_t> cache_bufs;
|
||||
int32_t cache_write_pos = 0;
|
||||
int32_t cache_n_filled = 0;
|
||||
int32_t cache_update_rows = 0;
|
||||
int32_t cache_reserved_rows = 0;
|
||||
int32_t cache_view_write_pos = 0;
|
||||
int32_t cache_view_n_filled = 0;
|
||||
uint64_t cache_applied_window_version = 0;
|
||||
bool cache_valid = false;
|
||||
bool cache_view_valid = false;
|
||||
int32_t workspace_write_pos = 0;
|
||||
int32_t workspace_n_filled = 0;
|
||||
int32_t workspace_reserved_rows = 0;
|
||||
int32_t workspace_token_capacity = 0;
|
||||
int32_t workspace_n_kv_total = 0;
|
||||
uint64_t workspace_applied_window_version = 0;
|
||||
bool workspace_valid = false;
|
||||
bool workspace_sync_pending = false;
|
||||
std::vector<uint8_t> cache_compute_meta;
|
||||
std::vector<uint8_t> workspace_compute_meta;
|
||||
ggml_backend_sched_t cache_sched = nullptr;
|
||||
ggml_backend_sched_t workspace_sched = nullptr;
|
||||
ggml_cgraph * cache_graph = nullptr;
|
||||
ggml_cgraph * workspace_graph = nullptr;
|
||||
int32_t cache_graph_rows = 0;
|
||||
int32_t cache_graph_write_pos = 0;
|
||||
int32_t workspace_graph_rows = 0;
|
||||
int32_t workspace_graph_write_pos = 0;
|
||||
struct ggml_tensor * cache_input_target_features = nullptr;
|
||||
struct ggml_tensor * cache_input_pos_ctx = nullptr;
|
||||
struct ggml_tensor * kq_mask_tensor = nullptr;
|
||||
struct ggml_tensor * kq_mask_swa_tensor = nullptr;
|
||||
};
|
||||
|
||||
struct capture_state {
|
||||
std::vector<int32_t> layer_ids;
|
||||
std::vector<std::vector<float>> layer_rows;
|
||||
int32_t row_count = 0;
|
||||
int32_t row_width = 0;
|
||||
uint64_t capture_batch_id = 0;
|
||||
std::vector<uint64_t> layer_seen_batch_id;
|
||||
ggml_backend_sched_eval_callback prev_cb_eval = nullptr;
|
||||
void * prev_cb_eval_user_data = nullptr;
|
||||
};
|
||||
|
||||
struct input_state {
|
||||
struct ggml_tensor * target_features = nullptr; // F32 [n_target_features, cross_ctx]
|
||||
struct ggml_tensor * pos_ctx = nullptr; // I32 [cross_ctx]
|
||||
struct ggml_tensor * kq_mask = nullptr; // F32 [cross_ctx + n_batch, GGML_PAD(n_batch)]
|
||||
struct ggml_tensor * kq_mask_swa = nullptr; // F32 [cross_ctx + n_batch, GGML_PAD(n_batch)]
|
||||
};
|
||||
|
||||
target_window_state target;
|
||||
kv_runtime_state kv;
|
||||
std::unique_ptr<capture_state> capture;
|
||||
std::vector<float> feature_view_buffer;
|
||||
input_state inputs;
|
||||
int32_t visible_cross_ctx = 0;
|
||||
|
||||
// Argmax token IDs from the DFlash draft graph, computed via GPU argmax.
|
||||
// Populated in llama_decode_internal after graph compute.
|
||||
std::vector<llama_token> draft_tokens;
|
||||
struct ggml_tensor * draft_tokens_tensor = nullptr;
|
||||
};
|
||||
dflash_runtime dflash;
|
||||
using dflash_capture_state = dflash_runtime::capture_state;
|
||||
|
||||
// input tensors
|
||||
struct ggml_tensor * inp_tokens; // I32 [n_batch]
|
||||
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
|
||||
@ -315,6 +411,9 @@ struct llama_context {
|
||||
|
||||
bool update_cache_copies();
|
||||
|
||||
bool ensure_dflash_kv_cache_tensors(int32_t cross_ctx);
|
||||
void free_dflash_kv_cache_tensors();
|
||||
|
||||
bool prepare_mtp_graph_inputs(
|
||||
struct llama_context & lctx);
|
||||
void set_mtp_op_type(llama_mtp_op_type value);
|
||||
@ -322,4 +421,3 @@ struct llama_context {
|
||||
int max_nodes(int n_tokens, int n_kv) const;
|
||||
|
||||
};
|
||||
|
||||
|
||||
694
src/llama-dflash.cpp
Normal file
694
src/llama-dflash.cpp
Normal file
@ -0,0 +1,694 @@
|
||||
#include "llama-dflash.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-build-context.h"
|
||||
#include "llama-context.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-spec-features.h"
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-backend.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
// Synchronizes the DFlash workspace scheduler when a sync has been flagged as
// pending, then clears the flag. Does nothing if no sync is pending or the
// scheduler does not exist.
void llama_sync_dflash_workspace_if_pending(struct llama_context & lctx) {
    auto & kv = lctx.dflash.kv;

    const bool must_wait = kv.workspace_sync_pending && kv.workspace_sched != nullptr;
    if (must_wait) {
        ggml_backend_sched_synchronize(kv.workspace_sched);
        kv.workspace_sync_pending = false;
    }
}
|
||||
|
||||
// Picks the buffer type for the DFlash K/V tensors of layer `il`.
//
// Preference order:
//   1. the buffer type recorded for the layer in model.buft_layer,
//   2. the buffer type of the buffer holding the layer's `wk` weight,
//   3. the default CPU buffer type.
static ggml_backend_buffer_type_t llama_dflash_kv_cache_layer_buft(const llama_context & lctx, int32_t il) {
    if (il >= 0) {
        const auto & model = lctx.model;
        const size_t layer = (size_t) il;

        if (layer < model.buft_layer.size() && model.buft_layer[layer].buft != nullptr) {
            return model.buft_layer[layer].buft;
        }

        if (layer < model.layers.size()) {
            const ggml_tensor * wk = model.layers[layer].wk;
            if (wk != nullptr && wk->buffer != nullptr) {
                return ggml_backend_buffer_get_type(wk->buffer);
            }
        }
    }

    return llama_default_buffer_type_cpu(true);
}
|
||||
|
||||
// Finds the first backend in lctx.backends whose buffer type matches the
// buffer that stores `tensor` (for a view, the buffer of its view source).
// For CPU backends the comparison uses llama_default_buffer_type_cpu(true);
// for all others the backend's default buffer type.
//
// Returns nullptr when `tensor` is null, has no buffer, or no backend matches.
static ggml_backend_t llama_backend_for_tensor(const llama_context & lctx, const ggml_tensor * tensor) {
    if (!tensor) {
        return nullptr;
    }

    ggml_backend_buffer_t storage = tensor->view_src != nullptr ? tensor->view_src->buffer : tensor->buffer;
    if (!storage) {
        return nullptr;
    }

    const ggml_backend_buffer_type_t wanted = ggml_backend_buffer_get_type(storage);

    const auto match = std::find_if(lctx.backends.begin(), lctx.backends.end(),
        [wanted](ggml_backend_t candidate) {
            if (ggml_backend_is_cpu(candidate)) {
                return llama_default_buffer_type_cpu(true) == wanted;
            }
            return ggml_backend_get_default_buffer_type(candidate) == wanted;
        });

    return match != lctx.backends.end() ? *match : nullptr;
}
|
||||
|
||||
bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
|
||||
const int32_t target_cross_ctx = std::max<int32_t>(1, cross_ctx);
|
||||
const int32_t target_token_capacity = std::max<int32_t>(1, (int32_t) model.hparams.dflash_block_size);
|
||||
const int32_t target_workspace_n_kv_total = GGML_PAD(target_cross_ctx + target_token_capacity, cparams.flash_attn ? 256 : 32);
|
||||
const int32_t n_layer = model.hparams.n_layer;
|
||||
const int64_t n_embd_head_k = model.hparams.n_embd_head_k(0);
|
||||
const int64_t n_embd_head_v = model.hparams.n_embd_head_v(0);
|
||||
const int64_t n_head_kv = model.hparams.n_head_kv();
|
||||
|
||||
if (dflash.kv.cache_ctx != nullptr &&
|
||||
(int32_t) dflash.kv.k_ctx_cache.size() == n_layer &&
|
||||
(int32_t) dflash.kv.k_ctx_workspace.size() == n_layer) {
|
||||
const bool cache_matches =
|
||||
(int32_t) dflash.kv.k_ctx_cache.front()->ne[2] == target_cross_ctx;
|
||||
const bool workspace_matches =
|
||||
(int32_t) dflash.kv.k_ctx_workspace.front()->ne[1] == target_workspace_n_kv_total;
|
||||
|
||||
if (cache_matches && workspace_matches) {
|
||||
return true;
|
||||
}
|
||||
|
||||
free_dflash_kv_cache_tensors();
|
||||
if (dflash.kv.cache_sched != nullptr) {
|
||||
ggml_backend_sched_free(dflash.kv.cache_sched);
|
||||
dflash.kv.cache_sched = nullptr;
|
||||
}
|
||||
if (dflash.kv.workspace_sched != nullptr) {
|
||||
ggml_backend_sched_free(dflash.kv.workspace_sched);
|
||||
dflash.kv.workspace_sched = nullptr;
|
||||
}
|
||||
dflash.kv.cache_graph = nullptr;
|
||||
dflash.kv.workspace_graph = nullptr;
|
||||
dflash.kv.cache_graph_rows = 0;
|
||||
dflash.kv.cache_graph_write_pos = 0;
|
||||
dflash.kv.workspace_graph_rows = 0;
|
||||
dflash.kv.workspace_graph_write_pos = 0;
|
||||
dflash.kv.workspace_reserved_rows = 0;
|
||||
}
|
||||
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ (size_t) (4 * std::max(1, n_layer)) * ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
dflash.kv.cache_ctx = ggml_init(params);
|
||||
if (dflash.kv.cache_ctx == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate DFlash K/V cache context\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
dflash.kv.k_ctx_cache.resize((size_t) n_layer);
|
||||
dflash.kv.v_ctx_cache.resize((size_t) n_layer);
|
||||
dflash.kv.k_ctx_workspace.clear();
|
||||
dflash.kv.v_ctx_workspace.clear();
|
||||
dflash.kv.k_ctx_workspace.resize((size_t) n_layer);
|
||||
dflash.kv.v_ctx_workspace.resize((size_t) n_layer);
|
||||
dflash.kv.cache_bufs.clear();
|
||||
dflash.kv.cache_bufs.reserve((size_t) std::max(1, n_layer) * 4);
|
||||
for (int32_t il = 0; il < n_layer; ++il) {
|
||||
ggml_backend_buffer_type_t layer_buft = llama_dflash_kv_cache_layer_buft(*this, il);
|
||||
auto alloc_kv_input = [&](ggml_tensor *& tensor, const char * tensor_tag, const char * tensor_name,
|
||||
int64_t ne0, int64_t ne1, int64_t ne2) -> bool {
|
||||
tensor = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, ne0, ne1, ne2);
|
||||
if (tensor == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to create %s for layer %d\n", __func__, tensor_tag, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_set_input(tensor);
|
||||
ggml_format_name(tensor, tensor_name, il);
|
||||
|
||||
const size_t tensor_bytes = ggml_backend_buft_get_alloc_size(layer_buft, tensor);
|
||||
ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(layer_buft, tensor_bytes);
|
||||
if (buf == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate %s buffer for layer %d (%zu bytes)\n",
|
||||
__func__, tensor_tag, il, tensor_bytes);
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE);
|
||||
ggml_backend_tensor_alloc(buf, tensor, ggml_backend_buffer_get_base(buf));
|
||||
ggml_backend_buffer_clear(buf, 0);
|
||||
dflash.kv.cache_bufs.push_back(buf);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
if (!alloc_kv_input(dflash.kv.k_ctx_cache[(size_t) il], "dflash_k_ctx_cache", "dflash_k_ctx_cache_%d",
|
||||
n_embd_head_k, n_head_kv, target_cross_ctx) ||
|
||||
!alloc_kv_input(dflash.kv.v_ctx_cache[(size_t) il], "dflash_v_ctx_cache", "dflash_v_ctx_cache_%d",
|
||||
n_embd_head_v, n_head_kv, target_cross_ctx) ||
|
||||
!alloc_kv_input(dflash.kv.k_ctx_workspace[(size_t) il], "dflash_k_ctx_workspace", "dflash_k_ctx_workspace_%d",
|
||||
n_embd_head_k, target_workspace_n_kv_total, n_head_kv) ||
|
||||
!alloc_kv_input(dflash.kv.v_ctx_workspace[(size_t) il], "dflash_v_ctx_workspace", "dflash_v_ctx_workspace_%d",
|
||||
n_embd_head_v, target_workspace_n_kv_total, n_head_kv)) {
|
||||
free_dflash_kv_cache_tensors();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
dflash.kv.workspace_token_capacity = target_token_capacity;
|
||||
dflash.kv.workspace_n_kv_total = target_workspace_n_kv_total;
|
||||
llama_reset_dflash_kv_cache_state(this);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void llama_context::free_dflash_kv_cache_tensors() {
|
||||
auto release_vector = [](auto & v) {
|
||||
using vec_type = std::decay_t<decltype(v)>;
|
||||
vec_type().swap(v);
|
||||
};
|
||||
|
||||
release_vector(dflash.kv.k_ctx_cache);
|
||||
release_vector(dflash.kv.v_ctx_cache);
|
||||
release_vector(dflash.kv.k_ctx_workspace);
|
||||
release_vector(dflash.kv.v_ctx_workspace);
|
||||
dflash.kv.cache_write_pos = 0;
|
||||
dflash.kv.cache_n_filled = 0;
|
||||
dflash.kv.cache_update_rows = 0;
|
||||
dflash.kv.cache_reserved_rows = 0;
|
||||
dflash.kv.cache_view_write_pos = 0;
|
||||
dflash.kv.cache_view_n_filled = 0;
|
||||
dflash.kv.cache_applied_window_version = 0;
|
||||
dflash.kv.cache_valid = false;
|
||||
dflash.kv.cache_view_valid = false;
|
||||
dflash.kv.workspace_write_pos = 0;
|
||||
dflash.kv.workspace_n_filled = 0;
|
||||
dflash.kv.workspace_reserved_rows = 0;
|
||||
dflash.kv.workspace_token_capacity = 0;
|
||||
dflash.kv.workspace_n_kv_total = 0;
|
||||
dflash.kv.workspace_applied_window_version = 0;
|
||||
dflash.kv.workspace_valid = false;
|
||||
dflash.kv.workspace_sync_pending = false;
|
||||
dflash.kv.cache_graph = nullptr;
|
||||
dflash.kv.workspace_graph = nullptr;
|
||||
dflash.kv.cache_graph_rows = 0;
|
||||
dflash.kv.cache_graph_write_pos = 0;
|
||||
dflash.kv.workspace_graph_rows = 0;
|
||||
dflash.kv.workspace_graph_write_pos = 0;
|
||||
dflash.kv.cache_input_target_features = nullptr;
|
||||
dflash.kv.cache_input_pos_ctx = nullptr;
|
||||
dflash.kv.kq_mask_tensor = nullptr;
|
||||
dflash.kv.kq_mask_swa_tensor = nullptr;
|
||||
|
||||
if (dflash.kv.workspace_sched != nullptr) {
|
||||
ggml_backend_sched_synchronize(dflash.kv.workspace_sched);
|
||||
ggml_backend_sched_free(dflash.kv.workspace_sched);
|
||||
dflash.kv.workspace_sched = nullptr;
|
||||
}
|
||||
|
||||
for (ggml_backend_buffer_t buf : dflash.kv.cache_bufs) {
|
||||
if (buf != nullptr) {
|
||||
ggml_backend_buffer_free(buf);
|
||||
}
|
||||
}
|
||||
release_vector(dflash.kv.cache_bufs);
|
||||
release_vector(dflash.kv.cache_compute_meta);
|
||||
release_vector(dflash.kv.workspace_compute_meta);
|
||||
if (dflash.kv.cache_ctx != nullptr) {
|
||||
ggml_free(dflash.kv.cache_ctx);
|
||||
dflash.kv.cache_ctx = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Configures the context's backends for `n_threads` and submits `gf` to
// `sched` with the asynchronous compute call; this function itself does not
// synchronize the scheduler.
static void llama_graph_compute_sched(
        llama_context & lctx,
        ggml_backend_sched_t sched,
        ggml_cgraph * gf,
        int n_threads) {
#ifdef GGML_USE_METAL
    if (ggml_backend_is_metal(lctx.backend_metal)) {
        ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads);
    }
#endif

    // The CPU backend also receives the context's abort callback.
    auto cpu_backend = lctx.backend_cpu;
    if (cpu_backend != nullptr) {
        ggml_backend_cpu_set_n_threads(cpu_backend, n_threads);
        ggml_backend_cpu_set_abort_callback(cpu_backend, lctx.abort_callback, lctx.abort_callback_data);
    }
#ifdef GGML_USE_BLAS
    if (lctx.backend_blas != nullptr) {
        ggml_backend_blas_set_n_threads(lctx.backend_blas, n_threads);
    }
#endif

    ggml_backend_sched_graph_compute_async(sched, gf);
}
|
||||
|
||||
static bool dflash_layer_has_attention_bias(const llama_layer & layer) {
|
||||
return layer.bq != nullptr ||
|
||||
layer.bk != nullptr ||
|
||||
layer.bv != nullptr ||
|
||||
layer.bo != nullptr ||
|
||||
layer.bqkv != nullptr ||
|
||||
layer.bqk != nullptr ||
|
||||
layer.bkv != nullptr;
|
||||
}
|
||||
|
||||
static bool validate_dflash_graph_contract(const llama_context & lctx) {
|
||||
const auto & model = lctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
auto rope_dim_for_layer = [&hparams](int32_t il) -> uint32_t {
|
||||
if (hparams.rope_dim_per_layer[(size_t) il] != 0) {
|
||||
return hparams.rope_dim_per_layer[(size_t) il];
|
||||
}
|
||||
|
||||
return hparams.swa_layers[(size_t) il] ? hparams.n_rot_swa : hparams.n_rot;
|
||||
};
|
||||
|
||||
auto rope_base_for_layer = [&hparams](int32_t il) -> float {
|
||||
if (hparams.has_rope_freq_base_per_layer) {
|
||||
return hparams.rope_freq_base_per_layer[(size_t) il];
|
||||
}
|
||||
|
||||
return hparams.swa_layers[(size_t) il] ? hparams.rope_freq_base_train_swa : hparams.rope_freq_base_train;
|
||||
};
|
||||
|
||||
auto rope_scale_for_layer = [&hparams](int32_t il) -> float {
|
||||
return hparams.swa_layers[(size_t) il] ? hparams.rope_freq_scale_train_swa : hparams.rope_freq_scale_train;
|
||||
};
|
||||
|
||||
const uint32_t ref_n_head = hparams.n_head(0);
|
||||
const uint32_t ref_n_head_kv = hparams.n_head_kv(0);
|
||||
const uint32_t ref_n_embd_head_k = hparams.n_embd_head_k(0);
|
||||
const uint32_t ref_n_embd_head_v = hparams.n_embd_head_v(0);
|
||||
const uint32_t ref_rope_dim = rope_dim_for_layer(0);
|
||||
const float ref_rope_base = rope_base_for_layer(0);
|
||||
const float ref_rope_scale = rope_scale_for_layer(0);
|
||||
|
||||
for (int32_t il = 0; il < (int32_t) hparams.n_layer; ++il) {
|
||||
if (hparams.n_head((uint32_t) il) != ref_n_head ||
|
||||
hparams.n_head_kv((uint32_t) il) != ref_n_head_kv ||
|
||||
hparams.n_embd_head_k(il) != ref_n_embd_head_k ||
|
||||
hparams.n_embd_head_v(il) != ref_n_embd_head_v) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash graph assumes layer-invariant head config, but layer %d differs (n_head=%u/%u n_head_kv=%u/%u head_k=%u/%u head_v=%u/%u)\n",
|
||||
__func__,
|
||||
il,
|
||||
hparams.n_head((uint32_t) il), ref_n_head,
|
||||
hparams.n_head_kv((uint32_t) il), ref_n_head_kv,
|
||||
hparams.n_embd_head_k(il), ref_n_embd_head_k,
|
||||
hparams.n_embd_head_v(il), ref_n_embd_head_v);
|
||||
return false;
|
||||
}
|
||||
|
||||
const uint32_t rope_dim = rope_dim_for_layer(il);
|
||||
const float rope_base = rope_base_for_layer(il);
|
||||
const float rope_scale = rope_scale_for_layer(il);
|
||||
if (rope_dim != ref_rope_dim || std::fabs(rope_base - ref_rope_base) > 1e-6f || std::fabs(rope_scale - ref_rope_scale) > 1e-6f) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash graph assumes layer-invariant RoPE config, but layer %d differs (dim=%u/%u base=%g/%g scale=%g/%g)\n",
|
||||
__func__,
|
||||
il,
|
||||
rope_dim, ref_rope_dim,
|
||||
(double) rope_base, (double) ref_rope_base,
|
||||
(double) rope_scale, (double) ref_rope_scale);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (model.layers[(size_t) il].attn_norm == nullptr ||
|
||||
model.layers[(size_t) il].attn_q_norm == nullptr ||
|
||||
model.layers[(size_t) il].attn_k_norm == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash graph requires attn_norm, attn_q_norm, and attn_k_norm weights, but layer %d is missing one or more of them\n",
|
||||
__func__, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool has_q_norm = model.layers[(size_t) il].attn_q_norm != nullptr;
|
||||
const bool has_k_norm = model.layers[(size_t) il].attn_k_norm != nullptr;
|
||||
if (has_q_norm != has_k_norm) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash graph requires symmetric Q/K norm presence, but layer %d has q_norm=%d k_norm=%d\n",
|
||||
__func__, il, (int) has_q_norm, (int) has_k_norm);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (model.layers[(size_t) il].attn_norm_b != nullptr ||
|
||||
model.layers[(size_t) il].attn_q_norm_b != nullptr ||
|
||||
model.layers[(size_t) il].attn_k_norm_b != nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash graph does not implement norm-bias tensors, but layer %d requires attn_norm_b/q_norm_b/k_norm_b\n",
|
||||
__func__, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dflash_layer_has_attention_bias(model.layers[(size_t) il])) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash graph does not implement attention bias tensors, but layer %d requires them\n",
|
||||
__func__, il);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_prepare_dflash_graph_inputs(
|
||||
struct llama_context & lctx,
|
||||
uint32_t n_tokens) {
|
||||
const int32_t cross_ctx = lctx.dflash.visible_cross_ctx > 0
|
||||
? lctx.dflash.visible_cross_ctx
|
||||
: std::max<int32_t>(1, (int32_t) lctx.cparams.n_ctx - (int32_t) lctx.model.hparams.dflash_block_size);
|
||||
ggml_tensor * kq_mask = lctx.dflash.kv.kq_mask_tensor;
|
||||
ggml_tensor * kq_mask_swa = lctx.dflash.kv.kq_mask_swa_tensor;
|
||||
|
||||
if (kq_mask == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash graph inputs are not initialized\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!validate_dflash_graph_contract(lctx)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!lctx.ensure_dflash_kv_cache_tensors(cross_ctx) || lctx.dflash.kv.k_ctx_cache.empty() || lctx.dflash.kv.v_ctx_cache.empty()) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash K/V cache inputs are not initialized\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
const float * src = lctx.dflash.target.features;
|
||||
const float * append_src = lctx.dflash.target.append_features;
|
||||
const llama_pos * src_pos = lctx.dflash.target.positions;
|
||||
const size_t total_floats = lctx.dflash.target.features_n_floats;
|
||||
const size_t append_floats = lctx.dflash.target.append_features_n_floats;
|
||||
const size_t total_positions = lctx.dflash.target.positions_n;
|
||||
const int32_t n_rows = lctx.dflash.target.features_n_rows;
|
||||
const int32_t append_rows_available = lctx.dflash.target.append_features_n_rows;
|
||||
const int32_t width = (int32_t) lctx.model.hparams.dflash_n_target_features;
|
||||
const int32_t graph_cross_ctx = lctx.dflash.kv.k_ctx_cache.front() != nullptr
|
||||
? (int32_t) lctx.dflash.kv.k_ctx_cache.front()->ne[2]
|
||||
: 0;
|
||||
const int32_t n_mask_tokens = (int32_t) kq_mask->ne[1];
|
||||
const int32_t n_kv_total = (int32_t) kq_mask->ne[0];
|
||||
|
||||
llama_sync_dflash_workspace_if_pending(lctx);
|
||||
|
||||
if (graph_cross_ctx != cross_ctx) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash graph cross_ctx drift (graph=%d configured=%d)\n",
|
||||
__func__, graph_cross_ctx, cross_ctx);
|
||||
return false;
|
||||
}
|
||||
if (n_rows <= 0) {
|
||||
LLAMA_LOG_ERROR("%s: missing DFlash target feature rows\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool have_full_src = src != nullptr && total_floats == (size_t) n_rows * (size_t) width;
|
||||
if (n_rows > cross_ctx || (src != nullptr && !have_full_src)) {
|
||||
LLAMA_LOG_ERROR("%s: invalid DFlash target feature shape (rows=%d width=%d floats=%zu cross_ctx=%d)\n",
|
||||
__func__, n_rows, width, total_floats, cross_ctx);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (n_kv_total < cross_ctx + (int32_t) n_tokens) {
|
||||
LLAMA_LOG_ERROR("%s: invalid DFlash mask shape (n_kv_total=%d < cross_ctx+n_tokens=%d)\n",
|
||||
__func__, n_kv_total, cross_ctx + (int32_t) n_tokens);
|
||||
return false;
|
||||
}
|
||||
|
||||
const int32_t left_pad = cross_ctx - n_rows;
|
||||
|
||||
lctx.dflash.target.pos_ctx_data.resize((size_t) cross_ctx);
|
||||
std::fill(lctx.dflash.target.pos_ctx_data.begin(), lctx.dflash.target.pos_ctx_data.end(), 0);
|
||||
if (src_pos == nullptr || total_positions != (size_t) n_rows) {
|
||||
LLAMA_LOG_ERROR("%s: missing DFlash target positions (rows=%d positions=%zu cross_ctx=%d)\n",
|
||||
__func__, n_rows, total_positions, cross_ctx);
|
||||
return false;
|
||||
}
|
||||
|
||||
const llama_pos last_target_pos = src_pos[n_rows - 1];
|
||||
for (int32_t i = 1; i < n_rows; ++i) {
|
||||
if (src_pos[i] <= src_pos[i - 1]) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash target positions are not strictly increasing (rows=%d first=%d last=%d)\n",
|
||||
__func__, n_rows, (int) src_pos[0], (int) src_pos[n_rows - 1]);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
std::copy(src_pos, src_pos + n_rows, lctx.dflash.target.pos_ctx_data.begin() + (ptrdiff_t) left_pad);
|
||||
|
||||
const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition(
|
||||
cross_ctx,
|
||||
lctx.dflash.kv.cache_n_filled,
|
||||
lctx.dflash.kv.cache_write_pos,
|
||||
lctx.dflash.kv.cache_valid,
|
||||
lctx.dflash.kv.cache_applied_window_version,
|
||||
lctx.dflash.target.version,
|
||||
lctx.dflash.target.keep_rows,
|
||||
lctx.dflash.target.append_rows,
|
||||
lctx.dflash.target.replace,
|
||||
n_rows);
|
||||
|
||||
const bool have_append_src = append_src != nullptr &&
|
||||
append_rows_available == cache_plan.append_rows &&
|
||||
append_floats == (size_t) cache_plan.append_rows * (size_t) width;
|
||||
|
||||
const int32_t update_rows = cache_plan.cache_up_to_date
|
||||
? 0
|
||||
: (cache_plan.rebuild_cache ? n_rows : cache_plan.append_rows);
|
||||
const size_t max_nodes = lctx.model.max_nodes((int) std::max<int32_t>(1, cross_ctx)) + 24 * lctx.model.hparams.n_layer;
|
||||
const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false);
|
||||
if (lctx.dflash.kv.cache_compute_meta.size() != meta_size) {
|
||||
lctx.dflash.kv.cache_compute_meta.resize(meta_size);
|
||||
}
|
||||
|
||||
if (lctx.dflash.kv.cache_sched == nullptr || lctx.dflash.kv.cache_reserved_rows != cross_ctx) {
|
||||
std::vector<ggml_backend_buffer_type_t> backend_buft;
|
||||
backend_buft.reserve(lctx.backends.size());
|
||||
for (auto * backend : lctx.backends) {
|
||||
if (ggml_backend_is_cpu(backend)) {
|
||||
backend_buft.push_back(llama_default_buffer_type_cpu(true));
|
||||
} else {
|
||||
backend_buft.push_back(ggml_backend_get_default_buffer_type(backend));
|
||||
}
|
||||
}
|
||||
|
||||
if (lctx.dflash.kv.cache_sched != nullptr) {
|
||||
ggml_backend_sched_free(lctx.dflash.kv.cache_sched);
|
||||
lctx.dflash.kv.cache_sched = nullptr;
|
||||
}
|
||||
lctx.dflash.kv.cache_graph = nullptr;
|
||||
lctx.dflash.kv.cache_graph_rows = 0;
|
||||
lctx.dflash.kv.cache_graph_write_pos = 0;
|
||||
|
||||
const int32_t saved_update_rows = lctx.dflash.kv.cache_update_rows;
|
||||
lctx.dflash.kv.cache_update_rows = cross_ctx;
|
||||
ggml_cgraph * gf_reserve = llm_build_context::llama_build_graph_dflash_kv_cache(lctx);
|
||||
lctx.dflash.kv.cache_update_rows = saved_update_rows;
|
||||
if (gf_reserve == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache reserve graph\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
lctx.dflash.kv.cache_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false);
|
||||
const bool reserved = lctx.dflash.kv.cache_sched != nullptr && ggml_backend_sched_reserve(lctx.dflash.kv.cache_sched, gf_reserve);
|
||||
if (!reserved) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V scheduler\n", __func__);
|
||||
return false;
|
||||
}
|
||||
lctx.dflash.kv.cache_reserved_rows = cross_ctx;
|
||||
}
|
||||
|
||||
if (update_rows > 0) {
|
||||
const float * update_src = nullptr;
|
||||
if (have_append_src && update_rows == cache_plan.append_rows) {
|
||||
update_src = append_src;
|
||||
} else if (have_full_src) {
|
||||
update_src = src + (size_t) (n_rows - update_rows) * (size_t) width;
|
||||
}
|
||||
const llama_pos * update_pos = src_pos + (n_rows - update_rows);
|
||||
|
||||
if (update_src == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: missing DFlash appended target features for cached update (rows=%d append_rows=%d floats=%zu)\n",
|
||||
__func__, n_rows, update_rows, append_floats);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cache_plan.rebuild_cache) {
|
||||
llama_reset_dflash_kv_cache_state(&lctx);
|
||||
}
|
||||
|
||||
lctx.dflash.kv.cache_update_rows = update_rows;
|
||||
ggml_cgraph * gf_kv = nullptr;
|
||||
const bool can_reuse_kv_graph = lctx.dflash.kv.cache_graph != nullptr &&
|
||||
lctx.dflash.kv.cache_graph_rows == update_rows &&
|
||||
lctx.dflash.kv.cache_graph_write_pos == lctx.dflash.kv.cache_write_pos;
|
||||
if (can_reuse_kv_graph) {
|
||||
gf_kv = lctx.dflash.kv.cache_graph;
|
||||
} else {
|
||||
gf_kv = llm_build_context::llama_build_graph_dflash_kv_cache(lctx);
|
||||
if (gf_kv == nullptr || lctx.dflash.kv.cache_input_target_features == nullptr || lctx.dflash.kv.cache_input_pos_ctx == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_backend_sched_reset(lctx.dflash.kv.cache_sched);
|
||||
ggml_backend_sched_alloc_graph(lctx.dflash.kv.cache_sched, gf_kv);
|
||||
|
||||
lctx.dflash.kv.cache_graph = gf_kv;
|
||||
lctx.dflash.kv.cache_graph_rows = update_rows;
|
||||
lctx.dflash.kv.cache_graph_write_pos = lctx.dflash.kv.cache_write_pos;
|
||||
}
|
||||
|
||||
ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash.kv.cache_input_target_features);
|
||||
if (kv_feature_backend != nullptr) {
|
||||
ggml_backend_tensor_set_async(kv_feature_backend, lctx.dflash.kv.cache_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash.kv.cache_input_target_features));
|
||||
} else {
|
||||
ggml_backend_tensor_set(lctx.dflash.kv.cache_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash.kv.cache_input_target_features));
|
||||
}
|
||||
|
||||
ggml_backend_t kv_pos_backend = llama_backend_for_tensor(lctx, lctx.dflash.kv.cache_input_pos_ctx);
|
||||
if (kv_pos_backend != nullptr) {
|
||||
ggml_backend_tensor_set_async(kv_pos_backend, lctx.dflash.kv.cache_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash.kv.cache_input_pos_ctx));
|
||||
} else {
|
||||
ggml_backend_tensor_set(lctx.dflash.kv.cache_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash.kv.cache_input_pos_ctx));
|
||||
}
|
||||
llama_graph_compute_sched(lctx, lctx.dflash.kv.cache_sched, gf_kv, lctx.cparams.n_threads);
|
||||
ggml_backend_sched_synchronize(lctx.dflash.kv.cache_sched);
|
||||
|
||||
lctx.dflash.kv.cache_n_filled = std::min(cross_ctx, lctx.dflash.kv.cache_n_filled + update_rows);
|
||||
lctx.dflash.kv.cache_write_pos = (lctx.dflash.kv.cache_write_pos + update_rows) % cross_ctx;
|
||||
lctx.dflash.kv.cache_applied_window_version = lctx.dflash.target.version;
|
||||
lctx.dflash.kv.cache_valid = true;
|
||||
lctx.dflash.kv.cache_view_n_filled = lctx.dflash.kv.cache_n_filled;
|
||||
lctx.dflash.kv.cache_view_write_pos = lctx.dflash.kv.cache_write_pos;
|
||||
lctx.dflash.kv.cache_view_valid = true;
|
||||
}
|
||||
|
||||
if (lctx.dflash.kv.cache_view_valid &&
|
||||
!lctx.dflash.kv.k_ctx_workspace.empty() && !lctx.dflash.kv.v_ctx_workspace.empty()) {
|
||||
const bool need_workspace_refresh = !lctx.dflash.kv.workspace_valid ||
|
||||
lctx.dflash.kv.workspace_n_filled != lctx.dflash.kv.cache_view_n_filled ||
|
||||
lctx.dflash.kv.workspace_write_pos != lctx.dflash.kv.cache_view_write_pos ||
|
||||
lctx.dflash.kv.workspace_applied_window_version != lctx.dflash.kv.cache_applied_window_version;
|
||||
|
||||
if (need_workspace_refresh) {
|
||||
const size_t max_nodes = lctx.model.max_nodes((int) std::max<int32_t>(1, cross_ctx)) + 16 * lctx.model.hparams.n_layer;
|
||||
const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false);
|
||||
if (lctx.dflash.kv.workspace_compute_meta.size() != meta_size) {
|
||||
lctx.dflash.kv.workspace_compute_meta.resize(meta_size);
|
||||
}
|
||||
|
||||
ggml_cgraph * gf_workspace = nullptr;
|
||||
const bool can_reuse_workspace_graph = lctx.dflash.kv.workspace_graph != nullptr &&
|
||||
lctx.dflash.kv.workspace_graph_rows == lctx.dflash.kv.cache_view_n_filled &&
|
||||
lctx.dflash.kv.workspace_graph_write_pos == lctx.dflash.kv.cache_view_write_pos;
|
||||
|
||||
if (can_reuse_workspace_graph) {
|
||||
gf_workspace = lctx.dflash.kv.workspace_graph;
|
||||
} else {
|
||||
gf_workspace = llm_build_context::llama_build_graph_dflash_kv_workspace(lctx);
|
||||
if (gf_workspace == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to build DFlash K/V workspace graph\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<ggml_backend_buffer_type_t> backend_buft;
|
||||
backend_buft.reserve(lctx.backends.size());
|
||||
for (auto * backend : lctx.backends) {
|
||||
if (ggml_backend_is_cpu(backend)) {
|
||||
backend_buft.push_back(llama_default_buffer_type_cpu(true));
|
||||
} else {
|
||||
backend_buft.push_back(ggml_backend_get_default_buffer_type(backend));
|
||||
}
|
||||
}
|
||||
|
||||
if (lctx.dflash.kv.workspace_sched == nullptr) {
|
||||
lctx.dflash.kv.workspace_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false);
|
||||
}
|
||||
|
||||
if (lctx.dflash.kv.workspace_reserved_rows != cross_ctx) {
|
||||
const bool saved_view_valid = lctx.dflash.kv.cache_view_valid;
|
||||
const int32_t saved_view_rows = lctx.dflash.kv.cache_view_n_filled;
|
||||
const int32_t saved_view_write_pos = lctx.dflash.kv.cache_view_write_pos;
|
||||
|
||||
lctx.dflash.kv.cache_view_valid = true;
|
||||
lctx.dflash.kv.cache_view_n_filled = cross_ctx;
|
||||
lctx.dflash.kv.cache_view_write_pos = cross_ctx > 1 ? 1 : 0;
|
||||
|
||||
ggml_cgraph * gf_workspace_reserve = llm_build_context::llama_build_graph_dflash_kv_workspace(lctx);
|
||||
|
||||
lctx.dflash.kv.cache_view_valid = saved_view_valid;
|
||||
lctx.dflash.kv.cache_view_n_filled = saved_view_rows;
|
||||
lctx.dflash.kv.cache_view_write_pos = saved_view_write_pos;
|
||||
|
||||
const bool reserved = lctx.dflash.kv.workspace_sched != nullptr &&
|
||||
gf_workspace_reserve != nullptr &&
|
||||
ggml_backend_sched_reserve(lctx.dflash.kv.workspace_sched, gf_workspace_reserve);
|
||||
if (!reserved) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V workspace scheduler\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
lctx.dflash.kv.workspace_reserved_rows = cross_ctx;
|
||||
}
|
||||
|
||||
ggml_backend_sched_reset(lctx.dflash.kv.workspace_sched);
|
||||
ggml_backend_sched_alloc_graph(lctx.dflash.kv.workspace_sched, gf_workspace);
|
||||
|
||||
lctx.dflash.kv.workspace_graph = gf_workspace;
|
||||
lctx.dflash.kv.workspace_graph_rows = lctx.dflash.kv.cache_view_n_filled;
|
||||
lctx.dflash.kv.workspace_graph_write_pos = lctx.dflash.kv.cache_view_write_pos;
|
||||
}
|
||||
|
||||
llama_graph_compute_sched(lctx, lctx.dflash.kv.workspace_sched, gf_workspace, lctx.cparams.n_threads);
|
||||
lctx.dflash.kv.workspace_sync_pending = true;
|
||||
|
||||
lctx.dflash.kv.workspace_n_filled = lctx.dflash.kv.cache_view_n_filled;
|
||||
lctx.dflash.kv.workspace_write_pos = lctx.dflash.kv.cache_view_write_pos;
|
||||
lctx.dflash.kv.workspace_applied_window_version = lctx.dflash.kv.cache_applied_window_version;
|
||||
lctx.dflash.kv.workspace_valid = true;
|
||||
}
|
||||
}
|
||||
|
||||
const int32_t full_visible_first = left_pad;
|
||||
const int32_t full_visible_last = cross_ctx + (int32_t) n_tokens - 1;
|
||||
lctx.dflash.target.kq_mask_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY);
|
||||
for (uint32_t j = 0; j < n_tokens; ++j) {
|
||||
float * row = lctx.dflash.target.kq_mask_data.data() + (size_t) j * (size_t) n_kv_total;
|
||||
for (int32_t i = full_visible_first; i <= full_visible_last; ++i) {
|
||||
row[i] = 0.0f;
|
||||
}
|
||||
}
|
||||
ggml_backend_tensor_set(kq_mask, lctx.dflash.target.kq_mask_data.data(), 0, ggml_nbytes(kq_mask));
|
||||
|
||||
if (kq_mask_swa != nullptr) {
|
||||
lctx.dflash.target.kq_mask_swa_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY);
|
||||
const int32_t swa_window = (int32_t) lctx.model.hparams.n_swa;
|
||||
const int32_t draft_pos_base = (int32_t) last_target_pos;
|
||||
for (uint32_t j = 0; j < n_tokens; ++j) {
|
||||
float * row = lctx.dflash.target.kq_mask_swa_data.data() + (size_t) j * (size_t) n_kv_total;
|
||||
const int32_t q_pos = draft_pos_base + (int32_t) j;
|
||||
|
||||
for (int32_t k = left_pad; k < cross_ctx; ++k) {
|
||||
const int32_t k_pos = (int32_t) lctx.dflash.target.pos_ctx_data[(size_t) k];
|
||||
if (q_pos - k_pos < swa_window) {
|
||||
row[k] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int32_t k = cross_ctx; k < cross_ctx + (int32_t) n_tokens; ++k) {
|
||||
const int32_t block_k = k - cross_ctx;
|
||||
if (block_k <= (int32_t) j) {
|
||||
row[k] = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(kq_mask_swa, lctx.dflash.target.kq_mask_swa_data.data(), 0, ggml_nbytes(kq_mask_swa));
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
8
src/llama-dflash.h
Normal file
8
src/llama-dflash.h
Normal file
@ -0,0 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
struct llama_context;
|
||||
|
||||
bool llama_prepare_dflash_graph_inputs(llama_context & lctx, uint32_t n_tokens);
|
||||
void llama_sync_dflash_workspace_if_pending(llama_context & lctx);
|
||||
@ -3,6 +3,7 @@
|
||||
#include "llama-model-loader.h"
|
||||
#include "llama-model.h"
|
||||
|
||||
#include <limits>
|
||||
#include <map>
|
||||
|
||||
#define LLAMA_MAX_EXPERTS 512 // Qwen3 Next
|
||||
@ -36,6 +37,89 @@ static inline const char * llm_expert_gating_func_name(llm_expert_gating_func_ty
|
||||
}
|
||||
}
|
||||
|
||||
// Reads the DFlash target layer id array stored under `key` into `hparams`.
//
// Returns false when the key is missing or is not an array and `required` is false;
// throws std::runtime_error in that situation when `required` is true.
// Also throws when the element type is neither uint32 nor int32, when the array is
// empty or holds more than 8 entries, or when an id is negative or repeated.
// On success hparams.dflash_n_target_layers is the entry count, the leading slots of
// hparams.dflash_target_layer_ids hold the ids (remaining slots are zero), and the
// function returns true.
static bool load_dflash_target_layer_ids(
        llama_model_loader & ml,
        const std::string & key,
        llama_hparams & hparams,
        bool required) {
    const char * key_name = key.c_str();

    const int key_idx = gguf_find_key(ml.meta, key_name);
    const bool is_array = key_idx >= 0 && gguf_get_kv_type(ml.meta, key_idx) == GGUF_TYPE_ARRAY;
    if (!is_array) {
        if (!required) {
            return false;
        }
        throw std::runtime_error(format("array key not found in model: %s", key_name));
    }

    const enum gguf_type elem_type = gguf_get_arr_type(ml.meta, key_idx);
    const bool is_signed = elem_type == GGUF_TYPE_INT32;
    if (!is_signed && elem_type != GGUF_TYPE_UINT32) {
        throw std::runtime_error(format("dflash: %s must be a uint32/int32 array", key_name));
    }

    const size_t count = gguf_get_arr_n(ml.meta, key_idx);
    if (count == 0) {
        throw std::runtime_error(format("dflash: %s must not be empty", key_name));
    }
    if (count > 8) {
        throw std::runtime_error(format("dflash: %s has %zu entries, max is 8", key_name, count));
    }

    // Publish the count and wipe every slot before filling, so unused slots stay zero.
    hparams.dflash_n_target_layers = (uint32_t) count;
    for (uint32_t & slot : hparams.dflash_target_layer_ids) {
        slot = 0;
    }

    const void * raw = gguf_get_arr_data(ml.meta, key_idx);
    for (uint32_t idx = 0; idx < hparams.dflash_n_target_layers; ++idx) {
        uint32_t layer_id = 0;
        if (is_signed) {
            const int32_t value = ((const int32_t *) raw)[idx];
            if (value < 0) {
                throw std::runtime_error(format("dflash: %s contains negative layer id %d", key_name, value));
            }
            layer_id = (uint32_t) value;
        } else {
            layer_id = ((const uint32_t *) raw)[idx];
        }
        hparams.dflash_target_layer_ids[idx] = layer_id;

        // Compare against every id stored so far; the list has at most 8 entries.
        for (uint32_t prev = 0; prev < idx; ++prev) {
            if (hparams.dflash_target_layer_ids[prev] == layer_id) {
                throw std::runtime_error(format(
                    "dflash: %s contains duplicate layer id %u",
                    key_name,
                    layer_id));
            }
        }
    }

    return true;
}
|
||||
|
||||
// Sanity-checks the DFlash metadata held in `hparams`, throwing std::runtime_error
// (prefixed with the architecture name) on the first violated rule:
//   - dflash_block_size must be greater than 1
//   - at least one target layer id must have been loaded
//   - dflash_n_target_features must be non-zero and a multiple of dflash_n_target_layers
static void validate_dflash_hparams(llama_hparams & hparams, llm_arch arch) {
    const char * const arch_name = llama_model_arch_name(arch);

    if (hparams.dflash_block_size <= 1) {
        throw std::runtime_error(format("%s: dflash block_size must be > 1", arch_name));
    }

    const uint32_t n_layers = hparams.dflash_n_target_layers;
    if (n_layers == 0) {
        throw std::runtime_error(format("%s: dflash target_layer_ids are required", arch_name));
    }

    // DFlash feature width is target-model specific. Keep the serialized metadata intact here
    // and validate it against the live target model during DFlash init.
    const uint32_t n_features = hparams.dflash_n_target_features;
    if (n_features == 0) {
        throw std::runtime_error(format(
            "%s: dflash n_target_features must be > 0",
            arch_name));
    }

    const bool evenly_split = (n_features % n_layers) == 0;
    if (!evenly_split) {
        throw std::runtime_error(format(
            "%s: dflash n_target_features=%u must be divisible by n_target_layers=%u",
            arch_name,
            n_features,
            n_layers));
    }
}
|
||||
|
||||
|
||||
void llm_load_hparams(
|
||||
llama_model_loader & ml,
|
||||
@ -806,6 +890,18 @@ void llm_load_hparams(
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_DFLASH_DRAFT:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_DFLASH_BLOCK_SIZE, hparams.dflash_block_size, false);
|
||||
ml.get_key(LLM_KV_DFLASH_MASK_TOKEN_ID, hparams.dflash_mask_token_id, false);
|
||||
ml.get_key(LLM_KV_DFLASH_N_TARGET_FEATURES, hparams.dflash_n_target_features, false);
|
||||
load_dflash_target_layer_ids(ml, LLM_KV(model.arch)(LLM_KV_DFLASH_TARGET_LAYER_IDS), hparams, false);
|
||||
validate_dflash_hparams(hparams, model.arch);
|
||||
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer;
|
||||
model.type = e_model::MODEL_UNKNOWN;
|
||||
} break;
|
||||
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
|
||||
@ -139,6 +139,13 @@ struct llama_hparams {
|
||||
uint32_t mtp_num_centroids = 0;
|
||||
uint32_t mtp_centroid_top_k = 0;
|
||||
|
||||
// DFlash draft model metadata
|
||||
uint32_t dflash_block_size = 16;
|
||||
uint32_t dflash_mask_token_id = 0;
|
||||
uint32_t dflash_n_target_features = 0;
|
||||
uint32_t dflash_n_target_layers = 0;
|
||||
uint32_t dflash_target_layer_ids[8] = {};
|
||||
|
||||
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
|
||||
llama_token dec_start_token_id = -1;
|
||||
@ -158,6 +165,10 @@ struct llama_hparams {
|
||||
if (this->n_ctx_train != other.n_ctx_train) return true;
|
||||
if (this->n_embd != other.n_embd) return true;
|
||||
if (this->mtp_backbone_n_embd != other.mtp_backbone_n_embd) return true;
|
||||
if (this->dflash_block_size != other.dflash_block_size) return true;
|
||||
if (this->dflash_mask_token_id != other.dflash_mask_token_id) return true;
|
||||
if (this->dflash_n_target_features != other.dflash_n_target_features) return true;
|
||||
if (this->dflash_n_target_layers != other.dflash_n_target_layers) return true;
|
||||
if (this->n_layer != other.n_layer) return true;
|
||||
if (this->n_rot != other.n_rot) return true;
|
||||
if (this->n_swa != other.n_swa) return true;
|
||||
@ -188,6 +199,9 @@ struct llama_hparams {
|
||||
if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
|
||||
if (this->ssm_n_group != other.ssm_n_group) return true;
|
||||
if (this->recurrent_layer_arr != other.recurrent_layer_arr) return true;
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
if (this->dflash_target_layer_ids[i] != other.dflash_target_layer_ids[i]) return true;
|
||||
}
|
||||
|
||||
if (this->dec_start_token_id != other.dec_start_token_id) return true;
|
||||
|
||||
|
||||
@ -100,6 +100,8 @@ struct create_tensors_helper : public create_tensors_helper_interface {
|
||||
|
||||
bool create_gemma4_mtp_tensors(const LLM_TN & tn);
|
||||
|
||||
bool create_dflash_tensors(const LLM_TN & tn);
|
||||
|
||||
bool create_starcoder2_tensors(const LLM_TN & tn);
|
||||
|
||||
bool create_mamba_tensors(const LLM_TN & tn);
|
||||
@ -2248,6 +2250,48 @@ bool create_tensors_helper::create_gemma4_mtp_tensors(const LLM_TN & tn) {
|
||||
return use_mmap_buffer;
|
||||
}
|
||||
|
||||
bool create_tensors_helper::create_dflash_tensors(const LLM_TN & tn) {
|
||||
LOADING_PRELUDE
|
||||
|
||||
const bool use_split_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN;
|
||||
|
||||
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
auto output_extra = create_tensor(ctx_output, "output_extra.weight", {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
if (output_extra != nullptr) {
|
||||
model.output = output_extra;
|
||||
}
|
||||
if (model.output == nullptr && model.tok_embd != nullptr) {
|
||||
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
||||
}
|
||||
model.output_mtp = model.output;
|
||||
model.dflash_fc = create_tensor(ctx_output, tn(LLM_TENSOR_DFLASH_FC, "weight"), {(int64_t) hparams.dflash_n_target_features, n_embd}, 0);
|
||||
model.dflash_hidden_norm = create_tensor(ctx_output, tn(LLM_TENSOR_DFLASH_HIDDEN_NORM, "weight"), {n_embd}, 0);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_split = use_split_ctx ? ctx_for_layer_split(i) : ctx_for_layer(i);
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.attn_post_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
||||
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head, n_embd}, 0);
|
||||
|
||||
layer.attn_q_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
layer.attn_k_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
|
||||
layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
|
||||
return use_mmap_buffer;
|
||||
}
|
||||
|
||||
bool create_tensors_helper::create_starcoder2_tensors(const LLM_TN & tn) {
|
||||
LOADING_PRELUDE
|
||||
|
||||
@ -4398,6 +4442,8 @@ bool create_tensors_helper::create_tensors() {
|
||||
case LLM_ARCH_GEMMA4_MTP:
|
||||
case LLM_ARCH_GEMMA4_ASSISTANT:
|
||||
use_mmap_buffer = create_gemma4_mtp_tensors(tn); break;
|
||||
case LLM_ARCH_DFLASH_DRAFT:
|
||||
use_mmap_buffer = create_dflash_tensors(tn); break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
use_mmap_buffer = create_starcoder2_tensors(tn); break;
|
||||
case LLM_ARCH_MAMBA:
|
||||
|
||||
@ -845,6 +845,27 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||
{ LLM_TENSOR_MTP_CENTROIDS, "mtp_centroids" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_DFLASH_DRAFT,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_DFLASH_FC, "dflash_fc" },
|
||||
{ LLM_TENSOR_DFLASH_HIDDEN_NORM, "dflash_hidden_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GEMMA4_ASSISTANT,
|
||||
{
|
||||
|
||||
@ -430,6 +430,8 @@ struct llama_model {
|
||||
struct ggml_tensor * mtp_post_proj = nullptr;
|
||||
struct ggml_tensor * mtp_token_ordering = nullptr;
|
||||
struct ggml_tensor * mtp_centroids = nullptr;
|
||||
struct ggml_tensor * dflash_fc = nullptr;
|
||||
struct ggml_tensor * dflash_hidden_norm = nullptr;
|
||||
|
||||
struct ggml_tensor * output_norm;
|
||||
struct ggml_tensor * output_norm_b;
|
||||
|
||||
@ -616,7 +616,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
|
||||
if (qs.model.hparams.n_vocab >= 127999 && (qs.model.type == MODEL_8B || qs.model.type == MODEL_70B))
|
||||
new_type = GGML_TYPE_IQ6_K;
|
||||
}
|
||||
else if (qs.model.hparams.n_gqa() >= 4) {
|
||||
else if (qs.model.hparams.n_gqa() >= 4 &&
|
||||
!(arch == LLM_ARCH_DFLASH_DRAFT &&
|
||||
(ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M))) {
|
||||
if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_IQ3_XXS) new_type = GGML_TYPE_IQ3_S;
|
||||
else if (new_type == GGML_TYPE_Q2_K_R4 || new_type == GGML_TYPE_IQ3_XXS_R4) new_type = GGML_TYPE_IQ3_K_R4;
|
||||
else if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_IQ3_S) new_type = GGML_TYPE_Q4_K;
|
||||
@ -1782,4 +1784,3 @@ uint32_t llama_model_quantize(
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
729
src/llama-spec-features-dflash.cpp
Normal file
729
src/llama-spec-features-dflash.cpp
Normal file
@ -0,0 +1,729 @@
|
||||
#include "llama-spec-features.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <random>
|
||||
|
||||
#include "llama-model.h"
|
||||
#include "llama-context.h"
|
||||
|
||||
void llama_reset_dflash_kv_cache_state(struct llama_context * ctx) {
|
||||
if (ctx == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
ctx->dflash.kv.cache_write_pos = 0;
|
||||
ctx->dflash.kv.cache_n_filled = 0;
|
||||
ctx->dflash.kv.cache_update_rows = 0;
|
||||
ctx->dflash.kv.cache_view_write_pos = 0;
|
||||
ctx->dflash.kv.cache_view_n_filled = 0;
|
||||
ctx->dflash.kv.cache_applied_window_version = 0;
|
||||
ctx->dflash.kv.cache_valid = false;
|
||||
ctx->dflash.kv.cache_view_valid = false;
|
||||
ctx->dflash.kv.workspace_write_pos = 0;
|
||||
ctx->dflash.kv.workspace_n_filled = 0;
|
||||
ctx->dflash.kv.workspace_applied_window_version = 0;
|
||||
ctx->dflash.kv.workspace_valid = false;
|
||||
ctx->dflash.kv.workspace_sync_pending = false;
|
||||
|
||||
for (ggml_backend_buffer_t buf : ctx->dflash.kv.cache_bufs) {
|
||||
if (buf != nullptr) {
|
||||
ggml_backend_buffer_clear(buf, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Plans the next DFlash K/V cache transition for `ctx`, given a window update that
// describes `n_rows` rows.
//
// With a null `ctx` there is no cache state to consult, so the plan requests a rebuild:
// append_rows is the window's append_rows bounded to [0, n_rows] and next_n_filled is
// n_rows. Otherwise the decision is delegated to llama_plan_dflash_kv_cache_transition()
// using the context's current cache counters and the effective cross-context size
// (dflash.visible_cross_ctx when positive, else n_ctx - dflash_block_size, at least 1).
llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition_for_ctx(
        const struct llama_context * ctx,
        const llama_dflash_window_update & window_update,
        int32_t n_rows) {
    if (ctx == nullptr) {
        llama_dflash_kv_cache_transition plan;
        plan.rebuild_cache = true;
        // std::clamp(v, 0, n_rows) has undefined behavior when n_rows < 0 (hi < lo).
        // max/min yields the same result for n_rows >= 0 and a well-defined 0 otherwise.
        plan.append_rows   = std::max<int32_t>(0, std::min<int32_t>(n_rows, window_update.append_rows));
        plan.next_n_filled = n_rows;
        return plan;
    }

    const int32_t cross_ctx = ctx->dflash.visible_cross_ctx > 0
        ? ctx->dflash.visible_cross_ctx
        : std::max<int32_t>(1, (int32_t) ctx->cparams.n_ctx - (int32_t) ctx->model.hparams.dflash_block_size);

    return llama_plan_dflash_kv_cache_transition(
        cross_ctx,
        ctx->dflash.kv.cache_n_filled,
        ctx->dflash.kv.cache_write_pos,
        ctx->dflash.kv.cache_valid,
        ctx->dflash.kv.cache_applied_window_version,
        window_update.version,
        window_update.keep_rows,
        window_update.append_rows,
        window_update.replace,
        n_rows);
}
|
||||
|
||||
void llama_set_dflash_visible_cross_ctx(
|
||||
struct llama_context * ctx,
|
||||
int32_t cross_ctx) {
|
||||
if (ctx == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
ctx->dflash.visible_cross_ctx = std::max<int32_t>(0, cross_ctx);
|
||||
}
|
||||
|
||||
int32_t llama_get_dflash_visible_cross_ctx(
|
||||
const struct llama_context * ctx) {
|
||||
return ctx != nullptr ? ctx->dflash.visible_cross_ctx : 0;
|
||||
}
|
||||
|
||||
int32_t llama_model_dflash_block_size(const struct llama_model * model) {
|
||||
return model ? (int32_t) model->hparams.dflash_block_size : 0;
|
||||
}
|
||||
|
||||
int32_t llama_model_dflash_mask_token_id(const struct llama_model * model) {
|
||||
return model ? (int32_t) model->hparams.dflash_mask_token_id : -1;
|
||||
}
|
||||
|
||||
int32_t llama_model_dflash_n_target_layers(const struct llama_model * model) {
|
||||
return model ? (int32_t) model->hparams.dflash_n_target_layers : 0;
|
||||
}
|
||||
|
||||
int32_t llama_model_dflash_n_target_features(const struct llama_model * model) {
|
||||
return model ? (int32_t) model->hparams.dflash_n_target_features : 0;
|
||||
}
|
||||
|
||||
int32_t llama_model_dflash_target_layer_ids(
|
||||
const struct llama_model * model,
|
||||
int32_t * layer_ids,
|
||||
int32_t capacity) {
|
||||
if (model == nullptr || layer_ids == nullptr || capacity <= 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const int32_t n_layers = std::min<int32_t>((int32_t) model->hparams.dflash_n_target_layers, capacity);
|
||||
for (int32_t i = 0; i < n_layers; ++i) {
|
||||
layer_ids[i] = (int32_t) model->hparams.dflash_target_layer_ids[i];
|
||||
}
|
||||
|
||||
return n_layers;
|
||||
}
|
||||
|
||||
int32_t llama_model_dflash_target_mask_token_id(const struct llama_model * model) {
|
||||
if (model == nullptr) {
|
||||
return (int32_t) LLAMA_TOKEN_NULL;
|
||||
}
|
||||
|
||||
return (int32_t) model->vocab.token_mask();
|
||||
}
|
||||
|
||||
const struct ggml_tensor * llama_model_dflash_output_tensor(
|
||||
const struct llama_model * model) {
|
||||
if (model == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (model->output_mtp != nullptr) {
|
||||
return model->output_mtp;
|
||||
}
|
||||
|
||||
if (model->output != nullptr) {
|
||||
return model->output;
|
||||
}
|
||||
|
||||
return model->tok_embd;
|
||||
}
|
||||
|
||||
int32_t llama_model_dflash_io_mode(
|
||||
const struct llama_model * draft_model,
|
||||
const struct llama_model * target_model) {
|
||||
if (draft_model == nullptr || target_model == nullptr || draft_model->arch != LLM_ARCH_DFLASH_DRAFT) {
|
||||
return LLAMA_DFLASH_IO_MODE_INVALID;
|
||||
}
|
||||
|
||||
const ggml_tensor * draft_output = llama_model_dflash_output_tensor(draft_model);
|
||||
const ggml_tensor * target_output = llama_model_dflash_output_tensor(target_model);
|
||||
if (draft_model->tok_embd == nullptr || draft_output == nullptr || target_model->tok_embd == nullptr || target_output == nullptr) {
|
||||
return LLAMA_DFLASH_IO_MODE_INVALID;
|
||||
}
|
||||
|
||||
const bool shared_tok = draft_model->tok_embd == target_model->tok_embd;
|
||||
const bool shared_output = draft_output == target_output;
|
||||
if (shared_tok && shared_output) {
|
||||
return LLAMA_DFLASH_IO_MODE_SHARED;
|
||||
}
|
||||
|
||||
if (!shared_tok && !shared_output) {
|
||||
return LLAMA_DFLASH_IO_MODE_SELF_CONTAINED;
|
||||
}
|
||||
|
||||
return LLAMA_DFLASH_IO_MODE_MIXED;
|
||||
}
|
||||
|
||||
bool llama_model_dflash_io_tensors_match(
|
||||
const struct llama_model * draft_model,
|
||||
int32_t n_embd,
|
||||
int32_t n_vocab) {
|
||||
const ggml_tensor * output = llama_model_dflash_output_tensor(draft_model);
|
||||
if (draft_model == nullptr || draft_model->tok_embd == nullptr || output == nullptr || n_embd <= 0 || n_vocab <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return (int32_t) draft_model->tok_embd->ne[0] == n_embd &&
|
||||
(int32_t) draft_model->tok_embd->ne[1] == n_vocab &&
|
||||
(int32_t) output->ne[0] == n_embd &&
|
||||
(int32_t) output->ne[1] == n_vocab;
|
||||
}
|
||||
|
||||
bool llama_model_share_dflash_io_tensors(
|
||||
struct llama_model * draft_model,
|
||||
const struct llama_model * target_model) {
|
||||
if (draft_model == nullptr || target_model == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (draft_model->arch != LLM_ARCH_DFLASH_DRAFT) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (draft_model->tok_embd == nullptr) {
|
||||
draft_model->tok_embd = target_model->tok_embd;
|
||||
}
|
||||
|
||||
if (draft_model->output == nullptr) {
|
||||
draft_model->output = target_model->output ? target_model->output : target_model->tok_embd;
|
||||
if (draft_model->output == nullptr) {
|
||||
draft_model->output = draft_model->tok_embd;
|
||||
}
|
||||
}
|
||||
|
||||
const bool uses_shared_tok = draft_model->tok_embd == target_model->tok_embd;
|
||||
const bool uses_shared_output = draft_model->output == target_model->output ||
|
||||
draft_model->output == target_model->tok_embd;
|
||||
|
||||
if (draft_model->output_mtp == nullptr && target_model->output_mtp != nullptr && uses_shared_tok && uses_shared_output) {
|
||||
draft_model->output_mtp = target_model->output_mtp;
|
||||
}
|
||||
|
||||
const struct ggml_tensor * output = llama_model_dflash_output_tensor(draft_model);
|
||||
return draft_model->tok_embd != nullptr && output != nullptr;
|
||||
}
|
||||
|
||||
static bool llama_set_dflash_target_features_impl(
|
||||
struct llama_context * ctx,
|
||||
const float * target_features,
|
||||
size_t n_floats,
|
||||
int32_t n_rows,
|
||||
const llama_pos * target_positions,
|
||||
bool copy_data,
|
||||
const llama_dflash_window_update * window_update) {
|
||||
const bool have_full_features = target_features != nullptr && n_floats > 0;
|
||||
const bool have_append_features = window_update != nullptr &&
|
||||
window_update->append_features != nullptr &&
|
||||
window_update->append_floats > 0 &&
|
||||
window_update->append_rows > 0;
|
||||
|
||||
if (ctx == nullptr || n_rows <= 0 || (!have_full_features && !have_append_features)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (have_full_features && copy_data) {
|
||||
ctx->dflash.target.features_owned.assign(target_features, target_features + n_floats);
|
||||
ctx->dflash.target.features = ctx->dflash.target.features_owned.data();
|
||||
} else if (have_full_features) {
|
||||
ctx->dflash.target.features_owned.clear();
|
||||
ctx->dflash.target.features = target_features;
|
||||
} else {
|
||||
ctx->dflash.target.features_owned.clear();
|
||||
ctx->dflash.target.features = nullptr;
|
||||
}
|
||||
ctx->dflash.target.features_n_floats = have_full_features ? n_floats : 0;
|
||||
ctx->dflash.target.features_n_rows = n_rows;
|
||||
if (have_append_features && copy_data) {
|
||||
ctx->dflash.target.append_features_owned.assign(
|
||||
window_update->append_features,
|
||||
window_update->append_features + window_update->append_floats);
|
||||
ctx->dflash.target.append_features = ctx->dflash.target.append_features_owned.data();
|
||||
} else if (have_append_features) {
|
||||
ctx->dflash.target.append_features_owned.clear();
|
||||
ctx->dflash.target.append_features = window_update->append_features;
|
||||
} else {
|
||||
ctx->dflash.target.append_features_owned.clear();
|
||||
ctx->dflash.target.append_features = nullptr;
|
||||
}
|
||||
ctx->dflash.target.append_features_n_floats = have_append_features ? window_update->append_floats : 0;
|
||||
ctx->dflash.target.append_features_n_rows = have_append_features ? window_update->append_rows : 0;
|
||||
ctx->dflash.target.version = window_update != nullptr && window_update->version > 0
|
||||
? window_update->version
|
||||
: ctx->dflash.target.version + 1;
|
||||
ctx->dflash.target.keep_rows = window_update != nullptr
|
||||
? std::max<int32_t>(0, std::min(n_rows, window_update->keep_rows))
|
||||
: 0;
|
||||
ctx->dflash.target.append_rows = window_update != nullptr
|
||||
? std::max<int32_t>(0, std::min(n_rows, window_update->append_rows))
|
||||
: n_rows;
|
||||
ctx->dflash.target.replace = window_update != nullptr
|
||||
? window_update->replace
|
||||
: true;
|
||||
if (ctx->dflash.target.keep_rows + ctx->dflash.target.append_rows > n_rows) {
|
||||
ctx->dflash.target.keep_rows = std::max<int32_t>(0, n_rows - ctx->dflash.target.append_rows);
|
||||
}
|
||||
|
||||
const int32_t cross_ctx = ctx->dflash.visible_cross_ctx > 0
|
||||
? ctx->dflash.visible_cross_ctx
|
||||
: std::max<int32_t>(1, (int32_t) ctx->cparams.n_ctx - (int32_t) ctx->model.hparams.dflash_block_size);
|
||||
const llama_dflash_window_update cache_window_update = {
|
||||
ctx->dflash.target.version,
|
||||
ctx->dflash.target.keep_rows,
|
||||
ctx->dflash.target.append_rows,
|
||||
ctx->dflash.target.replace,
|
||||
ctx->dflash.target.append_features,
|
||||
ctx->dflash.target.append_features_n_floats,
|
||||
};
|
||||
const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition_for_ctx(ctx, cache_window_update, n_rows);
|
||||
|
||||
if (cache_plan.cache_up_to_date) {
|
||||
ctx->dflash.kv.cache_view_n_filled = ctx->dflash.kv.cache_n_filled;
|
||||
ctx->dflash.kv.cache_view_write_pos = ctx->dflash.kv.cache_write_pos;
|
||||
ctx->dflash.kv.cache_view_valid = ctx->dflash.kv.cache_valid;
|
||||
} else if (cross_ctx > 0) {
|
||||
ctx->dflash.kv.cache_view_n_filled = cache_plan.next_n_filled;
|
||||
ctx->dflash.kv.cache_view_write_pos = cache_plan.next_write_pos;
|
||||
ctx->dflash.kv.cache_view_valid = cache_plan.next_n_filled > 0;
|
||||
}
|
||||
|
||||
if (target_positions != nullptr) {
|
||||
if (copy_data) {
|
||||
ctx->dflash.target.positions_owned.assign(target_positions, target_positions + n_rows);
|
||||
ctx->dflash.target.positions = ctx->dflash.target.positions_owned.data();
|
||||
} else {
|
||||
ctx->dflash.target.positions_owned.clear();
|
||||
ctx->dflash.target.positions = target_positions;
|
||||
}
|
||||
ctx->dflash.target.positions_n = (size_t) n_rows;
|
||||
} else {
|
||||
ctx->dflash.target.positions_owned.clear();
|
||||
ctx->dflash.target.positions = nullptr;
|
||||
ctx->dflash.target.positions_n = 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_set_dflash_target_features_copy(
|
||||
struct llama_context * ctx,
|
||||
const float * target_features,
|
||||
size_t n_floats,
|
||||
int32_t n_rows,
|
||||
const llama_pos * target_positions,
|
||||
const llama_dflash_window_update * window_update) {
|
||||
return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, true, window_update);
|
||||
}
|
||||
|
||||
bool llama_set_dflash_target_features_view(
|
||||
struct llama_context * ctx,
|
||||
const float * target_features,
|
||||
size_t n_floats,
|
||||
int32_t n_rows,
|
||||
const llama_pos * target_positions,
|
||||
const llama_dflash_window_update * window_update) {
|
||||
return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, false, window_update);
|
||||
}
|
||||
|
||||
static bool llama_dflash_parse_layer_id(const struct ggml_tensor * tensor, int32_t & layer_id) {
|
||||
if (tensor == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static constexpr const char * prefix = "l_out-";
|
||||
if (std::strncmp(tensor->name, prefix, std::strlen(prefix)) != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
char * end = nullptr;
|
||||
const long raw = std::strtol(tensor->name + std::strlen(prefix), &end, 10);
|
||||
if (end == tensor->name + std::strlen(prefix) || *end != '\0') {
|
||||
return false;
|
||||
}
|
||||
|
||||
layer_id = (int32_t) raw;
|
||||
if (layer_id >= 1000) {
|
||||
layer_id %= 1000;
|
||||
}
|
||||
|
||||
return layer_id >= 0;
|
||||
}
|
||||
|
||||
static int32_t llama_dflash_find_layer_index(const struct llama_context * ctx, int32_t layer_id) {
|
||||
if (ctx == nullptr || !ctx->dflash.capture) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
const auto & layer_ids = ctx->dflash.capture->layer_ids;
|
||||
const auto it = std::find(layer_ids.begin(), layer_ids.end(), layer_id);
|
||||
return it == layer_ids.end() ? -1 : (int32_t) std::distance(layer_ids.begin(), it);
|
||||
}
|
||||
|
||||
static bool llama_dflash_capture_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
|
||||
auto * ctx = static_cast<llama_context *>(user_data);
|
||||
if (ctx == nullptr || !ctx->dflash.capture) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t layer_id = -1;
|
||||
if (!llama_dflash_parse_layer_id(tensor, layer_id)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int32_t layer_idx = llama_dflash_find_layer_index(ctx, layer_id);
|
||||
if (layer_idx < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ask) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const int32_t row_width = (int32_t) tensor->ne[0];
|
||||
const int32_t row_count = row_width > 0 ? (int32_t) (ggml_nelements(tensor) / (int64_t) row_width) : 0;
|
||||
if (row_width <= 0 || row_count <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto & capture = *ctx->dflash.capture;
|
||||
if (capture.capture_batch_id == 0) {
|
||||
capture.capture_batch_id = 1;
|
||||
}
|
||||
if (capture.layer_seen_batch_id.size() != capture.layer_ids.size()) {
|
||||
capture.layer_seen_batch_id.assign(capture.layer_ids.size(), 0);
|
||||
}
|
||||
|
||||
auto & rows = capture.layer_rows[(size_t) layer_idx];
|
||||
rows.resize((size_t) row_count * (size_t) row_width);
|
||||
ggml_backend_tensor_get(tensor, rows.data(), 0, ggml_nbytes(tensor));
|
||||
capture.row_width = row_width;
|
||||
capture.row_count = row_count;
|
||||
capture.layer_seen_batch_id[(size_t) layer_idx] = capture.capture_batch_id;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_set_dflash_capture_layers(
|
||||
struct llama_context * ctx,
|
||||
const int32_t * layer_ids,
|
||||
int32_t n_layers) {
|
||||
if (ctx == nullptr || layer_ids == nullptr || n_layers <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto capture = std::make_unique<llama_context::dflash_runtime::capture_state>();
|
||||
capture->layer_ids.assign(layer_ids, layer_ids + n_layers);
|
||||
capture->layer_rows.resize((size_t) n_layers);
|
||||
capture->layer_seen_batch_id.assign((size_t) n_layers, 0);
|
||||
capture->prev_cb_eval = ctx->cparams.cb_eval;
|
||||
capture->prev_cb_eval_user_data = ctx->cparams.cb_eval_user_data;
|
||||
ctx->dflash.capture = std::move(capture);
|
||||
ctx->dflash.feature_view_buffer.clear();
|
||||
|
||||
ctx->cparams.cb_eval = llama_dflash_capture_eval_callback;
|
||||
ctx->cparams.cb_eval_user_data = ctx;
|
||||
if (ctx->sched != nullptr) {
|
||||
ggml_backend_sched_set_eval_callback(ctx->sched, ctx->cparams.cb_eval, ctx->cparams.cb_eval_user_data);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void llama_clear_dflash_capture(struct llama_context * ctx) {
|
||||
if (ctx == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_backend_sched_eval_callback prev_cb_eval = nullptr;
|
||||
void * prev_cb_eval_user_data = nullptr;
|
||||
if (ctx->dflash.capture) {
|
||||
prev_cb_eval = ctx->dflash.capture->prev_cb_eval;
|
||||
prev_cb_eval_user_data = ctx->dflash.capture->prev_cb_eval_user_data;
|
||||
}
|
||||
|
||||
ctx->dflash.capture.reset();
|
||||
ctx->dflash.feature_view_buffer.clear();
|
||||
|
||||
if (ctx->cparams.cb_eval == llama_dflash_capture_eval_callback && ctx->cparams.cb_eval_user_data == ctx) {
|
||||
ctx->cparams.cb_eval = prev_cb_eval;
|
||||
ctx->cparams.cb_eval_user_data = prev_cb_eval_user_data;
|
||||
if (ctx->sched != nullptr) {
|
||||
ggml_backend_sched_set_eval_callback(ctx->sched, prev_cb_eval, prev_cb_eval_user_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void llama_begin_dflash_capture_batch(struct llama_context * ctx) {
|
||||
if (ctx == nullptr || !ctx->dflash.capture) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto & capture = *ctx->dflash.capture;
|
||||
capture.capture_batch_id++;
|
||||
capture.row_count = 0;
|
||||
capture.row_width = 0;
|
||||
std::fill(capture.layer_seen_batch_id.begin(), capture.layer_seen_batch_id.end(), 0);
|
||||
}
|
||||
|
||||
void llama_finish_dflash_capture_batch(
|
||||
struct llama_context * ctx,
|
||||
bool is_prompt_warmup) {
|
||||
if (ctx == nullptr || !ctx->dflash.capture) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_UNUSED(is_prompt_warmup);
|
||||
auto & capture = *ctx->dflash.capture;
|
||||
// Reset the batch-local reference shape so the next decode only compares layers within
|
||||
// the same batch, not against the previous prompt/verify batch.
|
||||
capture.row_count = 0;
|
||||
capture.row_width = 0;
|
||||
}
|
||||
|
||||
static bool llama_spec_prepare_dflash_capture(
|
||||
struct llama_context * ctx,
|
||||
int32_t & row_count,
|
||||
int32_t & row_width,
|
||||
int32_t & n_layers) {
|
||||
if (ctx == nullptr || !ctx->dflash.capture) {
|
||||
return false;
|
||||
}
|
||||
|
||||
llama_synchronize(ctx);
|
||||
|
||||
auto & capture = *ctx->dflash.capture;
|
||||
row_count = capture.row_count;
|
||||
row_width = capture.row_width;
|
||||
n_layers = (int32_t) capture.layer_ids.size();
|
||||
if (row_count <= 0 || row_width <= 0 || n_layers <= 0 || capture.layer_rows.size() != (size_t) n_layers) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (capture.capture_batch_id == 0 || capture.layer_seen_batch_id.size() != (size_t) n_layers) {
|
||||
LLAMA_LOG_WARN("%s: DFlash capture batch markers are not initialized (batch_id=%llu layers=%zu expected=%d)\n",
|
||||
__func__,
|
||||
(unsigned long long) capture.capture_batch_id,
|
||||
capture.layer_seen_batch_id.size(),
|
||||
n_layers);
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int32_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) {
|
||||
if (capture.layer_seen_batch_id[(size_t) layer_idx] != capture.capture_batch_id) {
|
||||
LLAMA_LOG_WARN("%s: DFlash capture is stale for layer %d (seen_batch=%llu current_batch=%llu rows=%d width=%d)\n",
|
||||
__func__,
|
||||
capture.layer_ids[(size_t) layer_idx],
|
||||
(unsigned long long) capture.layer_seen_batch_id[(size_t) layer_idx],
|
||||
(unsigned long long) capture.capture_batch_id,
|
||||
row_count,
|
||||
row_width);
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto & rows = capture.layer_rows[(size_t) layer_idx];
|
||||
if (rows.size() != (size_t) row_count * (size_t) row_width) {
|
||||
LLAMA_LOG_WARN("%s: DFlash capture rows mismatch for layer %d: got=%zu expected=%zu (rows=%d width=%d)\n",
|
||||
__func__, capture.layer_ids[(size_t) layer_idx], rows.size(),
|
||||
(size_t) row_count * (size_t) row_width, row_count, row_width);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool llama_spec_materialize_dflash_rows_prepared(
|
||||
struct llama_context * ctx,
|
||||
int32_t row_count,
|
||||
int32_t row_width,
|
||||
int32_t n_layers,
|
||||
const std::vector<int32_t> & row_indices,
|
||||
std::vector<float> & rows_out,
|
||||
int32_t & combined_width);
|
||||
|
||||
static bool llama_spec_materialize_dflash_rows(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<int32_t> & row_indices,
|
||||
std::vector<float> & rows_out,
|
||||
int32_t & combined_width) {
|
||||
int32_t row_count = 0;
|
||||
int32_t row_width = 0;
|
||||
int32_t n_layers = 0;
|
||||
if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, rows_out, combined_width);
|
||||
}
|
||||
|
||||
static bool llama_spec_materialize_dflash_rows_prepared(
|
||||
struct llama_context * ctx,
|
||||
int32_t row_count,
|
||||
int32_t row_width,
|
||||
int32_t n_layers,
|
||||
const std::vector<int32_t> & row_indices,
|
||||
std::vector<float> & rows_out,
|
||||
int32_t & combined_width) {
|
||||
rows_out.clear();
|
||||
combined_width = 0;
|
||||
if (ctx == nullptr || row_indices.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (row_count <= 0 || row_width <= 0 || n_layers <= 0 || ctx->dflash.capture == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
combined_width = row_width * n_layers;
|
||||
rows_out.resize((size_t) row_indices.size() * (size_t) combined_width);
|
||||
|
||||
const auto & layer_rows = ctx->dflash.capture->layer_rows;
|
||||
for (size_t out_row = 0; out_row < row_indices.size(); ++out_row) {
|
||||
int32_t row_index = row_indices[out_row];
|
||||
if (row_index < 0) {
|
||||
row_index += row_count;
|
||||
}
|
||||
if (row_index < 0 || row_index >= row_count) {
|
||||
rows_out.clear();
|
||||
combined_width = 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
float * dst = rows_out.data() + out_row * (size_t) combined_width;
|
||||
for (int32_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) {
|
||||
const float * src = layer_rows[(size_t) layer_idx].data() + (size_t) row_index * (size_t) row_width;
|
||||
std::memcpy(dst + (size_t) layer_idx * (size_t) row_width, src, (size_t) row_width * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
bool llama_spec_get_dflash_feature_view(
|
||||
struct llama_context * ctx,
|
||||
const llama_batch & batch,
|
||||
llama_spec_feature_view & view) {
|
||||
if (ctx == nullptr || batch.n_tokens <= 0 || batch.pos == nullptr || batch.n_seq_id == nullptr || batch.seq_id == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t row_count = 0;
|
||||
int32_t row_width = 0;
|
||||
int32_t n_layers = 0;
|
||||
if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int32_t batch_row_offset = std::max<int32_t>(0, batch.n_tokens - row_count);
|
||||
std::vector<int32_t> row_indices;
|
||||
std::vector<int32_t> batch_indices;
|
||||
row_indices.reserve((size_t) (batch.n_tokens - batch_row_offset));
|
||||
batch_indices.reserve((size_t) (batch.n_tokens - batch_row_offset));
|
||||
for (int32_t i = batch_row_offset; i < batch.n_tokens; ++i) {
|
||||
row_indices.push_back(i - batch_row_offset);
|
||||
batch_indices.push_back(i);
|
||||
}
|
||||
|
||||
if (row_indices.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
view = {};
|
||||
view.kind = LLAMA_SPEC_FEATURE_HIDDEN_STATE;
|
||||
if (!llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, ctx->dflash.feature_view_buffer, view.width)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
view.rows.reserve(batch_indices.size());
|
||||
for (int32_t batch_index : batch_indices) {
|
||||
if (batch.n_seq_id[batch_index] <= 0 || batch.seq_id[batch_index] == nullptr) {
|
||||
view.rows.clear();
|
||||
return false;
|
||||
}
|
||||
|
||||
view.rows.push_back({
|
||||
/* .seq_id = */ batch.seq_id[batch_index][0],
|
||||
/* .pos = */ batch.pos[batch_index],
|
||||
/* .data = */ ctx->dflash.feature_view_buffer.data() + view.rows.size() * (size_t) view.width,
|
||||
});
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_spec_get_dflash_feature_view_for_seq(
|
||||
struct llama_context * ctx,
|
||||
const llama_batch & batch,
|
||||
llama_seq_id seq_id,
|
||||
llama_spec_feature_view & view) {
|
||||
if (ctx == nullptr || batch.n_tokens <= 0 || batch.pos == nullptr || batch.n_seq_id == nullptr || batch.seq_id == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t row_count = 0;
|
||||
int32_t row_width = 0;
|
||||
int32_t n_layers = 0;
|
||||
if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int32_t batch_row_offset = std::max<int32_t>(0, batch.n_tokens - row_count);
|
||||
std::vector<int32_t> row_indices;
|
||||
row_indices.reserve((size_t) batch.n_tokens);
|
||||
std::vector<int32_t> batch_indices;
|
||||
batch_indices.reserve((size_t) batch.n_tokens);
|
||||
for (int32_t i = batch_row_offset; i < batch.n_tokens; ++i) {
|
||||
if (batch.n_seq_id[i] <= 0 || batch.seq_id[i] == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
|
||||
if (batch.seq_id[i][j] == seq_id) {
|
||||
row_indices.push_back(i - batch_row_offset);
|
||||
batch_indices.push_back(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (row_indices.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
view = {};
|
||||
view.kind = LLAMA_SPEC_FEATURE_HIDDEN_STATE;
|
||||
if (!llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, ctx->dflash.feature_view_buffer, view.width)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
view.rows.reserve(row_indices.size());
|
||||
for (size_t i = 0; i < batch_indices.size(); ++i) {
|
||||
const int32_t batch_index = batch_indices[i];
|
||||
view.rows.push_back({
|
||||
/* .seq_id = */ seq_id,
|
||||
/* .pos = */ batch.pos[batch_index],
|
||||
/* .data = */ ctx->dflash.feature_view_buffer.data() + i * (size_t) view.width,
|
||||
});
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_spec_copy_dflash_rows_from_output_indices(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<int32_t> & output_indices,
|
||||
std::vector<float> & hidden_rows) {
|
||||
int32_t combined_width = 0;
|
||||
if (!llama_spec_materialize_dflash_rows(ctx, output_indices, hidden_rows, combined_width)) {
|
||||
hidden_rows.clear();
|
||||
return false;
|
||||
}
|
||||
|
||||
return hidden_rows.size() == (size_t) output_indices.size() * (size_t) combined_width;
|
||||
}
|
||||
136
src/llama-spec-features-dflash.h
Normal file
136
src/llama-spec-features-dflash.h
Normal file
@ -0,0 +1,136 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
struct llama_context;
|
||||
struct llama_model;
|
||||
struct ggml_tensor;
|
||||
struct llama_spec_feature_view;
|
||||
|
||||
// Describes an update to the DFlash feature window passed alongside target features.
// All members default to an "empty" update.
// NOTE(review): field meanings below are inferred from their names - confirm against the planner.
struct llama_dflash_window_update {
    uint64_t      version{0};               // window version this update leads to
    int32_t       keep_rows{0};             // rows of the existing window to retain
    int32_t       append_rows{0};           // rows newly appended to the window
    bool          replace{false};           // when true the cache plan is forced to rebuild
    const float * append_features{nullptr}; // not owned; data for the appended rows
    size_t        append_floats{0};         // number of floats behind append_features
};
|
||||
|
||||
// Result of planning how the DFlash cross-context KV cache ring should change.
struct llama_dflash_kv_cache_transition {
    bool    cache_up_to_date = false; // cache is valid and already at the target window version
    bool    rebuild_cache    = false; // incremental append is not possible; refill from scratch
    int32_t append_rows      = 0;     // requested append count, clamped to [0, n_rows]
    int32_t next_n_filled    = 0;     // number of filled ring slots after the transition
    int32_t next_write_pos   = 0;     // ring write position after the transition, in [0, cross_ctx)
};

// Plans the next state of a ring cache with `cross_ctx` slots.
//
// cross_ctx              - ring capacity; values < 1 are treated as 1.
// current_n_filled       - slots currently filled (clamped to [0, capacity]).
// current_write_pos      - current ring write position; any value is wrapped into range.
// cache_valid            - whether the current cache contents are usable.
// applied_window_version - window version the cache currently reflects.
// target_window_version  - window version the caller wants.
// keep_rows/append_rows  - window update request (each clamped to [0, n_rows]).
// replace                - forces a rebuild.
// n_rows                 - rows in the new window; negative values are treated as 0.
//
// Outcomes, in priority order:
//   1. up to date -> fill count and (wrapped) write position are kept;
//   2. rebuild    -> the cache is refilled with min(capacity, n_rows) rows;
//   3. append     -> append_rows rows are written at the current write position.
static inline llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition(
        int32_t  cross_ctx,
        int32_t  current_n_filled,
        int32_t  current_write_pos,
        bool     cache_valid,
        uint64_t applied_window_version,
        uint64_t target_window_version,
        int32_t  keep_rows,
        int32_t  append_rows,
        bool     replace,
        int32_t  n_rows) {
    llama_dflash_kv_cache_transition plan;

    const int32_t safe_cross_ctx = std::max<int32_t>(1, cross_ctx);
    // std::clamp requires lo <= hi, so a negative row count must be neutralised first.
    const int32_t safe_n_rows = std::max<int32_t>(0, n_rows);

    const int32_t bounded_n_filled    = std::clamp(current_n_filled, 0, safe_cross_ctx);
    const int32_t bounded_append_rows = std::clamp(append_rows, 0, safe_n_rows);
    const int32_t bounded_keep_rows   = std::clamp(keep_rows, 0, safe_n_rows);
    const int32_t expected_keep_rows  = std::min(bounded_n_filled, std::max<int32_t>(0, safe_cross_ctx - bounded_append_rows));

    // Wrap the write position once so every branch sees a value in [0, capacity).
    const int32_t wrapped_write_pos = ((current_write_pos % safe_cross_ctx) + safe_cross_ctx) % safe_cross_ctx;

    plan.cache_up_to_date = cache_valid && applied_window_version == target_window_version;
    plan.rebuild_cache    = !cache_valid || replace || bounded_append_rows <= 0 ||
                            bounded_keep_rows != expected_keep_rows;
    plan.append_rows      = bounded_append_rows;

    if (plan.cache_up_to_date) {
        plan.next_n_filled  = bounded_n_filled;
        plan.next_write_pos = wrapped_write_pos;
    } else if (plan.rebuild_cache) {
        plan.next_n_filled  = std::min(safe_cross_ctx, safe_n_rows);
        plan.next_write_pos = plan.next_n_filled % safe_cross_ctx;
    } else {
        // 64-bit intermediates keep the sums from overflowing for extreme inputs.
        const int64_t filled_after = (int64_t) bounded_n_filled + (int64_t) bounded_append_rows;
        const int64_t pos_after    = (int64_t) wrapped_write_pos + (int64_t) bounded_append_rows;
        plan.next_n_filled  = (int32_t) std::min<int64_t>(safe_cross_ctx, filled_after);
        plan.next_write_pos = (int32_t) (pos_after % safe_cross_ctx);
    }

    return plan;
}
|
||||
|
||||
llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition_for_ctx(
|
||||
const struct llama_context * ctx,
|
||||
const llama_dflash_window_update & window_update,
|
||||
int32_t n_rows);
|
||||
|
||||
void llama_reset_dflash_kv_cache_state(struct llama_context * ctx);
|
||||
void llama_set_dflash_visible_cross_ctx(struct llama_context * ctx, int32_t cross_ctx);
|
||||
int32_t llama_get_dflash_visible_cross_ctx(const struct llama_context * ctx);
|
||||
|
||||
int32_t llama_model_dflash_block_size(const struct llama_model * model);
|
||||
int32_t llama_model_dflash_mask_token_id(const struct llama_model * model);
|
||||
int32_t llama_model_dflash_n_target_layers(const struct llama_model * model);
|
||||
int32_t llama_model_dflash_n_target_features(const struct llama_model * model);
|
||||
int32_t llama_model_dflash_target_layer_ids(const struct llama_model * model, int32_t * layer_ids, int32_t capacity);
|
||||
int32_t llama_model_dflash_target_mask_token_id(const struct llama_model * model);
|
||||
const struct ggml_tensor * llama_model_dflash_output_tensor(const struct llama_model * model);
|
||||
|
||||
// I/O tensor arrangement of a DFlash draft model relative to its target model.
// Values are returned as int32_t by llama_model_dflash_io_mode.
// NOTE(review): per-value meanings are not visible here - confirm against the implementation.
enum llama_dflash_io_mode {
    LLAMA_DFLASH_IO_MODE_INVALID        = 0,
    LLAMA_DFLASH_IO_MODE_SHARED         = 1,
    LLAMA_DFLASH_IO_MODE_SELF_CONTAINED = 2,
    LLAMA_DFLASH_IO_MODE_MIXED          = 3,
};
|
||||
|
||||
int32_t llama_model_dflash_io_mode(const struct llama_model * draft_model, const struct llama_model * target_model);
|
||||
bool llama_model_dflash_io_tensors_match(const struct llama_model * draft_model, int32_t n_embd, int32_t n_vocab);
|
||||
bool llama_model_share_dflash_io_tensors(struct llama_model * draft_model, const struct llama_model * target_model);
|
||||
|
||||
bool llama_set_dflash_target_features_copy(
|
||||
struct llama_context * ctx,
|
||||
const float * target_features,
|
||||
size_t n_floats,
|
||||
int32_t n_rows,
|
||||
const llama_pos * target_positions,
|
||||
const llama_dflash_window_update * window_update = nullptr);
|
||||
|
||||
bool llama_set_dflash_target_features_view(
|
||||
struct llama_context * ctx,
|
||||
const float * target_features,
|
||||
size_t n_floats,
|
||||
int32_t n_rows,
|
||||
const llama_pos * target_positions,
|
||||
const llama_dflash_window_update * window_update = nullptr);
|
||||
|
||||
bool llama_set_dflash_capture_layers(struct llama_context * ctx, const int32_t * layer_ids, int32_t n_layers);
|
||||
void llama_clear_dflash_capture(struct llama_context * ctx);
|
||||
void llama_begin_dflash_capture_batch(struct llama_context * ctx);
|
||||
void llama_finish_dflash_capture_batch(struct llama_context * ctx, bool is_prompt_warmup);
|
||||
|
||||
bool llama_spec_get_dflash_feature_view(
|
||||
struct llama_context * ctx,
|
||||
const llama_batch & batch,
|
||||
llama_spec_feature_view & view);
|
||||
|
||||
bool llama_spec_get_dflash_feature_view_for_seq(
|
||||
struct llama_context * ctx,
|
||||
const llama_batch & batch,
|
||||
llama_seq_id seq_id,
|
||||
llama_spec_feature_view & view);
|
||||
|
||||
bool llama_spec_copy_dflash_rows_from_output_indices(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<int32_t> & output_indices,
|
||||
std::vector<float> & hidden_rows);
|
||||
@ -88,6 +88,7 @@ bool llama_spec_get_hidden_feature_view(
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
bool llama_spec_get_hidden_feature_view_for_seq(
|
||||
struct llama_context * ctx,
|
||||
const llama_batch & batch,
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
struct llama_context;
|
||||
@ -23,6 +24,8 @@ struct llama_spec_feature_view {
|
||||
std::vector<llama_spec_feature_row_view> rows;
|
||||
};
|
||||
|
||||
#include "llama-spec-features-dflash.h"
|
||||
|
||||
uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx);
|
||||
|
||||
bool llama_set_draft_input_hidden_state_copy(
|
||||
|
||||
111
src/llama.cpp
111
src/llama.cpp
@ -18,6 +18,7 @@
|
||||
#include "llama-hparams.h"
|
||||
#include "llama-context.h"
|
||||
#include "llama-spec-features.h"
|
||||
#include "llama-dflash.h"
|
||||
#include "llama-quantize.h"
|
||||
|
||||
#include "unicode.h"
|
||||
@ -171,6 +172,14 @@ static std::vector<std::string> string_split(const std::string& str, const std::
|
||||
return parts;
|
||||
}
|
||||
|
||||
// Reports whether the environment variable `name` is set to an "enabled" value.
//
// Returns false when name is null, the variable is unset or empty, or its value is
// "0", "false" or "off". The words are matched ASCII case-insensitively, so
// "FALSE" / "Off" also disable the flag. Any other value enables it.
static bool llama_env_flag_enabled(const char * name) {
    if (name == nullptr) {
        return false;
    }
    const char * env = std::getenv(name);
    if (env == nullptr || *env == '\0') {
        return false;
    }

    // ASCII-only, locale-independent comparison against an all-lowercase word.
    const auto equals_ignore_case = [](const char * value, const char * lowered) {
        for (; *value != '\0' && *lowered != '\0'; ++value, ++lowered) {
            const char c = (*value >= 'A' && *value <= 'Z') ? (char) (*value - 'A' + 'a') : *value;
            if (c != *lowered) {
                return false;
            }
        }
        return *value == '\0' && *lowered == '\0';
    };

    return std::strcmp(env, "0") != 0 &&
           !equals_ignore_case(env, "false") &&
           !equals_ignore_case(env, "off");
}
|
||||
|
||||
// extract ip and port from RPC[ip:port] for rpc and keep other device names
|
||||
static std::vector<rpc_device> extract_device_from_rpc_device(std::vector<std::string> devices) {
|
||||
std::vector<rpc_device> rpc_servers;
|
||||
@ -689,6 +698,10 @@ void llama_context::set_mtp_op_type(llama_mtp_op_type value) {
|
||||
}
|
||||
|
||||
llama_context::~llama_context() {
|
||||
if (dflash.kv.cache_sched != nullptr) {
|
||||
ggml_backend_sched_free(dflash.kv.cache_sched);
|
||||
}
|
||||
free_dflash_kv_cache_tensors();
|
||||
ggml_backend_sched_free(sched);
|
||||
|
||||
for (ggml_backend_t backend : backends) {
|
||||
@ -3163,6 +3176,10 @@ static std::pair<std::vector<double>, double> get_layer_sizes(const llama_model_
|
||||
name == "mtp_centroids.weight" || name == "mtp_token_ordering.weight") {
|
||||
continue;
|
||||
}
|
||||
if (name == "dflash_fc.weight" || name == "dflash_hidden_norm.weight") {
|
||||
output_misc_size += size;
|
||||
continue;
|
||||
}
|
||||
auto pos = name.find("blk.");
|
||||
if (pos != 0) {
|
||||
LLAMA_LOG_WARN("Oops: tensor with strange name %s\n", name.c_str());
|
||||
@ -3972,7 +3989,7 @@ static bool llm_load_tensors(
|
||||
if (model.arch == LLM_ARCH_GEMMA4) {
|
||||
llm_scale_gate_inp_s(model, use_mmap_buffer);
|
||||
}
|
||||
if ((model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) && extra_output_type != GGML_TYPE_COUNT) {
|
||||
if ((model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE || model.arch == LLM_ARCH_DFLASH_DRAFT) && extra_output_type != GGML_TYPE_COUNT) {
|
||||
llm_requantize_output_tensor(model, extra_output_type);
|
||||
}
|
||||
|
||||
@ -4959,6 +4976,30 @@ static void llama_graph_compute(
|
||||
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
|
||||
}
|
||||
|
||||
// Configures the context's backends for `n_threads` and then submits graph `gf`
// to the given scheduler asynchronously. Unlike a call that always uses the
// context's own scheduler, the scheduler is supplied by the caller.
static void llama_graph_compute_sched(
        llama_context & lctx,
        ggml_backend_sched_t sched,
        ggml_cgraph * gf,
        int n_threads) {
#ifdef GGML_USE_METAL
    const bool metal_active = ggml_backend_is_metal(lctx.backend_metal);
    if (metal_active) {
        ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads);
    }
#endif

    // CPU backend: thread count plus the context's abort hook.
    if (ggml_backend_t cpu = lctx.backend_cpu) {
        ggml_backend_cpu_set_n_threads(cpu, n_threads);
        ggml_backend_cpu_set_abort_callback(cpu, lctx.abort_callback, lctx.abort_callback_data);
    }
#ifdef GGML_USE_BLAS
    if (ggml_backend_t blas = lctx.backend_blas) {
        ggml_backend_blas_set_n_threads(blas, n_threads);
    }
#endif

    ggml_backend_sched_graph_compute_async(sched, gf);
}
|
||||
|
||||
static bool prepare_mtp_graph_inputs(
|
||||
struct llama_context & lctx,
|
||||
uint32_t cur_token,
|
||||
@ -5006,6 +5047,16 @@ static bool prepare_mtp_graph_inputs(
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool dflash_layer_has_attention_bias(const llama_layer & layer) {
|
||||
return layer.bq != nullptr ||
|
||||
layer.bk != nullptr ||
|
||||
layer.bv != nullptr ||
|
||||
layer.bo != nullptr ||
|
||||
layer.bqkv != nullptr ||
|
||||
layer.bqk != nullptr ||
|
||||
layer.bkv != nullptr;
|
||||
}
|
||||
|
||||
// decode a batch of tokens by evaluating the transformer
|
||||
//
|
||||
// - lctx: llama context
|
||||
@ -5034,6 +5085,8 @@ static int llama_decode_internal(
|
||||
const auto & hparams = model.hparams;
|
||||
const auto & cparams = lctx.cparams;
|
||||
|
||||
llama_begin_dflash_capture_batch(&lctx);
|
||||
|
||||
GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
|
||||
|
||||
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
|
||||
@ -5082,8 +5135,11 @@ static int llama_decode_internal(
|
||||
|
||||
// reserve output buffer
|
||||
n_outputs_embd = has_mtp && cparams.mtp_op_type == MTP_OP_NONE ? n_tokens_all : n_outputs;
|
||||
if (llama_output_reserve(lctx, std::max<size_t>(n_outputs, n_outputs_embd)) < std::max<size_t>(n_outputs, n_outputs_embd)) {
|
||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %zu outputs\n", __func__, std::max<size_t>(n_outputs, n_outputs_embd));
|
||||
const size_t required_outputs = std::max<size_t>(n_outputs, n_outputs_embd);
|
||||
const bool is_dflash_decode = lctx.model.arch == LLM_ARCH_DFLASH_DRAFT;
|
||||
const size_t reserved_outputs = llama_output_reserve(lctx, required_outputs);
|
||||
if (reserved_outputs < required_outputs) {
|
||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %zu outputs\n", __func__, required_outputs);
|
||||
return -2;
|
||||
};
|
||||
|
||||
@ -5278,7 +5334,6 @@ static int llama_decode_internal(
|
||||
auto tim2 = ggml_time_us();
|
||||
printf("prelude(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
@ -5329,10 +5384,25 @@ static int llama_decode_internal(
|
||||
}
|
||||
}
|
||||
|
||||
if (is_dflash_decode && !llama_prepare_dflash_graph_inputs(lctx, n_tokens)) {
|
||||
return GGML_STATUS_FAILED;
|
||||
}
|
||||
|
||||
// the output is always the last tensor in the graph
|
||||
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
||||
struct ggml_tensor * embd = nullptr;
|
||||
|
||||
// DFlash GPU argmax draft_argmax node
|
||||
if (lctx.dflash.draft_tokens_tensor != nullptr &&
|
||||
strcmp(res->name, "result_output") != 0) {
|
||||
for (int i = gf->n_nodes - 2; i >= 0; --i) {
|
||||
if (strcmp(gf->nodes[i]->name, "result_output") == 0) {
|
||||
res = gf->nodes[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (lctx.n_outputs == 0) {
|
||||
// no output
|
||||
res = nullptr;
|
||||
@ -5379,10 +5449,12 @@ static int llama_decode_internal(
|
||||
tim2 = ggml_time_us();
|
||||
printf("set_inputs(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
if (lctx.dflash.kv.workspace_sync_pending) {
|
||||
llama_sync_dflash_workspace_if_pending(lctx);
|
||||
}
|
||||
llama_graph_compute(lctx, gf, n_threads);
|
||||
#if IK_PRINT_TIMING
|
||||
llama_synchronize(&lctx);
|
||||
@ -5409,7 +5481,28 @@ static int llama_decode_internal(
|
||||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||
//}
|
||||
|
||||
lctx.dflash.draft_tokens.clear();
|
||||
if (lctx.dflash.draft_tokens_tensor != nullptr) {
|
||||
ggml_backend_t backend_argmax = ggml_backend_sched_get_tensor_backend(
|
||||
lctx.sched, lctx.dflash.draft_tokens_tensor);
|
||||
if (backend_argmax != nullptr) {
|
||||
const int64_t n_tokens_argmax = lctx.dflash.draft_tokens_tensor->ne[0];
|
||||
lctx.dflash.draft_tokens.resize((size_t) n_tokens_argmax);
|
||||
ggml_backend_tensor_get_async(backend_argmax,
|
||||
lctx.dflash.draft_tokens_tensor,
|
||||
lctx.dflash.draft_tokens.data(), 0,
|
||||
(size_t) n_tokens_argmax * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
|
||||
// extract logits
|
||||
{
|
||||
const bool dflash_skip_logits = (lctx.model.arch == LLM_ARCH_DFLASH_DRAFT
|
||||
&& !lctx.dflash.draft_tokens.empty());
|
||||
if (dflash_skip_logits) {
|
||||
res = nullptr;
|
||||
}
|
||||
}
|
||||
if (res) {
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
@ -7447,6 +7540,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_LAGUNA:
|
||||
case LLM_ARCH_GEMMA4:
|
||||
case LLM_ARCH_GEMMA4_MTP:
|
||||
case LLM_ARCH_DFLASH_DRAFT:
|
||||
case LLM_ARCH_GEMMA4_ASSISTANT:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
@ -9662,6 +9756,13 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the i-th draft token produced by the last DFlash decode, or
// LLAMA_TOKEN_NULL when ctx is null or i is out of range.
//
// The token buffer is filled by an asynchronous backend copy, so the context is
// synchronized before reading, as the other output getters do.
llama_token llama_get_dflash_draft_token_ith(struct llama_context * ctx, int32_t i) {
    if (ctx == nullptr || i < 0) {
        return LLAMA_TOKEN_NULL;
    }

    // Wait for any pending async transfer into draft_tokens to complete.
    llama_synchronize(ctx);

    const auto & tokens = ctx->dflash.draft_tokens;
    if ((size_t) i >= tokens.size()) {
        return LLAMA_TOKEN_NULL;
    }
    return tokens[(size_t) i];
}
|
||||
|
||||
float * llama_get_embeddings(struct llama_context * ctx) {
|
||||
llama_synchronize(ctx);
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user