Merge pull request #1970 from SamuelOliveirads/feat/dflash-implementation

Add DFlash support
This commit is contained in:
Kawrakow 2026-06-16 15:07:55 +02:00 committed by GitHub
commit f9078e169b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 3932 additions and 323 deletions

View File

@ -17,3 +17,5 @@ exclude =
# This contains builds that we don't want to check
dist # This is generated with `python build .` for package releases
# max-complexity = 10
per-file-ignores =
gguf-py/gguf/constants.py: E201, E222

View File

@ -157,6 +157,9 @@ common_params_speculative common_params_speculative::with_stage_overrides(const
if (stage.has_p_min_override()) {
result.p_min = stage.p_min;
}
if (stage.has_dflash_cross_ctx_override()) {
result.dflash_cross_ctx = stage.dflash_cross_ctx;
}
if (stage.has_ngram_size_n_override()) {
result.ngram_size_n = stage.ngram_size_n;
result.ngram_mod.reset();
@ -212,6 +215,7 @@ bool common_params_speculative::has_composite_stage_chain() const {
// Reports whether the configured stage chain needs a draft model:
// draft and DFlash stages always do, an MTP stage only when has_dft() is true.
bool common_params_speculative::needs_dft_model() const {
    if (has_stage_type(COMMON_SPECULATIVE_TYPE_DRAFT)) {
        return true;
    }
    if (has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH)) {
        return true;
    }
    return has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) && has_dft();
}
@ -287,8 +291,12 @@ bool common_speculative_validate_chain(const common_params_speculative & params,
return fail("speculative stage has n_min greater than n_max");
}
if (stage.type == COMMON_SPECULATIVE_TYPE_DRAFT && !params.has_dft()) {
return fail("draft speculative stage requires a draft model or draft params");
if ((stage.type == COMMON_SPECULATIVE_TYPE_DRAFT || stage.type == COMMON_SPECULATIVE_TYPE_DFLASH) && !params.has_dft()) {
return fail(common_speculative_type_to_str(stage.type) + " speculative stage requires a draft model or draft params");
}
if (stage.type == COMMON_SPECULATIVE_TYPE_DFLASH && stage_params.dflash_cross_ctx < 1) {
return fail("dflash speculative stage requires cross_ctx >= 1");
}
}
@ -906,6 +914,13 @@ static void common_speculative_stage_apply_kv(
}
return;
}
if (key == "cross_ctx" || key == "dflash_cross_ctx") {
stage.dflash_cross_ctx = std::stoi(value_raw);
if (stage.dflash_cross_ctx < 1) {
throw std::invalid_argument("speculative stage dflash cross_ctx must be at least 1");
}
return;
}
if (key == "ngram_size_n") {
stage.ngram_size_n = std::stoi(value_raw);
if (stage.ngram_size_n < 1 || stage.ngram_size_n > 1024) {
@ -3253,11 +3268,12 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
" gpu-fallback copy state to GPU buffer; re-decode on rejection\n"
" cpu serialise state via llama_state_seq; re-decode on rejection" });
options.push_back({ "*", "--spec-type SPEC[:k=v,...]", "canonical speculative stage entry; repeat for a supported two-stage chain.\n"
"types: none, draft, mtp, ngram-cache, ngram-simple, ngram-map-k, ngram-map-k4v, ngram-mod, suffix\n"
"canonical keys: n_max,n_min,p_min,ngram_size_n,ngram_size_m,ngram_min_hits,suffix_min_match_len,suffix_max_depth,suffix_corpus\n"
"types: none, draft, dflash, mtp, ngram-cache, ngram-simple, ngram-map-k, ngram-map-k4v, ngram-mod, suffix\n"
"canonical keys: n_max,n_min,p_min,cross_ctx,ngram_size_n,ngram_size_m,ngram_min_hits,suffix_min_match_len,suffix_max_depth,suffix_corpus\n"
"for comma-bearing string values, quote the value inside the stage payload for normal shell use\n"
"if argv is passed directly without shell unescaping, the parser also accepts escaped commas as \\,\n"
"examples: --spec-type mtp:n_max=1,p_min=0.0\n"
" --model-draft draft.gguf --spec-type dflash:n_max=4,cross_ctx=512\n"
" --spec-type ngram-mod:n_max=64,n_min=2,ngram_size_n=8 --spec-type mtp:n_max=1,p_min=0.0\n"
" --spec-type \"suffix:n_max=16,n_min=2,suffix_min_match_len=5,suffix_max_depth=64,suffix_corpus='/tmp/spec,type-corpus.json'\"\n"
"legacy --spec-stage, --draft-*, --spec-ngram-*, --suffix-* and -mtp flags are rejected" });

View File

@ -140,6 +140,7 @@ thinking_tokens thinking_tokens_from_string(const std::string& format);
enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
COMMON_SPECULATIVE_TYPE_DFLASH, // DFlash draft model
COMMON_SPECULATIVE_TYPE_MTP, // MTP model
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
@ -162,6 +163,7 @@ struct common_speculative_stage_params {
int32_t n_max = -1;
int32_t n_min = -1;
float p_min = -1.0f;
int32_t dflash_cross_ctx = -1;
uint16_t ngram_size_n = 0;
uint16_t ngram_size_m = 0;
@ -174,6 +176,7 @@ struct common_speculative_stage_params {
bool has_n_max_override() const { return n_max >= 0; }
bool has_n_min_override() const { return n_min >= 0; }
bool has_p_min_override() const { return p_min >= 0.0f; }
bool has_dflash_cross_ctx_override() const { return dflash_cross_ctx >= 0; }
bool has_ngram_size_n_override() const { return ngram_size_n > 0; }
bool has_ngram_size_m_override() const { return ngram_size_m > 0; }
bool has_ngram_min_hits_override() const { return ngram_min_hits > 0; }
@ -206,6 +209,7 @@ struct common_params_speculative {
int32_t n_max = 16; // number of tokens to draft during speculative decoding
int32_t n_min = 0; // minimum number of tokens to draft during speculative decoding
std::vector<common_speculative_stage_params> stages; // explicit stage chain for single-spec or self-spec + model fallback
int32_t dflash_cross_ctx = 512; // target-feature context window for DFlash
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)

View File

@ -0,0 +1,530 @@
#pragma once
#include <algorithm>
#include <cstddef>
#include <cstring>
#include <vector>
static bool common_speculative_are_dflash_compatible(
const llama_model * model_tgt,
const llama_model * model_dft) {
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
if (llama_vocab_type(vocab_tgt) != llama_vocab_type(vocab_dft)) {
LOG_DBG("%s: DFlash draft model vocab type must match the target model\n", __func__);
return false;
}
const bool add_bos_tgt = llama_vocab_get_add_bos(vocab_tgt);
const bool add_bos_dft = llama_vocab_get_add_bos(vocab_dft);
const bool add_eos_tgt = llama_vocab_get_add_eos(vocab_tgt);
const bool add_eos_dft = llama_vocab_get_add_eos(vocab_dft);
const llama_token bos_tgt = llama_vocab_bos(vocab_tgt);
const llama_token bos_dft = llama_vocab_bos(vocab_dft);
const llama_token eos_tgt = llama_vocab_eos(vocab_tgt);
const llama_token eos_dft = llama_vocab_eos(vocab_dft);
if (add_bos_tgt != add_bos_dft || add_eos_tgt != add_eos_dft ||
(add_bos_tgt && bos_tgt != bos_dft) ||
(add_eos_tgt && eos_tgt != eos_dft)) {
LOG_DBG("%s: DFlash draft special tokens must match the target model (add_bos=%d/%d add_eos=%d/%d bos=%d/%d eos=%d/%d)\n",
__func__,
(int) add_bos_tgt,
(int) add_bos_dft,
(int) add_eos_tgt,
(int) add_eos_dft,
(int) bos_tgt,
(int) bos_dft,
(int) eos_tgt,
(int) eos_dft);
return false;
}
const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
const int vocab_diff = n_vocab_tgt > n_vocab_dft
? n_vocab_tgt - n_vocab_dft
: n_vocab_dft - n_vocab_tgt;
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_DBG("%s: DFlash draft vocab size differs too much from the target model (%d vs %d)\n",
__func__, n_vocab_dft, n_vocab_tgt);
return false;
}
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
LOG_DBG("%s: DFlash draft token %d differs - target '%s', draft '%s'\n", __func__, i,
common_token_to_piece(vocab_tgt, i).c_str(),
common_token_to_piece(vocab_dft, i).c_str());
return false;
}
}
return true;
}
struct common_speculative_state_dflash;
static void dflash_materialize_target_window_features(common_speculative_state_dflash & state);

// DFlash runtime state and draft path.
//
// Keeps a bounded window (at most `cross_ctx` rows) of target-model feature rows
// in a ring buffer and feeds it to the DFlash draft context when drafting.
// The constructor validates the draft/target pairing; `ready` stays false when
// any check fails, and callers must test it before using the state.
struct common_speculative_state_dflash : public common_speculative_state {
    llama_context * ctx_tgt; // target context; only its DFlash capture is configured/cleared here
    llama_context * ctx_dft; // draft context; freed by the destructor

    llama_batch batch = {}; // draft batch, allocated once validation has succeeded

    int32_t block_size        = 0;     // from llama_model_dflash_block_size(draft model)
    int32_t mask_token_id     = -1;    // from llama_model_dflash_mask_token_id(draft model)
    int32_t n_target_features = 0;     // floats per feature row
    int32_t cross_ctx         = 0;     // ring capacity in rows (clamped to >= 1)
    bool    ready             = false; // true only after the constructor passed every check

    std::vector<int32_t> target_layer_ids; // target layers whose outputs are captured

    std::vector<float>     target_window;                 // linear copy of the window, built on demand
    std::vector<llama_pos> target_window_pos;             // position of each window row
    std::vector<float>     target_window_stage;           // reserved scratch buffer
    std::vector<llama_pos> target_window_pos_stage;       // scratch used while rebuilding positions
    std::vector<float>     target_window_ring;            // ring storage: cross_ctx * n_target_features floats
    std::vector<float>     target_window_append_features; // rows of the most recent window update

    int32_t  target_window_rows           = 0;     // rows currently in the window
    int32_t  target_window_ring_write_pos = 0;     // next ring row to write
    int32_t  target_window_ring_filled    = 0;     // rows stored in the ring
    uint64_t target_window_version        = 0;     // bumped on every recorded window update
    int32_t  target_window_keep_rows      = 0;     // last update: old rows kept
    int32_t  target_window_append_rows    = 0;     // last update: rows appended
    bool     target_window_replace        = false; // last update: window was replaced
    bool     target_window_materialized   = false; // target_window mirrors the ring
    llama_pos last_target_pos             = -1;    // position of the newest window row, -1 if none

    // Validates the draft model against the target model and prepares the
    // buffers. On any failure an error is logged and `ready` is left false.
    common_speculative_state_dflash(
            enum common_speculative_type type,
            llama_context * ctx_tgt,
            llama_context * ctx_dft,
            int32_t cross_ctx)
        : common_speculative_state(type)
        , ctx_tgt(ctx_tgt)
        , ctx_dft(ctx_dft)
        , cross_ctx(std::max(1, cross_ctx))
    {
        const llama_model * model_tgt = llama_get_model(ctx_tgt);
        const llama_model * model_dft = llama_get_model(ctx_dft);

        if (!common_speculative_are_dflash_compatible(model_tgt, model_dft)) {
            LOG_ERR("%s: DFlash draft model vocab/tokenizer is incompatible with the target model\n", __func__);
            return;
        }

        block_size        = llama_model_dflash_block_size(model_dft);
        mask_token_id     = llama_model_dflash_mask_token_id(model_dft);
        n_target_features = llama_model_dflash_n_target_features(model_dft);
        const int32_t n_target_layers = llama_model_dflash_n_target_layers(model_dft);

        if (block_size <= 0 || mask_token_id < 0 || n_target_features <= 0 || n_target_layers <= 0) {
            LOG_ERR("%s: invalid DFlash metadata (block_size=%d, mask_token_id=%d, n_target_features=%d, n_target_layers=%d)\n",
                    __func__, block_size, mask_token_id, n_target_features, n_target_layers);
            return;
        }

        target_layer_ids.resize((size_t) n_target_layers);
        if (llama_model_dflash_target_layer_ids(model_dft, target_layer_ids.data(), n_target_layers) != n_target_layers) {
            LOG_ERR("%s: failed to read DFlash target layer ids\n", __func__);
            target_layer_ids.clear();
            return;
        }

        const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
        const int32_t target_vocab_size = llama_vocab_n_tokens(vocab_tgt);
        const int32_t target_hidden_size = llama_model_n_embd(model_tgt);
        const int32_t draft_hidden_size = llama_model_n_embd(model_dft);
        const int32_t target_mask_token_id = llama_model_dflash_target_mask_token_id(model_tgt);
        const int32_t expected_n_target_features = target_hidden_size > 0 ? target_hidden_size * n_target_layers : 0;

        // A target without a mask token (LLAMA_TOKEN_NULL) imposes no constraint.
        if (target_mask_token_id != (int32_t) LLAMA_TOKEN_NULL && mask_token_id != target_mask_token_id) {
            LOG_ERR("%s: DFlash mask token mismatch (draft=%d target=%d)\n",
                    __func__, mask_token_id, target_mask_token_id);
            return;
        }

        if (target_hidden_size <= 0 || draft_hidden_size <= 0) {
            LOG_ERR("%s: invalid DFlash hidden sizes (draft=%d target=%d)\n",
                    __func__, draft_hidden_size, target_hidden_size);
            return;
        }

        // Each feature row is expected to hold one hidden vector per captured layer.
        if (expected_n_target_features <= 0 || n_target_features != expected_n_target_features) {
            LOG_ERR("%s: DFlash target feature width mismatch (metadata=%d expected=%d target_hidden=%d target_layers=%d)\n",
                    __func__, n_target_features, expected_n_target_features, target_hidden_size, n_target_layers);
            return;
        }

        std::vector<int32_t> sorted_target_layer_ids = target_layer_ids;
        std::sort(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end());
        if (std::adjacent_find(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()) != sorted_target_layer_ids.end()) {
            LOG_ERR("%s: duplicate DFlash target layer ids survived into runtime validation\n", __func__);
            target_layer_ids.clear();
            return;
        }

        const int32_t n_target_model_layers = llama_n_layer(model_tgt);
        for (int32_t layer_id : target_layer_ids) {
            if (layer_id < 0 || layer_id >= n_target_model_layers) {
                LOG_ERR("%s: invalid DFlash target layer id %d for target model with %d layers\n",
                        __func__, layer_id, n_target_model_layers);
                target_layer_ids.clear();
                return;
            }
        }

        const int32_t io_mode = llama_model_dflash_io_mode(model_dft, model_tgt);
        if (io_mode == LLAMA_DFLASH_IO_MODE_INVALID) {
            LOG_ERR("%s: DFlash draft is missing required IO tensors after target sharing\n", __func__);
            return;
        }
        if (io_mode == LLAMA_DFLASH_IO_MODE_MIXED) {
            LOG_ERR("%s: DFlash IO contract must be fully shared or fully self-contained, but resolved to mixed mode\n", __func__);
            return;
        }
        if (io_mode == LLAMA_DFLASH_IO_MODE_SELF_CONTAINED && !llama_model_dflash_io_tensors_match(model_dft, target_hidden_size, target_vocab_size)) {
            LOG_ERR("%s: DFlash self-contained IO tensors do not match the target hidden/vocab contract (target_hidden=%d target_vocab=%d)\n",
                    __func__,
                    target_hidden_size,
                    target_vocab_size);
            return;
        }

        if (!llama_set_dflash_capture_layers(ctx_tgt, target_layer_ids.data(), (int32_t) target_layer_ids.size())) {
            LOG_ERR("%s: failed to configure DFlash target capture callback\n", __func__);
            return;
        }

        batch = llama_batch_init(std::max(1, block_size), 0, 1);

        // Size every window buffer for the full window up front.
        target_window.reserve((size_t) this->cross_ctx * (size_t) n_target_features);
        target_window_stage.reserve((size_t) this->cross_ctx * (size_t) n_target_features);
        target_window_ring.resize((size_t) this->cross_ctx * (size_t) n_target_features);
        target_window_append_features.reserve((size_t) this->cross_ctx * (size_t) n_target_features);
        target_window_pos.reserve((size_t) this->cross_ctx);
        target_window_pos_stage.reserve((size_t) this->cross_ctx);

        ready = true;
        llama_set_dflash_visible_cross_ctx(ctx_dft, this->cross_ctx);

        LOG_INF("%s: DFlash context ready (n_ctx=%d, block_size=%d, cross_ctx=%d, n_target_features=%d, n_target_layers=%d)\n",
                __func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features, n_target_layers);
    }

    // The state frees ctx_dft and batch in its destructor, so a copy would
    // release both twice. Forbid copying (this also suppresses implicit moves).
    common_speculative_state_dflash(const common_speculative_state_dflash &) = delete;
    common_speculative_state_dflash & operator=(const common_speculative_state_dflash &) = delete;

    // Clears the capture configured on the target context and releases the
    // draft context and the draft batch.
    ~common_speculative_state_dflash() override {
        llama_clear_dflash_capture(ctx_tgt);

        if (ctx_dft) {
            llama_free(ctx_dft);
        }

        // batch.token is only non-null once llama_batch_init() ran in the constructor.
        if (batch.token != nullptr) {
            llama_batch_free(batch);
        }
    }

    // Starts a new generation: drops the draft KV cache and its DFlash bookkeeping.
    void begin(const llama_tokens & prompt) override {
        GGML_UNUSED(prompt);
        llama_kv_cache_clear(ctx_dft);
        llama_reset_dflash_kv_cache_state(ctx_dft);
    }

    // Drafts up to min(params.n_max, block_size - 1) tokens into `result`.
    //
    // The draft batch is `id_last` followed by mask tokens; one token is read
    // back per mask slot. `result` is left empty when the state is not ready,
    // no target features are available, or a llama call fails.
    void draft(
            const common_params_speculative & params,
            const llama_tokens & prompt_tgt,
            llama_token id_last,
            llama_tokens & result) override {
        GGML_UNUSED(prompt_tgt);

        result.clear();

        if (!ready || target_window_rows <= 0) {
            return;
        }

        const int32_t n_keep = std::min<int32_t>(params.n_max, block_size - 1);
        if (n_keep <= 0) {
            return;
        }

        const float * target_features = nullptr;
        size_t target_feature_floats = 0;

        // Describe the most recent window change.
        llama_dflash_window_update window_update = {
            target_window_version,
            target_window_keep_rows,
            target_window_append_rows,
            target_window_replace,
            target_window_append_features.empty() ? nullptr : target_window_append_features.data(),
            target_window_append_features.size(),
        };

        const llama_dflash_kv_cache_transition cache_plan =
            llama_plan_dflash_kv_cache_transition_for_ctx(ctx_dft, window_update, target_window_rows);

        // When a rebuild is planned, hand over the whole linearized window
        // instead of just the last delta.
        if (cache_plan.rebuild_cache) {
            dflash_materialize_target_window_features(*this);
            target_features = target_window.data();
            target_feature_floats = target_window.size();
            window_update.append_features = target_window.data();
            window_update.append_floats = target_window.size();
            window_update.append_rows = target_window_rows;
        }

        if (!llama_set_dflash_target_features_view(ctx_dft, target_features, target_feature_floats, target_window_rows, target_window_pos.data(), &window_update)) {
            LOG_ERR("%s: failed to set DFlash target features\n", __func__);
            return;
        }

        llama_kv_cache_clear(ctx_dft);

        batch.n_tokens = 0;

        const int32_t batch_len = n_keep + 1;
        const llama_pos draft_pos_base = last_target_pos >= 0 ? last_target_pos + 1 : (llama_pos) target_window_rows;
        const llama_pos seed_pos = last_target_pos >= 0 ? last_target_pos : draft_pos_base - 1;

        // Slot 0 carries the last accepted token and requests no logits.
        common_batch_add(batch, id_last, seed_pos, { 0 }, false);
        // Slots 1..n_keep are mask tokens; every one of them requests logits.
        for (int32_t i = 1; i < batch_len; ++i) {
            common_batch_add(batch, mask_token_id, draft_pos_base + (i - 1), { 0 }, true);
        }

        if (llama_decode(ctx_dft, batch) != 0) {
            LOG_ERR("%s: llama_decode() failed for DFlash draft batch\n", __func__);
            batch.n_tokens = 0;
            return;
        }

        result.reserve((size_t) n_keep);
        for (int32_t i = 0; i < n_keep; ++i) {
            llama_token id = llama_get_dflash_draft_token_ith(ctx_dft, i);
            // Fall back to sampling from batch slot i + 1 when no draft token is reported.
            if (id == LLAMA_TOKEN_NULL) {
                id = common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr);
            }
            result.push_back(id);
        }

        batch.n_tokens = 0;
    }

    // Nothing to do on acceptance in this class.
    void accept(uint16_t n_accepted) override {
        GGML_UNUSED(n_accepted);
    }
};
// Stores the description of the latest window change on `state` and bumps the
// window version. Negative row counts are stored as 0.
static void dflash_record_window_update(
        common_speculative_state_dflash & state,
        int32_t keep_rows,
        int32_t append_rows,
        bool replace) {
    const int32_t kept     = keep_rows   < 0 ? 0 : keep_rows;
    const int32_t appended = append_rows < 0 ? 0 : append_rows;

    state.target_window_keep_rows   = kept;
    state.target_window_append_rows = appended;
    state.target_window_replace     = replace;
    ++state.target_window_version;
}
// Replaces the ring contents with `n_rows` rows read from `rows`, stored from
// ring row 0 onwards.
//
// A null `rows`, a non-positive `n_rows` or a non-positive ring capacity
// (`state.cross_ctx`) empties the ring. When more rows are supplied than the
// ring can hold, only the newest `state.cross_ctx` rows are kept, so the copy
// never writes past the ring buffer.
static void dflash_ring_reset_rows(
        common_speculative_state_dflash & state,
        const float * rows,
        int32_t n_rows) {
    const size_t row_width = (size_t) state.n_target_features;

    // cross_ctx <= 0 would otherwise reach the modulo below with a zero divisor.
    if (n_rows <= 0 || rows == nullptr || state.cross_ctx <= 0) {
        state.target_window_ring_write_pos = 0;
        state.target_window_ring_filled = 0;
        return;
    }

    // The ring holds cross_ctx rows; skip the oldest input rows that do not fit.
    if (n_rows > state.cross_ctx) {
        rows += (size_t) (n_rows - state.cross_ctx) * row_width;
        n_rows = state.cross_ctx;
    }

    if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) {
        state.target_window_ring.resize((size_t) state.cross_ctx * row_width);
    }

    std::memcpy(state.target_window_ring.data(), rows, (size_t) n_rows * row_width * sizeof(float));

    // A completely full ring wraps the write cursor back to row 0.
    state.target_window_ring_write_pos = n_rows % state.cross_ctx;
    state.target_window_ring_filled = n_rows;
    state.target_window_materialized = false;
}
// Writes `n_rows` rows from `rows` into the ring at the current write cursor,
// wrapping to row 0 at the end of the ring, then advances the cursor and marks
// the linearized window as stale. Does nothing for a null `rows` or a
// non-positive `n_rows`.
static void dflash_ring_append_rows(
        common_speculative_state_dflash & state,
        const float * rows,
        int32_t n_rows) {
    const size_t stride = (size_t) state.n_target_features;

    if (rows == nullptr || n_rows <= 0) {
        return;
    }

    const size_t ring_floats = (size_t) state.cross_ctx * stride;
    if (state.target_window_ring.size() != ring_floats) {
        state.target_window_ring.resize(ring_floats);
    }

    int32_t cursor = state.target_window_ring_write_pos;
    for (int32_t done = 0; done < n_rows; ) {
        // Copy up to the end of the ring, then continue from row 0.
        const int32_t until_wrap = state.cross_ctx - cursor;
        const int32_t pending    = n_rows - done;
        const int32_t span       = pending < until_wrap ? pending : until_wrap;

        std::memcpy(
            state.target_window_ring.data() + (size_t) cursor * stride,
            rows + (size_t) done * stride,
            (size_t) span * stride * sizeof(float));

        done  += span;
        cursor = (cursor + span) % state.cross_ctx;
    }

    state.target_window_ring_write_pos = cursor;
    state.target_window_ring_filled    = std::min(state.cross_ctx, state.target_window_ring_filled + n_rows);
    state.target_window_materialized   = false;
}
// Copies the newest `target_window_rows` rows out of the ring into
// `target_window`, oldest row first. Skipped when the copy is already current
// or the window is empty.
static void dflash_materialize_target_window_features(common_speculative_state_dflash & state) {
    if (state.target_window_rows <= 0 || state.target_window_materialized) {
        return;
    }

    const size_t  stride    = (size_t) state.n_target_features;
    const int32_t n_rows    = state.target_window_rows;
    const int32_t ring_rows = state.cross_ctx;

    state.target_window.resize((size_t) n_rows * stride);

    // The oldest window row sits n_rows behind the write cursor (modulo the ring size).
    const int32_t oldest    = (state.target_window_ring_write_pos - n_rows + ring_rows) % ring_rows;
    const int32_t head_rows = std::min<int32_t>(n_rows, ring_rows - oldest);
    const int32_t tail_rows = n_rows - head_rows;

    float *       dst  = state.target_window.data();
    const float * ring = state.target_window_ring.data();

    // First the part from `oldest` to the end of the ring ...
    std::memcpy(dst, ring + (size_t) oldest * stride, (size_t) head_rows * stride * sizeof(float));

    // ... then whatever wrapped around to the start of the ring.
    if (tail_rows > 0) {
        std::memcpy(dst + (size_t) head_rows * stride, ring, (size_t) tail_rows * stride * sizeof(float));
    }

    state.target_window_materialized = true;
}
// Adds the feature rows of `features` that belong to `seq_id` to the window.
//
// Rows with another sequence id or a null data pointer are skipped. When the
// accepted rows alone fill the window, the window is replaced by the newest
// `cross_ctx` of them; otherwise the newest old rows are kept so that the total
// stays within `cross_ctx`. Returns false (leaving the state untouched) when
// the view has the wrong kind or width, is empty, or contributes no rows.
static bool dflash_append_target_features(
        common_speculative_state_dflash & state,
        const common_speculative_feature_view & features,
        llama_seq_id seq_id) {
    const bool usable_view =
        features.kind == COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE &&
        features.width == state.n_target_features &&
        !features.rows.empty() &&
        state.cross_ctx > 0;
    if (!usable_view) {
        return false;
    }

    const size_t stride = (size_t) state.n_target_features;

    // Gather the rows of the requested sequence into contiguous storage.
    std::vector<float>     incoming_rows;
    std::vector<llama_pos> incoming_pos;
    incoming_rows.reserve(features.rows.size() * stride);
    incoming_pos.reserve(features.rows.size());

    for (const auto & row : features.rows) {
        const bool accept_row = row.seq_id == seq_id && row.data != nullptr;
        if (accept_row) {
            incoming_pos.push_back(row.pos);
            incoming_rows.insert(incoming_rows.end(), row.data, row.data + stride);
        }
    }

    if (incoming_pos.empty()) {
        return false;
    }

    const int32_t n_rows = (int32_t) incoming_pos.size();

    if (n_rows >= state.cross_ctx) {
        // The new rows fill the whole window: keep only the newest cross_ctx of them.
        const int32_t first_kept = n_rows - state.cross_ctx;

        state.target_window_pos.assign(incoming_pos.begin() + first_kept, incoming_pos.end());
        state.target_window_append_features.assign(
            incoming_rows.begin() + (ptrdiff_t) first_kept * (ptrdiff_t) stride,
            incoming_rows.end());

        dflash_ring_reset_rows(state, state.target_window_append_features.data(), state.cross_ctx);

        state.target_window_rows        = state.cross_ctx;
        state.target_window_ring_filled = state.target_window_rows;
        state.last_target_pos           = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();

        dflash_record_window_update(state, 0, state.target_window_rows, true);
        return true;
    }

    // Partial update: retain as many of the newest old rows as still fit.
    const int32_t keep_old_rows = std::min<int32_t>(state.target_window_rows, state.cross_ctx - n_rows);
    const int32_t total_rows    = keep_old_rows + n_rows;

    std::vector<llama_pos> & staged_pos = state.target_window_pos_stage;
    staged_pos.resize((size_t) total_rows);
    if (keep_old_rows > 0) {
        std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), staged_pos.begin());
    }

    state.target_window_append_features.assign(incoming_rows.begin(), incoming_rows.end());
    dflash_ring_append_rows(state, state.target_window_append_features.data(), n_rows);

    std::copy(incoming_pos.begin(), incoming_pos.end(), staged_pos.begin() + keep_old_rows);

    // Publish the staged positions and leave the scratch vector empty.
    state.target_window_pos.swap(staged_pos);
    staged_pos.clear();

    state.target_window_rows        = total_rows;
    state.target_window_ring_filled = state.target_window_rows;
    state.last_target_pos           = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();

    dflash_record_window_update(state, keep_old_rows, n_rows, false);
    return true;
}
// Empties the feature window: clears every window buffer except the ring
// storage itself, zeroes the row counters and update record, and resets the
// draft context's DFlash KV-cache state. The window version is left unchanged.
static void dflash_clear_target_features(common_speculative_state_dflash & state) {
    // Buffers.
    state.target_window.clear();
    state.target_window_stage.clear();
    state.target_window_append_features.clear();
    state.target_window_pos.clear();
    state.target_window_pos_stage.clear();

    // Ring cursor and row counts.
    state.target_window_rows           = 0;
    state.target_window_ring_write_pos = 0;
    state.target_window_ring_filled    = 0;

    // Last recorded update.
    state.target_window_keep_rows    = 0;
    state.target_window_append_rows  = 0;
    state.target_window_replace      = false;
    state.target_window_materialized = false;

    state.last_target_pos = -1;

    llama_reset_dflash_kv_cache_state(state.ctx_dft);
}
// Mirrors a KV-cache context shift onto the DFlash feature window.
//
// Rows whose position lies in [kv_keep, kv_keep + kv_discard) are dropped, rows
// in [kv_keep + kv_discard, kv_past) move down by kv_discard, and all other
// rows keep their position. The window is then rebuilt from the surviving rows,
// recorded as a full replacement, and the draft context's DFlash KV-cache state
// is reset. Does nothing when kv_discard <= 0 or the window is empty.
static void dflash_context_shift(
        common_speculative_state_dflash & state,
        llama_pos kv_keep,
        llama_pos kv_discard,
        llama_pos kv_past) {
    if (kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window_pos.empty()) {
        return;
    }

    // Work on the linear copy so rows can be filtered in window order.
    dflash_materialize_target_window_features(state);

    const size_t row_width = (size_t) state.n_target_features;
    const llama_pos discard_begin = kv_keep;
    const llama_pos discard_end = kv_keep + kv_discard;

    std::vector<float> shifted_rows;
    std::vector<llama_pos> shifted_positions;
    shifted_rows.reserve(state.target_window.size());
    shifted_positions.reserve(state.target_window_pos.size());

    for (int32_t row = 0; row < state.target_window_rows; ++row) {
        llama_pos pos = state.target_window_pos[(size_t) row];
        if (pos >= discard_begin && pos < discard_end) {
            continue;
        }
        if (pos >= discard_end && pos < kv_past) {
            pos -= kv_discard;
        }

        const float * row_src = state.target_window.data() + (size_t) row * row_width;
        shifted_rows.insert(shifted_rows.end(), row_src, row_src + row_width);
        shifted_positions.push_back(pos);
    }

    state.target_window = std::move(shifted_rows);
    state.target_window_pos = std::move(shifted_positions);
    state.target_window_rows = (int32_t) state.target_window_pos.size();
    dflash_ring_reset_rows(state, state.target_window.data(), state.target_window_rows);

    // The update recorded below announces target_window_rows replaced rows, so
    // the append buffer must hold exactly those rows; otherwise it would still
    // contain the (differently sized) rows of the previous append.
    state.target_window_append_features.assign(state.target_window.begin(), state.target_window.end());

    state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
    dflash_record_window_update(state, 0, state.target_window_rows, true);
    llama_reset_dflash_kv_cache_state(state.ctx_dft);
}

View File

@ -11,9 +11,13 @@
#include "suffix-tree.h"
#include <algorithm>
#include <atomic>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <limits>
#include <map>
#include <sstream>
#include <unordered_map>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
@ -24,6 +28,7 @@ void llama_set_mtp_target_context(struct llama_context * ctx, struct llama_conte
const std::vector<enum common_speculative_type> common_speculative_types = {
COMMON_SPECULATIVE_TYPE_NONE,
COMMON_SPECULATIVE_TYPE_DRAFT,
COMMON_SPECULATIVE_TYPE_DFLASH,
COMMON_SPECULATIVE_TYPE_MTP,
COMMON_SPECULATIVE_TYPE_EAGLE3,
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
@ -37,6 +42,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
{"none", COMMON_SPECULATIVE_TYPE_NONE},
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
{"dflash", COMMON_SPECULATIVE_TYPE_DFLASH},
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
@ -180,9 +186,13 @@ struct common_speculative_state {
};
struct common_speculative_state_mtp;
struct common_speculative_state_dflash;
static common_speculative_state_mtp * common_speculative_get_mtp_state(common_speculative * spec);
static const common_speculative_state_mtp * common_speculative_get_mtp_state(const common_speculative * spec);
static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec);
static const common_speculative_state_dflash * common_speculative_get_dflash_state(const common_speculative * spec);
static int32_t common_speculative_feature_width(const common_speculative * spec);
static void mtp_invalidate_cached_drafts(common_speculative_state_mtp & state);
static bool common_speculative_checkpoint_save(
common_speculative_checkpoint & ckpt,
@ -325,6 +335,8 @@ struct common_speculative_state_mtp : public common_speculative_state {
}
};
#include "speculative-dflash-impl.h"
struct common_speculative_state_draft : public common_speculative_state {
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
llama_context * ctx_dft;
@ -1025,17 +1037,13 @@ struct common_speculative_state_suffix : public common_speculative_state {
};
struct common_speculative {
common_speculative_checkpoint checkpoint;
std::vector<common_speculative_config> configs; // resolved stage config for each implementation
std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states
common_speculative_checkpoint checkpoint;
common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
std::unique_ptr<spec_tuner> tuner;
int last_n_drafted = 0;
int64_t t_step_start_us = 0;
~common_speculative() {
checkpoint.clear();
}
};
static bool common_speculative_stage_chain_matches(
@ -1116,6 +1124,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
switch (type) {
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
case COMMON_SPECULATIVE_TYPE_DFLASH: return "dflash";
case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
@ -1188,13 +1197,18 @@ common_speculative * common_speculative_init(
});
if (has_draft_stage) {
LOG_ERR("%s: Gemma4 assistant models only support MTP stages; omit -md for self-spec-only runs or use --spec-type mtp:n_max=1,p_min=0.0 for assistant-backed MTP\n", __func__);
LOG_ERR("%s: Gemma4 assistant models only support MTP stages; omit -md for self-spec-only runs or use -mtp/--spec-stage mtp for assistant-backed MTP\n", __func__);
return nullptr;
}
}
const bool has_dflash_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) {
return stage.type == COMMON_SPECULATIVE_TYPE_DFLASH;
});
const bool needs_draft_ctx = std::any_of(stages.begin(), stages.end(), [&params](const common_speculative_stage_params & stage) {
return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT ||
stage.type == COMMON_SPECULATIVE_TYPE_DFLASH ||
(stage.type == COMMON_SPECULATIVE_TYPE_MTP && params.model_dft != nullptr);
});
@ -1205,7 +1219,40 @@ common_speculative * common_speculative_init(
return nullptr;
}
ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft);
llama_context_params cparams_dft = params.cparams_dft;
if (has_dflash_stage) {
if (!llama_model_share_dflash_io_tensors(params.model_dft, llama_get_model(ctx_tgt))) {
LOG_ERR("%s: failed to share target IO tensors with DFlash draft model\n", __func__);
return nullptr;
}
int32_t max_cross_ctx = 0;
for (const auto & stage : stages) {
if (stage.type != COMMON_SPECULATIVE_TYPE_DFLASH) {
continue;
}
max_cross_ctx = std::max(max_cross_ctx, params.with_stage_overrides(stage).dflash_cross_ctx);
}
const int32_t block_size = llama_model_dflash_block_size(params.model_dft);
if (block_size <= 0) {
LOG_ERR("%s: invalid DFlash draft block size\n", __func__);
return nullptr;
}
const int64_t required_n_ctx = (int64_t) max_cross_ctx + (int64_t) block_size;
if (required_n_ctx > std::numeric_limits<int32_t>::max()) {
LOG_ERR("%s: invalid DFlash draft context size cross_ctx=%d block_size=%d required_n_ctx=%lld\n",
__func__, max_cross_ctx, block_size, (long long) required_n_ctx);
return nullptr;
}
cparams_dft.n_ctx = (uint32_t) required_n_ctx;
}
ctx_dft = llama_init_from_model(params.model_dft, cparams_dft);
if (ctx_dft == nullptr) {
LOG_ERR("%s", "failed to create draft context\n");
return nullptr;
@ -1268,6 +1315,20 @@ common_speculative * common_speculative_init(
));
break;
}
case COMMON_SPECULATIVE_TYPE_DFLASH: {
auto state = std::make_unique<common_speculative_state_dflash>(
config.type,
ctx_tgt,
ctx_dft,
config.params.dflash_cross_ctx);
if (!state->ready) {
LOG_ERR("%s: failed to initialize DFlash speculative state\n", __func__);
return nullptr;
}
impls.push_back(std::move(state));
ctx_dft = nullptr;
break;
}
case COMMON_SPECULATIVE_TYPE_MTP: {
llama_context * ctx_mtp = ctx_dft;
if (!ctx_mtp) {
@ -1343,7 +1404,6 @@ common_speculative * common_speculative_init(
}
auto * result = new common_speculative {
/* .checkpoint = */ {},
/* .configs = */ std::move(configs),
/* .impls = */ std::move(impls)
};
@ -1369,175 +1429,12 @@ common_speculative * common_speculative_init(
return result;
}
// Attempts to build a speculative-decoding instance and classifies the outcome.
//
// `*out_spec` (when `out_spec` is non-null) is set to the new instance on
// success and to nullptr otherwise. Returns SKIPPED when no stage chain is
// configured, READY on success, and on failure ERR_RECURRENT for a recurrent
// target model, ERR_MTP when an MTP stage is configured, else ERR_GENERIC.
// NOTE(review): with a null `out_spec` a successfully created instance is not
// handed to anyone - confirm no caller relies on that combination.
common_speculative_init_status common_speculative_try_init(
        common_params_speculative & params,
        llama_context * ctx_tgt,
        common_speculative ** out_spec) {
    const bool wants_result = out_spec != nullptr;
    if (wants_result) {
        *out_spec = nullptr;
    }

    if (!params.has_stage_chain()) {
        return COMMON_SPECULATIVE_INIT_SKIPPED;
    }

    if (common_speculative * created = common_speculative_init(params, ctx_tgt)) {
        if (wants_result) {
            *out_spec = created;
        }
        return COMMON_SPECULATIVE_INIT_READY;
    }

    // Initialization failed: pick the most specific error code.
    if (ctx_tgt != nullptr) {
        const llama_model * model = llama_get_model(ctx_tgt);
        if (model != nullptr && llama_model_has_recurrent(model)) {
            return COMMON_SPECULATIVE_INIT_ERR_RECURRENT;
        }
    }

    return params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)
        ? COMMON_SPECULATIVE_INIT_ERR_MTP
        : COMMON_SPECULATIVE_INIT_ERR_GENERIC;
}
// Normalizes the speculative parameters before models are loaded.
//
// Removes the MTP stage (with a warning) when more than one parallel slot is
// requested and `allow_parallel_mtp` is false, clears the draft settings when
// no remaining stage needs a draft model, and refreshes `params_base.has_mtp`.
void common_speculative_prepare_startup(
        gpt_params & params_base,
        bool allow_parallel_mtp) {
    common_params_speculative & spec = params_base.speculative;

    if (!allow_parallel_mtp && params_base.n_parallel > 1) {
        if (spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
            LOG_WRN("%s: MTP is not supported with parallel slots yet, removing the MTP stage to avoid cross-slot corruption. n_parallel=%d, stage_chain=%s\n",
                    __func__, params_base.n_parallel, common_speculative_stage_chain_to_str(spec).c_str());
            spec.remove_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
        }
    }

    const bool draft_model_needed = spec.needs_dft_model();
    if (!draft_model_needed) {
        spec.clear_dft();
    }

    params_base.has_mtp = spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
}
// Completes speculative start-up once the target `model` is available.
//
// Loads the draft model when one is still required, then lets
// common_speculative_prepare_mtp_runtime() decide whether MTP stays enabled.
// When MTP is active, pooling on the base parameters is forced to NONE.
//
// Returns false only when the draft model fails to load.
bool common_speculative_finalize_startup(
        gpt_params & params_base,
        const llama_model * model) {
    auto & spec_params = params_base.speculative;

    if (!spec_params.needs_dft_model()) {
        spec_params.clear_dft();
    }

    if (spec_params.has_dft()) {
        LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");
        const bool draft_loaded = common_speculative_load_draft_model(spec_params, params_base);
        if (!draft_loaded) {
            return false;
        }
    }

    // has_mtp must be updated before the runtime preparation below, because
    // that call receives params_base and copies it.
    const bool mtp_requested = spec_params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
    params_base.has_mtp = mtp_requested;

    const bool has_external_mtp =
        mtp_requested && llama_model_is_gemma4_mtp_assistant(spec_params.model_dft);

    const bool mtp_active = common_speculative_prepare_mtp_runtime(
        spec_params,
        params_base,
        model,
        has_external_mtp);
    params_base.has_mtp = mtp_active;

    if (mtp_active) {
        params_base.pooling_type = LLAMA_POOLING_TYPE_NONE;
    }
    return true;
}
// Loads the draft model configured in `params` and stores the result back into
// `params` (model_dft, cparams_dft and mparams_dft.path).
//
// The draft parameters start from the speculative settings, inherit a few
// values from `params_base`, and may then be overridden by the free-form
// argument string in `params.params`.
//
// Returns true when no draft model is configured or when loading succeeded;
// false when the extra arguments fail to parse or the model cannot be loaded.
bool common_speculative_load_draft_model(
        common_params_speculative & params,
        const gpt_params & params_base) {
    if (!params.has_dft()) {
        return true;
    }

    gpt_params params_dft;
    params_dft.devices          = params.devices;
    params_dft.model            = params.model;
    params_dft.main_gpu         = params_base.main_gpu;
    params_dft.n_gpu_layers     = params.n_gpu_layers;
    params_dft.rpc_servers      = params_base.rpc_servers;
    // Draft-specific cache types win; otherwise fall back to the base ones.
    params_dft.cache_type_k     = params.cache_type_k.empty() ? params_base.cache_type_k : params.cache_type_k;
    params_dft.cache_type_v     = params.cache_type_v.empty() ? params_base.cache_type_v : params.cache_type_v;
    params_dft.flash_attn       = params_base.flash_attn;
    params_dft.k_cache_hadamard = params_base.k_cache_hadamard;
    params_dft.v_cache_hadamard = params_base.v_cache_hadamard;

    if (!params.params.empty()) {
        // Parse the extra draft arguments as if they were a command line.
        auto [argc, argv] = parse_command_line("llama-server " + params.params);
        if (!gpt_params_parse(argc, argv, params_dft)) {
            gpt_params_print_usage(argc, argv, params_dft);
            free_command_line(argc, argv);
            return false;
        }
        free_command_line(argc, argv);
    }

    LOG_INF("%s: loading draft model '%s'\n", __func__, params_dft.model.c_str());

    // Context size priority: extra arguments, then speculative n_ctx, then an
    // even share of the base context per parallel slot.
    if (params_dft.n_ctx == 0) {
        params_dft.n_ctx = params.n_ctx;
    }
    params_dft.n_ctx      = params_dft.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx;
    params_dft.n_parallel = 1;
    params_dft.n_batch    = params_dft.n_ctx;

    params.mparams_dft.path = params_dft.model;

    llama_model_params mparams_dft = common_model_params_to_llama(params_dft);
    llama_model * loaded_model = llama_model_load_from_file(params_dft.model.c_str(), mparams_dft);
    if (loaded_model == nullptr) {
        // Report the path that was actually passed to the loader; it can differ
        // from params.model when the extra arguments override the model path.
        LOG_ERR("%s: failed to load draft model '%s'\n", __func__, params_dft.model.c_str());
        return false;
    }

    params.model_dft   = loaded_model;
    params.cparams_dft = common_context_params_to_llama(params_dft);
    return true;
}
// Decides whether MTP can run and prepares the companion context parameters.
//
// Returns false when no MTP stage is configured, or when the target model has
// no NextN layers and no external MTP model is available; in the latter case
// the MTP stage is removed and unused draft settings are cleared.
// Otherwise fills `params.cparams_dft` for MTP warm-up and returns true.
bool common_speculative_prepare_mtp_runtime(
        common_params_speculative & params,
        const gpt_params & params_base,
        const llama_model * model,
        bool has_external_mtp) {
    if (!params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
        return false;
    }

    const bool target_lacks_nextn = llama_model_n_nextn_layer(model) == 0;
    if (target_lacks_nextn && !has_external_mtp) {
        LOG_WRN("%s: MTP speculative stage requested, but model has 0 NextN layers. Removing MTP from the configured stage chain.\n",
                __func__);
        params.remove_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
        if (!params.needs_dft_model()) {
            params.clear_dft();
        }
        return false;
    }

    if (!has_external_mtp) {
        // Without an external MTP model, derive the context parameters from
        // the base parameters with pooling disabled.
        gpt_params mtp_base = params_base;
        mtp_base.pooling_type = LLAMA_POOLING_TYPE_NONE;
        params.cparams_dft = common_context_params_to_llama(mtp_base);
    }

    auto & cparams = params.cparams_dft;
    cparams.mtp         = true;
    cparams.mtp_op_type = MTP_OP_WARMUP;
    cparams.embeddings  = true;
    return true;
}
// Releases a speculative decoder; passing nullptr is a no-op.
void common_speculative_free(common_speculative * spec) {
    if (spec != nullptr) {
        // Drop checkpoint contents explicitly before destroying the object.
        spec->checkpoint.clear();
        delete spec;
    }
}
@ -1546,11 +1443,6 @@ void common_speculative_begin(common_speculative * spec, const llama_tokens & pr
return;
}
spec->checkpoint.clear();
spec->curr_impl = nullptr;
spec->last_n_drafted = 0;
spec->t_step_start_us = 0;
for (auto & impl : spec->impls) {
common_time_meas tm(impl->t_begin_us, !impl->gen_perf);
impl->begin(prompt);
@ -1654,34 +1546,6 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
}
}
// Saves a checkpoint into `spec` before drafting starts.
//
// Returns false when `spec` is nullptr; otherwise forwards every argument to
// common_speculative_checkpoint_save() and returns its result.
bool common_speculative_before_draft(
        common_speculative * spec,
        llama_model * model,
        llama_context * ctx,
        common_sampler * sampler_src,
        const common_params_sampling & sparams,
        llama_seq_id seq_id,
        llama_pos n_past,
        llama_token sampled,
        int max_tokens,
        int ckpt_mode) {
    if (!spec) {
        return false;
    }

    auto & checkpoint = spec->checkpoint;
    return common_speculative_checkpoint_save(
        checkpoint, model, ctx, sampler_src, sparams,
        seq_id, n_past, sampled, max_tokens, ckpt_mode);
}
static bool common_speculative_has_type(const common_speculative * spec, common_speculative_type type) {
if (spec == nullptr) {
return false;
@ -1830,6 +1694,10 @@ static bool common_speculative_collect_target_batch_features(
const llama_batch & batch,
common_speculative_feature_view & features) {
features = {};
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
return llama_spec_get_dflash_feature_view(ctx, batch, features);
}
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
return true;
}
@ -1848,6 +1716,10 @@ static bool common_speculative_collect_target_seq_batch_features(
llama_seq_id seq_id,
common_speculative_feature_view & features) {
features = {};
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
return llama_spec_get_dflash_feature_view_for_seq(ctx, batch, seq_id, features);
}
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
return true;
}
@ -1921,27 +1793,246 @@ common_speculative_draft_result common_speculative_draft_ex(
return result;
}
// True when `spec` contains an MTP or a DFlash stage.
static bool common_speculative_has_target_features(const common_speculative * spec) {
    if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
        return true;
    }
    return common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH);
}
// Loads the draft model configured in `params` and stores the result back into
// `params` (model_dft, cparams_dft and mparams_dft.path).
//
// The draft parameters start from the speculative settings, inherit a few
// values from `params_base` (more of them when a DFlash stage is configured),
// and may then be overridden by the free-form argument string in `params.params`.
//
// Returns true when no draft model is configured or when loading succeeded;
// false when the extra arguments fail to parse or the model cannot be loaded.
bool common_speculative_load_draft_model(
        common_params_speculative & params,
        const gpt_params & params_base) {
    if (!params.has_dft()) {
        return true;
    }

    // Nothing below modifies the stage chain, so this can be evaluated once.
    const bool has_dflash = params.has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH);

    gpt_params params_dft;
    params_dft.devices          = params.devices;
    params_dft.model            = params.model;
    params_dft.main_gpu         = params_base.main_gpu;
    params_dft.n_gpu_layers     = params.n_gpu_layers;
    params_dft.rpc_servers      = params_base.rpc_servers;
    // Draft-specific cache types win; otherwise fall back to the base ones.
    params_dft.cache_type_k     = params.cache_type_k.empty() ? params_base.cache_type_k : params.cache_type_k;
    params_dft.cache_type_v     = params.cache_type_v.empty() ? params_base.cache_type_v : params.cache_type_v;
    params_dft.flash_attn       = params_base.flash_attn;
    params_dft.k_cache_hadamard = params_base.k_cache_hadamard;
    params_dft.v_cache_hadamard = params_base.v_cache_hadamard;

    if (has_dflash) {
        // A DFlash draft additionally inherits the split and scheduling
        // settings of the base parameters.
        params_dft.split_mode = params_base.split_mode;
        for (size_t i = 0; i < std::size(params_dft.tensor_split); ++i) {
            params_dft.tensor_split[i] = params_base.tensor_split[i];
        }
        params_dft.attn_max_batch              = params_base.attn_max_batch;
        params_dft.graph_reuse                 = params_base.graph_reuse;
        params_dft.split_mode_graph_scheduling = params_base.split_mode_graph_scheduling;
        params_dft.scheduler_async             = params_base.scheduler_async;
        params_dft.max_extra_alloc_MiB         = params_base.max_extra_alloc_MiB;
        params_dft.reduce_type                 = params_base.reduce_type;
    }

    if (!params.params.empty()) {
        // Parse the extra draft arguments as if they were a command line.
        auto [argc, argv] = parse_command_line("llama-server " + params.params);
        if (!gpt_params_parse(argc, argv, params_dft)) {
            gpt_params_print_usage(argc, argv, params_dft);
            free_command_line(argc, argv);
            return false;
        }
        free_command_line(argc, argv);
    }

    LOG_INF("%s: loading draft model '%s'\n", __func__, params_dft.model.c_str());

    if (params_dft.n_ctx == 0) {
        params_dft.n_ctx = params.n_ctx;
    }
    // For DFlash, a still-negative layer count falls back to the base value.
    if (has_dflash && params_dft.n_gpu_layers < 0) {
        params_dft.n_gpu_layers = params_base.n_gpu_layers;
    }
    // Last resort for the context size: an even share of the base context per slot.
    params_dft.n_ctx      = params_dft.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx;
    params_dft.n_parallel = 1;
    params_dft.n_batch    = params_dft.n_ctx;

    params.mparams_dft.path = params_dft.model;

    llama_model_params mparams_dft = common_model_params_to_llama(params_dft);
    llama_model * loaded_model = llama_model_load_from_file(params_dft.model.c_str(), mparams_dft);
    if (loaded_model == nullptr) {
        // Report the path that was actually passed to the loader; it can differ
        // from params.model when the extra arguments override the model path.
        LOG_ERR("%s: failed to load draft model '%s'\n", __func__, params_dft.model.c_str());
        return false;
    }

    params.model_dft   = loaded_model;
    params.cparams_dft = common_context_params_to_llama(params_dft);
    return true;
}
// Decides whether MTP can run and prepares the companion context parameters.
//
// Returns false when no MTP stage is configured, or when the target model has
// no NextN layers and no external MTP model is available; in the latter case
// the MTP stage is removed and unused draft settings are cleared.
// Otherwise fills `params.cparams_dft` for MTP warm-up and returns true.
bool common_speculative_prepare_mtp_runtime(
        common_params_speculative & params,
        const gpt_params & params_base,
        const llama_model * model,
        bool has_external_mtp) {
    if (!params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
        return false;
    }

    const bool target_lacks_nextn = llama_model_n_nextn_layer(model) == 0;
    if (target_lacks_nextn && !has_external_mtp) {
        LOG_WRN("%s: MTP speculative stage requested, but model has 0 NextN layers. Removing MTP from the configured stage chain.\n",
                __func__);
        params.remove_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
        if (!params.needs_dft_model()) {
            params.clear_dft();
        }
        return false;
    }

    if (!has_external_mtp) {
        // Without an external MTP model, derive the context parameters from
        // the base parameters with pooling disabled.
        gpt_params mtp_base = params_base;
        mtp_base.pooling_type = LLAMA_POOLING_TYPE_NONE;
        params.cparams_dft = common_context_params_to_llama(mtp_base);
    }

    auto & cparams = params.cparams_dft;
    cparams.mtp         = true;
    cparams.mtp_op_type = MTP_OP_WARMUP;
    cparams.embeddings  = true;
    return true;
}
// Tries to create a speculative decoder for `ctx_tgt` and reports the outcome
// as a status code.
//
// `out_spec` (optional) is reset to nullptr on entry and receives the created
// object on success.
// NOTE(review): when `out_spec` is nullptr a successfully created object is
// not handed to anyone.
//
// Returns SKIPPED when no stage chain is configured, READY on success, and
// otherwise the first matching error: ERR_RECURRENT (target model reports
// recurrent layers), ERR_MTP (an MTP stage is configured), ERR_GENERIC.
common_speculative_init_status common_speculative_try_init(
        common_params_speculative & params,
        llama_context * ctx_tgt,
        common_speculative ** out_spec) {
    const bool wants_result = out_spec != nullptr;
    if (wants_result) {
        *out_spec = nullptr;
    }

    if (!params.has_stage_chain()) {
        return COMMON_SPECULATIVE_INIT_SKIPPED;
    }

    common_speculative * created = common_speculative_init(params, ctx_tgt);
    if (created != nullptr) {
        if (wants_result) {
            *out_spec = created;
        }
        return COMMON_SPECULATIVE_INIT_READY;
    }

    // Initialization failed: classify the failure for the caller.
    const llama_model * tgt_model = nullptr;
    if (ctx_tgt != nullptr) {
        tgt_model = llama_get_model(ctx_tgt);
    }
    if (tgt_model != nullptr && llama_model_has_recurrent(tgt_model)) {
        return COMMON_SPECULATIVE_INIT_ERR_RECURRENT;
    }

    return params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)
        ? COMMON_SPECULATIVE_INIT_ERR_MTP
        : COMMON_SPECULATIVE_INIT_ERR_GENERIC;
}
// Normalizes the speculative configuration before any model is loaded.
//
// - Drops the MTP stage (with a warning) when more than one parallel slot is
//   requested and the caller does not allow parallel MTP.
// - Clears the draft-model settings when no remaining stage needs a draft model.
// - Mirrors the presence of an MTP stage into `params_base.has_mtp`.
void common_speculative_prepare_startup(
        gpt_params & params_base,
        bool allow_parallel_mtp) {
    auto & spec_params = params_base.speculative;

    const bool drop_mtp =
        !allow_parallel_mtp &&
        params_base.n_parallel > 1 &&
        spec_params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
    if (drop_mtp) {
        // Log the chain before it is modified so the message shows what was requested.
        LOG_WRN("%s: MTP is not supported with parallel slots yet, removing the MTP stage to avoid cross-slot corruption. n_parallel=%d, stage_chain=%s\n",
                __func__, params_base.n_parallel, common_speculative_stage_chain_to_str(spec_params).c_str());
        spec_params.remove_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
    }

    const bool draft_needed = spec_params.needs_dft_model();
    if (!draft_needed) {
        spec_params.clear_dft();
    }

    params_base.has_mtp = spec_params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
}
// Completes speculative start-up once the target `model` is available.
//
// Loads the draft model when one is still required, then lets
// common_speculative_prepare_mtp_runtime() decide whether MTP stays enabled.
// When MTP is active, pooling on the base parameters is forced to NONE.
//
// Returns false only when the draft model fails to load.
bool common_speculative_finalize_startup(
        gpt_params & params_base,
        const llama_model * model) {
    auto & spec_params = params_base.speculative;

    if (!spec_params.needs_dft_model()) {
        spec_params.clear_dft();
    }

    if (spec_params.has_dft()) {
        LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");
        const bool draft_loaded = common_speculative_load_draft_model(spec_params, params_base);
        if (!draft_loaded) {
            return false;
        }
    }

    // has_mtp must be updated before the runtime preparation below, because
    // that call receives params_base and copies it.
    const bool mtp_requested = spec_params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
    params_base.has_mtp = mtp_requested;

    const bool has_external_mtp =
        mtp_requested && llama_model_is_gemma4_mtp_assistant(spec_params.model_dft);

    const bool mtp_active = common_speculative_prepare_mtp_runtime(
        spec_params,
        params_base,
        model,
        has_external_mtp);
    params_base.has_mtp = mtp_active;

    if (mtp_active) {
        params_base.pooling_type = LLAMA_POOLING_TYPE_NONE;
    }
    return true;
}
// Saves a checkpoint into `spec` before drafting starts.
//
// Returns false when `spec` is nullptr; otherwise forwards every argument to
// common_speculative_checkpoint_save() and returns its result.
bool common_speculative_before_draft(
        common_speculative * spec,
        llama_model * model,
        llama_context * ctx,
        common_sampler * sampler_src,
        const common_params_sampling & sparams,
        llama_seq_id seq_id,
        llama_pos n_past,
        llama_token sampled,
        int max_tokens,
        int ckpt_mode) {
    if (!spec) {
        return false;
    }

    auto & checkpoint = spec->checkpoint;
    return common_speculative_checkpoint_save(
        checkpoint, model, ctx, sampler_src, sparams,
        seq_id, n_past, sampled, max_tokens, ckpt_mode);
}
int32_t common_speculative_on_target_seq_batch(
common_speculative * spec,
llama_context * ctx_tgt,
const llama_batch & batch,
llama_seq_id seq_id,
bool is_prompt_warmup) {
llama_context * ctx_mtp = common_speculative_get_companion_ctx(spec);
ctx_mtp = ctx_mtp ? ctx_mtp : ctx_tgt;
if (ctx_tgt == nullptr || ctx_mtp == nullptr || batch.n_tokens <= 0) {
if (ctx_tgt == nullptr || batch.n_tokens <= 0) {
return 0;
}
const int n_embd_src = common_speculative_ctx_mtp_n_embd(ctx_tgt);
const int n_embd_dst = common_speculative_ctx_mtp_n_embd(ctx_mtp);
if (n_embd_src <= 0 || n_embd_dst <= 0) {
return -1;
}
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
llama_context * ctx_mtp = common_speculative_get_companion_ctx(spec);
ctx_mtp = ctx_mtp ? ctx_mtp : ctx_tgt;
if (ctx_mtp == nullptr) {
return 0;
}
if (n_embd_src != n_embd_dst) {
LOG_ERR("MTP warmup hidden state width mismatch: n_embd_src = %d, n_embd_dst = %d\n", n_embd_src, n_embd_dst);
return -1;
const int n_embd_src = common_speculative_ctx_mtp_n_embd(ctx_tgt);
const int n_embd_dst = common_speculative_ctx_mtp_n_embd(ctx_mtp);
if (n_embd_src <= 0 || n_embd_dst <= 0) {
return -1;
}
if (n_embd_src != n_embd_dst) {
LOG_ERR("MTP warmup hidden state width mismatch: n_embd_src = %d, n_embd_dst = %d\n", n_embd_src, n_embd_dst);
return -1;
}
}
common_speculative_feature_view feature_view;
@ -1981,6 +2072,10 @@ bool common_speculative_copy_output_hidden_rows(
const std::vector<int32_t> & output_indices,
std::vector<float> & hidden_rows) {
hidden_rows.clear();
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
return llama_spec_copy_dflash_rows_from_output_indices(ctx, output_indices, hidden_rows);
}
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
return true;
}
@ -2018,13 +2113,13 @@ static bool common_speculative_apply_hidden_rows(
llama_pos pos_base,
const std::vector<llama_token> & ids,
const std::vector<float> & hidden_rows) {
auto * mtp_state = common_speculative_get_mtp_state(spec);
if (mtp_state == nullptr || ids.empty()) {
const int32_t feature_width = common_speculative_feature_width(spec);
if (feature_width <= 0 || ids.empty()) {
return true;
}
const size_t expected_floats = ids.size() * (size_t) mtp_state->n_embd;
if (mtp_state->n_embd <= 0 || hidden_rows.size() != expected_floats) {
const size_t expected_floats = ids.size() * (size_t) feature_width;
if (hidden_rows.size() != expected_floats) {
return false;
}
@ -2035,7 +2130,7 @@ static bool common_speculative_apply_hidden_rows(
common_speculative_feature_view feature_view;
const bool have_feature_view = common_speculative_feature_view_from_hidden_rows(
hidden_rows, mtp_state->n_embd, seq_id, pos_base, feature_view);
hidden_rows, feature_width, seq_id, pos_base, feature_view);
const int32_t ret = have_feature_view
? common_speculative_on_target_batch(spec, accepted_batch, feature_view, false)
: -1;
@ -2052,7 +2147,7 @@ bool common_speculative_commit_accepted_hidden_rows(
llama_token sampled_before,
const std::vector<llama_token> & ids,
const std::vector<float> & hidden_rows) {
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) || ids.empty()) {
if (common_speculative_feature_width(spec) <= 0 || ids.empty()) {
return true;
}
@ -2073,7 +2168,7 @@ bool common_speculative_commit_accepted_output(
llama_token sampled_before,
const std::vector<llama_token> & ids,
const std::vector<int32_t> & output_indices) {
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) || ids.empty()) {
if (common_speculative_feature_width(spec) <= 0 || ids.empty()) {
return true;
}
@ -2172,7 +2267,7 @@ void common_speculative_checkpoint_restore(
}
}
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) && !mtp_hidden_state_pre.empty()) {
if (common_speculative_has_target_features(spec) && !mtp_hidden_state_pre.empty()) {
if (!common_speculative_commit_accepted_hidden_rows(
spec,
spec_type_used,
@ -2218,7 +2313,7 @@ void common_speculative_checkpoint_restore(
__func__, (int) seq_id, ret);
}
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
if (common_speculative_has_target_features(spec)) {
std::vector<int32_t> redecoded_indices(n_re);
for (int j = 0; j < n_re; ++j) {
redecoded_indices[j] = j;
@ -2274,7 +2369,7 @@ void common_speculative_commit(
common_speculative_accept(spec, ids.size() - 1);
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) &&
if (common_speculative_has_target_features(spec) &&
any_rejected &&
ckpt.valid &&
!accepted_output_indices.empty()) {
@ -2299,7 +2394,7 @@ void common_speculative_commit(
return;
}
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) && !accepted_output_indices.empty()) {
if (common_speculative_has_target_features(spec) && !accepted_output_indices.empty()) {
if (!common_speculative_commit_accepted_output(
spec,
ctx,
@ -2345,6 +2440,7 @@ void common_speculative_print_stats(const common_speculative * spec, double slot
impl->n_gen_tokens,
impl->n_acc_tokens,
str_perf.c_str());
}
if (spec->tuner && spec->tuner->enabled && slot_tps > 0.0 && n_decoded > 0) {
@ -2384,6 +2480,40 @@ static const common_speculative_state_mtp * common_speculative_get_mtp_state(con
return common_speculative_get_mtp_state(const_cast<common_speculative *>(spec));
}
// Returns the first implementation in `spec` that is tagged as DFlash and
// really is a common_speculative_state_dflash, or nullptr when there is none
// (including when `spec` itself is nullptr).
static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec) {
    if (spec == nullptr) {
        return nullptr;
    }

    for (auto & candidate : spec->impls) {
        if (candidate->type == COMMON_SPECULATIVE_TYPE_DFLASH) {
            // The type tag alone is not trusted; confirm the dynamic type too.
            auto * as_dflash = dynamic_cast<common_speculative_state_dflash *>(candidate.get());
            if (as_dflash != nullptr) {
                return as_dflash;
            }
        }
    }

    return nullptr;
}
// Const overload: delegates to the mutable lookup and returns the result as
// a pointer-to-const.
static const common_speculative_state_dflash * common_speculative_get_dflash_state(const common_speculative * spec) {
    auto * mutable_spec = const_cast<common_speculative *>(spec);
    return common_speculative_get_dflash_state(mutable_spec);
}
// Width of one target feature row consumed by the speculative stages.
//
// A DFlash state takes precedence (its n_target_features); otherwise the MTP
// state's n_embd is used. Returns 0 when neither state exists.
static int32_t common_speculative_feature_width(const common_speculative * spec) {
    const auto * dflash_state = common_speculative_get_dflash_state(spec);
    if (dflash_state != nullptr) {
        return dflash_state->n_target_features;
    }

    const auto * mtp_state = common_speculative_get_mtp_state(spec);
    if (mtp_state == nullptr) {
        return 0;
    }
    return mtp_state->n_embd;
}
static mtp_last_embd & mtp_get_last_embd(common_speculative_state_mtp & state, llama_seq_id seq_id) {
auto & last = state.draft_cache_by_seq[seq_id];
if ((int) last.embd.size() != state.n_embd) {
@ -2459,11 +2589,13 @@ bool common_speculative_has_sequence_hidden(const common_speculative * spec, lla
void common_speculative_clear_sequence_hidden(common_speculative * spec, llama_seq_id seq_id) {
auto * mtp_state = common_speculative_get_mtp_state(spec);
if (mtp_state == nullptr) {
return;
if (mtp_state != nullptr) {
mtp_clear_target_hidden(*mtp_state, seq_id);
}
mtp_clear_target_hidden(*mtp_state, seq_id);
if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) {
dflash_clear_target_features(*dflash_state);
}
}
void common_speculative_clear_sequence(
@ -2515,6 +2647,10 @@ llama_context * common_speculative_get_companion_ctx(common_speculative * spec)
return mtp_state->ctx_mtp;
}
if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) {
return dflash_state->ctx_dft;
}
return nullptr;
}
@ -2553,6 +2689,34 @@ int32_t common_speculative_on_target_batch(
const llama_batch & batch,
const common_speculative_feature_view & features,
bool is_prompt_warmup) {
if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) {
if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || batch.n_tokens <= 0) {
return 0;
}
if (features.width != dflash_state->n_target_features) {
LOG_ERR("%s: DFlash feature width mismatch: got %d expected %d\n",
__func__, features.width, dflash_state->n_target_features);
return -1;
}
if (batch.n_seq_id == nullptr || batch.seq_id == nullptr || batch.n_seq_id[0] <= 0 || batch.seq_id[0] == nullptr) {
return -1;
}
const llama_seq_id seq_id = batch.seq_id[0][0];
for (int i = 0; i < batch.n_tokens; ++i) {
if (batch.n_seq_id[i] != 1 || batch.seq_id[i] == nullptr || batch.seq_id[i][0] != seq_id) {
return -1;
}
}
if (!dflash_append_target_features(*dflash_state, features, seq_id)) {
return -1;
}
return 0;
}
auto * mtp_state = common_speculative_get_mtp_state(spec);
if (mtp_state == nullptr) {
return 0;
@ -2617,6 +2781,10 @@ void common_speculative_context_shift(
llama_kv_cache_seq_rm (ctx_mtp, seq_id, kv_keep, kv_keep + kv_discard);
llama_kv_cache_seq_add(ctx_mtp, seq_id, kv_keep + kv_discard, kv_past, -kv_discard);
}
if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) {
dflash_context_shift(*dflash_state, kv_keep, kv_discard, kv_past);
}
}
std::vector<llama_token> mtp_speculative_gen_draft(

View File

@ -64,6 +64,7 @@ class Model:
model_name: str | None
metadata_override: Path | None
dir_model_card: Path
target_model_dir: Path | None
# subclasses should define this!
model_arch: gguf.MODEL_ARCH
@ -71,7 +72,8 @@ class Model:
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False,
target_model_dir: Path | None = None):
if type(self) is Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
@ -93,6 +95,7 @@ class Model:
self.metadata_override = metadata_override
self.model_name = model_name
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
self.target_model_dir = target_model_dir
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
if self.ftype == gguf.LlamaFileType.GUESSED:
@ -459,6 +462,14 @@ class Model:
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
return json.load(f)
@staticmethod
def load_text_hparams(dir_model: Path) -> dict[str, Any]:
    """Load the hyperparameters of ``dir_model`` with ``text_config`` flattened in.

    When the loaded config contains a ``text_config`` dictionary, its entries
    are merged over the top-level entries (nested values win on key clashes).
    Otherwise the loaded hyperparameters are returned unchanged.
    """
    top_level = Model.load_hparams(dir_model)
    nested = top_level.get("text_config")
    if not isinstance(nested, dict):
        return top_level
    return {**top_level, **nested}
@classmethod
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
assert names
@ -500,13 +511,14 @@ class Model:
return seems_special
# used for GPT-2 BPE and WordPiece vocabs
def get_vocab_base(self) -> tuple[list[str], list[int], str]:
def get_vocab_base(self, dir_model: Path | None = None, vocab_size: int | None = None) -> tuple[list[str], list[int], str]:
tokens: list[str] = []
toktypes: list[int] = []
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
dir_model = dir_model or self.dir_model
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
vocab_size = vocab_size or self.hparams.get("vocab_size", len(tokenizer.vocab))
assert max(tokenizer.vocab.values()) < vocab_size
tokpre = self.get_vocab_base_pre(tokenizer)
@ -558,12 +570,12 @@ class Model:
if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
# ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
res = "grok-2"
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
res = "llama-bpe"
if chkhsh == "972da7b59cec44d1f0a490a86c96df53859e486e481563e5dddac155013d87ac":
# ref: https://huggingface.co/poolside/Laguna-XS.2
res = "laguna"
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
res = "llama-bpe"
if chkhsh == "049ecf7629871e3041641907f3de7c733e4dbfdc736f57d882ba0b0845599754":
# ref: https://huggingface.co/deepseek-ai/deepseek-llm-7b-base
res = "deepseek-llm"
@ -600,6 +612,18 @@ class Model:
if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea":
# ref: https://huggingface.co/Qwen/Qwen1.5-7B
res = "qwen2"
if chkhsh == "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4":
# ref: https://huggingface.co/Qwen/Qwen3.5-9B-Instruct
res = "qwen35"
if chkhsh == "99cc61242f7106804ce24fdf3a6451e4a55251078dffd5453c806e11b2310db3":
# ref: https://huggingface.co/Qwen/Qwen3.5-27B
res = "qwen35"
if chkhsh == "1444df51289cfa8063b96f0e62b1125440111bc79a52003ea14b6eac7016fd5f":
# ref: https://huggingface.co/z-lab/Qwen3.5-27B-DFlash (uses Qwen3.5 tokenizer)
res = "qwen35"
if chkhsh == "4f53cda18c2baa0c0354bb5f9a3ecbe5ed12ab4d8e11ba873c2f11161202b945":
# ref: https://huggingface.co/Qwen/Qwen3.6-35B-A3B (identical pre-tokenizer regex to qwen35)
res = "qwen35"
if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166":
# ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf
res = "olmo"
@ -690,19 +714,20 @@ class Model:
return res
# Marker: End get_vocab_base_pre
def _set_vocab_gpt2(self) -> None:
tokens, toktypes, tokpre = self.get_vocab_base()
def _set_vocab_gpt2(self, dir_model: Path | None = None, vocab_size: int | None = None) -> None:
dir_model = dir_model or self.dir_model
tokens, toktypes, tokpre = self.get_vocab_base(dir_model=dir_model, vocab_size=vocab_size)
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
special_vocab.add_to_gguf(self.gguf_writer)
def _set_vocab_qwen(self):
dir_model = self.dir_model
hparams = self.hparams
def _set_vocab_qwen(self, dir_model: Path | None = None, hparams: dict[str, Any] | None = None):
dir_model = dir_model or self.dir_model
hparams = hparams or self.hparams
tokens: list[str] = []
toktypes: list[int] = []
@ -2260,15 +2285,254 @@ class Qwen2MoeModel(Model):
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
# Converter registered for the "Qwen3ForCausalLM" architecture name.
# It subclasses Qwen2Model and only overrides the GGUF architecture id here.
@Model.register("Qwen3ForCausalLM")
class Qwen3Model(Qwen2Model):
    model_arch = gguf.MODEL_ARCH.QWEN3
# Converter registered for the "Qwen3MoeForCausalLM" architecture name.
# It subclasses Qwen2MoeModel and only overrides the GGUF architecture id here.
@Model.register("Qwen3MoeForCausalLM")
class Qwen3MoeModel(Qwen2MoeModel):
    model_arch = gguf.MODEL_ARCH.QWEN3MOE
@Model.register("DFlashDraftModel")
class DFlashDraftModel(Qwen3Model):
model_arch = gguf.MODEL_ARCH.DFLASH_DRAFT
_target_hparams: dict[str, Any] | None = None
_target_raw_hparams: dict[str, Any] | None = None
_saw_token_embd = False
_saw_output = False
def _require_target_model_dir(self) -> Path:
if self.target_model_dir is None:
raise ValueError("DFlashDraftModel conversion requires --target-model-dir <matching target model directory>")
return self.target_model_dir
def _get_target_hparams(self) -> dict[str, Any]:
if self._target_hparams is None:
self._target_hparams = Model.load_text_hparams(self._require_target_model_dir())
return self._target_hparams
def _get_target_raw_hparams(self) -> dict[str, Any]:
if self._target_raw_hparams is None:
self._target_raw_hparams = Model.load_hparams(self._require_target_model_dir())
return self._target_raw_hparams
def _target_uses_gemma4_vocab(self) -> bool:
raw_hparams = self._get_target_raw_hparams()
model_type = str(raw_hparams.get("model_type", ""))
if model_type.startswith("gemma4"):
return True
architectures = raw_hparams.get("architectures")
if isinstance(architectures, list):
return any(str(arch).startswith("Gemma4") for arch in architectures)
return False
def _get_target_hidden_size(self) -> int | None:
raw_hparams = self._get_target_raw_hparams()
if (hidden_size := raw_hparams.get("hidden_size")) is not None:
return int(hidden_size)
if (hidden_size := raw_hparams.get("backbone_hidden_size")) is not None:
return int(hidden_size)
text_hparams = raw_hparams.get("text_config")
if isinstance(text_hparams, dict) and (hidden_size := text_hparams.get("hidden_size")) is not None:
return int(hidden_size)
return None
def _set_vocab_gemma4(self, dir_model: Path, vocab_size: int | None = None) -> None:
    """Write a Gemma4 tokenizer read from ``dir_model`` to the GGUF writer.

    A fixed set of channel/tool marker tokens is emitted as USER_DEFINED; all
    other tokens keep the type reported by the vocabulary.

    Raises:
        ValueError: if ``vocab_size`` is given and differs from the number of
            tokens read.
    """
    vocab = gguf.LlamaHfVocab(dir_model)

    # Tokens whose type is overridden to USER_DEFINED.
    user_defined_markers = frozenset((
        "<|channel>",
        "<channel|>",
        "<|tool_call>",
        "<tool_call|>",
        "<|tool_response>",
        "<tool_response|>",
        "<|\"|>",
    ))

    tokens = []
    scores = []
    toktypes = []
    for raw_text, score, default_type in vocab.all_tokens():
        tokens.append(raw_text)
        scores.append(score)

        decoded = raw_text.decode()
        if decoded not in user_defined_markers:
            toktypes.append(default_type)
            continue
        toktypes.append(gguf.TokenType.USER_DEFINED)
        logger.info(f"Token {decoded!r} is set to USER_DEFINED")

    if vocab_size is not None and len(tokens) != int(vocab_size):
        raise ValueError(
            f"DFlashDraftModel: Gemma4 tokenizer size {len(tokens)} does not match expected vocab_size={int(vocab_size)}"
        )

    writer = self.gguf_writer
    writer.add_tokenizer_model("gemma4")
    writer.add_token_list(tokens)
    writer.add_token_scores(scores)
    writer.add_token_types(toktypes)

    gguf.SpecialVocab(dir_model, load_merges=True).add_to_gguf(writer)

    writer.add_add_space_prefix(False)
    writer.add_add_bos_token(True)
def set_vocab(self):
    """Write the tokenizer of the target model (not the draft) to the GGUF file.

    Uses the Gemma4 vocabulary writer when the target config identifies a
    Gemma4 model, and the GPT-2 style writer otherwise. In both cases the
    target's ``vocab_size`` is passed along.
    """
    target_hparams = self._get_target_hparams()
    target_model_dir = self._require_target_model_dir()

    if self._target_uses_gemma4_vocab():
        write_vocab = self._set_vocab_gemma4
    else:
        write_vocab = self._set_vocab_gpt2

    write_vocab(
        dir_model=target_model_dir,
        vocab_size=target_hparams.get("vocab_size"),
    )
# Write the GGUF metadata for a DFlash draft model: the base parameters from
# the parent class, non-causal attention, RoPE settings, and the
# "<arch>.dflash.*" keys (block_size, mask_token_id, target_layer_ids,
# n_target_features).
# Raises ValueError when a required DFlash value is missing, when
# target_layer_ids is empty, or when n_target_features must be inferred but a
# hidden_size is unavailable.
def set_gguf_parameters(self):
super().set_gguf_parameters()
# Attention is written as non-causal for this architecture.
self.gguf_writer.add_causal_attention(False)
# RoPE dimension count is taken from "head_dim", falling back to 128 when absent.
self.gguf_writer.add_rope_dimension_count(self.hparams.get("head_dim", 128))
rope_scaling = self.hparams.get("rope_scaling")
if isinstance(rope_scaling, dict):
# The scaling kind may be stored under "rope_type" or under "type".
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
rope_factor = rope_scaling.get("factor")
# Only "linear" and "yarn" are handled here, and only when a factor is given.
if rope_type == "linear" and rope_factor is not None:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(rope_factor)
elif rope_type == "yarn" and rope_factor is not None:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_factor)
# Each optional value below is written only when its key yields a non-None value.
if (orig_ctx_len := rope_scaling.get("original_max_position_embeddings")) is not None:
self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_ctx_len)
if (yarn_ext_factor := rope_scaling.get("extrapolation_factor")) is not None:
self.gguf_writer.add_rope_scaling_yarn_ext_factor(yarn_ext_factor)
# "attention_factor" is preferred; "attn_factor" is accepted as an alternative key.
if (yarn_attn_factor := rope_scaling.get("attention_factor", rope_scaling.get("attn_factor"))) is not None:
self.gguf_writer.add_rope_scaling_yarn_attn_factor(yarn_attn_factor)
if (yarn_beta_fast := rope_scaling.get("beta_fast")) is not None:
self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_beta_fast)
if (yarn_beta_slow := rope_scaling.get("beta_slow")) is not None:
self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_beta_slow)
arch = self.gguf_writer.arch
# "dflash_config" is optional; anything that is not a dict is treated as empty.
dflash_cfg = self.hparams.get("dflash_config")
dflash_cfg = dflash_cfg if isinstance(dflash_cfg, dict) else {}
# Look up a mandatory DFlash setting: dflash_config wins over the top-level
# hparams; a ValueError is raised when neither provides the name.
def dflash_required_value(name: str) -> Any:
if name in dflash_cfg:
return dflash_cfg[name]
if name in self.hparams:
return self.hparams[name]
raise ValueError(f"DFlashDraftModel conversion requires explicit {name} metadata")
block_size = int(dflash_required_value("block_size"))
self.gguf_writer.add_uint32(f"{arch}.dflash.block_size", block_size)
mask_token_id = int(dflash_required_value("mask_token_id"))
self.gguf_writer.add_uint32(f"{arch}.dflash.mask_token_id", mask_token_id)
# Layer ids are coerced to int one by one; an empty list is rejected.
target_layer_ids = [int(layer_id) for layer_id in dflash_required_value("target_layer_ids")]
if len(target_layer_ids) == 0:
raise ValueError("DFlashDraftModel conversion requires at least one target_layer_id")
self.gguf_writer.add_array(f"{arch}.dflash.target_layer_ids", target_layer_ids)
# n_target_features: an explicit value (dflash_config first, then hparams) is
# used when present; otherwise it is inferred from the target hidden size.
if "n_target_features" in dflash_cfg:
n_target_features = int(dflash_cfg["n_target_features"])
elif "n_target_features" in self.hparams:
n_target_features = int(self.hparams["n_target_features"])
else:
target_hidden_size = self._get_target_hidden_size()
if target_hidden_size is None:
raise ValueError("DFlashDraftModel: target config is missing hidden_size")
draft_hidden_size = self.hparams.get("hidden_size")
if draft_hidden_size is None:
raise ValueError("DFlashDraftModel: draft config is missing hidden_size")
# Inferred width = target hidden size times the number of target layers.
n_target_features = int(target_hidden_size) * len(target_layer_ids)
# A draft/target hidden-size mismatch is only reported; the target width is still used.
if target_hidden_size is not None and int(target_hidden_size) != int(draft_hidden_size):
logger.warning(
"DFlashDraftModel: target hidden_size=%d differs from draft hidden_size=%d; using target hidden width for n_target_features",
int(target_hidden_size),
int(draft_hidden_size),
)
logger.info(
"DFlashDraftModel: inferred n_target_features=%d from target hidden_size=%d and n_target_layers=%d",
n_target_features,
int(target_hidden_size),
len(target_layer_ids),
)
self.gguf_writer.add_uint32(f"{arch}.dflash.n_target_features", n_target_features)
# Summarise the DFlash metadata that was written.
logger.info(
"DFlashDraftModel metadata: block_size=%s mask_token_id=%s target_layer_ids=%s n_target_features=%s",
block_size,
mask_token_id,
target_layer_ids,
n_target_features,
)
def prepare_tensors(self):
    """Run the base tensor preparation, then validate and log the IO layout.

    Raises:
        ValueError: if an output tensor was seen without a token embedding.
    """
    super().prepare_tensors()

    has_embd = self._saw_token_embd
    has_output = self._saw_output

    # An output head without its embedding table is not a supported layout.
    if has_output and not has_embd:
        raise ValueError(
            "DFlashDraftModel conversion requires token_embd.weight when output.weight is present"
        )

    # Classify by which of the two IO tensors the checkpoint provided.
    if not has_embd:
        io_mode = "shared-target"
    elif has_output:
        io_mode = "self-contained"
    else:
        io_mode = "self-contained-tied"

    logger.info(
        "DFlashDraftModel IO contract: io=%s token_embd=%s output=%s target_model_dir=%s",
        io_mode,
        has_embd,
        has_output,
        self._require_target_model_dir(),
    )
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
    """Map one checkpoint tensor to its GGUF name(s).

    ``fc.weight`` and ``hidden_norm.weight`` (with or without a leading
    ``model.``) are mapped directly to the DFlash tensor names. Everything
    else goes through the parent mapping, after bare ``norm.weight`` and
    ``layers.*`` names get a ``model.`` prefix. The method also records
    whether a token-embedding or output tensor was produced.
    """
    bare_name = name.removeprefix("model.")

    dflash_only = {
        "fc.weight": gguf.MODEL_TENSOR.DFLASH_FC,
        "hidden_norm.weight": gguf.MODEL_TENSOR.DFLASH_HIDDEN_NORM,
    }
    if bare_name in dflash_only:
        return [(f"{gguf.TENSOR_NAMES[dflash_only[bare_name]]}.weight", data_torch)]

    # Un-prefixed names are given the "model." prefix before the parent mapping.
    if name == "norm.weight" or name.startswith("layers."):
        name = f"model.{name}"

    mapped = list(super().modify_tensors(data_torch, name, bid))

    embd_key = f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD]}.weight"
    output_key = f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT]}.weight"
    for mapped_name, _tensor in mapped:
        if mapped_name == embd_key:
            self._saw_token_embd = True
            continue
        if mapped_name == output_key:
            self._saw_output = True

    return mapped
@Model.register("MellumForCausalLM")
class MellumModel(Model):
model_arch = gguf.MODEL_ARCH.MELLUM
@ -4617,6 +4881,7 @@ class JaisModel(Model):
super().prepare_tensors()
self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
@Model.register("MiniMaxM2ForCausalLM")
class MiniMaxM2Model(Model):
model_arch = gguf.MODEL_ARCH.MINIMAXM2
@ -4689,10 +4954,12 @@ class SmolLM3Model(LlamaModel):
chat_template = tokenizer.chat_template.replace("[:]", "")
self.gguf_writer.add_chat_template(chat_template)
@Model.register("SeedOssForCausalLM")
class SeedOssModel(Model):
model_arch = gguf.MODEL_ARCH.SEED_OSS
@Model.register("Dots1ForCausalLM")
class Dots1Model(Qwen2MoeModel):
model_arch = gguf.MODEL_ARCH.DOTS1
@ -4853,6 +5120,7 @@ class Glm4MoeModel(Model):
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
@Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration")
class ChatGLMModel(Model):
model_arch = gguf.MODEL_ARCH.CHATGLM
@ -5035,6 +5303,7 @@ class ChatGLMModel(Model):
name = name.removeprefix("transformer.")
return [(self.map_tensor_name(name), data_torch)]
@Model.register("BailingMoeV2ForCausalLM")
class BailingMoeV2Model(Model):
model_arch = gguf.MODEL_ARCH.BAILINGMOE2
@ -5392,6 +5661,10 @@ def parse_args() -> argparse.Namespace:
"--metadata", type=Path,
help="Specify the path for an authorship metadata override file"
)
parser.add_argument(
"--target-model-dir", type=Path,
help="matching target model directory; required for DFlash conversion to reuse tokenizer and infer target feature width",
)
return parser.parse_args()
@ -5471,7 +5744,8 @@ def main() -> None:
metadata_override=args.metadata, model_name=args.model_name,
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split)
small_first_shard=args.no_tensor_first_split,
target_model_dir=args.target_model_dir)
if args.vocab_only:
logger.info("Exporting model vocab...")

View File

@ -78,6 +78,10 @@ models = [
{"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", },
{"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", },
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", "chkhsh": "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4", },
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-27B", "chkhsh": "99cc61242f7106804ce24fdf3a6451e4a55251078dffd5453c806e11b2310db3", },
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/z-lab/Qwen3.5-27B-DFlash", "chkhsh": "1444df51289cfa8063b96f0e62b1125440111bc79a52003ea14b6eac7016fd5f", },
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.6-35B-A3B", "chkhsh": "4f53cda18c2baa0c0354bb5f9a3ecbe5ed12ab4d8e11ba873c2f11161202b945", },
{"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
{"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
{"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM!
@ -155,39 +159,46 @@ for model in models:
if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM:
continue
# Skip if the tokenizer folder does not exist or there are other download issues previously
if not os.path.exists(f"models/tokenizers/{name}"):
logger.warning(f"Directory for tokenizer {name} not found. Skipping...")
continue
chkhsh = model.get("chkhsh")
# create the tokenizer
try:
if name == "t5":
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
else:
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
except (OSError, TypeError) as e:
logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}")
continue # Skip to the next model if the tokenizer can't be loaded
if chkhsh is None:
# Skip if the tokenizer folder does not exist or there are other download issues previously
if not os.path.exists(f"models/tokenizers/{name}"):
logger.warning(f"Directory for tokenizer {name} not found. Skipping...")
continue
chktok = tokenizer.encode(CHK_TXT)
chkhsh = sha256(str(chktok).encode()).hexdigest()
# create the tokenizer
try:
if name == "t5":
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
else:
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
except (OSError, TypeError) as e:
logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}")
continue # Skip to the next model if the tokenizer can't be loaded
chktok = tokenizer.encode(CHK_TXT)
chkhsh = sha256(str(chktok).encode()).hexdigest()
logger.info(f"model: {name}")
logger.info(f"tokt: {tokt}")
logger.info(f"repo: {model['repo']}")
logger.info(f"chktok: {chktok}")
logger.info(f"chkhsh: {chkhsh}")
# print the "pre_tokenizer" content from the tokenizer.json
with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f:
cfg = json.load(f)
normalizer = cfg["normalizer"]
logger.info("normalizer: " + json.dumps(normalizer, indent=4))
pre_tokenizer = cfg["pre_tokenizer"]
logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4))
if "ignore_merges" in cfg["model"]:
logger.info("ignore_merges: " + json.dumps(cfg["model"]["ignore_merges"], indent=4))
if model.get("chkhsh") is None:
logger.info(f"chktok: {chktok}")
# print the "pre_tokenizer" content from the tokenizer.json
with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f:
cfg = json.load(f)
normalizer = cfg["normalizer"]
logger.info("normalizer: " + json.dumps(normalizer, indent=4))
pre_tokenizer = cfg["pre_tokenizer"]
logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4))
if "ignore_merges" in cfg["model"]:
logger.info("ignore_merges: " + json.dumps(cfg["model"]["ignore_merges"], indent=4))
else:
logger.info("using manually provided tokenizer hash")
logger.info("")
@ -354,6 +365,6 @@ logger.info("\nRun the following commands to generate the vocab files for testin
for model in models:
name = model["name"]
print(f"python3 convert_hf_to_gguf.py models/tokenizers/{name}/ --outfile models/ggml-vocab-{name}.gguf --vocab-only") # noqa: NP100
logger.info(f"python3 convert_hf_to_gguf.py models/tokenizers/{name}/ --outfile models/ggml-vocab-{name}.gguf --vocab-only") # noqa: NP100
logger.info("\n")

View File

@ -6,6 +6,7 @@
#include "common.h"
#include "llama.h"
#include "llama-spec-features.h"
#include "log.h"
#include "sampling.h"
#include "speculative.h"
@ -45,6 +46,16 @@ static void log_text(const gpt_params & params_base, const std::string & text) {
}
}
// Returns true when the slot's recorded prompt range
// [prompt_batch_i0, prompt_batch_i1) intersects the half-open batch range
// [batch_i0, batch_i1). A slot whose recorded range is unset (negative start)
// or empty/inverted never overlaps.
static bool server_slot_prompt_batch_overlaps(
        const server_slot & slot,
        int32_t batch_i0,
        int32_t batch_i1) {
    const int32_t prompt_begin = slot.prompt_batch_i0;
    const int32_t prompt_end   = slot.prompt_batch_i1;

    const bool has_prompt_range = prompt_begin >= 0 && prompt_end > prompt_begin;

    return has_prompt_range && prompt_begin < batch_i1 && batch_i0 < prompt_end;
}
struct server_mtp_warmup {
llama_context * ctx_tgt;
server_slot * slot;
@ -67,6 +78,15 @@ static bool server_response_needs_chat_parse(oaicompat_type oaicompat) {
oaicompat == OAICOMPAT_TYPE_RESP;
}
// True when the speculative configuration contains an MTP or a DFLASH stage.
static bool server_speculative_uses_target_features(const common_params_speculative & spec) {
    if (spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
        return true;
    }
    return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH);
}
// True when the speculative configuration reports a stage chain.
static bool server_speculative_requires_single_slot(const common_params_speculative & spec) {
    const bool has_chain = spec.has_stage_chain();
    return has_chain;
}
static bool server_speculative_same_stage_types(
const common_params_speculative & lhs,
const common_params_speculative & rhs) {
@ -151,6 +171,15 @@ static common_speculative_stage_params server_parse_speculative_stage_json(const
}
server_context::~server_context() {
// Speculative state may reference the live target context during teardown.
for (server_slot& slot : slots) {
if (slot.ctx_sampling != nullptr) {
common_sampler_free(slot.ctx_sampling);
}
common_speculative_free(slot.spec);
slot.spec = nullptr;
}
if (ctx) {
llama_free(ctx);
ctx = nullptr;
@ -162,17 +191,7 @@ server_context::~server_context() {
}
// Free multimodal
mtmd_free(mctx);
// Clear any sampling context
for (server_slot& slot : slots) {
if (slot.ctx_sampling != nullptr) {
common_sampler_free(slot.ctx_sampling);
}
common_speculative_free(slot.spec);
}
params_base.speculative.clear_dft();
llama_batch_free(batch);
}
@ -197,6 +216,12 @@ bool server_context::load_model(const gpt_params& params_) {
common_speculative_prepare_startup(params_base, false);
if (server_speculative_requires_single_slot(params_base.speculative) && params_base.n_parallel > 1) {
LOG_ERROR("Speculative decoding is currently limited to a single server slot (-np 1).\n", {
{"n_parallel", params_base.n_parallel},
});
return false;
}
const bool has_draft_model = params_base.speculative.has_dft();
std::string & mmproj_path = params_base.mmproj.path;
if (!mmproj_path.empty()) {
@ -307,7 +332,6 @@ void server_context::init() {
slot.params.speculative = params_base.speculative;
slot.sparams = params_base.sparams;
// try speculative decoding
if (can_spec && requested_spec) {
switch (common_speculative_try_init(params_base.speculative, slot.ctx, &slot.spec)) {
@ -445,9 +469,12 @@ void server_slot::reset() {
n_past_prompt = 0;
n_discarded_prompt = 0;
n_kept_prompt = 0;
prompt_batch_i0 = -1;
prompt_batch_i1 = -1;
n_sent_text = 0;
drafted.clear();
i_batch_dft.clear();
spec_prompt_warmup_failed = false;
n_sent_token_probs = 0;
infill = false;
ga_i = 0;
@ -540,7 +567,7 @@ void server_slot::add_token_string(const completion_token_output& token) {
}
bool server_slot::can_speculate() const {
return (!!spec || uses_mtp());
return !spec_prompt_warmup_failed && (!!spec || uses_mtp());
}
int server_slot::get_n_draft_max() const {
@ -3186,7 +3213,7 @@ void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_sl
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(kv_keep), slot.cache_tokens.pos_next(kv_keep + kv_discard));
llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard);
if (slot.uses_mtp() && slot.spec) {
if (slot.spec) {
common_speculative_context_shift(slot.spec, slot.id, kv_keep, kv_discard, kv_past);
}
if (slot.params.cache_prompt) {
@ -3586,6 +3613,9 @@ bool server_context::create_checkpoint(server_slot & slot) {
void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) {
if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto& slot : slots) {
slot.prompt_batch_i0 = -1;
slot.prompt_batch_i1 = -1;
// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) {
auto& prompt_tokens = slot.prompt_tokens;
@ -3868,6 +3898,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
int32_t ga_i = slot.ga_i;
int32_t ga_n = slot.ga_n;
int32_t ga_w = slot.ga_w;
const int32_t prompt_batch_i0 = batch.n_tokens;
// add prompt tokens for processing in the current batch
// TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
@ -3902,6 +3933,9 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
}
}
slot.prompt_batch_i0 = prompt_batch_i0;
slot.prompt_batch_i1 = batch.n_tokens;
LOG_VERBOSE("prompt processing progress", {
{"id_slot", slot.id},
{"n_past", slot.n_past},
@ -4014,7 +4048,7 @@ void server_context::speculative_decoding_accept() {
}
std::vector<int32_t> accepted_output_indices;
if (slot.uses_mtp()) {
if (server_speculative_uses_target_features(slot.params.speculative)) {
if (!ids.empty()) {
accepted_output_indices.assign(slot.i_batch_dft.begin(), slot.i_batch_dft.begin() + ids.size());
}
@ -4366,6 +4400,7 @@ void server_context::update_allowlist_state(server_slot& slot) {
void server_context::process_batch_tokens(int32_t & n_batch) {
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
bool finish_prompt_warmup_batch = false;
extend_context(n_tokens);
llama_batch batch_view = {
@ -4425,19 +4460,26 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
continue; // continue loop of n_batch
}
if (params_base.speculative.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) {
if (server_speculative_uses_target_features(params_base.speculative)) {
for (auto & slot : slots) {
if (!slot.spec || !slot.uses_mtp()) {
if (!slot.spec || !server_speculative_uses_target_features(slot.params.speculative)) {
continue;
}
if ((slot.state != SLOT_STATE_PROCESSING || slot.n_decoded != 0) &&
(slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_LOAD_PROMPT)) {
if (slot.spec_prompt_warmup_failed) {
continue;
}
if (!server_slot_prompt_batch_overlaps(slot, i, i + n_tokens)) {
continue;
}
if (common_speculative_on_target_seq_batch(slot.spec, ctx, batch_view, slot.id, true) != 0) {
LOG_ERROR("failed to warm up MTP state from prompt batch for slot %d\n", slot.id);
common_speculative_clear_sequence_hidden(slot.spec, slot.id);
slot.spec_prompt_warmup_failed = true;
LOG_ERROR("failed to warm up speculative target-feature state from prompt batch for slot %d\n", slot.id);
} else {
finish_prompt_warmup_batch = true;
}
}
}
@ -4558,6 +4600,10 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
}
// speculative decoding - main model sample and accept
speculative_decoding_accept();
if (finish_prompt_warmup_batch) {
llama_finish_dflash_capture_batch(ctx, true);
}
}
}
@ -4589,6 +4635,11 @@ void server_context::update_slots() {
// start populating the batch for this iteration
common_batch_clear(batch);
for (auto & slot : slots) {
slot.prompt_batch_i0 = -1;
slot.prompt_batch_i1 = -1;
}
// first, add sampled tokens from any ongoing sequences
add_sampled_tokens(); // Prepare batch for inference

View File

@ -51,6 +51,8 @@ struct server_slot {
int32_t i_batch = -1;
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
int32_t prompt_batch_i0 = -1;
int32_t prompt_batch_i1 = -1;
int32_t n_prompt_tokens = 0;
int32_t n_prompt_tokens_cache = 0;
@ -157,6 +159,7 @@ struct server_slot {
// expiring logit bias
std::vector<common_sampler::elb_state> prev_elb_states;
bool spec_prompt_warmup_failed = false;
// speculative decoding stats
int32_t n_draft_total = 0; // Total draft tokens generated
int32_t n_draft_accepted = 0; // Draft tokens actually accepted

View File

@ -246,6 +246,8 @@ class MODEL_ARCH(IntEnum):
GEMMA3 = auto()
GEMMA4 = auto()
GEMMA4_MTP = auto()
DFLASH = auto()
DFLASH_DRAFT = auto()
STARCODER2 = auto()
MAMBA = auto()
XVERSE = auto()
@ -272,6 +274,7 @@ class MODEL_ARCH(IntEnum):
SEED_OSS = auto()
LAGUNA = auto()
class MODEL_TENSOR(IntEnum):
TOKEN_EMBD = auto()
TOKEN_EMBD_NORM = auto()
@ -379,6 +382,8 @@ class MODEL_TENSOR(IntEnum):
MTP_POST_PROJ = auto()
MTP_TOKEN_ORDERING = auto()
MTP_CENTROIDS = auto()
DFLASH_FC = auto()
DFLASH_HIDDEN_NORM = auto()
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@ -416,6 +421,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.GEMMA3: "gemma3",
MODEL_ARCH.GEMMA4: "gemma4",
MODEL_ARCH.GEMMA4_MTP: "gemma4_mtp",
MODEL_ARCH.DFLASH: "dflash",
MODEL_ARCH.DFLASH_DRAFT: "dflash-draft",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
@ -551,6 +558,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.MTP_POST_PROJ: "mtp_post_proj",
MODEL_TENSOR.MTP_TOKEN_ORDERING: "mtp_token_ordering",
MODEL_TENSOR.MTP_CENTROIDS: "mtp_centroids",
MODEL_TENSOR.DFLASH_FC: "dflash_fc",
MODEL_TENSOR.DFLASH_HIDDEN_NORM: "dflash_hidden_norm",
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@ -1286,6 +1295,40 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.DFLASH: [
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.DFLASH_FC,
MODEL_TENSOR.DFLASH_HIDDEN_NORM,
],
MODEL_ARCH.DFLASH_DRAFT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.DFLASH_FC,
MODEL_TENSOR.DFLASH_HIDDEN_NORM,
],
MODEL_ARCH.BITNET: [
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,

View File

@ -53,7 +53,7 @@
#define LLAMA_STATE_SEQ_VERSION 3
#define LLAMA_SERVER_MAGIC 0x6c6d7376u // 'lmsv'
#define LLAMA_SERVER_VERSION 1
#define LLAMA_SERVER_VERSION 1
#ifdef __cplusplus
extern "C" {
@ -1096,6 +1096,10 @@ extern "C" {
// returns NULL for invalid ids.
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
// Get the argmax token ID for DFlash draft position i without materializing full logits.
// Returns LLAMA_TOKEN_NULL if argmax is not available (falls back to logits path).
LLAMA_API llama_token llama_get_dflash_draft_token_ith(struct llama_context * ctx, int32_t i);
// Get all output token embeddings.
// when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
// the embeddings for which llama_batch.logits[i] != 0 are stored contiguously

View File

@ -41,6 +41,8 @@ add_library(llama
../include/llama.h
llama.cpp
llama-spec-features.cpp
llama-spec-features-dflash.cpp
llama-dflash.cpp
llama-vocab.cpp
llama-grammar.cpp
llama-sampling.cpp
@ -99,6 +101,7 @@ add_library(llama
graphs/build_gemma2.cpp
graphs/build_gemma3.cpp
graphs/build_gemma4.cpp
graphs/build_dflash.cpp
graphs/build_mamba.cpp
graphs/build_command_r.cpp
graphs/build_olmo.cpp

439
src/graphs/build_dflash.cpp Normal file
View File

@ -0,0 +1,439 @@
#include "../llama-build-context.h"
#include "../llama-context.h"
#include "../llama-model.h"
#include <cmath>
// Builds a graph that, for every layer, takes the DFlash K/V context caches,
// reorders their rows along dim 2 when a valid cache view is recorded,
// permutes dims 1 and 2, and copies the result into the first ctx_len rows
// (dim 1) of the per-layer K/V workspaces.
ggml_cgraph * llm_build_context::build_dflash_kv_workspace() {
    const int64_t n_embd_head_k = hparams.n_embd_head_k(0);
    const int64_t n_embd_head_v = hparams.n_embd_head_v(0);

    // Context length: the explicit visible cross context if set, otherwise
    // n_ctx minus the DFlash block size (at least 1).
    const int64_t ctx_len = lctx.dflash.visible_cross_ctx > 0
        ? (int64_t) lctx.dflash.visible_cross_ctx
        : std::max<int64_t>(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size);

    const int32_t cache_rows = std::clamp(lctx.dflash.kv.cache_view_n_filled, 0, (int32_t) ctx_len);

    // Write position normalised into [0, ctx_len), also for negative inputs.
    const int32_t cache_write_pos = ctx_len > 0
        ? ((lctx.dflash.kv.cache_view_write_pos % (int32_t) ctx_len) + (int32_t) ctx_len) % (int32_t) ctx_len
        : 0;

    GGML_ASSERT(n_embd_head_k == n_embd_head_v);
    GGML_ASSERT(lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len));
    GGML_ASSERT((int32_t) lctx.dflash.kv.k_ctx_workspace.size() == n_layer);
    GGML_ASSERT((int32_t) lctx.dflash.kv.v_ctx_workspace.size() == n_layer);

    ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max<int64_t>(1, ctx_len)) + 16 * n_layer, false);

    // View of `n_rows` consecutive dim-2 rows of `t`, starting at `first_row`.
    const auto row_span = [&](ggml_tensor * t, int64_t n_rows, int64_t first_row) -> ggml_tensor * {
        return ggml_view_3d(ctx0, t,
                t->ne[0],
                t->ne[1],
                n_rows,
                t->nb[1],
                t->nb[2],
                (size_t) first_row * t->nb[2]);
    };

    // Returns `cache` with its dim-2 rows rotated so that rows [split, ctx_len)
    // come first and rows [0, split) come last. `split` is the filled row count
    // while the cache is only partially filled, and the write position once it
    // is full. The cache is returned untouched when no reordering is needed.
    const auto ordered_view = [&](ggml_tensor * cache) -> ggml_tensor * {
        const bool has_rows = lctx.dflash.kv.cache_view_valid && cache_rows > 0;
        if (!has_rows) {
            return cache;
        }

        int64_t split = cache_write_pos;
        if (cache_rows < ctx_len) {
            split = cache_rows;
        } else if (cache_write_pos == 0) {
            return cache;
        }

        ggml_tensor * leading  = row_span(cache, ctx_len - split, split);
        ggml_tensor * trailing = row_span(cache, split, 0);
        return ggml_concat(ctx0, leading, trailing, 2);
    };

    // Destination view covering the first ctx_len dim-1 rows of a workspace tensor.
    const auto workspace_view = [&](ggml_tensor * ws) -> ggml_tensor * {
        return ggml_view_3d(ctx0, ws,
                ws->ne[0],
                ctx_len,
                ws->ne[2],
                ws->nb[1],
                ws->nb[2],
                0);
    };

    for (int il = 0; il < n_layer; ++il) {
        const size_t layer = (size_t) il;
        GGML_ASSERT(layer < lctx.dflash.kv.k_ctx_cache.size());
        GGML_ASSERT(layer < lctx.dflash.kv.v_ctx_cache.size());

        ggml_tensor * k_ordered = ordered_view(lctx.dflash.kv.k_ctx_cache[layer]);
        ggml_tensor * v_ordered = ordered_view(lctx.dflash.kv.v_ctx_cache[layer]);
        cb(k_ordered, "dflash_workspace_k_ctx_view", il);
        cb(v_ordered, "dflash_workspace_v_ctx_view", il);

        // Swap dims 1 and 2 and make the result contiguous before copying.
        ggml_tensor * k_perm = ggml_cont(ctx0, ggml_permute(ctx0, k_ordered, 0, 2, 1, 3));
        ggml_tensor * v_perm = ggml_cont(ctx0, ggml_permute(ctx0, v_ordered, 0, 2, 1, 3));
        cb(k_perm, "dflash_workspace_k_perm_cont", il);
        cb(v_perm, "dflash_workspace_v_perm_cont", il);

        ggml_tensor * k_dst = workspace_view(lctx.dflash.kv.k_ctx_workspace[layer]);
        ggml_tensor * v_dst = workspace_view(lctx.dflash.kv.v_ctx_workspace[layer]);

        ggml_tensor * k_store = ggml_cpy(ctx0, k_perm, k_dst);
        ggml_tensor * v_store = ggml_cpy(ctx0, v_perm, v_dst);
        cb(k_store, "dflash_workspace_k_store", il);
        cb(v_store, "dflash_workspace_v_store", il);

        ggml_build_forward_expand(gf, k_store);
        ggml_build_forward_expand(gf, v_store);
    }

    return gf;
}
// Builds a graph that turns a block of target features into per-layer K/V rows
// and stores them in the DFlash context caches.
//
// The input features go through dflash_fc and an RMS norm once; each layer
// then projects K (RMS-normed and RoPE'd with the supplied positions) and V.
// The `update_rows` new rows are written starting at cache_write_pos; rows that
// would run past ctx_len are written from cache row 0 instead.
ggml_cgraph * llm_build_context::build_dflash_kv_cache() {
    const int64_t n_embd_head_k     = hparams.n_embd_head_k(0);
    const int64_t n_embd_head_v     = hparams.n_embd_head_v(0);
    const int64_t n_target_features = hparams.dflash_n_target_features;

    // Context length: the explicit visible cross context if set, otherwise
    // n_ctx minus the DFlash block size (at least 1).
    const int64_t ctx_len = lctx.dflash.visible_cross_ctx > 0
        ? (int64_t) lctx.dflash.visible_cross_ctx
        : std::max<int64_t>(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size);

    // A non-positive cache_update_rows means "update the whole context".
    const int64_t update_rows = std::max<int64_t>(1, lctx.dflash.kv.cache_update_rows > 0 ? lctx.dflash.kv.cache_update_rows : ctx_len);
    const int32_t write_pos   = lctx.dflash.kv.cache_write_pos;

    GGML_ASSERT(n_embd_head_k == n_embd_head_v);
    GGML_ASSERT(n_target_features > 0);
    GGML_ASSERT(lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len));
    GGML_ASSERT(update_rows > 0 && update_rows <= ctx_len);
    GGML_ASSERT(write_pos >= 0 && write_pos < ctx_len);

    ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max<int64_t>(1, update_rows)) + 24 * n_layer, false);

    // Graph inputs: the target features and the positions used for RoPE.
    lctx.dflash.kv.cache_input_target_features = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_target_features, update_rows);
    ggml_set_input(lctx.dflash.kv.cache_input_target_features);
    cb(lctx.dflash.kv.cache_input_target_features, "dflash_kv_input_target_features", -1);

    lctx.dflash.kv.cache_input_pos_ctx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, update_rows);
    ggml_set_input(lctx.dflash.kv.cache_input_pos_ctx);
    cb(lctx.dflash.kv.cache_input_pos_ctx, "dflash_kv_input_pos_ctx", -1);

    // Shared across layers: fc projection followed by RMS norm.
    ggml_tensor * fused_target = llm_build_lora_mm(lctx, ctx0, model.dflash_fc, lctx.dflash.kv.cache_input_target_features);
    fused_target = llm_build_norm(ctx0, fused_target, hparams, model.dflash_hidden_norm, nullptr, LLM_NORM_RMS, cb, -1);
    cb(fused_target, "dflash_kv_fused_target", -1);

    // Split of the update into the part that fits before the end of the cache
    // and the part that continues from row 0.
    const int32_t first_rows  = std::min<int32_t>((int32_t) update_rows, (int32_t) ctx_len - write_pos);
    const int32_t second_rows = (int32_t) update_rows - first_rows;

    // View of `n_rows` consecutive dim-2 rows of `t`, starting at `first_row`.
    const auto row_span = [&](ggml_tensor * t, int64_t n_rows, int64_t first_row) -> ggml_tensor * {
        return ggml_view_3d(ctx0, t,
                t->ne[0],
                t->ne[1],
                n_rows,
                t->nb[1],
                t->nb[2],
                (size_t) first_row * t->nb[2]);
    };

    for (int il = 0; il < n_layer; ++il) {
        const size_t layer = (size_t) il;
        GGML_ASSERT(layer < lctx.dflash.kv.k_ctx_cache.size());
        GGML_ASSERT(layer < lctx.dflash.kv.v_ctx_cache.size());

        // K: project, split into heads, RMS-norm, then apply RoPE.
        ggml_tensor * k_proj = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused_target);
        cb(k_proj, "dflash_kv_k_proj", il);
        ggml_tensor * k_new = ggml_reshape_3d(ctx0, k_proj, n_embd_head_k, n_head_kv, update_rows);
        k_new = llm_build_norm(ctx0, k_new, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
        cb(k_new, "dflash_kv_k_norm", il);
        k_new = ggml_rope_ext(ctx0, k_new, lctx.dflash.kv.cache_input_pos_ctx, nullptr,
                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow);
        cb(k_new, "dflash_kv_k_rope", il);

        // V: project and split into heads.
        ggml_tensor * v_new = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, fused_target);
        cb(v_new, "dflash_kv_v_proj", il);
        v_new = ggml_reshape_3d(ctx0, v_new, n_embd_head_v, n_head_kv, update_rows);

        // Copies `n_rows` source rows into this layer's caches at `dst_row`
        // and schedules both copies (K first, then V).
        const auto store_rows = [&](ggml_tensor * k_src, ggml_tensor * v_src, int32_t n_rows, int32_t dst_row) {
            ggml_tensor * k_dst = row_span(lctx.dflash.kv.k_ctx_cache[layer], n_rows, dst_row);
            ggml_tensor * v_dst = row_span(lctx.dflash.kv.v_ctx_cache[layer], n_rows, dst_row);

            ggml_tensor * k_store = ggml_cpy(ctx0, k_src, k_dst);
            cb(k_store, "dflash_kv_k_store", il);
            ggml_build_forward_expand(gf, k_store);

            ggml_tensor * v_store = ggml_cpy(ctx0, v_src, v_dst);
            cb(v_store, "dflash_kv_v_store", il);
            ggml_build_forward_expand(gf, v_store);
        };

        if (first_rows > 0) {
            // When the whole update fits, the new tensors are copied directly
            // without an intermediate view.
            const bool whole_update = first_rows == update_rows;
            ggml_tensor * k_src = whole_update ? k_new : row_span(k_new, first_rows, 0);
            ggml_tensor * v_src = whole_update ? v_new : row_span(v_new, first_rows, 0);
            store_rows(k_src, v_src, first_rows, write_pos);
        }

        if (second_rows > 0) {
            // Remaining rows continue from the start of the cache.
            ggml_tensor * k_src = row_span(k_new, second_rows, first_rows);
            ggml_tensor * v_src = row_span(v_new, second_rows, first_rows);
            store_rows(k_src, v_src, second_rows, 0);
        }
    }

    return gf;
}
// Builds the compute graph for the DFlash draft architecture (dispatched from
// llama_build_graph for LLM_ARCH_DFLASH_DRAFT).
//
// Per layer, Q/K/V are projected from the current batch tokens, K/V for the
// cross context are read from the per-layer workspace tensors in
// lctx.dflash.kv, both K/V parts are concatenated along dim 1, padded to
// n_kv_total and fed to ggml_flash_attn_ext. The layer finishes with the output
// projection, a residual add and a SiLU-gated parallel FFN with its own
// residual. After the last layer the output head is applied and an argmax node
// is appended; that node is published through lctx.dflash.draft_tokens_tensor.
//
// Side effects on lctx: sets dflash.inputs.kq_mask / kq_mask_swa,
// dflash.kv.kq_mask_tensor / kq_mask_swa_tensor and dflash.draft_tokens_tensor,
// and (via ensure_dflash_kv_cache_tensors) may allocate the K/V cache tensors.
ggml_cgraph * llm_build_context::build_dflash() {
const int64_t n_embd_head_k = hparams.n_embd_head_k(0);
const int64_t n_embd_head_v = hparams.n_embd_head_v(0);
const int64_t n_target_features = hparams.dflash_n_target_features;
// Cross-context length: the explicitly configured visible window if set,
// otherwise n_ctx minus the DFlash block size (clamped to at least 1).
const int64_t ctx_len = lctx.dflash.visible_cross_ctx > 0
? (int64_t) lctx.dflash.visible_cross_ctx
: std::max<int64_t>(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size);
// Write position wrapped into [0, ctx_len); the double modulo keeps the result
// non-negative for negative inputs. In this function it is only validated by
// the assertion below.
const int32_t cache_write_pos = ctx_len > 0
? ((lctx.dflash.kv.cache_view_write_pos % (int32_t) ctx_len) + (int32_t) ctx_len) % (int32_t) ctx_len
: 0;
// Total K/V length (context + batch tokens) rounded up to a multiple of 256
// with flash attention, 32 otherwise; n_kv_pad is the rounding remainder.
const int64_t n_kv_total = GGML_PAD(ctx_len + n_tokens, flash_attn ? 256 : 32);
const int64_t n_kv_pad = n_kv_total - (ctx_len + n_tokens);
GGML_ASSERT(n_embd_head_k == n_embd_head_v);
GGML_ASSERT(n_target_features > 0);
// Note: this call has side effects (it may allocate or reallocate the per-layer
// cache/workspace tensors), so the assertion must always be evaluated.
GGML_ASSERT(lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len));
GGML_ASSERT(cache_write_pos >= 0 && cache_write_pos < ctx_len);
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max<int64_t>(n_tokens, ctx_len)) + 32 * n_layer, false);
// A separate sliding-window mask is only needed if at least one layer uses SWA.
bool have_swa_layers = false;
for (int il = 0; il < n_layer; ++il) {
if (hparams.swa_layers[il]) {
have_swa_layers = true;
break;
}
}
// Full attention mask: F32 input of shape [n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)].
// The same tensor is also stored in dflash.kv.kq_mask_tensor, which
// llama_prepare_dflash_graph_inputs reads back to derive the mask dimensions.
lctx.dflash.inputs.kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
lctx.dflash.kv.kq_mask_tensor = lctx.dflash.inputs.kq_mask;
ggml_set_input(lctx.dflash.inputs.kq_mask);
cb(lctx.dflash.inputs.kq_mask, "dflash_kq_mask", -1);
// With flash attention enabled the mask is cast to F16; otherwise the F32 input is used directly.
ggml_tensor * dflash_kq_mask_full = flash_attn ? ggml_cast(ctx0, lctx.dflash.inputs.kq_mask, GGML_TYPE_F16) : lctx.dflash.inputs.kq_mask;
ggml_tensor * dflash_kq_mask_swa = nullptr;
lctx.dflash.inputs.kq_mask_swa = nullptr;
lctx.dflash.kv.kq_mask_swa_tensor = nullptr;
if (have_swa_layers && hparams.n_swa > 0) {
// Sliding-window mask with the same shape and dtype handling as the full mask.
lctx.dflash.inputs.kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
lctx.dflash.kv.kq_mask_swa_tensor = lctx.dflash.inputs.kq_mask_swa;
ggml_set_input(lctx.dflash.inputs.kq_mask_swa);
cb(lctx.dflash.inputs.kq_mask_swa, "dflash_kq_mask_swa", -1);
dflash_kq_mask_swa = flash_attn ? ggml_cast(ctx0, lctx.dflash.inputs.kq_mask_swa, GGML_TYPE_F16) : lctx.dflash.inputs.kq_mask_swa;
}
// If the model carries no token embedding tensor, a Q4_0 [n_embd, n_vocab]
// tensor is created in the graph context instead.
// NOTE(review): presumably a shape-only placeholder whose data is provided
// elsewhere (e.g. shared from the target model) - confirm against the loader.
ggml_tensor * tok_embd = model.tok_embd;
if (tok_embd == nullptr) {
tok_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab);
}
ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, tok_embd, cb);
ggml_tensor * inp_pos = build_inp_pos();
// Output-row selection is only built when a strict subset of a multi-token batch is requested.
ggml_tensor * inp_out_ids = (n_tokens > 1 && n_outputs < n_tokens) ? build_inp_out_ids() : nullptr;
bool result_rows_selected = false;
const float kq_scale = 1.0f / std::sqrt((float) n_embd_head_k);
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
ggml_tensor * cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// Query: project, split into heads, RMS-normalize per head, then apply RoPE.
ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens);
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
// Key for the current batch tokens ("noise" in the tensor names): same
// pipeline as Q but with n_head_kv heads and the K norm weights.
ggml_tensor * Kcur_noise = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
Kcur_noise = ggml_reshape_3d(ctx0, Kcur_noise, n_embd_head_k, n_head_kv, n_tokens);
Kcur_noise = llm_build_norm(ctx0, Kcur_noise, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
Kcur_noise = ggml_rope_ext(ctx0, Kcur_noise, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur_noise, "Kcur_noise", il);
// Value for the current batch tokens: projection and head split only (no norm, no RoPE).
ggml_tensor * Vcur_noise = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
Vcur_noise = ggml_reshape_3d(ctx0, Vcur_noise, n_embd_head_v, n_head_kv, n_tokens);
cb(Vcur_noise, "Vcur_noise", il);
GGML_ASSERT((size_t) il < lctx.dflash.kv.k_ctx_workspace.size());
GGML_ASSERT((size_t) il < lctx.dflash.kv.v_ctx_workspace.size());
GGML_ASSERT(lctx.dflash.kv.k_ctx_workspace[(size_t) il] != nullptr);
GGML_ASSERT(lctx.dflash.kv.v_ctx_workspace[(size_t) il] != nullptr);
// Context K/V: view of the first ctx_len entries along dim 1 of the per-layer
// workspace tensor, reusing the workspace strides and starting at offset 0.
ggml_tensor * Kcur_ctx = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_workspace[(size_t) il],
lctx.dflash.kv.k_ctx_workspace[(size_t) il]->ne[0],
ctx_len,
lctx.dflash.kv.k_ctx_workspace[(size_t) il]->ne[2],
lctx.dflash.kv.k_ctx_workspace[(size_t) il]->nb[1],
lctx.dflash.kv.k_ctx_workspace[(size_t) il]->nb[2],
0);
ggml_tensor * Vcur_ctx = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_workspace[(size_t) il],
lctx.dflash.kv.v_ctx_workspace[(size_t) il]->ne[0],
ctx_len,
lctx.dflash.kv.v_ctx_workspace[(size_t) il]->ne[2],
lctx.dflash.kv.v_ctx_workspace[(size_t) il]->nb[1],
lctx.dflash.kv.v_ctx_workspace[(size_t) il]->nb[2],
0);
cb(Kcur_ctx, "Kcur_ctx_workspace", il);
cb(Vcur_ctx, "Vcur_ctx_workspace", il);
// Swap dims 1 and 2 of the batch K/V so the token axis lines up with dim 1 of
// the workspace views, and make the result contiguous for the concat below.
ggml_tensor * Kcur_draft = ggml_cont(ctx0, ggml_permute(ctx0, Kcur_noise, 0, 2, 1, 3));
ggml_tensor * Vcur_draft = ggml_cont(ctx0, ggml_permute(ctx0, Vcur_noise, 0, 2, 1, 3));
cb(Kcur_draft, "dflash_main_k_perm_cont", il);
cb(Vcur_draft, "dflash_main_v_perm_cont", il);
// Attention K/V = [context entries | batch tokens] along dim 1.
ggml_tensor * Kcur = ggml_concat(ctx0, Kcur_ctx, Kcur_draft, 1);
ggml_tensor * Vcur = ggml_concat(ctx0, Vcur_ctx, Vcur_draft, 1);
cb(Kcur, "dflash_main_k_concat", il);
cb(Vcur, "dflash_main_v_concat", il);
// Pad dim 1 up to n_kv_total so it matches the first dimension of the masks.
if (n_kv_pad > 0) {
Kcur = ggml_pad(ctx0, Kcur, 0, (int) n_kv_pad, 0, 0);
Vcur = ggml_pad(ctx0, Vcur, 0, (int) n_kv_pad, 0, 0);
cb(Kcur, "dflash_main_k_pad", il);
cb(Vcur, "dflash_main_v_pad", il);
}
// NOTE(review): Qcur was already passed to cb under the same name right after
// RoPE above; this second call looks redundant.
cb(Qcur, "Qcur", il);
ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
ggml_tensor * k = Kcur;
ggml_tensor * v = Vcur;
// SWA layers use the sliding-window mask when one was created; all other layers use the full mask.
ggml_tensor * dflash_kq_mask_l = (hparams.swa_layers[il] && dflash_kq_mask_swa != nullptr)
? dflash_kq_mask_swa
: dflash_kq_mask_full;
cb(q, "q", il);
// NOTE(review): ggml_flash_attn_ext is used regardless of the flash_attn flag,
// which here only affects padding and mask dtype. Confirm that an F32 mask is
// accepted by this op when flash_attn is disabled.
cur = ggml_flash_attn_ext(ctx0, q, k, v, dflash_kq_mask_l, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
cb(cur, "flash_attn", il);
ggml_build_forward_expand(gf, cur);
// Flatten the attention output to [wo->ne[0], n_tokens] for the output projection.
cur = ggml_reshape_2d(ctx0, cur, model.layers[il].wo->ne[0], n_tokens);
cb(cur, "flash_attn_reshaped", il);
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
cb(cur, "kqv_out", il);
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "attn_residual", il);
// On the last layer, keep only the requested output rows before the FFN so the
// FFN and the output head run on fewer rows.
if (inp_out_ids != nullptr && il == n_layer - 1) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
cb(cur, "result_output_rows", -1);
result_rows_selected = true;
}
ggml_tensor * ffn_residual = cur;
cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "attn_post_norm", il);
cur = llm_build_ffn(ctx0, lctx, nullptr, cur,
model.layers[il].ffn_up, nullptr, nullptr,
model.layers[il].ffn_gate, nullptr, nullptr,
model.layers[il].ffn_down, nullptr, nullptr,
nullptr,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, false, false);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_residual);
cb(cur, "l_out", il);
inpL = cur;
}
// Output head tensor as reported by llama_model_dflash_output_tensor; when it
// is missing a Q4_0 [n_embd, n_vocab] tensor is created in the graph context.
// NOTE(review): presumably a placeholder as for tok_embd above - confirm.
ggml_tensor * output = const_cast<ggml_tensor *>(llama_model_dflash_output_tensor(&model));
if (output == nullptr) {
output = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab);
}
ggml_tensor * result_input = inpL;
// Fallback row selection for the case where the layer loop did not do it
// (only reachable when n_layer == 0, since the loop selects rows on its last iteration).
if (inp_out_ids && !result_rows_selected) {
result_input = ggml_get_rows(ctx0, result_input, inp_out_ids);
cb(result_input, "result_output_rows", -1);
}
ggml_tensor * result = build_output(lctx, ctx0, result_input, output, model.output_norm, cb);
cb(result, "result_output", -1);
ggml_build_forward_expand(gf, result);
// Append an argmax over the output so draft token ids are produced inside the
// graph; the node is exposed through lctx.dflash.draft_tokens_tensor.
lctx.dflash.draft_tokens_tensor = nullptr;
ggml_tensor * draft_tokens = ggml_argmax(ctx0, result);
ggml_set_name(draft_tokens, "draft_argmax");
ggml_build_forward_expand(gf, draft_tokens);
lctx.dflash.draft_tokens_tensor = draft_tokens;
return gf;
}

View File

@ -83,6 +83,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MISTRAL4, "mistral4" },
{ LLM_ARCH_GEMMA4, "gemma4" },
{ LLM_ARCH_GEMMA4_MTP, "gemma4_mtp" },
{ LLM_ARCH_DFLASH_DRAFT, "dflash-draft" },
{ LLM_ARCH_GEMMA4_ASSISTANT,"gemma4_assistant" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -153,6 +154,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_MTP_USE_ORDERED_EMBEDDINGS, "%s.use_ordered_embeddings" },
{ LLM_KV_MTP_CENTROID_COUNT, "%s.centroid_count" },
{ LLM_KV_MTP_CENTROID_TOP_K, "%s.centroid_top_k" },
{ LLM_KV_DFLASH_BLOCK_SIZE, "%s.dflash.block_size" },
{ LLM_KV_DFLASH_MASK_TOKEN_ID, "%s.dflash.mask_token_id" },
{ LLM_KV_DFLASH_TARGET_LAYER_IDS, "%s.dflash.target_layer_ids" },
{ LLM_KV_DFLASH_N_TARGET_FEATURES, "%s.dflash.n_target_features" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },

View File

@ -82,6 +82,7 @@ enum llm_arch {
LLM_ARCH_MISTRAL4,
LLM_ARCH_GEMMA4,
LLM_ARCH_GEMMA4_MTP,
LLM_ARCH_DFLASH_DRAFT,
LLM_ARCH_GEMMA4_ASSISTANT,
LLM_ARCH_UNKNOWN,
};
@ -143,6 +144,10 @@ enum llm_kv {
LLM_KV_MTP_USE_ORDERED_EMBEDDINGS,
LLM_KV_MTP_CENTROID_COUNT,
LLM_KV_MTP_CENTROID_TOP_K,
LLM_KV_DFLASH_BLOCK_SIZE,
LLM_KV_DFLASH_MASK_TOKEN_ID,
LLM_KV_DFLASH_TARGET_LAYER_IDS,
LLM_KV_DFLASH_N_TARGET_FEATURES,
LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
@ -372,6 +377,8 @@ enum llm_tensor {
LLM_TENSOR_MTP_POST_PROJ,
LLM_TENSOR_MTP_TOKEN_ORDERING,
LLM_TENSOR_MTP_CENTROIDS,
LLM_TENSOR_DFLASH_FC,
LLM_TENSOR_DFLASH_HIDDEN_NORM,
LLM_TENSOR_UNKNOWN,
};

View File

@ -35,7 +35,9 @@ llm_build_context::llm_build_context(
const llm_build_cb & cb,
bool worst_case,
bool warmup,
int n_outputs_) :
int n_outputs_,
bool clear_lctx_inputs,
std::vector<uint8_t> * buf_compute_meta_override) :
model (lctx.model),
lctx (lctx),
hparams (model.hparams),
@ -82,8 +84,9 @@ llm_build_context::llm_build_context(
thresh_experts (cparams.thresh_experts),
pooling_type (cparams.pooling_type),
rope_type (hparams.rope_type),
clear_lctx_inputs(clear_lctx_inputs),
cb (cb),
buf_compute_meta (lctx.buf_compute_meta) {
buf_compute_meta (buf_compute_meta_override ? *buf_compute_meta_override : lctx.buf_compute_meta) {
// all initializations should be done in init()
}
@ -96,22 +99,27 @@ void llm_build_context::init() {
ctx0 = ggml_init(params);
lctx.inp_tokens = nullptr;
lctx.inp_embd = nullptr;
lctx.inp_pos = nullptr;
lctx.inp_out_ids = nullptr;
lctx.inp_KQ_mask = nullptr;
lctx.inp_KQ_mask_swa = nullptr;
lctx.inp_K_shift = nullptr;
lctx.inp_mean = nullptr;
lctx.inp_cls = nullptr;
lctx.inp_s_copy = nullptr;
lctx.inp_s_mask = nullptr;
lctx.inp_s_seq = nullptr;
lctx.inp_s_seq_qnext = nullptr;
lctx.inp_pos_bucket = nullptr;
lctx.inp_embd_enc = nullptr;
lctx.inp_KQ_mask_cross = nullptr;
if (clear_lctx_inputs) {
lctx.inp_tokens = nullptr;
lctx.inp_embd = nullptr;
lctx.inp_pos = nullptr;
lctx.inp_out_ids = nullptr;
lctx.inp_KQ_mask = nullptr;
lctx.inp_KQ_mask_swa = nullptr;
lctx.inp_K_shift = nullptr;
lctx.inp_mean = nullptr;
lctx.inp_cls = nullptr;
lctx.inp_s_copy = nullptr;
lctx.inp_s_mask = nullptr;
lctx.inp_s_seq = nullptr;
lctx.inp_s_seq_qnext = nullptr;
lctx.inp_pos_bucket = nullptr;
lctx.inp_embd_enc = nullptr;
lctx.inp_KQ_mask_cross = nullptr;
lctx.dflash.inputs.target_features = nullptr;
lctx.dflash.inputs.pos_ctx = nullptr;
lctx.dflash.inputs.kq_mask = nullptr;
}
}
void llm_build_context::free() {
@ -2199,6 +2207,80 @@ struct ggml_cgraph * llm_build_context::llama_build_graph_s_copy(llama_context &
return result;
}
// Builds the DFlash K/V cache graph for `lctx` via build_dflash_kv_cache().
//
// The builder runs on an empty batch, does not clear the regular lctx input
// tensors (clear_lctx_inputs = false) and places graph metadata in
// lctx.dflash.kv.cache_compute_meta instead of the context's default buffer.
struct ggml_cgraph * llm_build_context::llama_build_graph_dflash_kv_cache(llama_context & lctx) {
    // Naming callback: tensors tied to a layer (il >= 0) are named "<name>-<il>",
    // truncated to fit GGML_MAX_NAME; everything else keeps its plain name.
    llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) {
        if (il < 0) {
            ggml_set_name(cur, name);
            return;
        }

        // Copy the base name, leaving room for the terminating NUL.
        int len = 0;
        while (len < GGML_MAX_NAME - 1 && name[len] != '\0') {
            cur->name[len] = name[len];
            ++len;
        }

        // Append "-<il>" only if at least the dash and one digit still fit.
        if (len < GGML_MAX_NAME - 3) {
            cur->name[len++] = '-';
            const auto layer_suffix = std::to_string(il);
            for (const char digit : layer_suffix) {
                if (len >= GGML_MAX_NAME - 1) {
                    break;
                }
                cur->name[len++] = digit;
            }
        }

        cur->name[len] = 0;
    };

    // The graph does not depend on batch contents; only n_tokens is set.
    llama_batch dummy;
    dummy.n_tokens = 0;

    struct llm_build_context llm(
            lctx, dummy, cb,
            /*worst_case        =*/ false,
            /*warmup            =*/ false,
            /*n_outputs         =*/ 0,
            /*clear_lctx_inputs =*/ false,
            &lctx.dflash.kv.cache_compute_meta);

    llm.init();
    struct ggml_cgraph * result = llm.build_dflash_kv_cache();
    llm.free();

    return result;
}
// Builds the DFlash K/V workspace graph for `lctx` via build_dflash_kv_workspace().
//
// The builder runs on an empty batch, does not clear the regular lctx input
// tensors (clear_lctx_inputs = false) and places graph metadata in
// lctx.dflash.kv.workspace_compute_meta instead of the context's default buffer.
struct ggml_cgraph * llm_build_context::llama_build_graph_dflash_kv_workspace(llama_context & lctx) {
    // Naming callback: tensors tied to a layer (il >= 0) are named "<name>-<il>",
    // truncated to fit GGML_MAX_NAME; everything else keeps its plain name.
    llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) {
        if (il < 0) {
            ggml_set_name(cur, name);
            return;
        }

        // Copy the base name, leaving room for the terminating NUL.
        int len = 0;
        while (len < GGML_MAX_NAME - 1 && name[len] != '\0') {
            cur->name[len] = name[len];
            ++len;
        }

        // Append "-<il>" only if at least the dash and one digit still fit.
        if (len < GGML_MAX_NAME - 3) {
            cur->name[len++] = '-';
            const auto layer_suffix = std::to_string(il);
            for (const char digit : layer_suffix) {
                if (len >= GGML_MAX_NAME - 1) {
                    break;
                }
                cur->name[len++] = digit;
            }
        }

        cur->name[len] = 0;
    };

    // The graph does not depend on batch contents; only n_tokens is set.
    llama_batch dummy;
    dummy.n_tokens = 0;

    struct llm_build_context llm(
            lctx, dummy, cb,
            /*worst_case        =*/ false,
            /*warmup            =*/ false,
            /*n_outputs         =*/ 0,
            /*clear_lctx_inputs =*/ false,
            &lctx.dflash.kv.workspace_compute_meta);

    llm.init();
    struct ggml_cgraph * result = llm.build_dflash_kv_workspace();
    llm.free();

    return result;
}
ggml_cgraph * llm_build_context::llama_build_graph(
llama_context & lctx,
const llama_batch & batch,
@ -2415,6 +2497,10 @@ ggml_cgraph * llm_build_context::llama_build_graph(
{
result = llm.build_gemma4_mtp();
} break;
case LLM_ARCH_DFLASH_DRAFT:
{
result = llm.build_dflash();
} break;
case LLM_ARCH_STARCODER2:
{
result = llm.build_starcoder2();

View File

@ -89,6 +89,7 @@ struct llm_build_context {
const enum llama_pooling_type pooling_type;
const enum llama_rope_type rope_type;
const bool clear_lctx_inputs;
const llm_build_cb & cb;
@ -103,7 +104,9 @@ struct llm_build_context {
const llm_build_cb & cb,
bool worst_case,
bool warmup,
int n_outputs = 0);
int n_outputs = 0,
bool clear_lctx_inputs = true,
std::vector<uint8_t> * buf_compute_meta_override = nullptr);
void init();
@ -244,6 +247,12 @@ struct llm_build_context {
ggml_cgraph * build_gemma4_mtp();
ggml_cgraph * build_dflash();
ggml_cgraph * build_dflash_kv_cache();
ggml_cgraph * build_dflash_kv_workspace();
ggml_cgraph * build_starcoder2();
ggml_cgraph * build_mamba();
@ -463,6 +472,10 @@ llm_expert_gating_func_type gating_op,
static ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx);
static ggml_cgraph * llama_build_graph_dflash_kv_cache(llama_context & lctx);
static ggml_cgraph * llama_build_graph_dflash_kv_workspace(llama_context & lctx);
static ggml_cgraph * llama_build_graph(llama_context & lctx, const llama_batch & batch, bool worst_case, int n_outputs = 0);
ggml_tensor * build_std_attention(ggml_cgraph * gf, ggml_tensor * attn_norm, ggml_tensor * cur,

View File

@ -278,6 +278,102 @@ struct llama_context {
size_t draft_input_hidden_state_n_floats = 0;
std::vector<float> draft_input_hidden_state_owned;
// Per-context runtime state for DFlash speculative drafting: the window of
// target-model features fed to the draft model, the per-layer K/V cache and
// workspace tensors with their bookkeeping, feature-capture state, graph input
// tensors and the drafted token ids.
struct dflash_runtime {
// Window of target-model feature rows and their positions, as consumed by
// llama_prepare_dflash_graph_inputs.
struct target_window_state {
// Feature rows. When non-null, llama_prepare_dflash_graph_inputs requires
// features_n_floats == features_n_rows * hparams.dflash_n_target_features.
const float * features = nullptr;
size_t features_n_floats = 0;
int32_t features_n_rows = 0;
// Feature rows to append; read by llama_prepare_dflash_graph_inputs.
// NOTE(review): exact append semantics are not visible here - confirm at the producer.
const float * append_features = nullptr;
size_t append_features_n_floats = 0;
int32_t append_features_n_rows = 0;
// One position per feature row; llama_prepare_dflash_graph_inputs requires
// positions_n == features_n_rows and strictly increasing values.
const llama_pos * positions = nullptr;
size_t positions_n = 0;
// Window version; passed to llama_plan_dflash_kv_cache_transition together
// with the cache's applied version.
uint64_t version = 0;
int32_t keep_rows = 0;
int32_t append_rows = 0;
bool replace = false;
// NOTE(review): presumably owning storage that the raw pointers above may
// point into - confirm against the code that fills the window.
std::vector<float> features_owned;
std::vector<float> append_features_owned;
std::vector<llama_pos> positions_owned;
std::vector<float> features_padded;
// Host-side positions for the cross context: resized to cross_ctx,
// zero-filled, then the window positions are copied right-aligned.
std::vector<llama_pos> pos_ctx_data;
// NOTE(review): presumably host staging buffers for the attention masks - confirm.
std::vector<float> kq_mask_data;
std::vector<float> kq_mask_swa_data;
};
// Per-layer K/V tensors plus the bookkeeping for the graphs and schedulers
// that update them. The tensors are created by
// llama_context::ensure_dflash_kv_cache_tensors and released by
// llama_context::free_dflash_kv_cache_tensors.
struct kv_runtime_state {
// One F32 tensor per layer, shape [n_embd_head, n_head_kv, cross_ctx].
std::vector<struct ggml_tensor *> k_ctx_cache;
std::vector<struct ggml_tensor *> v_ctx_cache;
// One F32 tensor per layer, shape [n_embd_head, workspace_n_kv_total, n_head_kv];
// build_dflash views the first ctx_len entries of dim 1 as the attention context.
std::vector<struct ggml_tensor *> k_ctx_workspace;
std::vector<struct ggml_tensor *> v_ctx_workspace;
// ggml context holding the tensor metadata (created with no_alloc = true).
struct ggml_context * cache_ctx = nullptr;
// Backing buffers, one per cache/workspace tensor.
std::vector<ggml_backend_buffer_t> cache_bufs;
int32_t cache_write_pos = 0;
int32_t cache_n_filled = 0;
int32_t cache_update_rows = 0;
int32_t cache_reserved_rows = 0;
// build_dflash wraps this value into [0, ctx_len) and asserts on the result.
int32_t cache_view_write_pos = 0;
int32_t cache_view_n_filled = 0;
uint64_t cache_applied_window_version = 0;
bool cache_valid = false;
bool cache_view_valid = false;
int32_t workspace_write_pos = 0;
int32_t workspace_n_filled = 0;
int32_t workspace_reserved_rows = 0;
// Set by ensure_dflash_kv_cache_tensors: max(1, dflash_block_size).
int32_t workspace_token_capacity = 0;
// Set by ensure_dflash_kv_cache_tensors: padded cross_ctx + token capacity.
int32_t workspace_n_kv_total = 0;
uint64_t workspace_applied_window_version = 0;
bool workspace_valid = false;
// When set, llama_sync_dflash_workspace_if_pending synchronizes
// workspace_sched and clears the flag.
bool workspace_sync_pending = false;
// Graph metadata buffers handed to llm_build_context as
// buf_compute_meta_override by the cache / workspace graph builders.
std::vector<uint8_t> cache_compute_meta;
std::vector<uint8_t> workspace_compute_meta;
ggml_backend_sched_t cache_sched = nullptr;
ggml_backend_sched_t workspace_sched = nullptr;
ggml_cgraph * cache_graph = nullptr;
ggml_cgraph * workspace_graph = nullptr;
// NOTE(review): presumably the parameters the stored graphs were built for,
// used to decide whether they can be reused - confirm.
int32_t cache_graph_rows = 0;
int32_t cache_graph_write_pos = 0;
int32_t workspace_graph_rows = 0;
int32_t workspace_graph_write_pos = 0;
struct ggml_tensor * cache_input_target_features = nullptr;
struct ggml_tensor * cache_input_pos_ctx = nullptr;
// Mask input tensors of the most recently built draft graph; set by
// build_dflash and read by llama_prepare_dflash_graph_inputs.
struct ggml_tensor * kq_mask_tensor = nullptr;
struct ggml_tensor * kq_mask_swa_tensor = nullptr;
};
// State for capturing per-layer rows during graph evaluation.
// NOTE(review): presumably filled from a backend-scheduler eval callback on
// the target model, with prev_cb_eval* saving the callback that was installed
// before - confirm.
struct capture_state {
std::vector<int32_t> layer_ids;
std::vector<std::vector<float>> layer_rows;
int32_t row_count = 0;
int32_t row_width = 0;
uint64_t capture_batch_id = 0;
std::vector<uint64_t> layer_seen_batch_id;
ggml_backend_sched_eval_callback prev_cb_eval = nullptr;
void * prev_cb_eval_user_data = nullptr;
};
// Graph input tensors; reset to nullptr by llm_build_context::init() when
// clear_lctx_inputs is set (target_features, pos_ctx, kq_mask).
struct input_state {
struct ggml_tensor * target_features = nullptr; // F32 [n_target_features, cross_ctx]
struct ggml_tensor * pos_ctx = nullptr; // I32 [cross_ctx]
struct ggml_tensor * kq_mask = nullptr; // F32 [cross_ctx + n_batch, GGML_PAD(n_batch)]
struct ggml_tensor * kq_mask_swa = nullptr; // F32 [cross_ctx + n_batch, GGML_PAD(n_batch)]
};
target_window_state target;
kv_runtime_state kv;
std::unique_ptr<capture_state> capture;
std::vector<float> feature_view_buffer;
input_state inputs;
// Cross-context length override: when > 0 it is used as the context length by
// build_dflash and llama_prepare_dflash_graph_inputs; otherwise both fall back
// to max(1, n_ctx - dflash_block_size).
int32_t visible_cross_ctx = 0;
// Argmax token IDs from the DFlash draft graph, computed via GPU argmax.
// Populated in llama_decode_internal after graph compute.
std::vector<llama_token> draft_tokens;
// The "draft_argmax" node of the most recently built draft graph (set by build_dflash).
struct ggml_tensor * draft_tokens_tensor = nullptr;
};
dflash_runtime dflash;
using dflash_capture_state = dflash_runtime::capture_state;
// input tensors
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
@ -315,6 +411,9 @@ struct llama_context {
bool update_cache_copies();
bool ensure_dflash_kv_cache_tensors(int32_t cross_ctx);
void free_dflash_kv_cache_tensors();
bool prepare_mtp_graph_inputs(
struct llama_context & lctx);
void set_mtp_op_type(llama_mtp_op_type value);
@ -322,4 +421,3 @@ struct llama_context {
int max_nodes(int n_tokens, int n_kv) const;
};

694
src/llama-dflash.cpp Normal file
View File

@ -0,0 +1,694 @@
#include "llama-dflash.h"
#include "llama-impl.h"
#include "llama-build-context.h"
#include "llama-context.h"
#include "llama-model.h"
#include "llama-spec-features.h"
#include "ggml.h"
#include "ggml-backend.h"
#include <algorithm>
#include <cmath>
#include <cstring>
#include <type_traits>
#include <vector>
// Blocks until outstanding work on the DFlash workspace scheduler has finished,
// but only if a sync was flagged as pending and a scheduler exists. The pending
// flag is cleared after the wait; otherwise nothing is touched.
void llama_sync_dflash_workspace_if_pending(struct llama_context & lctx) {
    auto & kv = lctx.dflash.kv;

    const bool must_wait = kv.workspace_sync_pending && kv.workspace_sched != nullptr;
    if (must_wait) {
        ggml_backend_sched_synchronize(kv.workspace_sched);
        kv.workspace_sync_pending = false;
    }
}
// Chooses the buffer type for the DFlash K/V tensors of layer `il`.
//
// Preference order:
//   1. the buffer type recorded for the layer in model.buft_layer,
//   2. the type of the buffer that holds the layer's wk weight,
//   3. the default CPU buffer type.
// A negative or out-of-range `il` falls through to the CPU default.
static ggml_backend_buffer_type_t llama_dflash_kv_cache_layer_buft(const llama_context & lctx, int32_t il) {
    if (il >= 0) {
        const size_t layer_idx = (size_t) il;
        const auto & model = lctx.model;

        if (layer_idx < model.buft_layer.size()) {
            const auto assigned = model.buft_layer[layer_idx].buft;
            if (assigned != nullptr) {
                return assigned;
            }
        }

        if (layer_idx < model.layers.size()) {
            const ggml_tensor * wk = model.layers[layer_idx].wk;
            const bool wk_has_buffer = wk != nullptr && wk->buffer != nullptr;
            if (wk_has_buffer) {
                return ggml_backend_buffer_get_type(wk->buffer);
            }
        }
    }

    return llama_default_buffer_type_cpu(true);
}
// Finds the backend in lctx.backends whose default buffer type matches the
// buffer that stores `tensor` (for views, the buffer of the view source).
// For CPU backends the comparison uses llama_default_buffer_type_cpu(true).
// Returns nullptr if the tensor is null, has no buffer, or nothing matches.
static ggml_backend_t llama_backend_for_tensor(const llama_context & lctx, const ggml_tensor * tensor) {
    if (tensor == nullptr) {
        return nullptr;
    }

    // Views do not own storage; look at the tensor they alias instead.
    const ggml_tensor * storage = tensor->view_src ? tensor->view_src : tensor;
    ggml_backend_buffer_t buf = storage->buffer;
    if (buf == nullptr) {
        return nullptr;
    }

    const ggml_backend_buffer_type_t wanted = ggml_backend_buffer_get_type(buf);

    for (ggml_backend_t candidate : lctx.backends) {
        ggml_backend_buffer_type_t candidate_buft;
        if (ggml_backend_is_cpu(candidate)) {
            candidate_buft = llama_default_buffer_type_cpu(true);
        } else {
            candidate_buft = ggml_backend_get_default_buffer_type(candidate);
        }

        if (candidate_buft == wanted) {
            return candidate;
        }
    }

    return nullptr;
}
// Makes sure the per-layer DFlash K/V cache and workspace tensors exist and are
// sized for `cross_ctx` (clamped to at least 1).
//
// If matching tensors already exist this is a no-op. If they exist with a
// different size, everything DFlash-KV related (tensors, buffers, schedulers,
// cached graphs) is torn down and rebuilt. Each tensor gets its own backend
// buffer, zero-initialized, on the buffer type chosen for its layer.
//
// Returns true on success; on failure logs an error, releases whatever was
// allocated and returns false.
bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
const int32_t target_cross_ctx = std::max<int32_t>(1, cross_ctx);
const int32_t target_token_capacity = std::max<int32_t>(1, (int32_t) model.hparams.dflash_block_size);
// Workspace K/V length: context plus one block of tokens, padded the same way
// build_dflash pads its K/V length (256 with flash attention, 32 without).
const int32_t target_workspace_n_kv_total = GGML_PAD(target_cross_ctx + target_token_capacity, cparams.flash_attn ? 256 : 32);
const int32_t n_layer = model.hparams.n_layer;
const int64_t n_embd_head_k = model.hparams.n_embd_head_k(0);
const int64_t n_embd_head_v = model.hparams.n_embd_head_v(0);
const int64_t n_head_kv = model.hparams.n_head_kv();
// Fast path / reallocation check. Only the K tensors of the first layer are
// inspected; V tensors and other layers are allocated with the same dimensions below.
// NOTE(review): if n_layer == 0 the vectors are empty and front() is invalid.
// NOTE(review): if cache_ctx is non-null but the vector sizes do not match
// n_layer, this block is skipped and cache_ctx is overwritten below without
// being freed - confirm that state cannot occur.
if (dflash.kv.cache_ctx != nullptr &&
(int32_t) dflash.kv.k_ctx_cache.size() == n_layer &&
(int32_t) dflash.kv.k_ctx_workspace.size() == n_layer) {
const bool cache_matches =
(int32_t) dflash.kv.k_ctx_cache.front()->ne[2] == target_cross_ctx;
const bool workspace_matches =
(int32_t) dflash.kv.k_ctx_workspace.front()->ne[1] == target_workspace_n_kv_total;
if (cache_matches && workspace_matches) {
return true;
}
// Size changed: drop tensors and buffers, then the schedulers and graph
// bookkeeping that referred to them. free_dflash_kv_cache_tensors already
// frees workspace_sched, so only cache_sched can still be live here.
free_dflash_kv_cache_tensors();
if (dflash.kv.cache_sched != nullptr) {
ggml_backend_sched_free(dflash.kv.cache_sched);
dflash.kv.cache_sched = nullptr;
}
if (dflash.kv.workspace_sched != nullptr) {
ggml_backend_sched_free(dflash.kv.workspace_sched);
dflash.kv.workspace_sched = nullptr;
}
dflash.kv.cache_graph = nullptr;
dflash.kv.workspace_graph = nullptr;
dflash.kv.cache_graph_rows = 0;
dflash.kv.cache_graph_write_pos = 0;
dflash.kv.workspace_graph_rows = 0;
dflash.kv.workspace_graph_write_pos = 0;
dflash.kv.workspace_reserved_rows = 0;
}
// Metadata-only context: room for exactly four tensor headers per layer
// (K/V cache + K/V workspace); tensor data lives in separate backend buffers.
ggml_init_params params = {
/*.mem_size =*/ (size_t) (4 * std::max(1, n_layer)) * ggml_tensor_overhead(),
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
dflash.kv.cache_ctx = ggml_init(params);
if (dflash.kv.cache_ctx == nullptr) {
LLAMA_LOG_ERROR("%s: failed to allocate DFlash K/V cache context\n", __func__);
return false;
}
dflash.kv.k_ctx_cache.resize((size_t) n_layer);
dflash.kv.v_ctx_cache.resize((size_t) n_layer);
dflash.kv.k_ctx_workspace.clear();
dflash.kv.v_ctx_workspace.clear();
dflash.kv.k_ctx_workspace.resize((size_t) n_layer);
dflash.kv.v_ctx_workspace.resize((size_t) n_layer);
dflash.kv.cache_bufs.clear();
dflash.kv.cache_bufs.reserve((size_t) std::max(1, n_layer) * 4);
for (int32_t il = 0; il < n_layer; ++il) {
ggml_backend_buffer_type_t layer_buft = llama_dflash_kv_cache_layer_buft(*this, il);
// Creates one F32 3D tensor in cache_ctx, marks it as a graph input, names it
// with the layer index, and backs it with a dedicated zero-filled buffer of
// the layer's buffer type. The buffer is recorded in cache_bufs so that
// free_dflash_kv_cache_tensors can release it.
auto alloc_kv_input = [&](ggml_tensor *& tensor, const char * tensor_tag, const char * tensor_name,
int64_t ne0, int64_t ne1, int64_t ne2) -> bool {
tensor = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, ne0, ne1, ne2);
if (tensor == nullptr) {
LLAMA_LOG_ERROR("%s: failed to create %s for layer %d\n", __func__, tensor_tag, il);
return false;
}
ggml_set_input(tensor);
ggml_format_name(tensor, tensor_name, il);
const size_t tensor_bytes = ggml_backend_buft_get_alloc_size(layer_buft, tensor);
ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(layer_buft, tensor_bytes);
if (buf == nullptr) {
LLAMA_LOG_ERROR("%s: failed to allocate %s buffer for layer %d (%zu bytes)\n",
__func__, tensor_tag, il, tensor_bytes);
return false;
}
ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE);
ggml_backend_tensor_alloc(buf, tensor, ggml_backend_buffer_get_base(buf));
ggml_backend_buffer_clear(buf, 0);
dflash.kv.cache_bufs.push_back(buf);
return true;
};
// Cache layout:     [head_dim, n_head_kv, cross_ctx]
// Workspace layout: [head_dim, workspace_n_kv_total, n_head_kv]
// The || chain stops at the first failure; cleanup releases everything
// allocated so far, including earlier layers.
if (!alloc_kv_input(dflash.kv.k_ctx_cache[(size_t) il], "dflash_k_ctx_cache", "dflash_k_ctx_cache_%d",
n_embd_head_k, n_head_kv, target_cross_ctx) ||
!alloc_kv_input(dflash.kv.v_ctx_cache[(size_t) il], "dflash_v_ctx_cache", "dflash_v_ctx_cache_%d",
n_embd_head_v, n_head_kv, target_cross_ctx) ||
!alloc_kv_input(dflash.kv.k_ctx_workspace[(size_t) il], "dflash_k_ctx_workspace", "dflash_k_ctx_workspace_%d",
n_embd_head_k, target_workspace_n_kv_total, n_head_kv) ||
!alloc_kv_input(dflash.kv.v_ctx_workspace[(size_t) il], "dflash_v_ctx_workspace", "dflash_v_ctx_workspace_%d",
n_embd_head_v, target_workspace_n_kv_total, n_head_kv)) {
free_dflash_kv_cache_tensors();
return false;
}
}
dflash.kv.workspace_token_capacity = target_token_capacity;
dflash.kv.workspace_n_kv_total = target_workspace_n_kv_total;
// Start from a clean logical cache state for the freshly allocated tensors.
llama_reset_dflash_kv_cache_state(this);
return true;
}
void llama_context::free_dflash_kv_cache_tensors() {
auto release_vector = [](auto & v) {
using vec_type = std::decay_t<decltype(v)>;
vec_type().swap(v);
};
release_vector(dflash.kv.k_ctx_cache);
release_vector(dflash.kv.v_ctx_cache);
release_vector(dflash.kv.k_ctx_workspace);
release_vector(dflash.kv.v_ctx_workspace);
dflash.kv.cache_write_pos = 0;
dflash.kv.cache_n_filled = 0;
dflash.kv.cache_update_rows = 0;
dflash.kv.cache_reserved_rows = 0;
dflash.kv.cache_view_write_pos = 0;
dflash.kv.cache_view_n_filled = 0;
dflash.kv.cache_applied_window_version = 0;
dflash.kv.cache_valid = false;
dflash.kv.cache_view_valid = false;
dflash.kv.workspace_write_pos = 0;
dflash.kv.workspace_n_filled = 0;
dflash.kv.workspace_reserved_rows = 0;
dflash.kv.workspace_token_capacity = 0;
dflash.kv.workspace_n_kv_total = 0;
dflash.kv.workspace_applied_window_version = 0;
dflash.kv.workspace_valid = false;
dflash.kv.workspace_sync_pending = false;
dflash.kv.cache_graph = nullptr;
dflash.kv.workspace_graph = nullptr;
dflash.kv.cache_graph_rows = 0;
dflash.kv.cache_graph_write_pos = 0;
dflash.kv.workspace_graph_rows = 0;
dflash.kv.workspace_graph_write_pos = 0;
dflash.kv.cache_input_target_features = nullptr;
dflash.kv.cache_input_pos_ctx = nullptr;
dflash.kv.kq_mask_tensor = nullptr;
dflash.kv.kq_mask_swa_tensor = nullptr;
if (dflash.kv.workspace_sched != nullptr) {
ggml_backend_sched_synchronize(dflash.kv.workspace_sched);
ggml_backend_sched_free(dflash.kv.workspace_sched);
dflash.kv.workspace_sched = nullptr;
}
for (ggml_backend_buffer_t buf : dflash.kv.cache_bufs) {
if (buf != nullptr) {
ggml_backend_buffer_free(buf);
}
}
release_vector(dflash.kv.cache_bufs);
release_vector(dflash.kv.cache_compute_meta);
release_vector(dflash.kv.workspace_compute_meta);
if (dflash.kv.cache_ctx != nullptr) {
ggml_free(dflash.kv.cache_ctx);
dflash.kv.cache_ctx = nullptr;
}
}
// Applies `n_threads` to the backends of `lctx` that accept a thread / command
// buffer count (Metal and BLAS only when compiled in, CPU when present), wires
// the context's abort callback into the CPU backend, and then submits `gf` to
// `sched` asynchronously. The call does not wait for the graph to finish.
static void llama_graph_compute_sched(
        llama_context & lctx,
        ggml_backend_sched_t sched,
        ggml_cgraph * gf,
        int n_threads) {
#ifdef GGML_USE_METAL
    {
        auto metal = lctx.backend_metal;
        if (ggml_backend_is_metal(metal)) {
            ggml_backend_metal_set_n_cb(metal, n_threads);
        }
    }
#endif

    auto cpu = lctx.backend_cpu;
    if (cpu != nullptr) {
        ggml_backend_cpu_set_n_threads(cpu, n_threads);
        ggml_backend_cpu_set_abort_callback(cpu, lctx.abort_callback, lctx.abort_callback_data);
    }

#ifdef GGML_USE_BLAS
    {
        auto blas = lctx.backend_blas;
        if (blas != nullptr) {
            ggml_backend_blas_set_n_threads(blas, n_threads);
        }
    }
#endif

    ggml_backend_sched_graph_compute_async(sched, gf);
}
static bool dflash_layer_has_attention_bias(const llama_layer & layer) {
return layer.bq != nullptr ||
layer.bk != nullptr ||
layer.bv != nullptr ||
layer.bo != nullptr ||
layer.bqkv != nullptr ||
layer.bqk != nullptr ||
layer.bkv != nullptr;
}
// Checks that the loaded model satisfies the assumptions baked into the DFlash
// draft graph. For every layer this verifies that:
//   - head configuration (n_head, n_head_kv, head sizes) equals that of layer 0,
//   - RoPE configuration (dimension, frequency base, frequency scale) equals
//     that of layer 0 (float values compared with a 1e-6 tolerance),
//   - attn_norm, attn_q_norm and attn_k_norm weights are all present,
//   - no norm-bias tensors and no attention-bias tensors are present.
//
// Returns true when every layer passes; otherwise logs the first violation and
// returns false.
static bool validate_dflash_graph_contract(const llama_context & lctx) {
    const auto & model   = lctx.model;
    const auto & hparams = model.hparams;

    // Effective RoPE dimension for a layer: the per-layer override when set,
    // otherwise the SWA or regular default depending on the layer type.
    auto rope_dim_for_layer = [&hparams](int32_t il) -> uint32_t {
        if (hparams.rope_dim_per_layer[(size_t) il] != 0) {
            return hparams.rope_dim_per_layer[(size_t) il];
        }
        return hparams.swa_layers[(size_t) il] ? hparams.n_rot_swa : hparams.n_rot;
    };

    // Effective RoPE frequency base for a layer (per-layer table when available).
    auto rope_base_for_layer = [&hparams](int32_t il) -> float {
        if (hparams.has_rope_freq_base_per_layer) {
            return hparams.rope_freq_base_per_layer[(size_t) il];
        }
        return hparams.swa_layers[(size_t) il] ? hparams.rope_freq_base_train_swa : hparams.rope_freq_base_train;
    };

    // Effective RoPE frequency scale for a layer.
    auto rope_scale_for_layer = [&hparams](int32_t il) -> float {
        return hparams.swa_layers[(size_t) il] ? hparams.rope_freq_scale_train_swa : hparams.rope_freq_scale_train;
    };

    // Layer 0 is the reference every other layer is compared against.
    const uint32_t ref_n_head        = hparams.n_head(0);
    const uint32_t ref_n_head_kv     = hparams.n_head_kv(0);
    const uint32_t ref_n_embd_head_k = hparams.n_embd_head_k(0);
    const uint32_t ref_n_embd_head_v = hparams.n_embd_head_v(0);
    const uint32_t ref_rope_dim      = rope_dim_for_layer(0);
    const float    ref_rope_base     = rope_base_for_layer(0);
    const float    ref_rope_scale    = rope_scale_for_layer(0);

    for (int32_t il = 0; il < (int32_t) hparams.n_layer; ++il) {
        if (hparams.n_head((uint32_t) il) != ref_n_head ||
            hparams.n_head_kv((uint32_t) il) != ref_n_head_kv ||
            hparams.n_embd_head_k(il) != ref_n_embd_head_k ||
            hparams.n_embd_head_v(il) != ref_n_embd_head_v) {
            LLAMA_LOG_ERROR("%s: DFlash graph assumes layer-invariant head config, but layer %d differs (n_head=%u/%u n_head_kv=%u/%u head_k=%u/%u head_v=%u/%u)\n",
                    __func__,
                    il,
                    hparams.n_head((uint32_t) il), ref_n_head,
                    hparams.n_head_kv((uint32_t) il), ref_n_head_kv,
                    hparams.n_embd_head_k(il), ref_n_embd_head_k,
                    hparams.n_embd_head_v(il), ref_n_embd_head_v);
            return false;
        }

        const uint32_t rope_dim   = rope_dim_for_layer(il);
        const float    rope_base  = rope_base_for_layer(il);
        const float    rope_scale = rope_scale_for_layer(il);
        if (rope_dim != ref_rope_dim || std::fabs(rope_base - ref_rope_base) > 1e-6f || std::fabs(rope_scale - ref_rope_scale) > 1e-6f) {
            LLAMA_LOG_ERROR("%s: DFlash graph assumes layer-invariant RoPE config, but layer %d differs (dim=%u/%u base=%g/%g scale=%g/%g)\n",
                    __func__,
                    il,
                    rope_dim, ref_rope_dim,
                    (double) rope_base, (double) ref_rope_base,
                    (double) rope_scale, (double) ref_rope_scale);
            return false;
        }

        const auto & layer = model.layers[(size_t) il];

        // All three norm weights are mandatory. Because Q and K norms are both
        // required here, a separate "Q/K norm presence must be symmetric" check
        // could never trigger and is intentionally not performed.
        if (layer.attn_norm == nullptr ||
            layer.attn_q_norm == nullptr ||
            layer.attn_k_norm == nullptr) {
            LLAMA_LOG_ERROR("%s: DFlash graph requires attn_norm, attn_q_norm, and attn_k_norm weights, but layer %d is missing one or more of them\n",
                    __func__, il);
            return false;
        }

        if (layer.attn_norm_b != nullptr ||
            layer.attn_q_norm_b != nullptr ||
            layer.attn_k_norm_b != nullptr) {
            LLAMA_LOG_ERROR("%s: DFlash graph does not implement norm-bias tensors, but layer %d requires attn_norm_b/q_norm_b/k_norm_b\n",
                    __func__, il);
            return false;
        }

        if (dflash_layer_has_attention_bias(layer)) {
            LLAMA_LOG_ERROR("%s: DFlash graph does not implement attention bias tensors, but layer %d requires them\n",
                    __func__, il);
            return false;
        }
    }

    return true;
}
// Prepares the inputs of the DFlash draft graph for a batch of n_tokens tokens.
//
// In order, the function:
//   1. resolves the cross-context window size and validates the mask tensor, the graph
//      contract, the K/V cache tensors and the target feature/position buffers;
//   2. builds the right-aligned context position vector (pos_ctx_data);
//   3. plans the K/V cache transition and, when rows must be written, builds (or reuses)
//      and runs the K/V cache update graph;
//   4. refreshes the K/V workspace through its own graph when it is out of date;
//   5. fills and uploads the attention mask (and the SWA mask when that tensor exists).
//
// Returns true on success. Every failure path logs an error and returns false.
bool llama_prepare_dflash_graph_inputs(
struct llama_context & lctx,
uint32_t n_tokens) {
// Window size: the explicit override when positive, otherwise n_ctx minus the DFlash
// block size, clamped to at least 1.
const int32_t cross_ctx = lctx.dflash.visible_cross_ctx > 0
? lctx.dflash.visible_cross_ctx
: std::max<int32_t>(1, (int32_t) lctx.cparams.n_ctx - (int32_t) lctx.model.hparams.dflash_block_size);
ggml_tensor * kq_mask = lctx.dflash.kv.kq_mask_tensor;
ggml_tensor * kq_mask_swa = lctx.dflash.kv.kq_mask_swa_tensor;
// The regular mask tensor is mandatory; the SWA mask is optional (checked near the end).
if (kq_mask == nullptr) {
LLAMA_LOG_ERROR("%s: DFlash graph inputs are not initialized\n", __func__);
return false;
}
if (!validate_dflash_graph_contract(lctx)) {
return false;
}
if (!lctx.ensure_dflash_kv_cache_tensors(cross_ctx) || lctx.dflash.kv.k_ctx_cache.empty() || lctx.dflash.kv.v_ctx_cache.empty()) {
LLAMA_LOG_ERROR("%s: DFlash K/V cache inputs are not initialized\n", __func__);
return false;
}
// Snapshot of the target-side buffers previously stored on the context.
const float * src = lctx.dflash.target.features;
const float * append_src = lctx.dflash.target.append_features;
const llama_pos * src_pos = lctx.dflash.target.positions;
const size_t total_floats = lctx.dflash.target.features_n_floats;
const size_t append_floats = lctx.dflash.target.append_features_n_floats;
const size_t total_positions = lctx.dflash.target.positions_n;
const int32_t n_rows = lctx.dflash.target.features_n_rows;
const int32_t append_rows_available = lctx.dflash.target.append_features_n_rows;
const int32_t width = (int32_t) lctx.model.hparams.dflash_n_target_features;
// Window size the cache tensors were actually allocated with (ne[2] of the first K cache).
const int32_t graph_cross_ctx = lctx.dflash.kv.k_ctx_cache.front() != nullptr
? (int32_t) lctx.dflash.kv.k_ctx_cache.front()->ne[2]
: 0;
const int32_t n_mask_tokens = (int32_t) kq_mask->ne[1];
const int32_t n_kv_total = (int32_t) kq_mask->ne[0];
// NOTE(review): presumably completes a workspace computation left pending by a previous
// call (see workspace_sync_pending below) — the helper is not visible here; confirm.
llama_sync_dflash_workspace_if_pending(lctx);
if (graph_cross_ctx != cross_ctx) {
LLAMA_LOG_ERROR("%s: DFlash graph cross_ctx drift (graph=%d configured=%d)\n",
__func__, graph_cross_ctx, cross_ctx);
return false;
}
if (n_rows <= 0) {
LLAMA_LOG_ERROR("%s: missing DFlash target feature rows\n", __func__);
return false;
}
// A full feature buffer is optional, but when present it must hold exactly n_rows * width floats.
const bool have_full_src = src != nullptr && total_floats == (size_t) n_rows * (size_t) width;
if (n_rows > cross_ctx || (src != nullptr && !have_full_src)) {
LLAMA_LOG_ERROR("%s: invalid DFlash target feature shape (rows=%d width=%d floats=%zu cross_ctx=%d)\n",
__func__, n_rows, width, total_floats, cross_ctx);
return false;
}
// Each mask row needs room for the context window plus the draft block.
if (n_kv_total < cross_ctx + (int32_t) n_tokens) {
LLAMA_LOG_ERROR("%s: invalid DFlash mask shape (n_kv_total=%d < cross_ctx+n_tokens=%d)\n",
__func__, n_kv_total, cross_ctx + (int32_t) n_tokens);
return false;
}
// The n_rows valid context rows are right-aligned inside the window; the first left_pad
// slots stay zero in pos_ctx_data and stay masked out below.
const int32_t left_pad = cross_ctx - n_rows;
lctx.dflash.target.pos_ctx_data.resize((size_t) cross_ctx);
std::fill(lctx.dflash.target.pos_ctx_data.begin(), lctx.dflash.target.pos_ctx_data.end(), 0);
if (src_pos == nullptr || total_positions != (size_t) n_rows) {
LLAMA_LOG_ERROR("%s: missing DFlash target positions (rows=%d positions=%zu cross_ctx=%d)\n",
__func__, n_rows, total_positions, cross_ctx);
return false;
}
const llama_pos last_target_pos = src_pos[n_rows - 1];
// Positions must be strictly increasing.
for (int32_t i = 1; i < n_rows; ++i) {
if (src_pos[i] <= src_pos[i - 1]) {
LLAMA_LOG_ERROR("%s: DFlash target positions are not strictly increasing (rows=%d first=%d last=%d)\n",
__func__, n_rows, (int) src_pos[0], (int) src_pos[n_rows - 1]);
return false;
}
}
std::copy(src_pos, src_pos + n_rows, lctx.dflash.target.pos_ctx_data.begin() + (ptrdiff_t) left_pad);
// Decide how the K/V cache must change for the current window version. The planner is
// defined elsewhere; only its cache_up_to_date / rebuild_cache / append_rows outputs are
// consumed here.
const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition(
cross_ctx,
lctx.dflash.kv.cache_n_filled,
lctx.dflash.kv.cache_write_pos,
lctx.dflash.kv.cache_valid,
lctx.dflash.kv.cache_applied_window_version,
lctx.dflash.target.version,
lctx.dflash.target.keep_rows,
lctx.dflash.target.append_rows,
lctx.dflash.target.replace,
n_rows);
// The append-only buffer is usable only if its row and float counts match the plan exactly.
const bool have_append_src = append_src != nullptr &&
append_rows_available == cache_plan.append_rows &&
append_floats == (size_t) cache_plan.append_rows * (size_t) width;
// Rows to write: none when up to date, all n_rows on rebuild, otherwise only the appended rows.
const int32_t update_rows = cache_plan.cache_up_to_date
? 0
: (cache_plan.rebuild_cache ? n_rows : cache_plan.append_rows);
const size_t max_nodes = lctx.model.max_nodes((int) std::max<int32_t>(1, cross_ctx)) + 24 * lctx.model.hparams.n_layer;
const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false);
if (lctx.dflash.kv.cache_compute_meta.size() != meta_size) {
lctx.dflash.kv.cache_compute_meta.resize(meta_size);
}
// (Re)create the cache scheduler when it does not exist yet or was reserved for a
// different window size.
if (lctx.dflash.kv.cache_sched == nullptr || lctx.dflash.kv.cache_reserved_rows != cross_ctx) {
std::vector<ggml_backend_buffer_type_t> backend_buft;
backend_buft.reserve(lctx.backends.size());
for (auto * backend : lctx.backends) {
if (ggml_backend_is_cpu(backend)) {
backend_buft.push_back(llama_default_buffer_type_cpu(true));
} else {
backend_buft.push_back(ggml_backend_get_default_buffer_type(backend));
}
}
if (lctx.dflash.kv.cache_sched != nullptr) {
ggml_backend_sched_free(lctx.dflash.kv.cache_sched);
lctx.dflash.kv.cache_sched = nullptr;
}
// Any previously cached update graph is dropped together with the old scheduler.
lctx.dflash.kv.cache_graph = nullptr;
lctx.dflash.kv.cache_graph_rows = 0;
lctx.dflash.kv.cache_graph_write_pos = 0;
// Build the reservation graph with cache_update_rows temporarily set to cross_ctx,
// then restore the previous value.
// NOTE(review): presumably so the reservation covers the largest possible update — confirm.
const int32_t saved_update_rows = lctx.dflash.kv.cache_update_rows;
lctx.dflash.kv.cache_update_rows = cross_ctx;
ggml_cgraph * gf_reserve = llm_build_context::llama_build_graph_dflash_kv_cache(lctx);
lctx.dflash.kv.cache_update_rows = saved_update_rows;
if (gf_reserve == nullptr) {
LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache reserve graph\n", __func__);
return false;
}
lctx.dflash.kv.cache_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false);
const bool reserved = lctx.dflash.kv.cache_sched != nullptr && ggml_backend_sched_reserve(lctx.dflash.kv.cache_sched, gf_reserve);
if (!reserved) {
LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V scheduler\n", __func__);
return false;
}
// Recorded only after a successful reserve, so a failed attempt is retried on the next call.
lctx.dflash.kv.cache_reserved_rows = cross_ctx;
}
if (update_rows > 0) {
// Pick the feature source: the append buffer when it covers exactly the rows being
// written, otherwise the last update_rows rows of the full buffer.
const float * update_src = nullptr;
if (have_append_src && update_rows == cache_plan.append_rows) {
update_src = append_src;
} else if (have_full_src) {
update_src = src + (size_t) (n_rows - update_rows) * (size_t) width;
}
// Positions always come from the tail of the full position array.
const llama_pos * update_pos = src_pos + (n_rows - update_rows);
if (update_src == nullptr) {
LLAMA_LOG_ERROR("%s: missing DFlash appended target features for cached update (rows=%d append_rows=%d floats=%zu)\n",
__func__, n_rows, update_rows, append_floats);
return false;
}
if (cache_plan.rebuild_cache) {
llama_reset_dflash_kv_cache_state(&lctx);
}
lctx.dflash.kv.cache_update_rows = update_rows;
ggml_cgraph * gf_kv = nullptr;
// The cached graph is reusable only for the same row count and the same write position.
const bool can_reuse_kv_graph = lctx.dflash.kv.cache_graph != nullptr &&
lctx.dflash.kv.cache_graph_rows == update_rows &&
lctx.dflash.kv.cache_graph_write_pos == lctx.dflash.kv.cache_write_pos;
if (can_reuse_kv_graph) {
gf_kv = lctx.dflash.kv.cache_graph;
} else {
gf_kv = llm_build_context::llama_build_graph_dflash_kv_cache(lctx);
if (gf_kv == nullptr || lctx.dflash.kv.cache_input_target_features == nullptr || lctx.dflash.kv.cache_input_pos_ctx == nullptr) {
LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__);
return false;
}
ggml_backend_sched_reset(lctx.dflash.kv.cache_sched);
// NOTE(review): the result of ggml_backend_sched_alloc_graph is not checked here.
ggml_backend_sched_alloc_graph(lctx.dflash.kv.cache_sched, gf_kv);
lctx.dflash.kv.cache_graph = gf_kv;
lctx.dflash.kv.cache_graph_rows = update_rows;
lctx.dflash.kv.cache_graph_write_pos = lctx.dflash.kv.cache_write_pos;
}
// Upload features and positions; the async setter is used when a backend is found for
// the tensor, the blocking setter otherwise.
ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash.kv.cache_input_target_features);
if (kv_feature_backend != nullptr) {
ggml_backend_tensor_set_async(kv_feature_backend, lctx.dflash.kv.cache_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash.kv.cache_input_target_features));
} else {
ggml_backend_tensor_set(lctx.dflash.kv.cache_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash.kv.cache_input_target_features));
}
ggml_backend_t kv_pos_backend = llama_backend_for_tensor(lctx, lctx.dflash.kv.cache_input_pos_ctx);
if (kv_pos_backend != nullptr) {
ggml_backend_tensor_set_async(kv_pos_backend, lctx.dflash.kv.cache_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash.kv.cache_input_pos_ctx));
} else {
ggml_backend_tensor_set(lctx.dflash.kv.cache_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash.kv.cache_input_pos_ctx));
}
llama_graph_compute_sched(lctx, lctx.dflash.kv.cache_sched, gf_kv, lctx.cparams.n_threads);
ggml_backend_sched_synchronize(lctx.dflash.kv.cache_sched);
// Advance the cache bookkeeping: the fill count saturates at cross_ctx and the write
// position wraps around modulo cross_ctx.
lctx.dflash.kv.cache_n_filled = std::min(cross_ctx, lctx.dflash.kv.cache_n_filled + update_rows);
lctx.dflash.kv.cache_write_pos = (lctx.dflash.kv.cache_write_pos + update_rows) % cross_ctx;
lctx.dflash.kv.cache_applied_window_version = lctx.dflash.target.version;
lctx.dflash.kv.cache_valid = true;
lctx.dflash.kv.cache_view_n_filled = lctx.dflash.kv.cache_n_filled;
lctx.dflash.kv.cache_view_write_pos = lctx.dflash.kv.cache_write_pos;
lctx.dflash.kv.cache_view_valid = true;
}
// Workspace refresh: only when a valid cache view exists and workspace tensors are present.
if (lctx.dflash.kv.cache_view_valid &&
!lctx.dflash.kv.k_ctx_workspace.empty() && !lctx.dflash.kv.v_ctx_workspace.empty()) {
// The workspace is stale when it is invalid or when its fill count, write position or
// applied window version differs from the cache view.
const bool need_workspace_refresh = !lctx.dflash.kv.workspace_valid ||
lctx.dflash.kv.workspace_n_filled != lctx.dflash.kv.cache_view_n_filled ||
lctx.dflash.kv.workspace_write_pos != lctx.dflash.kv.cache_view_write_pos ||
lctx.dflash.kv.workspace_applied_window_version != lctx.dflash.kv.cache_applied_window_version;
if (need_workspace_refresh) {
// Shadows the outer max_nodes/meta_size on purpose: the workspace graph budget uses
// 16 (not 24) extra nodes per layer.
const size_t max_nodes = lctx.model.max_nodes((int) std::max<int32_t>(1, cross_ctx)) + 16 * lctx.model.hparams.n_layer;
const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false);
if (lctx.dflash.kv.workspace_compute_meta.size() != meta_size) {
lctx.dflash.kv.workspace_compute_meta.resize(meta_size);
}
ggml_cgraph * gf_workspace = nullptr;
const bool can_reuse_workspace_graph = lctx.dflash.kv.workspace_graph != nullptr &&
lctx.dflash.kv.workspace_graph_rows == lctx.dflash.kv.cache_view_n_filled &&
lctx.dflash.kv.workspace_graph_write_pos == lctx.dflash.kv.cache_view_write_pos;
if (can_reuse_workspace_graph) {
gf_workspace = lctx.dflash.kv.workspace_graph;
} else {
gf_workspace = llm_build_context::llama_build_graph_dflash_kv_workspace(lctx);
if (gf_workspace == nullptr) {
LLAMA_LOG_ERROR("%s: failed to build DFlash K/V workspace graph\n", __func__);
return false;
}
std::vector<ggml_backend_buffer_type_t> backend_buft;
backend_buft.reserve(lctx.backends.size());
for (auto * backend : lctx.backends) {
if (ggml_backend_is_cpu(backend)) {
backend_buft.push_back(llama_default_buffer_type_cpu(true));
} else {
backend_buft.push_back(ggml_backend_get_default_buffer_type(backend));
}
}
// Unlike the cache scheduler, the workspace scheduler is created once and kept.
if (lctx.dflash.kv.workspace_sched == nullptr) {
lctx.dflash.kv.workspace_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false);
}
if (lctx.dflash.kv.workspace_reserved_rows != cross_ctx) {
// Temporarily present a full view (cross_ctx rows, write position 1 when
// cross_ctx > 1) to build the reservation graph, then restore the real view state.
// NOTE(review): presumably this is the largest graph shape (full, wrapped window) — confirm.
const bool saved_view_valid = lctx.dflash.kv.cache_view_valid;
const int32_t saved_view_rows = lctx.dflash.kv.cache_view_n_filled;
const int32_t saved_view_write_pos = lctx.dflash.kv.cache_view_write_pos;
lctx.dflash.kv.cache_view_valid = true;
lctx.dflash.kv.cache_view_n_filled = cross_ctx;
lctx.dflash.kv.cache_view_write_pos = cross_ctx > 1 ? 1 : 0;
ggml_cgraph * gf_workspace_reserve = llm_build_context::llama_build_graph_dflash_kv_workspace(lctx);
lctx.dflash.kv.cache_view_valid = saved_view_valid;
lctx.dflash.kv.cache_view_n_filled = saved_view_rows;
lctx.dflash.kv.cache_view_write_pos = saved_view_write_pos;
const bool reserved = lctx.dflash.kv.workspace_sched != nullptr &&
gf_workspace_reserve != nullptr &&
ggml_backend_sched_reserve(lctx.dflash.kv.workspace_sched, gf_workspace_reserve);
if (!reserved) {
LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V workspace scheduler\n", __func__);
return false;
}
lctx.dflash.kv.workspace_reserved_rows = cross_ctx;
}
ggml_backend_sched_reset(lctx.dflash.kv.workspace_sched);
ggml_backend_sched_alloc_graph(lctx.dflash.kv.workspace_sched, gf_workspace);
lctx.dflash.kv.workspace_graph = gf_workspace;
lctx.dflash.kv.workspace_graph_rows = lctx.dflash.kv.cache_view_n_filled;
lctx.dflash.kv.workspace_graph_write_pos = lctx.dflash.kv.cache_view_write_pos;
}
// No synchronize call follows this compute (unlike the cache update above); the pending
// flag is raised instead.
llama_graph_compute_sched(lctx, lctx.dflash.kv.workspace_sched, gf_workspace, lctx.cparams.n_threads);
lctx.dflash.kv.workspace_sync_pending = true;
lctx.dflash.kv.workspace_n_filled = lctx.dflash.kv.cache_view_n_filled;
lctx.dflash.kv.workspace_write_pos = lctx.dflash.kv.cache_view_write_pos;
lctx.dflash.kv.workspace_applied_window_version = lctx.dflash.kv.cache_applied_window_version;
lctx.dflash.kv.workspace_valid = true;
}
}
// Regular mask: start fully masked (-INFINITY), then for every token row open the columns
// from the first valid context slot through the end of the draft block. Every token sees
// the whole block here (no causal restriction).
// NOTE(review): rows are written for j < n_tokens while the buffer holds n_mask_tokens
// rows; this assumes n_tokens <= n_mask_tokens, which is not checked in this function.
const int32_t full_visible_first = left_pad;
const int32_t full_visible_last = cross_ctx + (int32_t) n_tokens - 1;
lctx.dflash.target.kq_mask_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY);
for (uint32_t j = 0; j < n_tokens; ++j) {
float * row = lctx.dflash.target.kq_mask_data.data() + (size_t) j * (size_t) n_kv_total;
for (int32_t i = full_visible_first; i <= full_visible_last; ++i) {
row[i] = 0.0f;
}
}
ggml_backend_tensor_set(kq_mask, lctx.dflash.target.kq_mask_data.data(), 0, ggml_nbytes(kq_mask));
if (kq_mask_swa != nullptr) {
// SWA mask: query j is given position last_target_pos + j. A context column is visible
// when its distance to the query is below n_swa; block columns are causal (a token sees
// block entries at or before its own index).
lctx.dflash.target.kq_mask_swa_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY);
const int32_t swa_window = (int32_t) lctx.model.hparams.n_swa;
const int32_t draft_pos_base = (int32_t) last_target_pos;
for (uint32_t j = 0; j < n_tokens; ++j) {
float * row = lctx.dflash.target.kq_mask_swa_data.data() + (size_t) j * (size_t) n_kv_total;
const int32_t q_pos = draft_pos_base + (int32_t) j;
for (int32_t k = left_pad; k < cross_ctx; ++k) {
const int32_t k_pos = (int32_t) lctx.dflash.target.pos_ctx_data[(size_t) k];
if (q_pos - k_pos < swa_window) {
row[k] = 0.0f;
}
}
for (int32_t k = cross_ctx; k < cross_ctx + (int32_t) n_tokens; ++k) {
const int32_t block_k = k - cross_ctx;
if (block_k <= (int32_t) j) {
row[k] = 0.0f;
}
}
}
ggml_backend_tensor_set(kq_mask_swa, lctx.dflash.target.kq_mask_swa_data.data(), 0, ggml_nbytes(kq_mask_swa));
}
return true;
}

8
src/llama-dflash.h Normal file
View File

@ -0,0 +1,8 @@
#pragma once
#include <cstdint>
// Forward declaration: this header only passes llama_context by reference.
struct llama_context;
// Prepares the DFlash draft graph inputs (K/V cache, workspace and attention masks) for a
// batch of n_tokens tokens. Returns false, after logging an error, when preparation fails.
bool llama_prepare_dflash_graph_inputs(llama_context & lctx, uint32_t n_tokens);
// NOTE(review): presumably completes a DFlash workspace computation that was left pending;
// the definition is not visible from this header — confirm against the implementation.
void llama_sync_dflash_workspace_if_pending(llama_context & lctx);

View File

@ -3,6 +3,7 @@
#include "llama-model-loader.h"
#include "llama-model.h"
#include <limits>
#include <map>
#define LLAMA_MAX_EXPERTS 512 // Qwen3 Next
@ -36,6 +37,89 @@ static inline const char * llm_expert_gating_func_name(llm_expert_gating_func_ty
}
}
// Loads the DFlash target-layer id array stored under `key` into `hparams`.
//
// On success, hparams.dflash_n_target_layers holds the entry count, the first entries of
// hparams.dflash_target_layer_ids hold the ids (remaining slots are zero), and true is returned.
//
// When the key is absent or is not an array: returns false, or throws std::runtime_error
// if `required` is set. Also throws std::runtime_error when the element type is neither
// uint32 nor int32, when the array is empty or has more than 8 entries, when an int32 id
// is negative, or when an id appears twice.
static bool load_dflash_target_layer_ids(
        llama_model_loader & ml,
        const std::string & key,
        llama_hparams & hparams,
        bool required) {
    const int key_idx = gguf_find_key(ml.meta, key.c_str());
    const bool is_array = key_idx >= 0 && gguf_get_kv_type(ml.meta, key_idx) == GGUF_TYPE_ARRAY;
    if (!is_array) {
        if (!required) {
            return false;
        }
        throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
    }

    const enum gguf_type elem_type = gguf_get_arr_type(ml.meta, key_idx);
    const bool is_signed = elem_type == GGUF_TYPE_INT32;
    if (!is_signed && elem_type != GGUF_TYPE_UINT32) {
        throw std::runtime_error(format("dflash: %s must be a uint32/int32 array", key.c_str()));
    }

    const size_t count = gguf_get_arr_n(ml.meta, key_idx);
    if (count == 0) {
        throw std::runtime_error(format("dflash: %s must not be empty", key.c_str()));
    }
    if (count > 8) {
        throw std::runtime_error(format("dflash: %s has %zu entries, max is 8", key.c_str(), count));
    }

    // Publish the count and clear every slot before parsing, so unused slots read as zero.
    hparams.dflash_n_target_layers = (uint32_t) count;
    const size_t n_slots = sizeof(hparams.dflash_target_layer_ids) / sizeof(hparams.dflash_target_layer_ids[0]);
    for (size_t slot = 0; slot < n_slots; ++slot) {
        hparams.dflash_target_layer_ids[slot] = 0;
    }

    const void * raw = gguf_get_arr_data(ml.meta, key_idx);
    for (uint32_t i = 0; i < hparams.dflash_n_target_layers; ++i) {
        uint32_t layer_id;
        if (is_signed) {
            const int32_t signed_id = static_cast<const int32_t *>(raw)[i];
            if (signed_id < 0) {
                throw std::runtime_error(format("dflash: %s contains negative layer id %d", key.c_str(), signed_id));
            }
            layer_id = (uint32_t) signed_id;
        } else {
            layer_id = static_cast<const uint32_t *>(raw)[i];
        }
        hparams.dflash_target_layer_ids[i] = layer_id;

        // Reject an id that already appeared in an earlier slot.
        for (uint32_t prev = 0; prev < i; ++prev) {
            if (hparams.dflash_target_layer_ids[prev] != layer_id) {
                continue;
            }
            throw std::runtime_error(format(
                "dflash: %s contains duplicate layer id %u",
                key.c_str(),
                layer_id));
        }
    }
    return true;
}
// Sanity-checks the DFlash metadata in `hparams`; throws std::runtime_error (prefixed with
// the architecture name) on the first violated rule:
//   - block_size must be greater than 1,
//   - at least one target layer id must be present,
//   - n_target_features must be non-zero and a multiple of n_target_layers.
static void validate_dflash_hparams(llama_hparams & hparams, llm_arch arch) {
    const uint32_t block_size = hparams.dflash_block_size;
    const uint32_t n_layers   = hparams.dflash_n_target_layers;
    const uint32_t n_features = hparams.dflash_n_target_features;

    if (block_size <= 1) {
        throw std::runtime_error(format("%s: dflash block_size must be > 1", llama_model_arch_name(arch)));
    }
    if (n_layers == 0) {
        throw std::runtime_error(format("%s: dflash target_layer_ids are required", llama_model_arch_name(arch)));
    }

    // The feature width depends on the target model, so the serialized value is kept as-is
    // here; matching it against the live target model is left to DFlash initialization.
    if (n_features == 0) {
        throw std::runtime_error(format(
            "%s: dflash n_target_features must be > 0",
            llama_model_arch_name(arch)));
    }

    const bool evenly_split = n_features % n_layers == 0;
    if (!evenly_split) {
        throw std::runtime_error(format(
            "%s: dflash n_target_features=%u must be divisible by n_target_layers=%u",
            llama_model_arch_name(arch),
            hparams.dflash_n_target_features,
            hparams.dflash_n_target_layers));
    }
}
void llm_load_hparams(
llama_model_loader & ml,
@ -806,6 +890,18 @@ void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_DFLASH_DRAFT:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_DFLASH_BLOCK_SIZE, hparams.dflash_block_size, false);
ml.get_key(LLM_KV_DFLASH_MASK_TOKEN_ID, hparams.dflash_mask_token_id, false);
ml.get_key(LLM_KV_DFLASH_N_TARGET_FEATURES, hparams.dflash_n_target_features, false);
load_dflash_target_layer_ids(ml, LLM_KV(model.arch)(LLM_KV_DFLASH_TARGET_LAYER_IDS), hparams, false);
validate_dflash_hparams(hparams, model.arch);
hparams.n_layer_kv_from_start = hparams.n_layer;
model.type = e_model::MODEL_UNKNOWN;
} break;
case LLM_ARCH_STARCODER2:
{

View File

@ -139,6 +139,13 @@ struct llama_hparams {
uint32_t mtp_num_centroids = 0;
uint32_t mtp_centroid_top_k = 0;
// DFlash draft model metadata
uint32_t dflash_block_size = 16;
uint32_t dflash_mask_token_id = 0;
uint32_t dflash_n_target_features = 0;
uint32_t dflash_n_target_layers = 0;
uint32_t dflash_target_layer_ids[8] = {};
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = -1;
@ -158,6 +165,10 @@ struct llama_hparams {
if (this->n_ctx_train != other.n_ctx_train) return true;
if (this->n_embd != other.n_embd) return true;
if (this->mtp_backbone_n_embd != other.mtp_backbone_n_embd) return true;
if (this->dflash_block_size != other.dflash_block_size) return true;
if (this->dflash_mask_token_id != other.dflash_mask_token_id) return true;
if (this->dflash_n_target_features != other.dflash_n_target_features) return true;
if (this->dflash_n_target_layers != other.dflash_n_target_layers) return true;
if (this->n_layer != other.n_layer) return true;
if (this->n_rot != other.n_rot) return true;
if (this->n_swa != other.n_swa) return true;
@ -188,6 +199,9 @@ struct llama_hparams {
if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
if (this->ssm_n_group != other.ssm_n_group) return true;
if (this->recurrent_layer_arr != other.recurrent_layer_arr) return true;
for (int i = 0; i < 8; ++i) {
if (this->dflash_target_layer_ids[i] != other.dflash_target_layer_ids[i]) return true;
}
if (this->dec_start_token_id != other.dec_start_token_id) return true;

View File

@ -100,6 +100,8 @@ struct create_tensors_helper : public create_tensors_helper_interface {
bool create_gemma4_mtp_tensors(const LLM_TN & tn);
bool create_dflash_tensors(const LLM_TN & tn);
bool create_starcoder2_tensors(const LLM_TN & tn);
bool create_mamba_tensors(const LLM_TN & tn);
@ -2248,6 +2250,48 @@ bool create_tensors_helper::create_gemma4_mtp_tensors(const LLM_TN & tn) {
return use_mmap_buffer;
}
// Creates the tensors of a DFlash draft model: token embedding / output head, the two
// DFlash-specific tensors (dflash_fc, dflash_hidden_norm) and, per layer, the attention
// and FFN weights with their norms. Returns use_mmap_buffer.
// NOTE(review): n_embd, n_vocab, n_layer, n_ff, hparams, ctx_input, ctx_output and
// use_mmap_buffer are not declared here — presumably provided by LOADING_PRELUDE; confirm.
bool create_tensors_helper::create_dflash_tensors(const LLM_TN & tn) {
LOADING_PRELUDE
// Per-layer tensors use the split context only in GRAPH or ATTN split mode.
const bool use_split_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN;
// Token embedding and output head are both optional in the file (TENSOR_NOT_REQUIRED).
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
// An "output_extra.weight" tensor, when present, takes precedence over the regular output head.
auto output_extra = create_tensor(ctx_output, "output_extra.weight", {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
if (output_extra != nullptr) {
model.output = output_extra;
}
// No output head at all: reuse the token embedding as output (created as a duplicate).
if (model.output == nullptr && model.tok_embd != nullptr) {
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
}
// output_mtp aliases whatever output head was selected above (may still be nullptr).
model.output_mtp = model.output;
// DFlash-specific tensors: dflash_fc has shape {dflash_n_target_features, n_embd}.
model.dflash_fc = create_tensor(ctx_output, tn(LLM_TENSOR_DFLASH_FC, "weight"), {(int64_t) hparams.dflash_n_target_features, n_embd}, 0);
model.dflash_hidden_norm = create_tensor(ctx_output, tn(LLM_TENSOR_DFLASH_HIDDEN_NORM, "weight"), {n_embd}, 0);
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_split = use_split_ctx ? ctx_for_layer_split(i) : ctx_for_layer(i);
auto & layer = model.layers[i];
// Weight tensors only: no bias tensors are created for any layer.
layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.attn_post_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head, n_embd}, 0);
layer.attn_q_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
layer.attn_k_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
return use_mmap_buffer;
}
bool create_tensors_helper::create_starcoder2_tensors(const LLM_TN & tn) {
LOADING_PRELUDE
@ -4398,6 +4442,8 @@ bool create_tensors_helper::create_tensors() {
case LLM_ARCH_GEMMA4_MTP:
case LLM_ARCH_GEMMA4_ASSISTANT:
use_mmap_buffer = create_gemma4_mtp_tensors(tn); break;
case LLM_ARCH_DFLASH_DRAFT:
use_mmap_buffer = create_dflash_tensors(tn); break;
case LLM_ARCH_STARCODER2:
use_mmap_buffer = create_starcoder2_tensors(tn); break;
case LLM_ARCH_MAMBA:

View File

@ -845,6 +845,27 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_MTP_CENTROIDS, "mtp_centroids" },
},
},
{
LLM_ARCH_DFLASH_DRAFT,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_DFLASH_FC, "dflash_fc" },
{ LLM_TENSOR_DFLASH_HIDDEN_NORM, "dflash_hidden_norm" },
},
},
{
LLM_ARCH_GEMMA4_ASSISTANT,
{

View File

@ -430,6 +430,8 @@ struct llama_model {
struct ggml_tensor * mtp_post_proj = nullptr;
struct ggml_tensor * mtp_token_ordering = nullptr;
struct ggml_tensor * mtp_centroids = nullptr;
struct ggml_tensor * dflash_fc = nullptr;
struct ggml_tensor * dflash_hidden_norm = nullptr;
struct ggml_tensor * output_norm;
struct ggml_tensor * output_norm_b;

View File

@ -616,7 +616,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
if (qs.model.hparams.n_vocab >= 127999 && (qs.model.type == MODEL_8B || qs.model.type == MODEL_70B))
new_type = GGML_TYPE_IQ6_K;
}
else if (qs.model.hparams.n_gqa() >= 4) {
else if (qs.model.hparams.n_gqa() >= 4 &&
!(arch == LLM_ARCH_DFLASH_DRAFT &&
(ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M))) {
if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_IQ3_XXS) new_type = GGML_TYPE_IQ3_S;
else if (new_type == GGML_TYPE_Q2_K_R4 || new_type == GGML_TYPE_IQ3_XXS_R4) new_type = GGML_TYPE_IQ3_K_R4;
else if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_IQ3_S) new_type = GGML_TYPE_Q4_K;
@ -1782,4 +1784,3 @@ uint32_t llama_model_quantize(
return 1;
}
}

View File

@ -0,0 +1,729 @@
#include "llama-spec-features.h"
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <random>
#include "llama-model.h"
#include "llama-context.h"
// Resets the DFlash K/V cache bookkeeping of `ctx` to its empty state and zero-fills every
// non-null cache buffer. Does nothing when `ctx` is null.
void llama_reset_dflash_kv_cache_state(struct llama_context * ctx) {
    if (!ctx) {
        return;
    }

    auto & kv = ctx->dflash.kv;

    // Cache proper.
    kv.cache_write_pos              = 0;
    kv.cache_n_filled               = 0;
    kv.cache_update_rows            = 0;
    kv.cache_applied_window_version = 0;
    kv.cache_valid                  = false;

    // View of the cache.
    kv.cache_view_write_pos = 0;
    kv.cache_view_n_filled  = 0;
    kv.cache_view_valid     = false;

    // Workspace mirror.
    kv.workspace_write_pos              = 0;
    kv.workspace_n_filled               = 0;
    kv.workspace_applied_window_version = 0;
    kv.workspace_valid                  = false;
    kv.workspace_sync_pending           = false;

    for (ggml_backend_buffer_t cache_buf : kv.cache_bufs) {
        if (cache_buf == nullptr) {
            continue;
        }
        ggml_backend_buffer_clear(cache_buf, 0);
    }
}
// Plans the DFlash K/V cache transition for `ctx` given a window update and the number of
// target rows.
//
// With a null context the result is a fixed plan: rebuild the cache, append_rows clamped
// to [0, n_rows], next_n_filled = n_rows. Otherwise the context's current cache state and
// resolved cross-context window size are forwarded to llama_plan_dflash_kv_cache_transition.
// NOTE(review): std::clamp requires n_rows >= 0 on the null-context path.
llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition_for_ctx(
        const struct llama_context * ctx,
        const llama_dflash_window_update & window_update,
        int32_t n_rows) {
    if (ctx != nullptr) {
        // Window size: explicit override when positive, else n_ctx - block_size (at least 1).
        int32_t window_rows;
        if (ctx->dflash.visible_cross_ctx > 0) {
            window_rows = ctx->dflash.visible_cross_ctx;
        } else {
            window_rows = std::max<int32_t>(1, (int32_t) ctx->cparams.n_ctx - (int32_t) ctx->model.hparams.dflash_block_size);
        }

        const auto & kv = ctx->dflash.kv;
        return llama_plan_dflash_kv_cache_transition(
            window_rows,
            kv.cache_n_filled,
            kv.cache_write_pos,
            kv.cache_valid,
            kv.cache_applied_window_version,
            window_update.version,
            window_update.keep_rows,
            window_update.append_rows,
            window_update.replace,
            n_rows);
    }

    llama_dflash_kv_cache_transition fallback;
    fallback.rebuild_cache = true;
    fallback.append_rows   = std::clamp(window_update.append_rows, 0, n_rows);
    fallback.next_n_filled = n_rows;
    return fallback;
}
// Stores the visible cross-context window size on `ctx`; negative values are stored as 0.
// Does nothing when `ctx` is null.
void llama_set_dflash_visible_cross_ctx(
        struct llama_context * ctx,
        int32_t cross_ctx) {
    if (ctx != nullptr) {
        ctx->dflash.visible_cross_ctx = std::max<int32_t>(0, cross_ctx);
    }
}
// Returns the visible cross-context window size stored on `ctx`, or 0 for a null context.
int32_t llama_get_dflash_visible_cross_ctx(
        const struct llama_context * ctx) {
    if (ctx == nullptr) {
        return 0;
    }
    return ctx->dflash.visible_cross_ctx;
}
// Returns the model's DFlash block size, or 0 when `model` is null.
int32_t llama_model_dflash_block_size(const struct llama_model * model) {
    if (!model) {
        return 0;
    }
    return (int32_t) model->hparams.dflash_block_size;
}
// Returns the model's DFlash mask token id, or -1 when `model` is null.
int32_t llama_model_dflash_mask_token_id(const struct llama_model * model) {
    if (!model) {
        return -1;
    }
    return (int32_t) model->hparams.dflash_mask_token_id;
}
// Returns the number of DFlash target layers of the model, or 0 when `model` is null.
int32_t llama_model_dflash_n_target_layers(const struct llama_model * model) {
    if (!model) {
        return 0;
    }
    return (int32_t) model->hparams.dflash_n_target_layers;
}
// Returns the DFlash target feature width of the model, or 0 when `model` is null.
int32_t llama_model_dflash_n_target_features(const struct llama_model * model) {
    if (!model) {
        return 0;
    }
    return (int32_t) model->hparams.dflash_n_target_features;
}
// Copies up to `capacity` DFlash target layer ids of the model into `layer_ids`.
// Returns the number of ids written; 0 when `model` or `layer_ids` is null or when
// `capacity` is not positive.
int32_t llama_model_dflash_target_layer_ids(
        const struct llama_model * model,
        int32_t * layer_ids,
        int32_t capacity) {
    const bool usable = model != nullptr && layer_ids != nullptr && capacity > 0;
    if (!usable) {
        return 0;
    }

    const auto & hp = model->hparams;
    const int32_t n_copy = std::min<int32_t>((int32_t) hp.dflash_n_target_layers, capacity);

    int32_t written = 0;
    while (written < n_copy) {
        layer_ids[written] = (int32_t) hp.dflash_target_layer_ids[written];
        ++written;
    }
    return n_copy;
}
// Returns the mask token of the model's vocabulary, or LLAMA_TOKEN_NULL for a null model.
int32_t llama_model_dflash_target_mask_token_id(const struct llama_model * model) {
    return model != nullptr
        ? (int32_t) model->vocab.token_mask()
        : (int32_t) LLAMA_TOKEN_NULL;
}
// Selects the tensor used as the DFlash output head: output_mtp when set, else output,
// else tok_embd (which may itself be null). Returns nullptr for a null model.
const struct ggml_tensor * llama_model_dflash_output_tensor(
        const struct llama_model * model) {
    if (!model) {
        return nullptr;
    }
    const struct ggml_tensor * head = model->output_mtp;
    if (head == nullptr) {
        head = model->output;
    }
    if (head == nullptr) {
        head = model->tok_embd;
    }
    return head;
}
int32_t llama_model_dflash_io_mode(
const struct llama_model * draft_model,
const struct llama_model * target_model) {
if (draft_model == nullptr || target_model == nullptr || draft_model->arch != LLM_ARCH_DFLASH_DRAFT) {
return LLAMA_DFLASH_IO_MODE_INVALID;
}
const ggml_tensor * draft_output = llama_model_dflash_output_tensor(draft_model);
const ggml_tensor * target_output = llama_model_dflash_output_tensor(target_model);
if (draft_model->tok_embd == nullptr || draft_output == nullptr || target_model->tok_embd == nullptr || target_output == nullptr) {
return LLAMA_DFLASH_IO_MODE_INVALID;
}
const bool shared_tok = draft_model->tok_embd == target_model->tok_embd;
const bool shared_output = draft_output == target_output;
if (shared_tok && shared_output) {
return LLAMA_DFLASH_IO_MODE_SHARED;
}
if (!shared_tok && !shared_output) {
return LLAMA_DFLASH_IO_MODE_SELF_CONTAINED;
}
return LLAMA_DFLASH_IO_MODE_MIXED;
}
bool llama_model_dflash_io_tensors_match(
const struct llama_model * draft_model,
int32_t n_embd,
int32_t n_vocab) {
const ggml_tensor * output = llama_model_dflash_output_tensor(draft_model);
if (draft_model == nullptr || draft_model->tok_embd == nullptr || output == nullptr || n_embd <= 0 || n_vocab <= 0) {
return false;
}
return (int32_t) draft_model->tok_embd->ne[0] == n_embd &&
(int32_t) draft_model->tok_embd->ne[1] == n_vocab &&
(int32_t) output->ne[0] == n_embd &&
(int32_t) output->ne[1] == n_vocab;
}
// Fills in missing I/O tensors of a DFlash draft model from the target model.
//
// - Returns false when either model is null.
// - Returns true without touching anything when the draft is not a DFlash draft architecture.
// - Otherwise: a missing tok_embd is borrowed from the target; a missing output is taken
//   from the target's output, then the target's tok_embd, then the draft's own tok_embd;
//   the target's output_mtp is borrowed only when the draft has none and both the draft's
//   embedding and output come from the target.
// Returns true when the draft ends up with a token embedding and an output head.
bool llama_model_share_dflash_io_tensors(
        struct llama_model * draft_model,
        const struct llama_model * target_model) {
    if (!draft_model || !target_model) {
        return false;
    }
    if (draft_model->arch != LLM_ARCH_DFLASH_DRAFT) {
        return true;
    }

    if (draft_model->tok_embd == nullptr) {
        draft_model->tok_embd = target_model->tok_embd;
    }

    if (draft_model->output == nullptr) {
        // Fallback chain: target output -> target embedding -> draft embedding.
        if (target_model->output != nullptr) {
            draft_model->output = target_model->output;
        } else {
            draft_model->output = target_model->tok_embd;
        }
        if (draft_model->output == nullptr) {
            draft_model->output = draft_model->tok_embd;
        }
    }

    const bool tok_from_target    = draft_model->tok_embd == target_model->tok_embd;
    const bool output_from_target = draft_model->output == target_model->output ||
                                    draft_model->output == target_model->tok_embd;
    const bool may_borrow_mtp     = draft_model->output_mtp == nullptr &&
                                    target_model->output_mtp != nullptr &&
                                    tok_from_target && output_from_target;
    if (may_borrow_mtp) {
        draft_model->output_mtp = target_model->output_mtp;
    }

    if (draft_model->tok_embd == nullptr) {
        return false;
    }
    return llama_model_dflash_output_tensor(draft_model) != nullptr;
}
// Shared implementation behind the copy/view setters for DFlash target features.
//
// Records on `ctx` the full feature buffer (optional), the append-only feature buffer
// taken from `window_update` (optional), the window bookkeeping (version, keep_rows,
// append_rows, replace) and the target positions, then updates the K/V cache view state
// from the planned cache transition.
//
// ctx              - context to update; a null context makes the call fail.
// target_features  - full feature buffer of n_floats floats, or null.
// n_rows           - number of target rows; must be positive.
// target_positions - n_rows positions, or null to clear the stored positions.
// copy_data        - true: buffers are copied into context-owned storage;
//                    false: the caller's pointers are stored as-is and must stay valid.
// window_update    - optional incremental window description.
//
// Returns false when ctx is null, n_rows <= 0, or neither a full nor an append feature
// buffer is supplied; true otherwise.
static bool llama_set_dflash_target_features_impl(
struct llama_context * ctx,
const float * target_features,
size_t n_floats,
int32_t n_rows,
const llama_pos * target_positions,
bool copy_data,
const llama_dflash_window_update * window_update) {
const bool have_full_features = target_features != nullptr && n_floats > 0;
const bool have_append_features = window_update != nullptr &&
window_update->append_features != nullptr &&
window_update->append_floats > 0 &&
window_update->append_rows > 0;
if (ctx == nullptr || n_rows <= 0 || (!have_full_features && !have_append_features)) {
return false;
}
// Full features: copy into owned storage, reference the caller's buffer, or clear.
if (have_full_features && copy_data) {
ctx->dflash.target.features_owned.assign(target_features, target_features + n_floats);
ctx->dflash.target.features = ctx->dflash.target.features_owned.data();
} else if (have_full_features) {
ctx->dflash.target.features_owned.clear();
ctx->dflash.target.features = target_features;
} else {
ctx->dflash.target.features_owned.clear();
ctx->dflash.target.features = nullptr;
}
ctx->dflash.target.features_n_floats = have_full_features ? n_floats : 0;
ctx->dflash.target.features_n_rows = n_rows;
// Append features: same three-way handling as the full buffer.
if (have_append_features && copy_data) {
ctx->dflash.target.append_features_owned.assign(
window_update->append_features,
window_update->append_features + window_update->append_floats);
ctx->dflash.target.append_features = ctx->dflash.target.append_features_owned.data();
} else if (have_append_features) {
ctx->dflash.target.append_features_owned.clear();
ctx->dflash.target.append_features = window_update->append_features;
} else {
ctx->dflash.target.append_features_owned.clear();
ctx->dflash.target.append_features = nullptr;
}
ctx->dflash.target.append_features_n_floats = have_append_features ? window_update->append_floats : 0;
ctx->dflash.target.append_features_n_rows = have_append_features ? window_update->append_rows : 0;
// Use the caller's version when it is positive, otherwise bump the stored version by one.
ctx->dflash.target.version = window_update != nullptr && window_update->version > 0
? window_update->version
: ctx->dflash.target.version + 1;
// Without a window update the call is treated as a full replacement:
// keep_rows = 0, append_rows = n_rows, replace = true. Both counts are clamped to [0, n_rows].
ctx->dflash.target.keep_rows = window_update != nullptr
? std::max<int32_t>(0, std::min(n_rows, window_update->keep_rows))
: 0;
ctx->dflash.target.append_rows = window_update != nullptr
? std::max<int32_t>(0, std::min(n_rows, window_update->append_rows))
: n_rows;
ctx->dflash.target.replace = window_update != nullptr
? window_update->replace
: true;
// If kept + appended rows exceed the window, shrink keep_rows so they sum to n_rows.
if (ctx->dflash.target.keep_rows + ctx->dflash.target.append_rows > n_rows) {
ctx->dflash.target.keep_rows = std::max<int32_t>(0, n_rows - ctx->dflash.target.append_rows);
}
// Window size: explicit override when positive, else n_ctx - block_size, at least 1.
const int32_t cross_ctx = ctx->dflash.visible_cross_ctx > 0
? ctx->dflash.visible_cross_ctx
: std::max<int32_t>(1, (int32_t) ctx->cparams.n_ctx - (int32_t) ctx->model.hparams.dflash_block_size);
// Positional aggregate initialization built from the sanitized values stored above.
// NOTE(review): relies on the member order of llama_dflash_window_update, which is not
// visible here — confirm against the struct declaration.
const llama_dflash_window_update cache_window_update = {
ctx->dflash.target.version,
ctx->dflash.target.keep_rows,
ctx->dflash.target.append_rows,
ctx->dflash.target.replace,
ctx->dflash.target.append_features,
ctx->dflash.target.append_features_n_floats,
};
const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition_for_ctx(ctx, cache_window_update, n_rows);
// Update only the cache *view* here: mirror the current cache when it is already up to
// date, otherwise adopt the state the plan predicts after the pending update.
if (cache_plan.cache_up_to_date) {
ctx->dflash.kv.cache_view_n_filled = ctx->dflash.kv.cache_n_filled;
ctx->dflash.kv.cache_view_write_pos = ctx->dflash.kv.cache_write_pos;
ctx->dflash.kv.cache_view_valid = ctx->dflash.kv.cache_valid;
} else if (cross_ctx > 0) {
ctx->dflash.kv.cache_view_n_filled = cache_plan.next_n_filled;
ctx->dflash.kv.cache_view_write_pos = cache_plan.next_write_pos;
ctx->dflash.kv.cache_view_valid = cache_plan.next_n_filled > 0;
}
// Positions: copy, reference, or clear — exactly n_rows entries are read when provided.
if (target_positions != nullptr) {
if (copy_data) {
ctx->dflash.target.positions_owned.assign(target_positions, target_positions + n_rows);
ctx->dflash.target.positions = ctx->dflash.target.positions_owned.data();
} else {
ctx->dflash.target.positions_owned.clear();
ctx->dflash.target.positions = target_positions;
}
ctx->dflash.target.positions_n = (size_t) n_rows;
} else {
ctx->dflash.target.positions_owned.clear();
ctx->dflash.target.positions = nullptr;
ctx->dflash.target.positions_n = 0;
}
return true;
}
// Publishes DFlash target features on `ctx` using the copying flavour of the
// shared setter: the feature rows, the optional append rows of `window_update`
// and the optional positions are duplicated into context-owned storage.
// Returns whatever the shared implementation returns.
bool llama_set_dflash_target_features_copy(
        struct llama_context * ctx,
        const float * target_features,
        size_t n_floats,
        int32_t n_rows,
        const llama_pos * target_positions,
        const llama_dflash_window_update * window_update) {
    constexpr bool copy_data = true;
    return llama_set_dflash_target_features_impl(
        ctx, target_features, n_floats, n_rows, target_positions, copy_data, window_update);
}
// Publishes DFlash target features on `ctx` using the non-copying flavour of
// the shared setter: the context keeps the caller's pointers instead of
// duplicating the data, so the buffers have to stay alive while the context
// still references them.
// Returns whatever the shared implementation returns.
bool llama_set_dflash_target_features_view(
        struct llama_context * ctx,
        const float * target_features,
        size_t n_floats,
        int32_t n_rows,
        const llama_pos * target_positions,
        const llama_dflash_window_update * window_update) {
    constexpr bool copy_data = false;
    return llama_set_dflash_target_features_impl(
        ctx, target_features, n_floats, n_rows, target_positions, copy_data, window_update);
}
// Extracts the layer number from a tensor whose name is exactly "l_out-<N>".
//
// tensor   : tensor to inspect; nullptr is rejected.
// layer_id : receives the parsed id on success and is left untouched on failure.
//
// Returns true only when the name carries the prefix, the suffix is a complete
// base-10 number, and that number is non-negative and fits in int32_t. Ids of
// 1000 and above are folded with `% 1000`, as before.
static bool llama_dflash_parse_layer_id(const struct ggml_tensor * tensor, int32_t & layer_id) {
    if (tensor == nullptr) {
        return false;
    }

    static constexpr const char * prefix = "l_out-";
    const size_t prefix_len = std::strlen(prefix);
    if (std::strncmp(tensor->name, prefix, prefix_len) != 0) {
        return false;
    }

    const char * digits = tensor->name + prefix_len;
    char * end = nullptr;
    const long raw = std::strtol(digits, &end, 10);
    if (end == digits || *end != '\0') {
        // no digits at all, or trailing characters after the number
        return false;
    }

    // Range-check before narrowing: an unchecked cast of an oversized value
    // could wrap around into something that looks like a valid layer id.
    if (raw < 0 || raw > (long) INT32_MAX) {
        return false;
    }

    layer_id = (int32_t) raw;
    if (layer_id >= 1000) {
        layer_id %= 1000;
    }
    return true;
}
// Maps a layer id to its slot in the active capture's `layer_ids` list.
// Returns the zero-based position of the first match, or -1 when there is no
// context, no active capture, or the id is not being captured.
static int32_t llama_dflash_find_layer_index(const struct llama_context * ctx, int32_t layer_id) {
    if (ctx == nullptr || !ctx->dflash.capture) {
        return -1;
    }

    const auto & ids = ctx->dflash.capture->layer_ids;
    for (size_t slot = 0; slot < ids.size(); ++slot) {
        if (ids[slot] == layer_id) {
            return (int32_t) slot;
        }
    }
    return -1;
}
static bool llama_dflash_capture_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
auto * ctx = static_cast<llama_context *>(user_data);
if (ctx == nullptr || !ctx->dflash.capture) {
return false;
}
int32_t layer_id = -1;
if (!llama_dflash_parse_layer_id(tensor, layer_id)) {
return false;
}
const int32_t layer_idx = llama_dflash_find_layer_index(ctx, layer_id);
if (layer_idx < 0) {
return false;
}
if (ask) {
return true;
}
const int32_t row_width = (int32_t) tensor->ne[0];
const int32_t row_count = row_width > 0 ? (int32_t) (ggml_nelements(tensor) / (int64_t) row_width) : 0;
if (row_width <= 0 || row_count <= 0) {
return false;
}
auto & capture = *ctx->dflash.capture;
if (capture.capture_batch_id == 0) {
capture.capture_batch_id = 1;
}
if (capture.layer_seen_batch_id.size() != capture.layer_ids.size()) {
capture.layer_seen_batch_id.assign(capture.layer_ids.size(), 0);
}
auto & rows = capture.layer_rows[(size_t) layer_idx];
rows.resize((size_t) row_count * (size_t) row_width);
ggml_backend_tensor_get(tensor, rows.data(), 0, ggml_nbytes(tensor));
capture.row_width = row_width;
capture.row_count = row_count;
capture.layer_seen_batch_id[(size_t) layer_idx] = capture.capture_batch_id;
return true;
}
// Starts capturing the outputs of the given layers on `ctx`.
//
// ctx       : context to instrument; nullptr is rejected.
// layer_ids : array of `n_layers` layer ids to capture; nullptr is rejected.
// n_layers  : number of ids; must be > 0.
//
// Replaces any existing capture state, clears the feature-view buffer and
// installs the DFlash capture callback as the context's eval callback (also on
// the scheduler when one exists). The callback that was active before capture
// was first installed is remembered so llama_clear_dflash_capture can restore it.
// Returns false on invalid arguments, true otherwise.
bool llama_set_dflash_capture_layers(
        struct llama_context * ctx,
        const int32_t * layer_ids,
        int32_t n_layers) {
    if (ctx == nullptr || layer_ids == nullptr || n_layers <= 0) {
        return false;
    }

    auto capture = std::make_unique<llama_context::dflash_runtime::capture_state>();
    capture->layer_ids.assign(layer_ids, layer_ids + n_layers);
    capture->layer_rows.resize((size_t) n_layers);
    capture->layer_seen_batch_id.assign((size_t) n_layers, 0);

    // When capture is already installed the current eval callback is our own
    // hook. Saving it as "previous" would lose the caller's original callback
    // and make the later restore re-install the capture hook, so carry the
    // originally saved callback over instead.
    const bool already_installed =
        ctx->dflash.capture != nullptr &&
        ctx->cparams.cb_eval == llama_dflash_capture_eval_callback &&
        ctx->cparams.cb_eval_user_data == ctx;
    if (already_installed) {
        capture->prev_cb_eval           = ctx->dflash.capture->prev_cb_eval;
        capture->prev_cb_eval_user_data = ctx->dflash.capture->prev_cb_eval_user_data;
    } else {
        capture->prev_cb_eval           = ctx->cparams.cb_eval;
        capture->prev_cb_eval_user_data = ctx->cparams.cb_eval_user_data;
    }

    ctx->dflash.capture = std::move(capture);
    ctx->dflash.feature_view_buffer.clear();

    ctx->cparams.cb_eval = llama_dflash_capture_eval_callback;
    ctx->cparams.cb_eval_user_data = ctx;
    if (ctx->sched != nullptr) {
        ggml_backend_sched_set_eval_callback(ctx->sched, ctx->cparams.cb_eval, ctx->cparams.cb_eval_user_data);
    }
    return true;
}
void llama_clear_dflash_capture(struct llama_context * ctx) {
if (ctx == nullptr) {
return;
}
ggml_backend_sched_eval_callback prev_cb_eval = nullptr;
void * prev_cb_eval_user_data = nullptr;
if (ctx->dflash.capture) {
prev_cb_eval = ctx->dflash.capture->prev_cb_eval;
prev_cb_eval_user_data = ctx->dflash.capture->prev_cb_eval_user_data;
}
ctx->dflash.capture.reset();
ctx->dflash.feature_view_buffer.clear();
if (ctx->cparams.cb_eval == llama_dflash_capture_eval_callback && ctx->cparams.cb_eval_user_data == ctx) {
ctx->cparams.cb_eval = prev_cb_eval;
ctx->cparams.cb_eval_user_data = prev_cb_eval_user_data;
if (ctx->sched != nullptr) {
ggml_backend_sched_set_eval_callback(ctx->sched, prev_cb_eval, prev_cb_eval_user_data);
}
}
}
// Opens a new capture batch: bumps the batch id, forgets the previously
// recorded row shape and marks every layer as not yet seen.
// Does nothing when `ctx` is null or capture is not active.
void llama_begin_dflash_capture_batch(struct llama_context * ctx) {
    if (ctx == nullptr || !ctx->dflash.capture) {
        return;
    }

    auto & state = *ctx->dflash.capture;
    ++state.capture_batch_id;
    state.row_width = 0;
    state.row_count = 0;
    for (auto & seen : state.layer_seen_batch_id) {
        seen = 0;
    }
}
// Closes the current capture batch by clearing the recorded row shape.
// `is_prompt_warmup` is accepted for interface compatibility and ignored.
// Does nothing when `ctx` is null or capture is not active.
void llama_finish_dflash_capture_batch(
        struct llama_context * ctx,
        bool is_prompt_warmup) {
    (void) is_prompt_warmup;

    if (ctx == nullptr || !ctx->dflash.capture) {
        return;
    }

    // The shape is a per-batch reference: wiping it here keeps the next decode
    // from being compared against the shape of this (prompt/verify) batch.
    auto & state = *ctx->dflash.capture;
    state.row_width = 0;
    state.row_count = 0;
}
// Validates the current capture and reports its shape.
//
// ctx       : context holding the capture; it is synchronized before reading.
// row_count : out - rows recorded for the current batch.
// row_width : out - floats per row, per layer.
// n_layers  : out - number of captured layers.
//
// Returns true only when the shape is positive, every layer was written during
// the current capture batch and every layer buffer holds exactly
// row_count * row_width floats. The out-parameters are written as soon as the
// capture exists, even when validation subsequently fails.
static bool llama_spec_prepare_dflash_capture(
        struct llama_context * ctx,
        int32_t & row_count,
        int32_t & row_width,
        int32_t & n_layers) {
    if (ctx == nullptr || !ctx->dflash.capture) {
        return false;
    }

    llama_synchronize(ctx);

    auto & state = *ctx->dflash.capture;
    row_count = state.row_count;
    row_width = state.row_width;
    n_layers  = (int32_t) state.layer_ids.size();

    const bool shape_ok =
        row_count > 0 && row_width > 0 && n_layers > 0 &&
        state.layer_rows.size() == (size_t) n_layers;
    if (!shape_ok) {
        return false;
    }

    const size_t layer_total = (size_t) n_layers;
    if (state.capture_batch_id == 0 || state.layer_seen_batch_id.size() != layer_total) {
        LLAMA_LOG_WARN("%s: DFlash capture batch markers are not initialized (batch_id=%llu layers=%zu expected=%d)\n",
            __func__,
            (unsigned long long) state.capture_batch_id,
            state.layer_seen_batch_id.size(),
            n_layers);
        return false;
    }

    const size_t expected_floats = (size_t) row_count * (size_t) row_width;
    for (size_t li = 0; li < layer_total; ++li) {
        // A layer that was last written in another batch holds stale rows.
        if (state.layer_seen_batch_id[li] != state.capture_batch_id) {
            LLAMA_LOG_WARN("%s: DFlash capture is stale for layer %d (seen_batch=%llu current_batch=%llu rows=%d width=%d)\n",
                __func__,
                state.layer_ids[li],
                (unsigned long long) state.layer_seen_batch_id[li],
                (unsigned long long) state.capture_batch_id,
                row_count,
                row_width);
            return false;
        }

        const size_t got_floats = state.layer_rows[li].size();
        if (got_floats != expected_floats) {
            LLAMA_LOG_WARN("%s: DFlash capture rows mismatch for layer %d: got=%zu expected=%zu (rows=%d width=%d)\n",
                __func__, state.layer_ids[li], got_floats,
                expected_floats, row_count, row_width);
            return false;
        }
    }
    return true;
}
static bool llama_spec_materialize_dflash_rows_prepared(
struct llama_context * ctx,
int32_t row_count,
int32_t row_width,
int32_t n_layers,
const std::vector<int32_t> & row_indices,
std::vector<float> & rows_out,
int32_t & combined_width);
// Convenience wrapper: validates the current capture, then gathers the
// requested rows (all captured layers side by side) into `rows_out`.
// Returns false when the capture is not usable or the gather fails;
// `combined_width` receives the floats per output row set by the gather step.
static bool llama_spec_materialize_dflash_rows(
        struct llama_context * ctx,
        const std::vector<int32_t> & row_indices,
        std::vector<float> & rows_out,
        int32_t & combined_width) {
    int32_t captured_rows  = 0;
    int32_t captured_width = 0;
    int32_t captured_layers = 0;

    const bool capture_ready = llama_spec_prepare_dflash_capture(ctx, captured_rows, captured_width, captured_layers);
    if (!capture_ready) {
        return false;
    }

    return llama_spec_materialize_dflash_rows_prepared(
        ctx, captured_rows, captured_width, captured_layers, row_indices, rows_out, combined_width);
}
// Gathers captured rows into a dense buffer, with the layers of each row laid
// out back to back.
//
// ctx            : context holding the capture buffers.
// row_count      : number of captured rows (from the prepare step).
// row_width      : floats per row, per layer.
// n_layers       : number of captured layers.
// row_indices    : rows to gather; a negative index counts from the end.
// rows_out       : out - row_indices.size() * combined_width floats.
// combined_width : out - row_width * n_layers.
//
// On any failure `rows_out` is left empty, `combined_width` is 0 and false is
// returned.
static bool llama_spec_materialize_dflash_rows_prepared(
        struct llama_context * ctx,
        int32_t row_count,
        int32_t row_width,
        int32_t n_layers,
        const std::vector<int32_t> & row_indices,
        std::vector<float> & rows_out,
        int32_t & combined_width) {
    rows_out.clear();
    combined_width = 0;

    if (ctx == nullptr || row_indices.empty()) {
        return false;
    }
    if (row_count <= 0 || row_width <= 0 || n_layers <= 0 || ctx->dflash.capture == nullptr) {
        return false;
    }

    combined_width = row_width * n_layers;
    const size_t layer_stride = (size_t) row_width;
    const size_t out_stride   = (size_t) combined_width;
    rows_out.resize((size_t) row_indices.size() * out_stride);

    const auto & layer_rows = ctx->dflash.capture->layer_rows;
    float * dst_row = rows_out.data();
    for (const int32_t requested : row_indices) {
        // Negative indices address rows relative to the end of the capture.
        const int32_t resolved = requested < 0 ? requested + row_count : requested;
        if (resolved < 0 || resolved >= row_count) {
            rows_out.clear();
            combined_width = 0;
            return false;
        }

        for (int32_t layer = 0; layer < n_layers; ++layer) {
            const float * src = layer_rows[(size_t) layer].data() + (size_t) resolved * layer_stride;
            std::memcpy(dst_row + (size_t) layer * layer_stride, src, layer_stride * sizeof(float));
        }
        dst_row += out_stride;
    }
    return true;
}
// Builds a hidden-state feature view over the rows captured for `batch`.
//
// The captured rows are matched against the tail of the batch: when the batch
// has more tokens than captured rows, the leading tokens are skipped. Each view
// row carries the first sequence id and the position of its token and points
// into the context-owned feature-view buffer.
//
// Returns false when the arguments or the capture are unusable, or when a
// selected token has no sequence id. `view` is reset once the capture has been
// validated, so it can be left empty on a later failure.
bool llama_spec_get_dflash_feature_view(
        struct llama_context * ctx,
        const llama_batch & batch,
        llama_spec_feature_view & view) {
    const bool batch_usable =
        ctx != nullptr && batch.n_tokens > 0 &&
        batch.pos != nullptr && batch.n_seq_id != nullptr && batch.seq_id != nullptr;
    if (!batch_usable) {
        return false;
    }

    int32_t row_count = 0;
    int32_t row_width = 0;
    int32_t n_layers  = 0;
    if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) {
        return false;
    }

    // First batch token that has a captured row.
    const int32_t first_token = std::max<int32_t>(0, batch.n_tokens - row_count);
    const int32_t n_view_rows = batch.n_tokens - first_token;

    std::vector<int32_t> row_indices;
    row_indices.reserve((size_t) n_view_rows);
    for (int32_t row = 0; row < n_view_rows; ++row) {
        row_indices.push_back(row);
    }
    if (row_indices.empty()) {
        return false;
    }

    view = {};
    view.kind = LLAMA_SPEC_FEATURE_HIDDEN_STATE;
    auto & buffer = ctx->dflash.feature_view_buffer;
    if (!llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, buffer, view.width)) {
        return false;
    }

    view.rows.reserve(row_indices.size());
    for (size_t row = 0; row < row_indices.size(); ++row) {
        const int32_t token = first_token + (int32_t) row;
        if (batch.n_seq_id[token] <= 0 || batch.seq_id[token] == nullptr) {
            view.rows.clear();
            return false;
        }
        view.rows.push_back({
            /* .seq_id = */ batch.seq_id[token][0],
            /* .pos    = */ batch.pos[token],
            /* .data   = */ buffer.data() + row * (size_t) view.width,
        });
    }
    return true;
}
bool llama_spec_get_dflash_feature_view_for_seq(
struct llama_context * ctx,
const llama_batch & batch,
llama_seq_id seq_id,
llama_spec_feature_view & view) {
if (ctx == nullptr || batch.n_tokens <= 0 || batch.pos == nullptr || batch.n_seq_id == nullptr || batch.seq_id == nullptr) {
return false;
}
int32_t row_count = 0;
int32_t row_width = 0;
int32_t n_layers = 0;
if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) {
return false;
}
const int32_t batch_row_offset = std::max<int32_t>(0, batch.n_tokens - row_count);
std::vector<int32_t> row_indices;
row_indices.reserve((size_t) batch.n_tokens);
std::vector<int32_t> batch_indices;
batch_indices.reserve((size_t) batch.n_tokens);
for (int32_t i = batch_row_offset; i < batch.n_tokens; ++i) {
if (batch.n_seq_id[i] <= 0 || batch.seq_id[i] == nullptr) {
return false;
}
for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
if (batch.seq_id[i][j] == seq_id) {
row_indices.push_back(i - batch_row_offset);
batch_indices.push_back(i);
break;
}
}
}
if (row_indices.empty()) {
return false;
}
view = {};
view.kind = LLAMA_SPEC_FEATURE_HIDDEN_STATE;
if (!llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, ctx->dflash.feature_view_buffer, view.width)) {
return false;
}
view.rows.reserve(row_indices.size());
for (size_t i = 0; i < batch_indices.size(); ++i) {
const int32_t batch_index = batch_indices[i];
view.rows.push_back({
/* .seq_id = */ seq_id,
/* .pos = */ batch.pos[batch_index],
/* .data = */ ctx->dflash.feature_view_buffer.data() + i * (size_t) view.width,
});
}
return true;
}
// Copies the captured rows selected by `output_indices` into `hidden_rows`
// (all captured layers of a row laid out back to back).
// Returns true when the gather succeeded and produced exactly
// output_indices.size() * combined_width floats; on a failed gather
// `hidden_rows` is left empty.
bool llama_spec_copy_dflash_rows_from_output_indices(
        struct llama_context * ctx,
        const std::vector<int32_t> & output_indices,
        std::vector<float> & hidden_rows) {
    int32_t floats_per_row = 0;
    const bool gathered = llama_spec_materialize_dflash_rows(ctx, output_indices, hidden_rows, floats_per_row);
    if (!gathered) {
        hidden_rows.clear();
        return false;
    }

    const size_t expected_floats = (size_t) output_indices.size() * (size_t) floats_per_row;
    return hidden_rows.size() == expected_floats;
}

View File

@ -0,0 +1,136 @@
#pragma once
#include "llama.h"
#include <algorithm>
#include <cstdint>
#include <vector>
struct llama_context;
struct llama_model;
struct ggml_tensor;
struct llama_spec_feature_view;
// Optional description of how the target-feature window changed, passed to the
// llama_set_dflash_target_features_* setters.
// NOTE: the field order is relied upon by positional aggregate initialisation.
struct llama_dflash_window_update {
// Window version; the setter adopts it when > 0 and otherwise increments the
// version it already stores.
uint64_t version = 0;
// Rows to keep; the setter clamps this to [0, n_rows] and lowers it further if
// keep_rows + append_rows would exceed n_rows.
int32_t keep_rows = 0;
// Rows being appended; the setter clamps this to [0, n_rows].
int32_t append_rows = 0;
// Stored as-is by the setter; the cache planner treats `true` as a request to
// rebuild the cache.
bool replace = false;
// Optional feature data for the appended rows: `append_floats` floats starting
// at `append_features`. Copied or referenced depending on the setter variant.
// NOTE(review): the exact condition under which the setter treats these as
// present is decided outside the visible code - confirm before relying on it.
const float * append_features = nullptr;
size_t append_floats = 0;
};
// Result of planning how the DFlash KV cache (a ring of `cross_ctx` rows)
// moves from its current state to the requested window.
struct llama_dflash_kv_cache_transition {
    bool cache_up_to_date = false; // cache is valid and already at the target window version
    bool rebuild_cache = false;    // cache cannot be extended incrementally and must be refilled
    int32_t append_rows = 0;       // append_rows after clamping to [0, n_rows]
    int32_t next_n_filled = 0;     // number of filled rows after the transition
    int32_t next_write_pos = 0;    // ring write position after the transition, always in [0, cross_ctx)
};

// Plans the next KV cache state without touching any cache.
//
// cross_ctx              : ring capacity in rows; values < 1 are treated as 1.
// current_n_filled       : rows currently filled (clamped to [0, cross_ctx]).
// current_write_pos      : current ring write position (wrapped into [0, cross_ctx)).
// cache_valid            : whether the current cache content is usable.
// applied_window_version : window version the cache was last built for.
// target_window_version  : window version being requested.
// keep_rows, append_rows : requested window split (each clamped to [0, n_rows]).
// replace                : caller requests a full rebuild.
// n_rows                 : rows in the target window; negative values are treated as 0.
//
// Outcome:
//  - up to date : state is carried over unchanged (normalised into range);
//  - rebuild    : the cache is refilled with min(cross_ctx, n_rows) rows;
//  - otherwise  : append_rows rows are appended to the ring.
static inline llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition(
        int32_t cross_ctx,
        int32_t current_n_filled,
        int32_t current_write_pos,
        bool cache_valid,
        uint64_t applied_window_version,
        uint64_t target_window_version,
        int32_t keep_rows,
        int32_t append_rows,
        bool replace,
        int32_t n_rows) {
    llama_dflash_kv_cache_transition plan;

    const int32_t safe_cross_ctx = std::max<int32_t>(1, cross_ctx);
    // std::clamp requires lo <= hi, so a negative row count must be floored first.
    const int32_t safe_n_rows = std::max<int32_t>(0, n_rows);

    const int32_t bounded_n_filled    = std::clamp(current_n_filled, 0, safe_cross_ctx);
    const int32_t bounded_append_rows = std::clamp(append_rows, 0, safe_n_rows);
    const int32_t bounded_keep_rows   = std::clamp(keep_rows, 0, safe_n_rows);

    // Rows that survive an append: what is filled now, limited by the room left
    // in the ring once the new rows are in.
    const int32_t expected_keep_rows =
        std::min(bounded_n_filled, std::max<int32_t>(0, safe_cross_ctx - bounded_append_rows));

    // Wrap the write position once so every branch yields a value in [0, cross_ctx),
    // including when the incoming position is negative.
    const int32_t wrapped_write_pos =
        ((current_write_pos % safe_cross_ctx) + safe_cross_ctx) % safe_cross_ctx;

    plan.cache_up_to_date = cache_valid && applied_window_version == target_window_version;
    plan.rebuild_cache =
        !cache_valid || replace || bounded_append_rows <= 0 ||
        bounded_keep_rows != expected_keep_rows;
    plan.append_rows = bounded_append_rows;

    if (plan.cache_up_to_date) {
        plan.next_n_filled  = bounded_n_filled;
        plan.next_write_pos = wrapped_write_pos;
    } else if (plan.rebuild_cache) {
        plan.next_n_filled  = std::min(safe_cross_ctx, safe_n_rows);
        plan.next_write_pos = plan.next_n_filled % safe_cross_ctx;
    } else {
        plan.next_n_filled  = std::min(safe_cross_ctx, bounded_n_filled + bounded_append_rows);
        plan.next_write_pos = (wrapped_write_pos + bounded_append_rows % safe_cross_ctx) % safe_cross_ctx;
    }
    return plan;
}
llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition_for_ctx(
const struct llama_context * ctx,
const llama_dflash_window_update & window_update,
int32_t n_rows);
void llama_reset_dflash_kv_cache_state(struct llama_context * ctx);
void llama_set_dflash_visible_cross_ctx(struct llama_context * ctx, int32_t cross_ctx);
int32_t llama_get_dflash_visible_cross_ctx(const struct llama_context * ctx);
int32_t llama_model_dflash_block_size(const struct llama_model * model);
int32_t llama_model_dflash_mask_token_id(const struct llama_model * model);
int32_t llama_model_dflash_n_target_layers(const struct llama_model * model);
int32_t llama_model_dflash_n_target_features(const struct llama_model * model);
int32_t llama_model_dflash_target_layer_ids(const struct llama_model * model, int32_t * layer_ids, int32_t capacity);
int32_t llama_model_dflash_target_mask_token_id(const struct llama_model * model);
const struct ggml_tensor * llama_model_dflash_output_tensor(const struct llama_model * model);
// Classification of how a DFlash draft model obtains its input/output tensors.
// NOTE(review): the per-value meanings below are inferred from the enumerator
// names only - confirm against the implementation of llama_model_dflash_io_mode.
enum llama_dflash_io_mode {
LLAMA_DFLASH_IO_MODE_INVALID = 0, // zero value; presumably "could not be determined / unsupported"
LLAMA_DFLASH_IO_MODE_SHARED, // presumably: I/O tensors are taken from the target model
LLAMA_DFLASH_IO_MODE_SELF_CONTAINED, // presumably: the draft model ships its own I/O tensors
LLAMA_DFLASH_IO_MODE_MIXED, // presumably: some tensors are shared and some are own
};
int32_t llama_model_dflash_io_mode(const struct llama_model * draft_model, const struct llama_model * target_model);
bool llama_model_dflash_io_tensors_match(const struct llama_model * draft_model, int32_t n_embd, int32_t n_vocab);
bool llama_model_share_dflash_io_tensors(struct llama_model * draft_model, const struct llama_model * target_model);
// Sets the target features (and optional positions / window update) on `ctx`,
// copying the provided buffers into context-owned storage.
bool llama_set_dflash_target_features_copy(
struct llama_context * ctx,
const float * target_features,
size_t n_floats,
int32_t n_rows,
const llama_pos * target_positions,
const llama_dflash_window_update * window_update = nullptr);
// Same as the _copy variant, but the context keeps the caller's pointers
// instead of copying, so the buffers must stay alive while they are referenced.
bool llama_set_dflash_target_features_view(
struct llama_context * ctx,
const float * target_features,
size_t n_floats,
int32_t n_rows,
const llama_pos * target_positions,
const llama_dflash_window_update * window_update = nullptr);
// Starts capturing the outputs of the listed layers: installs fresh capture
// state and makes the DFlash capture hook the context's eval callback.
// Returns false when `ctx`/`layer_ids` is null or `n_layers` <= 0.
bool llama_set_dflash_capture_layers(struct llama_context * ctx, const int32_t * layer_ids, int32_t n_layers);
// Drops the capture state and feature-view buffer; restores the eval callback
// saved at setup time if the capture hook is still the installed callback.
void llama_clear_dflash_capture(struct llama_context * ctx);
// Opens a new capture batch: advances the batch id and resets the recorded row
// shape and the per-layer "seen" markers. No-op without active capture.
void llama_begin_dflash_capture_batch(struct llama_context * ctx);
// Closes the capture batch by clearing the recorded row shape.
// `is_prompt_warmup` is currently unused.
void llama_finish_dflash_capture_batch(struct llama_context * ctx, bool is_prompt_warmup);
// Fills `view` with hidden-state rows for the tail of `batch` that has captured
// rows; each row uses the token's first sequence id and points into a
// context-owned buffer. Returns false if the capture or batch is unusable.
bool llama_spec_get_dflash_feature_view(
struct llama_context * ctx,
const llama_batch & batch,
llama_spec_feature_view & view);
// Like llama_spec_get_dflash_feature_view, but keeps only tokens that list
// `seq_id`; returns false when no such token exists.
bool llama_spec_get_dflash_feature_view_for_seq(
struct llama_context * ctx,
const llama_batch & batch,
llama_seq_id seq_id,
llama_spec_feature_view & view);
// Copies the captured rows selected by `output_indices` (negative values count
// from the end) into `hidden_rows`, all captured layers of a row back to back.
// On failure `hidden_rows` is left empty and false is returned.
bool llama_spec_copy_dflash_rows_from_output_indices(
struct llama_context * ctx,
const std::vector<int32_t> & output_indices,
std::vector<float> & hidden_rows);

View File

@ -88,6 +88,7 @@ bool llama_spec_get_hidden_feature_view(
return true;
}
bool llama_spec_get_hidden_feature_view_for_seq(
struct llama_context * ctx,
const llama_batch & batch,

View File

@ -2,6 +2,7 @@
#include "llama.h"
#include <cstdint>
#include <vector>
struct llama_context;
@ -23,6 +24,8 @@ struct llama_spec_feature_view {
std::vector<llama_spec_feature_row_view> rows;
};
#include "llama-spec-features-dflash.h"
uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx);
bool llama_set_draft_input_hidden_state_copy(
@ -51,4 +54,4 @@ bool llama_spec_get_hidden_feature_view_from_output_index(
bool llama_spec_copy_hidden_rows_from_output_indices(
struct llama_context * ctx,
const std::vector<int32_t> & output_indices,
std::vector<float> & hidden_rows);
std::vector<float> & hidden_rows);

View File

@ -18,6 +18,7 @@
#include "llama-hparams.h"
#include "llama-context.h"
#include "llama-spec-features.h"
#include "llama-dflash.h"
#include "llama-quantize.h"
#include "unicode.h"
@ -171,6 +172,14 @@ static std::vector<std::string> string_split(const std::string& str, const std::
return parts;
}
// Reports whether the environment variable `name` is set to an "enabled" value.
// Unset and empty values are disabled, as are the exact (case-sensitive)
// strings "0", "false" and "off"; anything else counts as enabled.
static bool llama_env_flag_enabled(const char * name) {
    const char * value = std::getenv(name);
    if (value == nullptr || value[0] == '\0') {
        return false;
    }

    static const char * const disabled_values[] = { "0", "false", "off" };
    for (const char * disabled : disabled_values) {
        if (std::strcmp(value, disabled) == 0) {
            return false;
        }
    }
    return true;
}
// extract ip and port from RPC[ip:port] for rpc and keep other device names
static std::vector<rpc_device> extract_device_from_rpc_device(std::vector<std::string> devices) {
std::vector<rpc_device> rpc_servers;
@ -689,6 +698,10 @@ void llama_context::set_mtp_op_type(llama_mtp_op_type value) {
}
llama_context::~llama_context() {
if (dflash.kv.cache_sched != nullptr) {
ggml_backend_sched_free(dflash.kv.cache_sched);
}
free_dflash_kv_cache_tensors();
ggml_backend_sched_free(sched);
for (ggml_backend_t backend : backends) {
@ -3163,6 +3176,10 @@ static std::pair<std::vector<double>, double> get_layer_sizes(const llama_model_
name == "mtp_centroids.weight" || name == "mtp_token_ordering.weight") {
continue;
}
if (name == "dflash_fc.weight" || name == "dflash_hidden_norm.weight") {
output_misc_size += size;
continue;
}
auto pos = name.find("blk.");
if (pos != 0) {
LLAMA_LOG_WARN("Oops: tensor with strange name %s\n", name.c_str());
@ -3972,7 +3989,7 @@ static bool llm_load_tensors(
if (model.arch == LLM_ARCH_GEMMA4) {
llm_scale_gate_inp_s(model, use_mmap_buffer);
}
if ((model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) && extra_output_type != GGML_TYPE_COUNT) {
if ((model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE || model.arch == LLM_ARCH_DFLASH_DRAFT) && extra_output_type != GGML_TYPE_COUNT) {
llm_requantize_output_tensor(model, extra_output_type);
}
@ -4959,6 +4976,30 @@ static void llama_graph_compute(
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
}
// Launches `gf` asynchronously on an explicitly given scheduler.
// Before the launch, the backends owned by `lctx` are configured with
// `n_threads` (Metal command buffers, CPU threads, BLAS threads - each only
// when compiled in / present) and the CPU backend receives the context's
// abort callback.
static void llama_graph_compute_sched(
        llama_context & lctx,
        ggml_backend_sched_t sched,
        ggml_cgraph * gf,
        int n_threads) {
#ifdef GGML_USE_METAL
    const auto metal_backend = lctx.backend_metal;
    if (ggml_backend_is_metal(metal_backend)) {
        ggml_backend_metal_set_n_cb(metal_backend, n_threads);
    }
#endif

    const auto cpu_backend = lctx.backend_cpu;
    if (cpu_backend != nullptr) {
        ggml_backend_cpu_set_n_threads(cpu_backend, n_threads);
        ggml_backend_cpu_set_abort_callback(cpu_backend, lctx.abort_callback, lctx.abort_callback_data);
    }

#ifdef GGML_USE_BLAS
    const auto blas_backend = lctx.backend_blas;
    if (blas_backend != nullptr) {
        ggml_backend_blas_set_n_threads(blas_backend, n_threads);
    }
#endif

    ggml_backend_sched_graph_compute_async(sched, gf);
}
static bool prepare_mtp_graph_inputs(
struct llama_context & lctx,
uint32_t cur_token,
@ -5006,6 +5047,16 @@ static bool prepare_mtp_graph_inputs(
return true;
}
// Returns true when the layer carries any attention bias tensor, i.e. at least
// one of bq, bk, bv, bo, bqkv, bqk or bkv is non-null.
static bool dflash_layer_has_attention_bias(const llama_layer & layer) {
    const struct ggml_tensor * const bias_tensors[] = {
        layer.bq,
        layer.bk,
        layer.bv,
        layer.bo,
        layer.bqkv,
        layer.bqk,
        layer.bkv,
    };

    for (const struct ggml_tensor * bias : bias_tensors) {
        if (bias != nullptr) {
            return true;
        }
    }
    return false;
}
// decode a batch of tokens by evaluating the transformer
//
// - lctx: llama context
@ -5034,6 +5085,8 @@ static int llama_decode_internal(
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
llama_begin_dflash_capture_batch(&lctx);
GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
@ -5082,8 +5135,11 @@ static int llama_decode_internal(
// reserve output buffer
n_outputs_embd = has_mtp && cparams.mtp_op_type == MTP_OP_NONE ? n_tokens_all : n_outputs;
if (llama_output_reserve(lctx, std::max<size_t>(n_outputs, n_outputs_embd)) < std::max<size_t>(n_outputs, n_outputs_embd)) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %zu outputs\n", __func__, std::max<size_t>(n_outputs, n_outputs_embd));
const size_t required_outputs = std::max<size_t>(n_outputs, n_outputs_embd);
const bool is_dflash_decode = lctx.model.arch == LLM_ARCH_DFLASH_DRAFT;
const size_t reserved_outputs = llama_output_reserve(lctx, required_outputs);
if (reserved_outputs < required_outputs) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %zu outputs\n", __func__, required_outputs);
return -2;
};
@ -5160,7 +5216,7 @@ static int llama_decode_internal(
// * mrope (embd): (section-major array of rope fields) [t; n][h; n][w; n][extra; n]
const uint8_t rope_params_per_token = (hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ||
hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE) ? 4 : 1;
llama_pos * u_batch_pos;
if (!batch_all.pos) {
u_batch_pos = nullptr;
@ -5278,7 +5334,6 @@ static int llama_decode_internal(
auto tim2 = ggml_time_us();
printf("prelude(...): %d us\n", int(tim2-tim1));
#endif
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
#endif
@ -5329,10 +5384,25 @@ static int llama_decode_internal(
}
}
if (is_dflash_decode && !llama_prepare_dflash_graph_inputs(lctx, n_tokens)) {
return GGML_STATUS_FAILED;
}
// the output is always the last tensor in the graph
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
struct ggml_tensor * embd = nullptr;
// DFlash GPU argmax draft_argmax node
if (lctx.dflash.draft_tokens_tensor != nullptr &&
strcmp(res->name, "result_output") != 0) {
for (int i = gf->n_nodes - 2; i >= 0; --i) {
if (strcmp(gf->nodes[i]->name, "result_output") == 0) {
res = gf->nodes[i];
break;
}
}
}
if (lctx.n_outputs == 0) {
// no output
res = nullptr;
@ -5379,10 +5449,12 @@ static int llama_decode_internal(
tim2 = ggml_time_us();
printf("set_inputs(...): %d us\n", int(tim2-tim1));
#endif
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
#endif
if (lctx.dflash.kv.workspace_sync_pending) {
llama_sync_dflash_workspace_if_pending(lctx);
}
llama_graph_compute(lctx, gf, n_threads);
#if IK_PRINT_TIMING
llama_synchronize(&lctx);
@ -5409,7 +5481,28 @@ static int llama_decode_internal(
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}
lctx.dflash.draft_tokens.clear();
if (lctx.dflash.draft_tokens_tensor != nullptr) {
ggml_backend_t backend_argmax = ggml_backend_sched_get_tensor_backend(
lctx.sched, lctx.dflash.draft_tokens_tensor);
if (backend_argmax != nullptr) {
const int64_t n_tokens_argmax = lctx.dflash.draft_tokens_tensor->ne[0];
lctx.dflash.draft_tokens.resize((size_t) n_tokens_argmax);
ggml_backend_tensor_get_async(backend_argmax,
lctx.dflash.draft_tokens_tensor,
lctx.dflash.draft_tokens.data(), 0,
(size_t) n_tokens_argmax * sizeof(int32_t));
}
}
// extract logits
{
const bool dflash_skip_logits = (lctx.model.arch == LLM_ARCH_DFLASH_DRAFT
&& !lctx.dflash.draft_tokens.empty());
if (dflash_skip_logits) {
res = nullptr;
}
}
if (res) {
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
@ -7447,6 +7540,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_LAGUNA:
case LLM_ARCH_GEMMA4:
case LLM_ARCH_GEMMA4_MTP:
case LLM_ARCH_DFLASH_DRAFT:
case LLM_ARCH_GEMMA4_ASSISTANT:
return LLAMA_ROPE_TYPE_NEOX;
@ -9662,6 +9756,13 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
}
}
// Returns the i-th draft token produced by the last DFlash decode.
//
// ctx : context to query; nullptr yields LLAMA_TOKEN_NULL.
// i   : zero-based token index; negative or out-of-range values yield
//       LLAMA_TOKEN_NULL.
//
// The draft tokens are fetched from the backend with an asynchronous tensor
// read during decode, so the context is synchronized before the host-side
// buffer is read (the same pattern the embeddings getter follows).
llama_token llama_get_dflash_draft_token_ith(struct llama_context * ctx, int32_t i) {
    if (ctx == nullptr || i < 0) {
        return LLAMA_TOKEN_NULL;
    }

    llama_synchronize(ctx);

    const auto & draft_tokens = ctx->dflash.draft_tokens;
    const size_t index = (size_t) i;
    if (index >= draft_tokens.size()) {
        return LLAMA_TOKEN_NULL;
    }
    return draft_tokens[index];
}
float * llama_get_embeddings(struct llama_context * ctx) {
llama_synchronize(ctx);