remove duplicated code and unnecesary refactor

This commit is contained in:
SamuelOliveirads 2026-06-14 16:02:02 -03:00
parent 3b1a0f88d5
commit 0d75eee35a
11 changed files with 2818 additions and 2963 deletions

View File

@ -0,0 +1,872 @@
#pragma once
#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <vector>
static bool common_speculative_are_dflash_compatible(
const llama_model * model_tgt,
const llama_model * model_dft) {
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
if (llama_vocab_type(vocab_tgt) != llama_vocab_type(vocab_dft)) {
LOG_DBG("%s: DFlash draft model vocab type must match the target model\n", __func__);
return false;
}
const bool add_bos_tgt = llama_vocab_get_add_bos(vocab_tgt);
const bool add_bos_dft = llama_vocab_get_add_bos(vocab_dft);
const bool add_eos_tgt = llama_vocab_get_add_eos(vocab_tgt);
const bool add_eos_dft = llama_vocab_get_add_eos(vocab_dft);
const llama_token bos_tgt = llama_vocab_bos(vocab_tgt);
const llama_token bos_dft = llama_vocab_bos(vocab_dft);
const llama_token eos_tgt = llama_vocab_eos(vocab_tgt);
const llama_token eos_dft = llama_vocab_eos(vocab_dft);
if (add_bos_tgt != add_bos_dft || add_eos_tgt != add_eos_dft ||
(add_bos_tgt && bos_tgt != bos_dft) ||
(add_eos_tgt && eos_tgt != eos_dft)) {
LOG_DBG("%s: DFlash draft special tokens must match the target model (add_bos=%d/%d add_eos=%d/%d bos=%d/%d eos=%d/%d)\n",
__func__,
(int) add_bos_tgt,
(int) add_bos_dft,
(int) add_eos_tgt,
(int) add_eos_dft,
(int) bos_tgt,
(int) bos_dft,
(int) eos_tgt,
(int) eos_dft);
return false;
}
const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
const int vocab_diff = n_vocab_tgt > n_vocab_dft
? n_vocab_tgt - n_vocab_dft
: n_vocab_dft - n_vocab_tgt;
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_DBG("%s: DFlash draft vocab size differs too much from the target model (%d vs %d)\n",
__func__, n_vocab_dft, n_vocab_tgt);
return false;
}
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
LOG_DBG("%s: DFlash draft token %d differs - target '%s', draft '%s'\n", __func__, i,
common_token_to_piece(vocab_tgt, i).c_str(),
common_token_to_piece(vocab_dft, i).c_str());
return false;
}
}
return true;
}
static bool dflash_contract_log_enabled() {
const char * env = std::getenv("IK_DFLASH_CONTRACT_LOG");
if (env == nullptr || *env == '\0') {
return false;
}
return std::strcmp(env, "0") != 0 &&
std::strcmp(env, "false") != 0 &&
std::strcmp(env, "off") != 0;
}
static bool dflash_stats_log_enabled() {
const char * env = std::getenv("IK_DFLASH_STATS_LOG");
if (env == nullptr || *env == '\0') {
return false;
}
return std::strcmp(env, "0") != 0 &&
std::strcmp(env, "false") != 0 &&
std::strcmp(env, "off") != 0;
}
template <typename T>
static std::string dflash_contract_format_values(
const std::vector<T> & values,
size_t edge_count = 4) {
std::ostringstream oss;
oss << '[';
if (values.empty()) {
oss << ']';
return oss.str();
}
const size_t head = std::min(edge_count, values.size());
for (size_t i = 0; i < head; ++i) {
if (i > 0) {
oss << ',';
}
oss << values[i];
}
if (values.size() > edge_count * 2) {
oss << ",...,";
for (size_t i = values.size() - edge_count; i < values.size(); ++i) {
if (i > values.size() - edge_count) {
oss << ',';
}
oss << values[i];
}
} else {
for (size_t i = head; i < values.size(); ++i) {
oss << ',' << values[i];
}
}
oss << ']';
return oss.str();
}
struct dflash_contract_pos_summary {
llama_pos first = -1;
llama_pos last = -1;
int32_t gap_count = 0;
int32_t nonmono_count = 0;
};
static dflash_contract_pos_summary dflash_contract_summarize_positions(
const std::vector<llama_pos> & positions) {
dflash_contract_pos_summary summary;
if (positions.empty()) {
return summary;
}
summary.first = positions.front();
summary.last = positions.back();
for (size_t i = 1; i < positions.size(); ++i) {
if (positions[i] <= positions[i - 1]) {
summary.nonmono_count++;
} else if (positions[i] != positions[i - 1] + 1) {
summary.gap_count++;
}
}
return summary;
}
struct common_speculative_state_dflash;
static void dflash_contract_log_append(
const common_speculative_state_dflash & state,
llama_seq_id seq_id,
const std::vector<llama_pos> & new_positions);
static void dflash_contract_log_draft(
const common_speculative_state_dflash & state,
int32_t n_keep,
size_t result_size);
static void dflash_materialize_target_window_features(common_speculative_state_dflash & state);
// DFlash runtime state and draft path.
struct common_speculative_state_dflash : public common_speculative_state {
llama_context * ctx_tgt;
llama_context * ctx_dft;
llama_batch batch = {};
int32_t block_size = 0;
int32_t mask_token_id = -1;
int32_t n_target_features = 0;
int32_t cross_ctx = 0;
bool ready = false;
std::vector<int32_t> target_layer_ids;
std::vector<float> target_window;
std::vector<llama_pos> target_window_pos;
std::vector<float> target_window_stage;
std::vector<llama_pos> target_window_pos_stage;
std::vector<float> target_window_ring;
std::vector<float> target_window_append_features;
int32_t target_window_rows = 0;
int32_t target_window_ring_write_pos = 0;
int32_t target_window_ring_filled = 0;
uint64_t target_window_version = 0;
int32_t target_window_keep_rows = 0;
int32_t target_window_append_rows = 0;
bool target_window_replace = false;
bool target_window_materialized = false;
llama_pos last_target_pos = -1;
size_t n_window_updates = 0;
size_t n_rows_seen = 0;
size_t n_rows_dropped = 0;
size_t n_context_shifts = 0;
size_t n_draft_empty = 0;
size_t n_set_target_fail = 0;
size_t n_decode_fail = 0;
llama_pos last_draft_pos_base = -1;
uint64_t t_draft_decode_us = 0;
uint64_t t_draft_sample_us = 0;
uint64_t t_warmup_collect_us = 0;
uint64_t t_warmup_append_us = 0;
uint64_t t_accept_output_copy_us = 0;
uint64_t t_accept_commit_us = 0;
uint64_t t_accept_append_us = 0;
uint64_t t_accept_append_filter_us = 0;
uint64_t t_accept_append_window_alloc_us = 0;
uint64_t t_accept_append_replace_us = 0;
uint64_t t_accept_append_keep_old_us = 0;
uint64_t t_accept_append_new_rows_us = 0;
uint64_t t_accept_append_commit_detail_us = 0;
uint64_t t_accept_append_log_us = 0;
size_t n_warmup_collect_calls = 0;
size_t n_warmup_collect_rows = 0;
size_t n_warmup_append_calls = 0;
size_t n_warmup_append_rows = 0;
size_t n_accept_output_copy_calls = 0;
size_t n_accept_output_copy_rows = 0;
size_t n_accept_commit_calls = 0;
size_t n_accept_commit_rows = 0;
size_t n_accept_append_calls = 0;
size_t n_accept_append_rows = 0;
size_t n_accept_append_replace_calls = 0;
size_t n_accept_append_slide_calls = 0;
common_speculative_state_dflash(
enum common_speculative_type type,
llama_context * ctx_tgt,
llama_context * ctx_dft,
int32_t cross_ctx)
: common_speculative_state(type)
, ctx_tgt(ctx_tgt)
, ctx_dft(ctx_dft)
, cross_ctx(std::max(1, cross_ctx))
{
const llama_model * model_tgt = llama_get_model(ctx_tgt);
const llama_model * model_dft = llama_get_model(ctx_dft);
if (!common_speculative_are_dflash_compatible(model_tgt, model_dft)) {
LOG_ERR("%s: DFlash draft model vocab/tokenizer is incompatible with the target model\n", __func__);
return;
}
block_size = llama_model_dflash_block_size(model_dft);
mask_token_id = llama_model_dflash_mask_token_id(model_dft);
n_target_features = llama_model_dflash_n_target_features(model_dft);
const int32_t n_target_layers = llama_model_dflash_n_target_layers(model_dft);
if (block_size <= 0 || mask_token_id < 0 || n_target_features <= 0 || n_target_layers <= 0) {
LOG_ERR("%s: invalid DFlash metadata (block_size=%d, mask_token_id=%d, n_target_features=%d, n_target_layers=%d)\n",
__func__, block_size, mask_token_id, n_target_features, n_target_layers);
return;
}
target_layer_ids.resize((size_t) n_target_layers);
if (llama_model_dflash_target_layer_ids(model_dft, target_layer_ids.data(), n_target_layers) != n_target_layers) {
LOG_ERR("%s: failed to read DFlash target layer ids\n", __func__);
target_layer_ids.clear();
return;
}
const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
const auto * vocab_dft = llama_model_get_vocab(model_dft);
const int32_t target_vocab_size = llama_vocab_n_tokens(vocab_tgt);
const int32_t draft_vocab_size = llama_vocab_n_tokens(vocab_dft);
const int32_t target_hidden_size = llama_model_n_embd(model_tgt);
const int32_t draft_hidden_size = llama_model_n_embd(model_dft);
const int32_t target_mask_token_id = llama_model_dflash_target_mask_token_id(model_tgt);
const int32_t expected_n_target_features = target_hidden_size > 0 ? target_hidden_size * n_target_layers : 0;
if (target_mask_token_id != (int32_t) LLAMA_TOKEN_NULL && mask_token_id != target_mask_token_id) {
LOG_ERR("%s: DFlash mask token mismatch (draft=%d target=%d)\n",
__func__, mask_token_id, target_mask_token_id);
return;
}
if (target_hidden_size <= 0 || draft_hidden_size <= 0) {
LOG_ERR("%s: invalid DFlash hidden sizes (draft=%d target=%d)\n",
__func__, draft_hidden_size, target_hidden_size);
return;
}
if (expected_n_target_features <= 0 || n_target_features != expected_n_target_features) {
LOG_ERR("%s: DFlash target feature width mismatch (metadata=%d expected=%d target_hidden=%d target_layers=%d)\n",
__func__, n_target_features, expected_n_target_features, target_hidden_size, n_target_layers);
return;
}
std::vector<int32_t> sorted_target_layer_ids = target_layer_ids;
std::sort(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end());
if (std::adjacent_find(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()) != sorted_target_layer_ids.end()) {
LOG_ERR("%s: duplicate DFlash target layer ids survived into runtime validation\n", __func__);
target_layer_ids.clear();
return;
}
const int32_t n_target_model_layers = llama_n_layer(model_tgt);
for (int32_t layer_id : target_layer_ids) {
if (layer_id < 0 || layer_id >= n_target_model_layers) {
LOG_ERR("%s: invalid DFlash target layer id %d for target model with %d layers\n",
__func__, layer_id, n_target_model_layers);
target_layer_ids.clear();
return;
}
}
const int32_t io_mode = llama_model_dflash_io_mode(model_dft, model_tgt);
if (io_mode == LLAMA_DFLASH_IO_MODE_INVALID) {
LOG_ERR("%s: DFlash draft is missing required IO tensors after target sharing\n", __func__);
return;
}
if (io_mode == LLAMA_DFLASH_IO_MODE_MIXED) {
LOG_ERR("%s: DFlash IO contract must be fully shared or fully self-contained, but resolved to mixed mode\n", __func__);
return;
}
if (io_mode == LLAMA_DFLASH_IO_MODE_SELF_CONTAINED && !llama_model_dflash_io_tensors_match(model_dft, target_hidden_size, target_vocab_size)) {
LOG_ERR("%s: DFlash self-contained IO tensors do not match the target hidden/vocab contract (target_hidden=%d target_vocab=%d)\n",
__func__,
target_hidden_size,
target_vocab_size);
return;
}
if (!llama_set_dflash_capture_layers(ctx_tgt, target_layer_ids.data(), (int32_t) target_layer_ids.size())) {
LOG_ERR("%s: failed to configure DFlash target capture callback\n", __func__);
return;
}
batch = llama_batch_init(std::max(1, block_size), 0, 1);
target_window.reserve((size_t) this->cross_ctx * (size_t) n_target_features);
target_window_stage.reserve((size_t) this->cross_ctx * (size_t) n_target_features);
target_window_ring.resize((size_t) this->cross_ctx * (size_t) n_target_features);
target_window_append_features.reserve((size_t) this->cross_ctx * (size_t) n_target_features);
target_window_pos.reserve((size_t) this->cross_ctx);
target_window_pos_stage.reserve((size_t) this->cross_ctx);
ready = true;
llama_set_dflash_visible_cross_ctx(ctx_dft, this->cross_ctx);
llama_dflash_profile_reset(ctx_tgt);
llama_dflash_profile_reset(ctx_dft);
std::ostringstream layers_oss;
for (size_t i = 0; i < target_layer_ids.size(); ++i) {
if (i > 0) {
layers_oss << ",";
}
layers_oss << target_layer_ids[i];
}
const char * io_mode_name = io_mode == LLAMA_DFLASH_IO_MODE_SHARED ? "shared" : "self-contained";
LOG_INF("%s: DFlash context ready (n_ctx=%d, block_size=%d, cross_ctx=%d, n_target_features=%d, target_layer_ids=[%s])\n",
__func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features, layers_oss.str().c_str());
LOG_INF("%s: DFlash artifact io=%s draft_vocab=%d target_vocab=%d draft_hidden=%d target_hidden=%d mask_token_id=%d target_mask_token_id=%d\n",
__func__, io_mode_name, draft_vocab_size, target_vocab_size, draft_hidden_size, target_hidden_size, mask_token_id, target_mask_token_id);
}
~common_speculative_state_dflash() override {
llama_clear_dflash_capture(ctx_tgt);
if (ctx_dft) {
llama_free(ctx_dft);
}
if (batch.token != nullptr) {
llama_batch_free(batch);
}
}
void begin(const llama_tokens & prompt) override {
GGML_UNUSED(prompt);
llama_kv_cache_clear(ctx_dft);
llama_reset_dflash_kv_cache_state(ctx_dft);
n_window_updates = 0;
n_rows_seen = 0;
n_rows_dropped = 0;
n_context_shifts = 0;
n_draft_empty = 0;
n_set_target_fail = 0;
n_decode_fail = 0;
last_draft_pos_base = -1;
t_draft_decode_us = 0;
t_draft_sample_us = 0;
t_warmup_collect_us = 0;
t_warmup_append_us = 0;
t_accept_output_copy_us = 0;
t_accept_commit_us = 0;
t_accept_append_us = 0;
t_accept_append_filter_us = 0;
t_accept_append_window_alloc_us = 0;
t_accept_append_replace_us = 0;
t_accept_append_keep_old_us = 0;
t_accept_append_new_rows_us = 0;
t_accept_append_commit_detail_us = 0;
t_accept_append_log_us = 0;
n_warmup_collect_calls = 0;
n_warmup_collect_rows = 0;
n_warmup_append_calls = 0;
n_warmup_append_rows = 0;
n_accept_output_copy_calls = 0;
n_accept_output_copy_rows = 0;
n_accept_commit_calls = 0;
n_accept_commit_rows = 0;
n_accept_append_calls = 0;
n_accept_append_rows = 0;
n_accept_append_replace_calls = 0;
n_accept_append_slide_calls = 0;
llama_dflash_profile_reset(ctx_tgt);
llama_dflash_profile_reset(ctx_dft);
}
void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
GGML_UNUSED(prompt_tgt);
result.clear();
if (!ready || target_window_rows <= 0) {
n_draft_empty++;
return;
}
const int32_t n_keep = std::min<int32_t>(params.n_max, block_size - 1);
if (n_keep <= 0) {
return;
}
const float * target_features = nullptr;
size_t target_feature_floats = 0;
llama_dflash_window_update window_update = {
target_window_version,
target_window_keep_rows,
target_window_append_rows,
target_window_replace,
target_window_append_features.empty() ? nullptr : target_window_append_features.data(),
target_window_append_features.size(),
};
const llama_dflash_kv_cache_transition cache_plan =
llama_plan_dflash_kv_cache_transition_for_ctx(ctx_dft, window_update, target_window_rows);
if (cache_plan.rebuild_cache) {
dflash_materialize_target_window_features(*this);
target_features = target_window.data();
target_feature_floats = target_window.size();
window_update.append_features = target_window.data();
window_update.append_floats = target_window.size();
window_update.append_rows = target_window_rows;
}
if (!llama_set_dflash_target_features_view(ctx_dft, target_features, target_feature_floats, target_window_rows, target_window_pos.data(), &window_update)) {
LOG_ERR("%s: failed to set DFlash target features\n", __func__);
n_set_target_fail++;
return;
}
llama_kv_cache_clear(ctx_dft);
batch.n_tokens = 0;
const int32_t batch_len = n_keep + 1;
const llama_pos draft_pos_base = last_target_pos >= 0 ? last_target_pos + 1 : (llama_pos) target_window_rows;
const llama_pos seed_pos = last_target_pos >= 0 ? last_target_pos : draft_pos_base - 1;
last_draft_pos_base = draft_pos_base;
common_batch_add(batch, id_last, seed_pos, { 0 }, false);
for (int32_t i = 1; i < batch_len; ++i) {
common_batch_add(batch, mask_token_id, draft_pos_base + (i - 1), { 0 }, i <= n_keep);
}
const int64_t t_decode_us = ggml_time_us();
if (llama_decode(ctx_dft, batch) != 0) {
LOG_ERR("%s: llama_decode() failed for DFlash draft batch\n", __func__);
n_decode_fail++;
batch.n_tokens = 0;
return;
}
t_draft_decode_us += (uint64_t) (ggml_time_us() - t_decode_us);
result.reserve((size_t) n_keep);
const int64_t t_sample_us = ggml_time_us();
for (int32_t i = 0; i < n_keep; ++i) {
llama_token id = llama_get_dflash_draft_token_ith(ctx_dft, i);
if (id == LLAMA_TOKEN_NULL) {
id = common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr);
}
result.push_back(id);
}
t_draft_sample_us += (uint64_t) (ggml_time_us() - t_sample_us);
batch.n_tokens = 0;
dflash_contract_log_draft(*this, n_keep, result.size());
}
void accept(uint16_t n_accepted) override {
GGML_UNUSED(n_accepted);
}
};
static void dflash_contract_log_append(
const common_speculative_state_dflash & state,
llama_seq_id seq_id,
const std::vector<llama_pos> & new_positions) {
if (!dflash_contract_log_enabled()) {
return;
}
static std::atomic<uint64_t> counter = 0;
const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed);
if (ordinal >= 8) {
return;
}
const dflash_contract_pos_summary incoming = dflash_contract_summarize_positions(new_positions);
const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos);
LOG_INF("dflash contract append[%llu]: seq=%d incoming_rows=%zu incoming_pos=%s pos=[%d..%d] gaps=%d nonmono=%d window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d\n",
(unsigned long long) (ordinal + 1),
(int) seq_id,
new_positions.size(),
dflash_contract_format_values(new_positions).c_str(),
(int) incoming.first,
(int) incoming.last,
incoming.gap_count,
incoming.nonmono_count,
state.target_window_rows,
dflash_contract_format_values(state.target_window_pos).c_str(),
(int) window.first,
(int) window.last,
window.gap_count,
window.nonmono_count,
(int) state.last_target_pos);
}
static void dflash_contract_log_draft(
const common_speculative_state_dflash & state,
int32_t n_keep,
size_t result_size) {
if (!dflash_contract_log_enabled()) {
return;
}
static std::atomic<uint64_t> counter = 0;
const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed);
if (ordinal >= 8) {
return;
}
const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos);
llama_dflash_profile_stats graph_stats = {};
llama_dflash_profile_get_stats(state.ctx_dft, &graph_stats);
const int draft_delta = (state.last_target_pos >= 0 && state.last_draft_pos_base >= 0)
? (int) (state.last_draft_pos_base - state.last_target_pos)
: -1;
const llama_pos seed_pos = state.last_target_pos;
const llama_pos mask_first_pos = state.last_draft_pos_base;
const llama_pos mask_last_pos = state.last_draft_pos_base >= 0
? state.last_draft_pos_base + n_keep - 1
: -1;
LOG_INF("dflash contract draft[%llu]: window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d seed_pos=%d mask_pos=[%d..%d] sample_rows=[1..%d] output_rows=[1..%d] draft_pos_base=%d delta=%d n_keep=%d result=%zu set_target(missing/nonmono)=%llu/%llu graph(fallback/nonmono)=%llu/%llu graph_pos=[%d..%d]\n",
(unsigned long long) (ordinal + 1),
state.target_window_rows,
dflash_contract_format_values(state.target_window_pos).c_str(),
(int) window.first,
(int) window.last,
window.gap_count,
window.nonmono_count,
(int) state.last_target_pos,
(int) seed_pos,
(int) mask_first_pos,
(int) mask_last_pos,
n_keep,
n_keep,
(int) state.last_draft_pos_base,
draft_delta,
n_keep,
result_size,
(unsigned long long) graph_stats.set_target_missing_positions,
(unsigned long long) graph_stats.set_target_non_monotonic_positions,
(unsigned long long) graph_stats.graph_pos_fallbacks,
(unsigned long long) graph_stats.graph_pos_non_monotonic,
(int) graph_stats.last_pos_first,
(int) graph_stats.last_pos_last);
}
struct dflash_append_breakdown {
uint64_t filter_us = 0;
uint64_t window_alloc_us = 0;
uint64_t replace_us = 0;
uint64_t keep_old_us = 0;
uint64_t new_rows_us = 0;
uint64_t commit_us = 0;
uint64_t log_us = 0;
bool replace_call = false;
};
static void dflash_record_window_update(
common_speculative_state_dflash & state,
int32_t keep_rows,
int32_t append_rows,
bool replace) {
state.target_window_keep_rows = std::max<int32_t>(0, keep_rows);
state.target_window_append_rows = std::max<int32_t>(0, append_rows);
state.target_window_replace = replace;
state.target_window_version++;
}
static void dflash_ring_reset_rows(
common_speculative_state_dflash & state,
const float * rows,
int32_t n_rows) {
const size_t row_width = (size_t) state.n_target_features;
if (n_rows <= 0 || rows == nullptr) {
state.target_window_ring_write_pos = 0;
state.target_window_ring_filled = 0;
return;
}
if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) {
state.target_window_ring.resize((size_t) state.cross_ctx * row_width);
}
std::memcpy(state.target_window_ring.data(), rows, (size_t) n_rows * row_width * sizeof(float));
state.target_window_ring_write_pos = n_rows % state.cross_ctx;
state.target_window_ring_filled = n_rows;
state.target_window_materialized = false;
}
static void dflash_ring_append_rows(
common_speculative_state_dflash & state,
const float * rows,
int32_t n_rows) {
const size_t row_width = (size_t) state.n_target_features;
if (n_rows <= 0 || rows == nullptr) {
return;
}
if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) {
state.target_window_ring.resize((size_t) state.cross_ctx * row_width);
}
int32_t write_pos = state.target_window_ring_write_pos;
int32_t remaining = n_rows;
const float * src = rows;
while (remaining > 0) {
const int32_t chunk_rows = std::min<int32_t>(remaining, state.cross_ctx - write_pos);
std::memcpy(
state.target_window_ring.data() + (size_t) write_pos * row_width,
src,
(size_t) chunk_rows * row_width * sizeof(float));
src += (size_t) chunk_rows * row_width;
remaining -= chunk_rows;
write_pos = (write_pos + chunk_rows) % state.cross_ctx;
}
state.target_window_ring_write_pos = write_pos;
state.target_window_ring_filled = std::min(state.cross_ctx, state.target_window_ring_filled + n_rows);
state.target_window_materialized = false;
}
static void dflash_materialize_target_window_features(common_speculative_state_dflash & state) {
if (state.target_window_materialized || state.target_window_rows <= 0) {
return;
}
const size_t row_width = (size_t) state.n_target_features;
state.target_window.resize((size_t) state.target_window_rows * row_width);
const int32_t read_start = (state.target_window_ring_write_pos - state.target_window_rows + state.cross_ctx) % state.cross_ctx;
const int32_t first_rows = std::min<int32_t>(state.target_window_rows, state.cross_ctx - read_start);
std::memcpy(
state.target_window.data(),
state.target_window_ring.data() + (size_t) read_start * row_width,
(size_t) first_rows * row_width * sizeof(float));
const int32_t second_rows = state.target_window_rows - first_rows;
if (second_rows > 0) {
std::memcpy(
state.target_window.data() + (size_t) first_rows * row_width,
state.target_window_ring.data(),
(size_t) second_rows * row_width * sizeof(float));
}
state.target_window_materialized = true;
}
static bool dflash_append_target_features(
common_speculative_state_dflash & state,
const common_speculative_feature_view & features,
const llama_batch & batch,
llama_seq_id seq_id,
dflash_append_breakdown * breakdown = nullptr) {
GGML_UNUSED(batch);
if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE ||
features.width != state.n_target_features ||
features.rows.empty() ||
state.cross_ctx <= 0) {
return false;
}
const size_t row_width = (size_t) state.n_target_features;
std::vector<float> new_rows;
std::vector<llama_pos> new_positions;
new_rows.reserve(features.rows.size() * row_width);
new_positions.reserve(features.rows.size());
const int64_t t_filter_us = ggml_time_us();
for (const auto & row : features.rows) {
if (row.seq_id != seq_id || row.data == nullptr) {
continue;
}
new_positions.push_back(row.pos);
new_rows.insert(new_rows.end(), row.data, row.data + row_width);
}
if (breakdown != nullptr) {
breakdown->filter_us += (uint64_t) (ggml_time_us() - t_filter_us);
}
if (new_positions.empty()) {
return false;
}
const int32_t n_rows = (int32_t) new_positions.size();
state.n_window_updates++;
state.n_rows_seen += (size_t) n_rows;
if (n_rows >= state.cross_ctx) {
state.n_rows_dropped += (size_t) state.target_window_rows + (size_t) (n_rows - state.cross_ctx);
const int32_t keep_from = n_rows - state.cross_ctx;
const int64_t t_replace_us = ggml_time_us();
state.target_window_pos.assign(new_positions.begin() + keep_from, new_positions.end());
state.target_window_append_features.assign(
new_rows.begin() + (ptrdiff_t) keep_from * (ptrdiff_t) row_width,
new_rows.end());
dflash_ring_reset_rows(state, state.target_window_append_features.data(), state.cross_ctx);
if (breakdown != nullptr) {
breakdown->replace_us += (uint64_t) (ggml_time_us() - t_replace_us);
breakdown->replace_call = true;
}
const int64_t t_commit_us = ggml_time_us();
state.target_window_rows = state.cross_ctx;
state.target_window_ring_filled = state.target_window_rows;
state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
dflash_record_window_update(state, 0, state.target_window_rows, true);
if (breakdown != nullptr) {
breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us);
}
const int64_t t_log_us = ggml_time_us();
dflash_contract_log_append(state, seq_id, new_positions);
if (breakdown != nullptr) {
breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us);
}
return true;
}
const int32_t keep_old_rows = std::min<int32_t>(state.target_window_rows, state.cross_ctx - n_rows);
state.n_rows_dropped += (size_t) std::max<int32_t>(0, state.target_window_rows - keep_old_rows);
const int64_t t_window_alloc_us = ggml_time_us();
std::vector<llama_pos> & next_window_pos = state.target_window_pos_stage;
next_window_pos.resize((size_t) (keep_old_rows + n_rows));
if (breakdown != nullptr) {
breakdown->window_alloc_us += (uint64_t) (ggml_time_us() - t_window_alloc_us);
}
if (keep_old_rows > 0) {
const int64_t t_keep_old_us = ggml_time_us();
std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), next_window_pos.begin());
if (breakdown != nullptr) {
breakdown->keep_old_us += (uint64_t) (ggml_time_us() - t_keep_old_us);
}
}
const int64_t t_new_rows_us = ggml_time_us();
state.target_window_append_features.assign(new_rows.begin(), new_rows.end());
dflash_ring_append_rows(state, state.target_window_append_features.data(), n_rows);
std::copy(new_positions.begin(), new_positions.end(), next_window_pos.begin() + keep_old_rows);
if (breakdown != nullptr) {
breakdown->new_rows_us += (uint64_t) (ggml_time_us() - t_new_rows_us);
}
const int64_t t_commit_us = ggml_time_us();
state.target_window_pos.swap(next_window_pos);
next_window_pos.clear();
state.target_window_rows = keep_old_rows + n_rows;
state.target_window_ring_filled = state.target_window_rows;
state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
dflash_record_window_update(state, keep_old_rows, n_rows, false);
if (breakdown != nullptr) {
breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us);
}
const int64_t t_log_us = ggml_time_us();
dflash_contract_log_append(state, seq_id, new_positions);
if (breakdown != nullptr) {
breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us);
}
return true;
}
static void dflash_clear_target_features(common_speculative_state_dflash & state) {
state.target_window.clear();
state.target_window_pos.clear();
state.target_window_stage.clear();
state.target_window_pos_stage.clear();
state.target_window_append_features.clear();
state.target_window_rows = 0;
state.target_window_ring_write_pos = 0;
state.target_window_ring_filled = 0;
state.target_window_keep_rows = 0;
state.target_window_append_rows = 0;
state.target_window_replace = false;
state.target_window_materialized = false;
state.last_target_pos = -1;
llama_reset_dflash_kv_cache_state(state.ctx_dft);
}
static void dflash_context_shift(
common_speculative_state_dflash & state,
llama_pos kv_keep,
llama_pos kv_discard,
llama_pos kv_past) {
if (kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window_pos.empty()) {
return;
}
dflash_materialize_target_window_features(state);
const size_t row_width = (size_t) state.n_target_features;
const llama_pos discard_begin = kv_keep;
const llama_pos discard_end = kv_keep + kv_discard;
std::vector<float> shifted_rows;
std::vector<llama_pos> shifted_positions;
shifted_rows.reserve(state.target_window.size());
shifted_positions.reserve(state.target_window_pos.size());
for (int32_t row = 0; row < state.target_window_rows; ++row) {
llama_pos pos = state.target_window_pos[(size_t) row];
if (pos >= discard_begin && pos < discard_end) {
continue;
}
if (pos >= discard_end && pos < kv_past) {
pos -= kv_discard;
}
const float * row_src = state.target_window.data() + (size_t) row * row_width;
shifted_rows.insert(shifted_rows.end(), row_src, row_src + row_width);
shifted_positions.push_back(pos);
}
state.target_window = std::move(shifted_rows);
state.target_window_pos = std::move(shifted_positions);
state.target_window_rows = (int32_t) state.target_window_pos.size();
dflash_ring_reset_rows(state, state.target_window.data(), state.target_window_rows);
state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
dflash_record_window_update(state, 0, state.target_window_rows, true);
llama_reset_dflash_kv_cache_state(state.ctx_dft);
state.n_context_shifts++;
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -570,6 +570,9 @@ class Model:
if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
# ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
res = "grok-2"
if chkhsh == "972da7b59cec44d1f0a490a86c96df53859e486e481563e5dddac155013d87ac":
# ref: https://huggingface.co/poolside/Laguna-XS.2
res = "laguna"
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
res = "llama-bpe"
@ -603,6 +606,9 @@ class Model:
if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8":
# ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01
res = "command-r"
if chkhsh == "52df12b4c8d4176e7481aab4b6e8454d1fd0a210a04a574f6d4e067d10e23c3e":
# ref: https://huggingface.co/CohereLabs/North-Mini-Code-1.0
res = "cohere2_moe"
if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea":
# ref: https://huggingface.co/Qwen/Qwen1.5-7B
res = "qwen2"
@ -684,6 +690,9 @@ class Model:
if chkhsh == "f4f37b6c8eb9ea29b3eac6bb8c8487c5ab7885f8d8022e67edc1c68ce8403e95":
# ref: https://huggingface.co/MiniMaxAI/MiniMax-M2
res = "minimax-m2"
if chkhsh == "9dcf830ee9990cdbf78cc523a5f7bd9ad8f3f9890c2d3581d2785ad10f07049d":
# ref: https://huggingface.co/JetBrains/Mellum2-12B-A2.5B-Base
res = "mellum2"
if res is None:
logger.warning("\n")
logger.warning("**************************************************************************************")
@ -1580,7 +1589,12 @@ class LlamaModel(Model):
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
saved_intermediate_size = self.hparams.get("intermediate_size")
saved_num_experts_per_tok = self.hparams.pop("num_experts_per_tok")
self.hparams["intermediate_size"] = self.hparams["prefix_dense_intermediate_size"]
super().set_gguf_parameters()
self.hparams["intermediate_size"] = saved_intermediate_size
self.hparams["num_experts_per_tok"] = saved_num_experts_per_tok
hparams = self.hparams
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
@ -3735,7 +3749,7 @@ class Gemma4Model(Gemma4BaseModel):
return [(self.map_tensor_name(name), data_torch)]
@Model.register("Gemma4AssistantForCausalLM")
@Model.register("Gemma4AssistantForCausalLM", "Gemma4UnifiedAssistantForCausalLM")
class Gemma4AssistantModel(Gemma4BaseModel):
model_arch = gguf.MODEL_ARCH.GEMMA4_MTP
@ -3950,6 +3964,86 @@ class CommandR2Model(Model):
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
@Model.register("Cohere2MoeForCausalLM")
class Cohere2MoeModel(Model):
model_arch = gguf.MODEL_ARCH.COHERE2_MOE
_experts: list[dict[str, Tensor]] | None = None
def set_gguf_parameters(self):
saved_intermediate_size = self.hparams["intermediate_size"]
saved_num_experts_per_tok = self.hparams.pop("num_experts_per_tok")
self.hparams["intermediate_size"] = self.hparams["prefix_dense_intermediate_size"]
super().set_gguf_parameters()
self.hparams["intermediate_size"] = saved_intermediate_size
self.hparams["num_experts_per_tok"] = saved_num_experts_per_tok
hparams = self.hparams
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
self.gguf_writer.add_logit_scale(hparams.get("logit_scale", 1.0))
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
self.gguf_writer.add_sliding_window_pattern([
layer_type == "sliding_attention"
for layer_type in hparams["layer_types"]
])
self.gguf_writer.add_rope_dimension_count(hparams.get("head_dim", hparams["hidden_size"] // hparams["num_attention_heads"]))
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_expert_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
self.gguf_writer.add_expert_count(hparams["num_experts"])
self.gguf_writer.add_expert_used_count(hparams["num_experts_per_tok"])
self.gguf_writer.add_expert_weights_norm(bool(hparams.get("norm_topk_prob", False)))
expert_selection_fn = hparams.get("expert_selection_fn", "softmax")
if expert_selection_fn != "sigmoid":
raise ValueError(f"Unsupported Cohere2-MoE expert_selection_fn={expert_selection_fn!r}")
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
if hparams.get("num_shared_experts", 0) != 0:
raise ValueError("Cohere2-MoE shared experts are not supported in this GGUF converter yet")
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Cohere2-MoE HF tensors already use the interleaved RoPE layout expected here.
if ".mlp.experts." in name:
n_experts = self.hparams["num_experts"]
assert bid is not None
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) < n_experts * 3:
return []
tensors: list[tuple[str, Tensor]] = []
for src, dst in [
("gate_proj", "gate_proj"),
("down_proj", "down_proj"),
("up_proj", "up_proj"),
]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{src}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]
merged_name = f"model.layers.{bid}.mlp.experts.{dst}.weight"
tensors.append((self.map_tensor_name(merged_name), torch.stack(datas, dim=0)))
yield from tensors
return
if name == "model.embed_tokens.weight":
yield self.map_tensor_name(name), data_torch
if self.tensor_names is None or "lm_head.weight" not in self.tensor_names:
yield self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT, suffix=".weight"), data_torch
return
yield self.map_tensor_name(name), data_torch
@Model.register("OlmoForCausalLM")
@Model.register("OLMoForCausalLM")
class OlmoModel(Model):
@ -5307,6 +5401,138 @@ class BailingMoeV2Model(Model):
raise ValueError(f"Unprocessed experts: {experts}")
@Model.register("LagunaForCausalLM")
class LagunaModel(Model):
model_arch = gguf.MODEL_ARCH.LAGUNA
_experts: list[dict[str, Tensor]] | None = None
def set_gguf_parameters(self):
hparams = self.hparams
arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
n_layers = int(hparams["num_hidden_layers"])
n_head_base = int(hparams["num_attention_heads"])
n_kv_base = int(hparams.get("num_key_value_heads", n_head_base))
head_dim = int(hparams.get("head_dim", hparams["hidden_size"] // n_head_base))
heads_per_layer = hparams.get("num_attention_heads_per_layer")
kv_per_layer = hparams.get("num_key_value_heads_per_layer")
head_arr: list[int] = []
kv_arr: list[int] = []
for i in range(n_layers):
head_arr.append(int(heads_per_layer[i]) if heads_per_layer is not None else n_head_base)
kv_arr.append(int(kv_per_layer[i]) if kv_per_layer is not None else n_kv_base)
rope_params = hparams.get("rope_parameters", {})
full_rope = rope_params.get("full_attention", rope_params)
swa_rope = rope_params.get("sliding_attention", {})
self.gguf_writer.add_context_length(int(hparams["max_position_embeddings"]))
self.gguf_writer.add_embedding_length(int(hparams["hidden_size"]))
self.gguf_writer.add_block_count(n_layers)
self.gguf_writer.add_feed_forward_length(int(hparams["intermediate_size"]))
self.gguf_writer.add_head_count(head_arr)
if all(n_kv == kv_arr[0] for n_kv in kv_arr):
self.gguf_writer.add_head_count_kv(kv_arr[0])
else:
self.gguf_writer.add_head_count_kv(kv_arr)
self.gguf_writer.add_key_length(head_dim)
self.gguf_writer.add_value_length(head_dim)
self.gguf_writer.add_layer_norm_rms_eps(float(hparams["rms_norm_eps"]))
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_sliding_window(int(hparams["sliding_window"]))
self.gguf_writer.add_rope_dimension_count(head_dim // 2)
self.gguf_writer.add_uint32(f"{arch}.rope.dimension_count_swa", head_dim)
self.gguf_writer.add_rope_freq_base(float(full_rope.get("rope_theta", 500000.0)))
self.gguf_writer.add_float32(f"{arch}.rope.freq_base_swa", float(swa_rope.get("rope_theta", 10000.0)))
if full_rope.get("rope_type") == "yarn":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(float(full_rope.get("factor", 1.0)))
self.gguf_writer.add_rope_scaling_orig_ctx_len(int(full_rope.get(
"original_max_position_embeddings",
rope_params.get("original_max_position_embeddings", hparams["max_position_embeddings"]),
)))
self.gguf_writer.add_rope_scaling_yarn_ext_factor(float(full_rope.get("factor", 1.0)))
self.gguf_writer.add_rope_scaling_yarn_attn_factor(float(full_rope.get("attention_factor", 1.0)))
self.gguf_writer.add_rope_scaling_yarn_beta_fast(float(full_rope.get("beta_fast", 32.0)))
self.gguf_writer.add_rope_scaling_yarn_beta_slow(float(full_rope.get("beta_slow", 1.0)))
self.gguf_writer.add_expert_count(int(hparams["num_experts"]))
self.gguf_writer.add_expert_used_count(int(hparams["num_experts_per_tok"]))
self.gguf_writer.add_expert_feed_forward_length(int(hparams["moe_intermediate_size"]))
if (shared_dim := hparams.get("shared_expert_intermediate_size")) is not None and int(shared_dim) > 0:
self.gguf_writer.add_expert_shared_feed_forward_length(int(shared_dim))
if (routing_scale := hparams.get("moe_routed_scaling_factor")) is not None:
self.gguf_writer.add_expert_weights_scale(float(routing_scale))
self.gguf_writer.add_expert_weights_norm(True)
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
leading_dense = 0
for mlp_type in hparams.get("mlp_layer_types", []):
if mlp_type != "dense":
break
leading_dense += 1
self.gguf_writer.add_uint32(f"{arch}.leading_dense_block_count", leading_dense)
if hparams.get("moe_apply_router_weight_on_input", False):
raise ValueError("moe_apply_router_weight_on_input=True is not supported for Laguna")
def set_vocab(self) -> None:
super().set_vocab()
if isinstance(eos_token_id := self.hparams.get("eos_token_id"), list) and len(eos_token_id) > 1:
# Poolside uses token 24 (</assistant>) as a turn boundary.
self.gguf_writer.add_eot_token_id(int(eos_token_id[1]))
template_file = self.dir_model / "chat_template.jinja"
if template_file.is_file():
self.gguf_writer.add_chat_template(template_file.read_text(encoding="utf-8"))
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if bid is not None and name in (
f"model.layers.{bid}.mlp.experts.e_score_correction_bias",
f"model.layers.{bid}.mlp.experts.e_score_correction",
):
# The C++ loader asks for this tensor through the ".bias" suffix.
# Keep the Laguna converter aligned with existing community GGUFs.
yield f"blk.{bid}.exp_probs_b.bias", data_torch
return
if name.endswith(".self_attn.g_proj.weight"):
# HF stores the head-wise attention gate with a singleton dimension.
data_torch = data_torch.squeeze().contiguous()
if bid is not None and re.match(r"model\.layers\.\d+\.mlp\.experts\.\d+\.(gate_proj|up_proj|down_proj)\.weight$", name):
n_experts = int(self.find_hparam(["num_experts"]))
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) < n_experts * 3:
return
for w_name in ("down_proj", "gate_proj", "up_proj"):
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]
merged = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
yield from super().modify_tensors(merged, merged_name, bid)
return
yield from super().modify_tensors(data_torch, name, bid)
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
experts = [k for d in self._experts for k in d.keys()]
if experts:
raise ValueError(f"Unprocessed experts: {experts}")
###### CONVERSION LOGIC ######

View File

@ -6,6 +6,7 @@
#include "common.h"
#include "llama.h"
#include "llama-spec-features.h"
#include "log.h"
#include "sampling.h"
#include "speculative.h"
@ -13,12 +14,8 @@
#include "mtmd-helper.h"
#include <fstream>
#include <atomic>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <regex>
#include <sstream>
#include <exception>
static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1, int32_t offset = 0) {
@ -49,83 +46,6 @@ static void log_text(const gpt_params & params_base, const std::string & text) {
}
}
static bool server_dflash_contract_log_enabled() {
const char * env = std::getenv("IK_DFLASH_CONTRACT_LOG");
if (env == nullptr || *env == '\0') {
return false;
}
return std::strcmp(env, "0") != 0 &&
std::strcmp(env, "false") != 0 &&
std::strcmp(env, "off") != 0;
}
static std::string server_dflash_contract_format_indices(
const std::vector<int32_t> & values,
size_t edge_count = 4) {
std::ostringstream oss;
oss << '[';
if (values.empty()) {
oss << ']';
return oss.str();
}
const size_t head = std::min(edge_count, values.size());
for (size_t i = 0; i < head; ++i) {
if (i > 0) {
oss << ',';
}
oss << values[i];
}
if (values.size() > edge_count * 2) {
oss << ",...,";
for (size_t i = values.size() - edge_count; i < values.size(); ++i) {
if (i > values.size() - edge_count) {
oss << ',';
}
oss << values[i];
}
} else {
for (size_t i = head; i < values.size(); ++i) {
oss << ',' << values[i];
}
}
oss << ']';
return oss.str();
}
static void server_dflash_contract_log_accept(
const server_slot & slot,
common_speculative_type spec_type_used,
const char * path,
bool any_rejected,
size_t n_draft,
const std::vector<llama_token> & ids,
llama_pos pos_base,
const std::vector<int32_t> & output_indices) {
if (!server_dflash_contract_log_enabled() || spec_type_used != COMMON_SPECULATIVE_TYPE_DFLASH) {
return;
}
static std::atomic<uint64_t> counter = 0;
const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed);
if (ordinal >= 8) {
return;
}
LLAMA_LOG_INFO("dflash contract accept[%llu]: slot=%d path=%s rejected=%s drafted=%zu accepted=%zu pos_base=%d output_indices=%s\n",
(unsigned long long) (ordinal + 1),
slot.id,
path,
any_rejected ? "true" : "false",
n_draft,
ids.size(),
(int) pos_base,
server_dflash_contract_format_indices(output_indices).c_str());
}
static bool server_slot_prompt_batch_overlaps(
const server_slot & slot,
int32_t batch_i0,
@ -158,12 +78,13 @@ static bool server_response_needs_chat_parse(oaicompat_type oaicompat) {
oaicompat == OAICOMPAT_TYPE_RESP;
}
static bool server_speculative_has_dflash(const common_params_speculative & spec) {
return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH);
static bool server_speculative_uses_target_features(const common_params_speculative & spec) {
return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) ||
spec.has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH);
}
static bool server_speculative_has_target_features(const common_params_speculative & spec) {
return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) || server_speculative_has_dflash(spec);
static bool server_speculative_requires_single_slot(const common_params_speculative & spec) {
return spec.has_stage_chain();
}
static bool server_speculative_same_stage_types(
@ -295,8 +216,8 @@ bool server_context::load_model(const gpt_params& params_) {
common_speculative_prepare_startup(params_base, false);
if (server_speculative_has_dflash(params_base.speculative) && params_base.n_parallel > 1) {
LOG_ERROR("DFlash is currently limited to a single server slot (-np 1).\n", {
if (server_speculative_requires_single_slot(params_base.speculative) && params_base.n_parallel > 1) {
LOG_ERROR("Speculative decoding is currently limited to a single server slot (-np 1).\n", {
{"n_parallel", params_base.n_parallel},
});
return false;
@ -4127,7 +4048,7 @@ void server_context::speculative_decoding_accept() {
}
std::vector<int32_t> accepted_output_indices;
if (server_speculative_has_target_features(slot.params.speculative)) {
if (server_speculative_uses_target_features(slot.params.speculative)) {
if (!ids.empty()) {
accepted_output_indices.assign(slot.i_batch_dft.begin(), slot.i_batch_dft.begin() + ids.size());
}
@ -4160,14 +4081,14 @@ void server_context::speculative_decoding_accept() {
const common_speculative_checkpoint * ckpt = common_speculative_get_checkpoint(slot.spec);
const bool will_restore = any_rejected && ckpt != nullptr && ckpt->valid;
if (server_speculative_has_target_features(slot.params.speculative) && !accepted_output_indices.empty()) {
server_dflash_contract_log_accept(
slot,
spec_type_used,
if (server_speculative_uses_target_features(slot.params.speculative) && !accepted_output_indices.empty()) {
llama_dflash_contract_log_accept(
slot.id,
spec_type_used == COMMON_SPECULATIVE_TYPE_DFLASH,
will_restore ? "restore" : "direct",
any_rejected,
n_draft,
ids,
ids.size(),
spec_pos_base,
accepted_output_indices);
}
@ -4556,9 +4477,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
continue; // continue loop of n_batch
}
if (server_speculative_has_target_features(params_base.speculative)) {
if (server_speculative_uses_target_features(params_base.speculative)) {
for (auto & slot : slots) {
if (!slot.spec || !server_speculative_has_target_features(slot.params.speculative)) {
if (!slot.spec || !server_speculative_uses_target_features(slot.params.speculative)) {
continue;
}

340
src/llama-dflash-profile.h Normal file
View File

@ -0,0 +1,340 @@
#pragma once
#include <cstdint>
#include <cstring>
enum llama_dflash_kv_node_kind {
LLAMA_DFLASH_KV_NODE_NONE = 0,
LLAMA_DFLASH_KV_NODE_FUSED_TARGET,
LLAMA_DFLASH_KV_NODE_K_PROJ,
LLAMA_DFLASH_KV_NODE_K_NORM,
LLAMA_DFLASH_KV_NODE_K_ROPE,
LLAMA_DFLASH_KV_NODE_V_PROJ,
LLAMA_DFLASH_KV_NODE_K_STORE,
LLAMA_DFLASH_KV_NODE_V_STORE,
};
enum llama_dflash_main_node_kind {
LLAMA_DFLASH_MAIN_NODE_NONE = 0,
LLAMA_DFLASH_MAIN_NODE_QCUR,
LLAMA_DFLASH_MAIN_NODE_K_DRAFT,
LLAMA_DFLASH_MAIN_NODE_V_DRAFT,
LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW,
LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW,
LLAMA_DFLASH_MAIN_NODE_K_CONCAT,
LLAMA_DFLASH_MAIN_NODE_V_CONCAT,
LLAMA_DFLASH_MAIN_NODE_K_PAD,
LLAMA_DFLASH_MAIN_NODE_V_PAD,
LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT,
LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT,
LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN,
LLAMA_DFLASH_MAIN_NODE_ATTN_OUT,
LLAMA_DFLASH_MAIN_NODE_FFN,
LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS,
LLAMA_DFLASH_MAIN_NODE_RESULT_NORM,
LLAMA_DFLASH_MAIN_NODE_RESULT,
};
struct llama_dflash_kv_node_profiler {
llama_dflash_profile_stats * profile = nullptr;
int64_t t_start_us = 0;
llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE;
};
struct llama_dflash_main_node_profiler {
llama_dflash_profile_stats * profile = nullptr;
ggml_backend_sched_eval_callback prev_callback = nullptr;
void * prev_user_data = nullptr;
bool prev_active = false;
int64_t t_start_us = 0;
llama_dflash_main_node_kind active_kind = LLAMA_DFLASH_MAIN_NODE_NONE;
};
static inline bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) {
if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') {
return false;
}
return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0;
}
static inline bool llama_dflash_tensor_name_matches_label(const struct ggml_tensor * tensor, const char * label) {
if (!llama_dflash_tensor_name_has_prefix(tensor, label)) {
return false;
}
const size_t label_len = std::strlen(label);
const char next = tensor->name[label_len];
return next == '\0' || next == '-';
}
static inline llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) {
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_fused_target")) {
return LLAMA_DFLASH_KV_NODE_FUSED_TARGET;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_proj")) {
return LLAMA_DFLASH_KV_NODE_K_PROJ;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_norm")) {
return LLAMA_DFLASH_KV_NODE_K_NORM;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_rope")) {
return LLAMA_DFLASH_KV_NODE_K_ROPE;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_proj")) {
return LLAMA_DFLASH_KV_NODE_V_PROJ;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_store")) {
return LLAMA_DFLASH_KV_NODE_K_STORE;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_store")) {
return LLAMA_DFLASH_KV_NODE_V_STORE;
}
return LLAMA_DFLASH_KV_NODE_NONE;
}
static inline void llama_dflash_kv_node_profile_add(
llama_dflash_profile_stats & profile,
llama_dflash_kv_node_kind kind,
uint64_t elapsed_us) {
switch (kind) {
case LLAMA_DFLASH_KV_NODE_FUSED_TARGET:
profile.graph_kv_node_fused_target_calls++;
profile.graph_kv_node_fused_target_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_K_PROJ:
profile.graph_kv_node_k_proj_calls++;
profile.graph_kv_node_k_proj_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_K_NORM:
profile.graph_kv_node_k_norm_calls++;
profile.graph_kv_node_k_norm_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_K_ROPE:
profile.graph_kv_node_k_rope_calls++;
profile.graph_kv_node_k_rope_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_V_PROJ:
profile.graph_kv_node_v_proj_calls++;
profile.graph_kv_node_v_proj_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_K_STORE:
profile.graph_kv_node_k_store_calls++;
profile.graph_kv_node_k_store_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_V_STORE:
profile.graph_kv_node_v_store_calls++;
profile.graph_kv_node_v_store_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_NONE:
break;
}
}
static inline llama_dflash_main_node_kind llama_dflash_main_node_kind_from_tensor(const struct ggml_tensor * tensor) {
if (llama_dflash_tensor_name_has_prefix(tensor, "Qcur")) {
return LLAMA_DFLASH_MAIN_NODE_QCUR;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_noise")) {
return LLAMA_DFLASH_MAIN_NODE_K_DRAFT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_noise")) {
return LLAMA_DFLASH_MAIN_NODE_V_DRAFT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_ctx_cache")) {
return LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_ctx_cache")) {
return LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_concat")) {
return LLAMA_DFLASH_MAIN_NODE_K_CONCAT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_concat")) {
return LLAMA_DFLASH_MAIN_NODE_V_CONCAT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_pad")) {
return LLAMA_DFLASH_MAIN_NODE_K_PAD;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_pad")) {
return LLAMA_DFLASH_MAIN_NODE_V_PAD;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_perm_cont")) {
return LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_perm_cont")) {
return LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "flash_attn_reshaped")) {
return LLAMA_DFLASH_MAIN_NODE_NONE;
}
if (llama_dflash_tensor_name_matches_label(tensor, "flash_attn")) {
return LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "kqv_out")) {
return LLAMA_DFLASH_MAIN_NODE_ATTN_OUT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "ffn_out")) {
return LLAMA_DFLASH_MAIN_NODE_FFN;
}
if (llama_dflash_tensor_name_matches_label(tensor, "result_output_rows")) {
return LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS;
}
if (llama_dflash_tensor_name_matches_label(tensor, "result_norm")) {
return LLAMA_DFLASH_MAIN_NODE_RESULT_NORM;
}
if (llama_dflash_tensor_name_matches_label(tensor, "output")) {
return LLAMA_DFLASH_MAIN_NODE_RESULT;
}
if (llama_dflash_tensor_name_matches_label(tensor, "result_output")) {
return LLAMA_DFLASH_MAIN_NODE_RESULT;
}
return LLAMA_DFLASH_MAIN_NODE_NONE;
}
static inline void llama_dflash_main_node_profile_add(
llama_dflash_profile_stats & profile,
llama_dflash_main_node_kind kind,
uint64_t elapsed_us) {
switch (kind) {
case LLAMA_DFLASH_MAIN_NODE_QCUR:
profile.graph_main_node_qcur_calls++;
profile.graph_main_node_qcur_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_K_DRAFT:
profile.graph_main_node_k_draft_calls++;
profile.graph_main_node_k_draft_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_V_DRAFT:
profile.graph_main_node_v_draft_calls++;
profile.graph_main_node_v_draft_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW:
profile.graph_main_node_k_ctx_view_calls++;
profile.graph_main_node_k_ctx_view_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW:
profile.graph_main_node_v_ctx_view_calls++;
profile.graph_main_node_v_ctx_view_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_K_CONCAT:
profile.graph_main_node_k_concat_calls++;
profile.graph_main_node_k_concat_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_V_CONCAT:
profile.graph_main_node_v_concat_calls++;
profile.graph_main_node_v_concat_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_K_PAD:
profile.graph_main_node_k_pad_calls++;
profile.graph_main_node_k_pad_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_V_PAD:
profile.graph_main_node_v_pad_calls++;
profile.graph_main_node_v_pad_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT:
profile.graph_main_node_k_perm_cont_calls++;
profile.graph_main_node_k_perm_cont_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT:
profile.graph_main_node_v_perm_cont_calls++;
profile.graph_main_node_v_perm_cont_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN:
profile.graph_main_node_flash_attn_calls++;
profile.graph_main_node_flash_attn_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_ATTN_OUT:
profile.graph_main_node_attn_out_calls++;
profile.graph_main_node_attn_out_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_FFN:
profile.graph_main_node_ffn_calls++;
profile.graph_main_node_ffn_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS:
profile.graph_main_node_result_rows_calls++;
profile.graph_main_node_result_rows_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_RESULT_NORM:
profile.graph_main_node_result_norm_calls++;
profile.graph_main_node_result_norm_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_RESULT:
profile.graph_main_node_result_calls++;
profile.graph_main_node_result_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_NONE:
break;
}
}
static inline bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
auto * profiler = static_cast<llama_dflash_kv_node_profiler *>(user_data);
if (profiler == nullptr || profiler->profile == nullptr) {
return false;
}
const llama_dflash_kv_node_kind kind = llama_dflash_kv_node_kind_from_tensor(tensor);
if (ask) {
if (kind == LLAMA_DFLASH_KV_NODE_NONE) {
return false;
}
profiler->active_kind = kind;
profiler->t_start_us = ggml_time_us();
return true;
}
if (kind != LLAMA_DFLASH_KV_NODE_NONE && profiler->active_kind == kind && profiler->t_start_us > 0) {
llama_dflash_kv_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us));
}
profiler->active_kind = LLAMA_DFLASH_KV_NODE_NONE;
profiler->t_start_us = 0;
return true;
}
static inline bool llama_dflash_main_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
auto * profiler = static_cast<llama_dflash_main_node_profiler *>(user_data);
if (profiler == nullptr || profiler->profile == nullptr) {
return false;
}
const llama_dflash_main_node_kind kind = llama_dflash_main_node_kind_from_tensor(tensor);
if (ask) {
profiler->prev_active = profiler->prev_callback != nullptr
? profiler->prev_callback(tensor, ask, profiler->prev_user_data)
: false;
if (kind == LLAMA_DFLASH_MAIN_NODE_NONE) {
profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE;
profiler->t_start_us = 0;
return profiler->prev_active;
}
profiler->active_kind = kind;
profiler->t_start_us = ggml_time_us();
return true;
}
bool prev_result = false;
if (profiler->prev_active && profiler->prev_callback != nullptr) {
prev_result = profiler->prev_callback(tensor, ask, profiler->prev_user_data);
}
const bool tracked = kind != LLAMA_DFLASH_MAIN_NODE_NONE &&
profiler->active_kind == kind &&
profiler->t_start_us > 0;
if (tracked) {
llama_dflash_main_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us));
}
profiler->prev_active = false;
profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE;
profiler->t_start_us = 0;
return prev_result || tracked;
}

View File

@ -5,6 +5,7 @@
#include "llama-context.h"
#include "llama-model.h"
#include "llama-spec-features.h"
#include "llama-dflash-profile.h"
#include "ggml.h"
#include "ggml-backend.h"
@ -27,342 +28,6 @@ static bool llama_dflash_stats_log_enabled() {
return llama_env_flag_enabled_local("IK_DFLASH_STATS_LOG");
}
enum llama_dflash_kv_node_kind {
LLAMA_DFLASH_KV_NODE_NONE = 0,
LLAMA_DFLASH_KV_NODE_FUSED_TARGET,
LLAMA_DFLASH_KV_NODE_K_PROJ,
LLAMA_DFLASH_KV_NODE_K_NORM,
LLAMA_DFLASH_KV_NODE_K_ROPE,
LLAMA_DFLASH_KV_NODE_V_PROJ,
LLAMA_DFLASH_KV_NODE_K_STORE,
LLAMA_DFLASH_KV_NODE_V_STORE,
};
enum llama_dflash_main_node_kind {
LLAMA_DFLASH_MAIN_NODE_NONE = 0,
LLAMA_DFLASH_MAIN_NODE_QCUR,
LLAMA_DFLASH_MAIN_NODE_K_DRAFT,
LLAMA_DFLASH_MAIN_NODE_V_DRAFT,
LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW,
LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW,
LLAMA_DFLASH_MAIN_NODE_K_CONCAT,
LLAMA_DFLASH_MAIN_NODE_V_CONCAT,
LLAMA_DFLASH_MAIN_NODE_K_PAD,
LLAMA_DFLASH_MAIN_NODE_V_PAD,
LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT,
LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT,
LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN,
LLAMA_DFLASH_MAIN_NODE_ATTN_OUT,
LLAMA_DFLASH_MAIN_NODE_FFN,
LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS,
LLAMA_DFLASH_MAIN_NODE_RESULT_NORM,
LLAMA_DFLASH_MAIN_NODE_RESULT,
};
struct llama_dflash_kv_node_profiler {
llama_dflash_profile_stats * profile = nullptr;
int64_t t_start_us = 0;
llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE;
};
struct llama_dflash_main_node_profiler {
llama_dflash_profile_stats * profile = nullptr;
ggml_backend_sched_eval_callback prev_callback = nullptr;
void * prev_user_data = nullptr;
bool prev_active = false;
int64_t t_start_us = 0;
llama_dflash_main_node_kind active_kind = LLAMA_DFLASH_MAIN_NODE_NONE;
};
static bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) {
if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') {
return false;
}
return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0;
}
static bool llama_dflash_tensor_name_matches_label(const struct ggml_tensor * tensor, const char * label) {
if (!llama_dflash_tensor_name_has_prefix(tensor, label)) {
return false;
}
const size_t label_len = std::strlen(label);
const char next = tensor->name[label_len];
return next == '\0' || next == '-';
}
static llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) {
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_fused_target")) {
return LLAMA_DFLASH_KV_NODE_FUSED_TARGET;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_proj")) {
return LLAMA_DFLASH_KV_NODE_K_PROJ;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_norm")) {
return LLAMA_DFLASH_KV_NODE_K_NORM;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_rope")) {
return LLAMA_DFLASH_KV_NODE_K_ROPE;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_proj")) {
return LLAMA_DFLASH_KV_NODE_V_PROJ;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_store")) {
return LLAMA_DFLASH_KV_NODE_K_STORE;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_store")) {
return LLAMA_DFLASH_KV_NODE_V_STORE;
}
return LLAMA_DFLASH_KV_NODE_NONE;
}
static void llama_dflash_kv_node_profile_add(
llama_dflash_profile_stats & profile,
llama_dflash_kv_node_kind kind,
uint64_t elapsed_us) {
switch (kind) {
case LLAMA_DFLASH_KV_NODE_FUSED_TARGET:
profile.graph_kv_node_fused_target_calls++;
profile.graph_kv_node_fused_target_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_K_PROJ:
profile.graph_kv_node_k_proj_calls++;
profile.graph_kv_node_k_proj_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_K_NORM:
profile.graph_kv_node_k_norm_calls++;
profile.graph_kv_node_k_norm_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_K_ROPE:
profile.graph_kv_node_k_rope_calls++;
profile.graph_kv_node_k_rope_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_V_PROJ:
profile.graph_kv_node_v_proj_calls++;
profile.graph_kv_node_v_proj_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_K_STORE:
profile.graph_kv_node_k_store_calls++;
profile.graph_kv_node_k_store_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_V_STORE:
profile.graph_kv_node_v_store_calls++;
profile.graph_kv_node_v_store_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_NONE:
break;
}
}
static llama_dflash_main_node_kind llama_dflash_main_node_kind_from_tensor(const struct ggml_tensor * tensor) {
if (llama_dflash_tensor_name_has_prefix(tensor, "Qcur")) {
return LLAMA_DFLASH_MAIN_NODE_QCUR;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_noise")) {
return LLAMA_DFLASH_MAIN_NODE_K_DRAFT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_noise")) {
return LLAMA_DFLASH_MAIN_NODE_V_DRAFT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_ctx_cache")) {
return LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_ctx_cache")) {
return LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_concat")) {
return LLAMA_DFLASH_MAIN_NODE_K_CONCAT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_concat")) {
return LLAMA_DFLASH_MAIN_NODE_V_CONCAT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_pad")) {
return LLAMA_DFLASH_MAIN_NODE_K_PAD;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_pad")) {
return LLAMA_DFLASH_MAIN_NODE_V_PAD;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_perm_cont")) {
return LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_perm_cont")) {
return LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "flash_attn_reshaped")) {
return LLAMA_DFLASH_MAIN_NODE_NONE;
}
if (llama_dflash_tensor_name_matches_label(tensor, "flash_attn")) {
return LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "kqv_out")) {
return LLAMA_DFLASH_MAIN_NODE_ATTN_OUT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "ffn_out")) {
return LLAMA_DFLASH_MAIN_NODE_FFN;
}
if (llama_dflash_tensor_name_matches_label(tensor, "result_output_rows")) {
return LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS;
}
if (llama_dflash_tensor_name_matches_label(tensor, "result_norm")) {
return LLAMA_DFLASH_MAIN_NODE_RESULT_NORM;
}
if (llama_dflash_tensor_name_matches_label(tensor, "output")) {
return LLAMA_DFLASH_MAIN_NODE_RESULT;
}
if (llama_dflash_tensor_name_matches_label(tensor, "result_output")) {
return LLAMA_DFLASH_MAIN_NODE_RESULT;
}
return LLAMA_DFLASH_MAIN_NODE_NONE;
}
static void llama_dflash_main_node_profile_add(
llama_dflash_profile_stats & profile,
llama_dflash_main_node_kind kind,
uint64_t elapsed_us) {
switch (kind) {
case LLAMA_DFLASH_MAIN_NODE_QCUR:
profile.graph_main_node_qcur_calls++;
profile.graph_main_node_qcur_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_K_DRAFT:
profile.graph_main_node_k_draft_calls++;
profile.graph_main_node_k_draft_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_V_DRAFT:
profile.graph_main_node_v_draft_calls++;
profile.graph_main_node_v_draft_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW:
profile.graph_main_node_k_ctx_view_calls++;
profile.graph_main_node_k_ctx_view_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW:
profile.graph_main_node_v_ctx_view_calls++;
profile.graph_main_node_v_ctx_view_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_K_CONCAT:
profile.graph_main_node_k_concat_calls++;
profile.graph_main_node_k_concat_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_V_CONCAT:
profile.graph_main_node_v_concat_calls++;
profile.graph_main_node_v_concat_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_K_PAD:
profile.graph_main_node_k_pad_calls++;
profile.graph_main_node_k_pad_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_V_PAD:
profile.graph_main_node_v_pad_calls++;
profile.graph_main_node_v_pad_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT:
profile.graph_main_node_k_perm_cont_calls++;
profile.graph_main_node_k_perm_cont_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT:
profile.graph_main_node_v_perm_cont_calls++;
profile.graph_main_node_v_perm_cont_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN:
profile.graph_main_node_flash_attn_calls++;
profile.graph_main_node_flash_attn_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_ATTN_OUT:
profile.graph_main_node_attn_out_calls++;
profile.graph_main_node_attn_out_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_FFN:
profile.graph_main_node_ffn_calls++;
profile.graph_main_node_ffn_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS:
profile.graph_main_node_result_rows_calls++;
profile.graph_main_node_result_rows_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_RESULT_NORM:
profile.graph_main_node_result_norm_calls++;
profile.graph_main_node_result_norm_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_RESULT:
profile.graph_main_node_result_calls++;
profile.graph_main_node_result_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_NONE:
break;
}
}
static bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
auto * profiler = static_cast<llama_dflash_kv_node_profiler *>(user_data);
if (profiler == nullptr || profiler->profile == nullptr) {
return false;
}
const llama_dflash_kv_node_kind kind = llama_dflash_kv_node_kind_from_tensor(tensor);
if (ask) {
if (kind == LLAMA_DFLASH_KV_NODE_NONE) {
return false;
}
profiler->active_kind = kind;
profiler->t_start_us = ggml_time_us();
return true;
}
if (kind != LLAMA_DFLASH_KV_NODE_NONE && profiler->active_kind == kind && profiler->t_start_us > 0) {
llama_dflash_kv_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us));
}
profiler->active_kind = LLAMA_DFLASH_KV_NODE_NONE;
profiler->t_start_us = 0;
return true;
}
static bool llama_dflash_main_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
auto * profiler = static_cast<llama_dflash_main_node_profiler *>(user_data);
if (profiler == nullptr || profiler->profile == nullptr) {
return false;
}
const llama_dflash_main_node_kind kind = llama_dflash_main_node_kind_from_tensor(tensor);
if (ask) {
profiler->prev_active = profiler->prev_callback != nullptr
? profiler->prev_callback(tensor, ask, profiler->prev_user_data)
: false;
if (kind == LLAMA_DFLASH_MAIN_NODE_NONE) {
profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE;
profiler->t_start_us = 0;
return profiler->prev_active;
}
profiler->active_kind = kind;
profiler->t_start_us = ggml_time_us();
return true;
}
bool prev_result = false;
if (profiler->prev_active && profiler->prev_callback != nullptr) {
prev_result = profiler->prev_callback(tensor, ask, profiler->prev_user_data);
}
const bool tracked = kind != LLAMA_DFLASH_MAIN_NODE_NONE &&
profiler->active_kind == kind &&
profiler->t_start_us > 0;
if (tracked) {
llama_dflash_main_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us));
}
profiler->prev_active = false;
profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE;
profiler->t_start_us = 0;
return prev_result || tracked;
}
void llama_sync_dflash_workspace_if_pending(struct llama_context & lctx) {
if (!lctx.dflash_kv_workspace_sync_pending || lctx.dflash_workspace_sched == nullptr) {
return;

View File

@ -880,6 +880,36 @@ static void llama_dflash_contract_log_output_indices(
have_capture ? "true" : "false");
}
void llama_dflash_contract_log_accept(
int slot_id,
bool is_dflash,
const char * path,
bool any_rejected,
size_t n_draft,
size_t n_accepted,
llama_pos pos_base,
const std::vector<int32_t> & output_indices) {
if (!llama_dflash_contract_log_enabled() || !is_dflash) {
return;
}
static std::atomic<uint64_t> counter = 0;
const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed);
if (ordinal >= 8) {
return;
}
LLAMA_LOG_INFO("dflash contract accept[%llu]: slot=%d path=%s rejected=%s drafted=%zu accepted=%zu pos_base=%d output_indices=%s\n",
(unsigned long long) (ordinal + 1),
slot_id,
path,
any_rejected ? "true" : "false",
n_draft,
n_accepted,
(int) pos_base,
llama_dflash_contract_format_values(output_indices).c_str());
}
static bool llama_spec_materialize_dflash_rows_prepared(
struct llama_context * ctx,
int32_t row_count,

View File

@ -277,3 +277,13 @@ bool llama_spec_copy_dflash_rows_from_output_indices(
struct llama_context * ctx,
const std::vector<int32_t> & output_indices,
std::vector<float> & hidden_rows);
void llama_dflash_contract_log_accept(
int slot_id,
bool is_dflash,
const char * path,
bool any_rejected,
size_t n_draft,
size_t n_accepted,
llama_pos pos_base,
const std::vector<int32_t> & output_indices);

View File

@ -1,11 +1,6 @@
#include "llama-spec-features.h"
#include <algorithm>
#include <atomic>
#include <cstdlib>
#include <cstring>
#include <random>
#include <sstream>
#include "llama-model.h"
#include "llama-context.h"

View File

@ -19,6 +19,7 @@
#include "llama-context.h"
#include "llama-spec-features.h"
#include "llama-dflash.h"
#include "llama-dflash-profile.h"
#include "llama-quantize.h"
#include "unicode.h"
@ -180,343 +181,6 @@ static bool llama_env_flag_enabled(const char * name) {
std::strcmp(env, "off") != 0;
}
enum llama_dflash_kv_node_kind {
LLAMA_DFLASH_KV_NODE_NONE = 0,
LLAMA_DFLASH_KV_NODE_FUSED_TARGET,
LLAMA_DFLASH_KV_NODE_K_PROJ,
LLAMA_DFLASH_KV_NODE_K_NORM,
LLAMA_DFLASH_KV_NODE_K_ROPE,
LLAMA_DFLASH_KV_NODE_V_PROJ,
LLAMA_DFLASH_KV_NODE_K_STORE,
LLAMA_DFLASH_KV_NODE_V_STORE,
};
enum llama_dflash_main_node_kind {
LLAMA_DFLASH_MAIN_NODE_NONE = 0,
LLAMA_DFLASH_MAIN_NODE_QCUR,
LLAMA_DFLASH_MAIN_NODE_K_DRAFT,
LLAMA_DFLASH_MAIN_NODE_V_DRAFT,
LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW,
LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW,
LLAMA_DFLASH_MAIN_NODE_K_CONCAT,
LLAMA_DFLASH_MAIN_NODE_V_CONCAT,
LLAMA_DFLASH_MAIN_NODE_K_PAD,
LLAMA_DFLASH_MAIN_NODE_V_PAD,
LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT,
LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT,
LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN,
LLAMA_DFLASH_MAIN_NODE_ATTN_OUT,
LLAMA_DFLASH_MAIN_NODE_FFN,
LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS,
LLAMA_DFLASH_MAIN_NODE_RESULT_NORM,
LLAMA_DFLASH_MAIN_NODE_RESULT,
};
struct llama_dflash_kv_node_profiler {
llama_dflash_profile_stats * profile = nullptr;
int64_t t_start_us = 0;
llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE;
};
struct llama_dflash_main_node_profiler {
llama_dflash_profile_stats * profile = nullptr;
ggml_backend_sched_eval_callback prev_callback = nullptr;
void * prev_user_data = nullptr;
bool prev_active = false;
int64_t t_start_us = 0;
llama_dflash_main_node_kind active_kind = LLAMA_DFLASH_MAIN_NODE_NONE;
};
static bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) {
if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') {
return false;
}
return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0;
}
static bool llama_dflash_tensor_name_matches_label(const struct ggml_tensor * tensor, const char * label) {
if (!llama_dflash_tensor_name_has_prefix(tensor, label)) {
return false;
}
const size_t label_len = std::strlen(label);
const char next = tensor->name[label_len];
return next == '\0' || next == '-';
}
static llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) {
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_fused_target")) {
return LLAMA_DFLASH_KV_NODE_FUSED_TARGET;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_proj")) {
return LLAMA_DFLASH_KV_NODE_K_PROJ;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_norm")) {
return LLAMA_DFLASH_KV_NODE_K_NORM;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_rope")) {
return LLAMA_DFLASH_KV_NODE_K_ROPE;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_proj")) {
return LLAMA_DFLASH_KV_NODE_V_PROJ;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_store")) {
return LLAMA_DFLASH_KV_NODE_K_STORE;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_store")) {
return LLAMA_DFLASH_KV_NODE_V_STORE;
}
return LLAMA_DFLASH_KV_NODE_NONE;
}
static void llama_dflash_kv_node_profile_add(
llama_dflash_profile_stats & profile,
llama_dflash_kv_node_kind kind,
uint64_t elapsed_us) {
switch (kind) {
case LLAMA_DFLASH_KV_NODE_FUSED_TARGET:
profile.graph_kv_node_fused_target_calls++;
profile.graph_kv_node_fused_target_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_K_PROJ:
profile.graph_kv_node_k_proj_calls++;
profile.graph_kv_node_k_proj_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_K_NORM:
profile.graph_kv_node_k_norm_calls++;
profile.graph_kv_node_k_norm_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_K_ROPE:
profile.graph_kv_node_k_rope_calls++;
profile.graph_kv_node_k_rope_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_V_PROJ:
profile.graph_kv_node_v_proj_calls++;
profile.graph_kv_node_v_proj_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_K_STORE:
profile.graph_kv_node_k_store_calls++;
profile.graph_kv_node_k_store_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_V_STORE:
profile.graph_kv_node_v_store_calls++;
profile.graph_kv_node_v_store_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_NONE:
break;
}
}
static llama_dflash_main_node_kind llama_dflash_main_node_kind_from_tensor(const struct ggml_tensor * tensor) {
if (llama_dflash_tensor_name_has_prefix(tensor, "Qcur")) {
return LLAMA_DFLASH_MAIN_NODE_QCUR;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_noise")) {
return LLAMA_DFLASH_MAIN_NODE_K_DRAFT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_noise")) {
return LLAMA_DFLASH_MAIN_NODE_V_DRAFT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_ctx_cache")) {
return LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_ctx_cache")) {
return LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_concat")) {
return LLAMA_DFLASH_MAIN_NODE_K_CONCAT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_concat")) {
return LLAMA_DFLASH_MAIN_NODE_V_CONCAT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_pad")) {
return LLAMA_DFLASH_MAIN_NODE_K_PAD;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_pad")) {
return LLAMA_DFLASH_MAIN_NODE_V_PAD;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_perm_cont")) {
return LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_perm_cont")) {
return LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "flash_attn_reshaped")) {
return LLAMA_DFLASH_MAIN_NODE_NONE;
}
if (llama_dflash_tensor_name_matches_label(tensor, "flash_attn")) {
return LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "kqv_out")) {
return LLAMA_DFLASH_MAIN_NODE_ATTN_OUT;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "ffn_out")) {
return LLAMA_DFLASH_MAIN_NODE_FFN;
}
if (llama_dflash_tensor_name_matches_label(tensor, "result_output_rows")) {
return LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS;
}
if (llama_dflash_tensor_name_matches_label(tensor, "result_norm")) {
return LLAMA_DFLASH_MAIN_NODE_RESULT_NORM;
}
if (llama_dflash_tensor_name_matches_label(tensor, "output")) {
return LLAMA_DFLASH_MAIN_NODE_RESULT;
}
if (llama_dflash_tensor_name_matches_label(tensor, "result_output")) {
return LLAMA_DFLASH_MAIN_NODE_RESULT;
}
return LLAMA_DFLASH_MAIN_NODE_NONE;
}
static void llama_dflash_main_node_profile_add(
llama_dflash_profile_stats & profile,
llama_dflash_main_node_kind kind,
uint64_t elapsed_us) {
switch (kind) {
case LLAMA_DFLASH_MAIN_NODE_QCUR:
profile.graph_main_node_qcur_calls++;
profile.graph_main_node_qcur_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_K_DRAFT:
profile.graph_main_node_k_draft_calls++;
profile.graph_main_node_k_draft_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_V_DRAFT:
profile.graph_main_node_v_draft_calls++;
profile.graph_main_node_v_draft_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW:
profile.graph_main_node_k_ctx_view_calls++;
profile.graph_main_node_k_ctx_view_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW:
profile.graph_main_node_v_ctx_view_calls++;
profile.graph_main_node_v_ctx_view_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_K_CONCAT:
profile.graph_main_node_k_concat_calls++;
profile.graph_main_node_k_concat_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_V_CONCAT:
profile.graph_main_node_v_concat_calls++;
profile.graph_main_node_v_concat_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_K_PAD:
profile.graph_main_node_k_pad_calls++;
profile.graph_main_node_k_pad_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_V_PAD:
profile.graph_main_node_v_pad_calls++;
profile.graph_main_node_v_pad_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT:
profile.graph_main_node_k_perm_cont_calls++;
profile.graph_main_node_k_perm_cont_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT:
profile.graph_main_node_v_perm_cont_calls++;
profile.graph_main_node_v_perm_cont_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN:
profile.graph_main_node_flash_attn_calls++;
profile.graph_main_node_flash_attn_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_ATTN_OUT:
profile.graph_main_node_attn_out_calls++;
profile.graph_main_node_attn_out_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_FFN:
profile.graph_main_node_ffn_calls++;
profile.graph_main_node_ffn_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS:
profile.graph_main_node_result_rows_calls++;
profile.graph_main_node_result_rows_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_RESULT_NORM:
profile.graph_main_node_result_norm_calls++;
profile.graph_main_node_result_norm_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_RESULT:
profile.graph_main_node_result_calls++;
profile.graph_main_node_result_us += elapsed_us;
break;
case LLAMA_DFLASH_MAIN_NODE_NONE:
break;
}
}
static bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
auto * profiler = static_cast<llama_dflash_kv_node_profiler *>(user_data);
if (profiler == nullptr || profiler->profile == nullptr) {
return false;
}
const llama_dflash_kv_node_kind kind = llama_dflash_kv_node_kind_from_tensor(tensor);
if (ask) {
if (kind == LLAMA_DFLASH_KV_NODE_NONE) {
return false;
}
profiler->active_kind = kind;
profiler->t_start_us = ggml_time_us();
return true;
}
if (kind != LLAMA_DFLASH_KV_NODE_NONE && profiler->active_kind == kind && profiler->t_start_us > 0) {
llama_dflash_kv_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us));
}
profiler->active_kind = LLAMA_DFLASH_KV_NODE_NONE;
profiler->t_start_us = 0;
return true;
}
static bool llama_dflash_main_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
auto * profiler = static_cast<llama_dflash_main_node_profiler *>(user_data);
if (profiler == nullptr || profiler->profile == nullptr) {
return false;
}
const llama_dflash_main_node_kind kind = llama_dflash_main_node_kind_from_tensor(tensor);
if (ask) {
profiler->prev_active = profiler->prev_callback != nullptr
? profiler->prev_callback(tensor, ask, profiler->prev_user_data)
: false;
if (kind == LLAMA_DFLASH_MAIN_NODE_NONE) {
profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE;
profiler->t_start_us = 0;
return profiler->prev_active;
}
profiler->active_kind = kind;
profiler->t_start_us = ggml_time_us();
return true;
}
bool prev_result = false;
if (profiler->prev_active && profiler->prev_callback != nullptr) {
prev_result = profiler->prev_callback(tensor, ask, profiler->prev_user_data);
}
const bool tracked = kind != LLAMA_DFLASH_MAIN_NODE_NONE &&
profiler->active_kind == kind &&
profiler->t_start_us > 0;
if (tracked) {
llama_dflash_main_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us));
}
profiler->prev_active = false;
profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE;
profiler->t_start_us = 0;
return prev_result || tracked;
}
// extract ip and port from RPC[ip:port] for rpc and keep other device names
static std::vector<rpc_device> extract_device_from_rpc_device(std::vector<std::string> devices) {
std::vector<rpc_device> rpc_servers;