mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
remove duplicated code and unnecesary refactor
This commit is contained in:
parent
3b1a0f88d5
commit
0d75eee35a
872
common/speculative-dflash-impl.h
Normal file
872
common/speculative-dflash-impl.h
Normal file
@ -0,0 +1,872 @@
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cstddef>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
static bool common_speculative_are_dflash_compatible(
|
||||
const llama_model * model_tgt,
|
||||
const llama_model * model_dft) {
|
||||
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
|
||||
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
|
||||
|
||||
if (llama_vocab_type(vocab_tgt) != llama_vocab_type(vocab_dft)) {
|
||||
LOG_DBG("%s: DFlash draft model vocab type must match the target model\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool add_bos_tgt = llama_vocab_get_add_bos(vocab_tgt);
|
||||
const bool add_bos_dft = llama_vocab_get_add_bos(vocab_dft);
|
||||
const bool add_eos_tgt = llama_vocab_get_add_eos(vocab_tgt);
|
||||
const bool add_eos_dft = llama_vocab_get_add_eos(vocab_dft);
|
||||
const llama_token bos_tgt = llama_vocab_bos(vocab_tgt);
|
||||
const llama_token bos_dft = llama_vocab_bos(vocab_dft);
|
||||
const llama_token eos_tgt = llama_vocab_eos(vocab_tgt);
|
||||
const llama_token eos_dft = llama_vocab_eos(vocab_dft);
|
||||
|
||||
if (add_bos_tgt != add_bos_dft || add_eos_tgt != add_eos_dft ||
|
||||
(add_bos_tgt && bos_tgt != bos_dft) ||
|
||||
(add_eos_tgt && eos_tgt != eos_dft)) {
|
||||
LOG_DBG("%s: DFlash draft special tokens must match the target model (add_bos=%d/%d add_eos=%d/%d bos=%d/%d eos=%d/%d)\n",
|
||||
__func__,
|
||||
(int) add_bos_tgt,
|
||||
(int) add_bos_dft,
|
||||
(int) add_eos_tgt,
|
||||
(int) add_eos_dft,
|
||||
(int) bos_tgt,
|
||||
(int) bos_dft,
|
||||
(int) eos_tgt,
|
||||
(int) eos_dft);
|
||||
return false;
|
||||
}
|
||||
|
||||
const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
|
||||
const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
|
||||
const int vocab_diff = n_vocab_tgt > n_vocab_dft
|
||||
? n_vocab_tgt - n_vocab_dft
|
||||
: n_vocab_dft - n_vocab_tgt;
|
||||
|
||||
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
|
||||
LOG_DBG("%s: DFlash draft vocab size differs too much from the target model (%d vs %d)\n",
|
||||
__func__, n_vocab_dft, n_vocab_tgt);
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
|
||||
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
|
||||
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
|
||||
|
||||
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
||||
LOG_DBG("%s: DFlash draft token %d differs - target '%s', draft '%s'\n", __func__, i,
|
||||
common_token_to_piece(vocab_tgt, i).c_str(),
|
||||
common_token_to_piece(vocab_dft, i).c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool dflash_contract_log_enabled() {
|
||||
const char * env = std::getenv("IK_DFLASH_CONTRACT_LOG");
|
||||
if (env == nullptr || *env == '\0') {
|
||||
return false;
|
||||
}
|
||||
|
||||
return std::strcmp(env, "0") != 0 &&
|
||||
std::strcmp(env, "false") != 0 &&
|
||||
std::strcmp(env, "off") != 0;
|
||||
}
|
||||
|
||||
static bool dflash_stats_log_enabled() {
|
||||
const char * env = std::getenv("IK_DFLASH_STATS_LOG");
|
||||
if (env == nullptr || *env == '\0') {
|
||||
return false;
|
||||
}
|
||||
|
||||
return std::strcmp(env, "0") != 0 &&
|
||||
std::strcmp(env, "false") != 0 &&
|
||||
std::strcmp(env, "off") != 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static std::string dflash_contract_format_values(
|
||||
const std::vector<T> & values,
|
||||
size_t edge_count = 4) {
|
||||
std::ostringstream oss;
|
||||
oss << '[';
|
||||
if (values.empty()) {
|
||||
oss << ']';
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
const size_t head = std::min(edge_count, values.size());
|
||||
for (size_t i = 0; i < head; ++i) {
|
||||
if (i > 0) {
|
||||
oss << ',';
|
||||
}
|
||||
oss << values[i];
|
||||
}
|
||||
|
||||
if (values.size() > edge_count * 2) {
|
||||
oss << ",...,";
|
||||
for (size_t i = values.size() - edge_count; i < values.size(); ++i) {
|
||||
if (i > values.size() - edge_count) {
|
||||
oss << ',';
|
||||
}
|
||||
oss << values[i];
|
||||
}
|
||||
} else {
|
||||
for (size_t i = head; i < values.size(); ++i) {
|
||||
oss << ',' << values[i];
|
||||
}
|
||||
}
|
||||
|
||||
oss << ']';
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
struct dflash_contract_pos_summary {
|
||||
llama_pos first = -1;
|
||||
llama_pos last = -1;
|
||||
int32_t gap_count = 0;
|
||||
int32_t nonmono_count = 0;
|
||||
};
|
||||
|
||||
static dflash_contract_pos_summary dflash_contract_summarize_positions(
|
||||
const std::vector<llama_pos> & positions) {
|
||||
dflash_contract_pos_summary summary;
|
||||
if (positions.empty()) {
|
||||
return summary;
|
||||
}
|
||||
|
||||
summary.first = positions.front();
|
||||
summary.last = positions.back();
|
||||
for (size_t i = 1; i < positions.size(); ++i) {
|
||||
if (positions[i] <= positions[i - 1]) {
|
||||
summary.nonmono_count++;
|
||||
} else if (positions[i] != positions[i - 1] + 1) {
|
||||
summary.gap_count++;
|
||||
}
|
||||
}
|
||||
|
||||
return summary;
|
||||
}
|
||||
|
||||
struct common_speculative_state_dflash;
|
||||
|
||||
static void dflash_contract_log_append(
|
||||
const common_speculative_state_dflash & state,
|
||||
llama_seq_id seq_id,
|
||||
const std::vector<llama_pos> & new_positions);
|
||||
static void dflash_contract_log_draft(
|
||||
const common_speculative_state_dflash & state,
|
||||
int32_t n_keep,
|
||||
size_t result_size);
|
||||
static void dflash_materialize_target_window_features(common_speculative_state_dflash & state);
|
||||
|
||||
// DFlash runtime state and draft path.
|
||||
struct common_speculative_state_dflash : public common_speculative_state {
|
||||
llama_context * ctx_tgt;
|
||||
llama_context * ctx_dft;
|
||||
|
||||
llama_batch batch = {};
|
||||
|
||||
int32_t block_size = 0;
|
||||
int32_t mask_token_id = -1;
|
||||
int32_t n_target_features = 0;
|
||||
int32_t cross_ctx = 0;
|
||||
bool ready = false;
|
||||
|
||||
std::vector<int32_t> target_layer_ids;
|
||||
std::vector<float> target_window;
|
||||
std::vector<llama_pos> target_window_pos;
|
||||
std::vector<float> target_window_stage;
|
||||
std::vector<llama_pos> target_window_pos_stage;
|
||||
std::vector<float> target_window_ring;
|
||||
std::vector<float> target_window_append_features;
|
||||
int32_t target_window_rows = 0;
|
||||
int32_t target_window_ring_write_pos = 0;
|
||||
int32_t target_window_ring_filled = 0;
|
||||
uint64_t target_window_version = 0;
|
||||
int32_t target_window_keep_rows = 0;
|
||||
int32_t target_window_append_rows = 0;
|
||||
bool target_window_replace = false;
|
||||
bool target_window_materialized = false;
|
||||
llama_pos last_target_pos = -1;
|
||||
size_t n_window_updates = 0;
|
||||
size_t n_rows_seen = 0;
|
||||
size_t n_rows_dropped = 0;
|
||||
size_t n_context_shifts = 0;
|
||||
size_t n_draft_empty = 0;
|
||||
size_t n_set_target_fail = 0;
|
||||
size_t n_decode_fail = 0;
|
||||
llama_pos last_draft_pos_base = -1;
|
||||
|
||||
uint64_t t_draft_decode_us = 0;
|
||||
uint64_t t_draft_sample_us = 0;
|
||||
uint64_t t_warmup_collect_us = 0;
|
||||
uint64_t t_warmup_append_us = 0;
|
||||
uint64_t t_accept_output_copy_us = 0;
|
||||
uint64_t t_accept_commit_us = 0;
|
||||
uint64_t t_accept_append_us = 0;
|
||||
uint64_t t_accept_append_filter_us = 0;
|
||||
uint64_t t_accept_append_window_alloc_us = 0;
|
||||
uint64_t t_accept_append_replace_us = 0;
|
||||
uint64_t t_accept_append_keep_old_us = 0;
|
||||
uint64_t t_accept_append_new_rows_us = 0;
|
||||
uint64_t t_accept_append_commit_detail_us = 0;
|
||||
uint64_t t_accept_append_log_us = 0;
|
||||
size_t n_warmup_collect_calls = 0;
|
||||
size_t n_warmup_collect_rows = 0;
|
||||
size_t n_warmup_append_calls = 0;
|
||||
size_t n_warmup_append_rows = 0;
|
||||
size_t n_accept_output_copy_calls = 0;
|
||||
size_t n_accept_output_copy_rows = 0;
|
||||
size_t n_accept_commit_calls = 0;
|
||||
size_t n_accept_commit_rows = 0;
|
||||
size_t n_accept_append_calls = 0;
|
||||
size_t n_accept_append_rows = 0;
|
||||
size_t n_accept_append_replace_calls = 0;
|
||||
size_t n_accept_append_slide_calls = 0;
|
||||
|
||||
common_speculative_state_dflash(
|
||||
enum common_speculative_type type,
|
||||
llama_context * ctx_tgt,
|
||||
llama_context * ctx_dft,
|
||||
int32_t cross_ctx)
|
||||
: common_speculative_state(type)
|
||||
, ctx_tgt(ctx_tgt)
|
||||
, ctx_dft(ctx_dft)
|
||||
, cross_ctx(std::max(1, cross_ctx))
|
||||
{
|
||||
const llama_model * model_tgt = llama_get_model(ctx_tgt);
|
||||
const llama_model * model_dft = llama_get_model(ctx_dft);
|
||||
|
||||
if (!common_speculative_are_dflash_compatible(model_tgt, model_dft)) {
|
||||
LOG_ERR("%s: DFlash draft model vocab/tokenizer is incompatible with the target model\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
block_size = llama_model_dflash_block_size(model_dft);
|
||||
mask_token_id = llama_model_dflash_mask_token_id(model_dft);
|
||||
n_target_features = llama_model_dflash_n_target_features(model_dft);
|
||||
const int32_t n_target_layers = llama_model_dflash_n_target_layers(model_dft);
|
||||
|
||||
if (block_size <= 0 || mask_token_id < 0 || n_target_features <= 0 || n_target_layers <= 0) {
|
||||
LOG_ERR("%s: invalid DFlash metadata (block_size=%d, mask_token_id=%d, n_target_features=%d, n_target_layers=%d)\n",
|
||||
__func__, block_size, mask_token_id, n_target_features, n_target_layers);
|
||||
return;
|
||||
}
|
||||
|
||||
target_layer_ids.resize((size_t) n_target_layers);
|
||||
if (llama_model_dflash_target_layer_ids(model_dft, target_layer_ids.data(), n_target_layers) != n_target_layers) {
|
||||
LOG_ERR("%s: failed to read DFlash target layer ids\n", __func__);
|
||||
target_layer_ids.clear();
|
||||
return;
|
||||
}
|
||||
|
||||
const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
|
||||
const auto * vocab_dft = llama_model_get_vocab(model_dft);
|
||||
const int32_t target_vocab_size = llama_vocab_n_tokens(vocab_tgt);
|
||||
const int32_t draft_vocab_size = llama_vocab_n_tokens(vocab_dft);
|
||||
const int32_t target_hidden_size = llama_model_n_embd(model_tgt);
|
||||
const int32_t draft_hidden_size = llama_model_n_embd(model_dft);
|
||||
const int32_t target_mask_token_id = llama_model_dflash_target_mask_token_id(model_tgt);
|
||||
const int32_t expected_n_target_features = target_hidden_size > 0 ? target_hidden_size * n_target_layers : 0;
|
||||
|
||||
if (target_mask_token_id != (int32_t) LLAMA_TOKEN_NULL && mask_token_id != target_mask_token_id) {
|
||||
LOG_ERR("%s: DFlash mask token mismatch (draft=%d target=%d)\n",
|
||||
__func__, mask_token_id, target_mask_token_id);
|
||||
return;
|
||||
}
|
||||
|
||||
if (target_hidden_size <= 0 || draft_hidden_size <= 0) {
|
||||
LOG_ERR("%s: invalid DFlash hidden sizes (draft=%d target=%d)\n",
|
||||
__func__, draft_hidden_size, target_hidden_size);
|
||||
return;
|
||||
}
|
||||
|
||||
if (expected_n_target_features <= 0 || n_target_features != expected_n_target_features) {
|
||||
LOG_ERR("%s: DFlash target feature width mismatch (metadata=%d expected=%d target_hidden=%d target_layers=%d)\n",
|
||||
__func__, n_target_features, expected_n_target_features, target_hidden_size, n_target_layers);
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<int32_t> sorted_target_layer_ids = target_layer_ids;
|
||||
std::sort(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end());
|
||||
if (std::adjacent_find(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()) != sorted_target_layer_ids.end()) {
|
||||
LOG_ERR("%s: duplicate DFlash target layer ids survived into runtime validation\n", __func__);
|
||||
target_layer_ids.clear();
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t n_target_model_layers = llama_n_layer(model_tgt);
|
||||
for (int32_t layer_id : target_layer_ids) {
|
||||
if (layer_id < 0 || layer_id >= n_target_model_layers) {
|
||||
LOG_ERR("%s: invalid DFlash target layer id %d for target model with %d layers\n",
|
||||
__func__, layer_id, n_target_model_layers);
|
||||
target_layer_ids.clear();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const int32_t io_mode = llama_model_dflash_io_mode(model_dft, model_tgt);
|
||||
if (io_mode == LLAMA_DFLASH_IO_MODE_INVALID) {
|
||||
LOG_ERR("%s: DFlash draft is missing required IO tensors after target sharing\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
if (io_mode == LLAMA_DFLASH_IO_MODE_MIXED) {
|
||||
LOG_ERR("%s: DFlash IO contract must be fully shared or fully self-contained, but resolved to mixed mode\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
if (io_mode == LLAMA_DFLASH_IO_MODE_SELF_CONTAINED && !llama_model_dflash_io_tensors_match(model_dft, target_hidden_size, target_vocab_size)) {
|
||||
LOG_ERR("%s: DFlash self-contained IO tensors do not match the target hidden/vocab contract (target_hidden=%d target_vocab=%d)\n",
|
||||
__func__,
|
||||
target_hidden_size,
|
||||
target_vocab_size);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!llama_set_dflash_capture_layers(ctx_tgt, target_layer_ids.data(), (int32_t) target_layer_ids.size())) {
|
||||
LOG_ERR("%s: failed to configure DFlash target capture callback\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
batch = llama_batch_init(std::max(1, block_size), 0, 1);
|
||||
target_window.reserve((size_t) this->cross_ctx * (size_t) n_target_features);
|
||||
target_window_stage.reserve((size_t) this->cross_ctx * (size_t) n_target_features);
|
||||
target_window_ring.resize((size_t) this->cross_ctx * (size_t) n_target_features);
|
||||
target_window_append_features.reserve((size_t) this->cross_ctx * (size_t) n_target_features);
|
||||
target_window_pos.reserve((size_t) this->cross_ctx);
|
||||
target_window_pos_stage.reserve((size_t) this->cross_ctx);
|
||||
ready = true;
|
||||
|
||||
llama_set_dflash_visible_cross_ctx(ctx_dft, this->cross_ctx);
|
||||
llama_dflash_profile_reset(ctx_tgt);
|
||||
llama_dflash_profile_reset(ctx_dft);
|
||||
|
||||
std::ostringstream layers_oss;
|
||||
for (size_t i = 0; i < target_layer_ids.size(); ++i) {
|
||||
if (i > 0) {
|
||||
layers_oss << ",";
|
||||
}
|
||||
layers_oss << target_layer_ids[i];
|
||||
}
|
||||
|
||||
const char * io_mode_name = io_mode == LLAMA_DFLASH_IO_MODE_SHARED ? "shared" : "self-contained";
|
||||
LOG_INF("%s: DFlash context ready (n_ctx=%d, block_size=%d, cross_ctx=%d, n_target_features=%d, target_layer_ids=[%s])\n",
|
||||
__func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features, layers_oss.str().c_str());
|
||||
LOG_INF("%s: DFlash artifact io=%s draft_vocab=%d target_vocab=%d draft_hidden=%d target_hidden=%d mask_token_id=%d target_mask_token_id=%d\n",
|
||||
__func__, io_mode_name, draft_vocab_size, target_vocab_size, draft_hidden_size, target_hidden_size, mask_token_id, target_mask_token_id);
|
||||
}
|
||||
|
||||
~common_speculative_state_dflash() override {
|
||||
llama_clear_dflash_capture(ctx_tgt);
|
||||
if (ctx_dft) {
|
||||
llama_free(ctx_dft);
|
||||
}
|
||||
if (batch.token != nullptr) {
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
}
|
||||
|
||||
void begin(const llama_tokens & prompt) override {
|
||||
GGML_UNUSED(prompt);
|
||||
llama_kv_cache_clear(ctx_dft);
|
||||
llama_reset_dflash_kv_cache_state(ctx_dft);
|
||||
n_window_updates = 0;
|
||||
n_rows_seen = 0;
|
||||
n_rows_dropped = 0;
|
||||
n_context_shifts = 0;
|
||||
n_draft_empty = 0;
|
||||
n_set_target_fail = 0;
|
||||
n_decode_fail = 0;
|
||||
last_draft_pos_base = -1;
|
||||
t_draft_decode_us = 0;
|
||||
t_draft_sample_us = 0;
|
||||
t_warmup_collect_us = 0;
|
||||
t_warmup_append_us = 0;
|
||||
t_accept_output_copy_us = 0;
|
||||
t_accept_commit_us = 0;
|
||||
t_accept_append_us = 0;
|
||||
t_accept_append_filter_us = 0;
|
||||
t_accept_append_window_alloc_us = 0;
|
||||
t_accept_append_replace_us = 0;
|
||||
t_accept_append_keep_old_us = 0;
|
||||
t_accept_append_new_rows_us = 0;
|
||||
t_accept_append_commit_detail_us = 0;
|
||||
t_accept_append_log_us = 0;
|
||||
n_warmup_collect_calls = 0;
|
||||
n_warmup_collect_rows = 0;
|
||||
n_warmup_append_calls = 0;
|
||||
n_warmup_append_rows = 0;
|
||||
n_accept_output_copy_calls = 0;
|
||||
n_accept_output_copy_rows = 0;
|
||||
n_accept_commit_calls = 0;
|
||||
n_accept_commit_rows = 0;
|
||||
n_accept_append_calls = 0;
|
||||
n_accept_append_rows = 0;
|
||||
n_accept_append_replace_calls = 0;
|
||||
n_accept_append_slide_calls = 0;
|
||||
llama_dflash_profile_reset(ctx_tgt);
|
||||
llama_dflash_profile_reset(ctx_dft);
|
||||
}
|
||||
|
||||
void draft(
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
GGML_UNUSED(prompt_tgt);
|
||||
|
||||
result.clear();
|
||||
if (!ready || target_window_rows <= 0) {
|
||||
n_draft_empty++;
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t n_keep = std::min<int32_t>(params.n_max, block_size - 1);
|
||||
if (n_keep <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float * target_features = nullptr;
|
||||
size_t target_feature_floats = 0;
|
||||
llama_dflash_window_update window_update = {
|
||||
target_window_version,
|
||||
target_window_keep_rows,
|
||||
target_window_append_rows,
|
||||
target_window_replace,
|
||||
target_window_append_features.empty() ? nullptr : target_window_append_features.data(),
|
||||
target_window_append_features.size(),
|
||||
};
|
||||
const llama_dflash_kv_cache_transition cache_plan =
|
||||
llama_plan_dflash_kv_cache_transition_for_ctx(ctx_dft, window_update, target_window_rows);
|
||||
|
||||
if (cache_plan.rebuild_cache) {
|
||||
dflash_materialize_target_window_features(*this);
|
||||
target_features = target_window.data();
|
||||
target_feature_floats = target_window.size();
|
||||
window_update.append_features = target_window.data();
|
||||
window_update.append_floats = target_window.size();
|
||||
window_update.append_rows = target_window_rows;
|
||||
}
|
||||
|
||||
if (!llama_set_dflash_target_features_view(ctx_dft, target_features, target_feature_floats, target_window_rows, target_window_pos.data(), &window_update)) {
|
||||
LOG_ERR("%s: failed to set DFlash target features\n", __func__);
|
||||
n_set_target_fail++;
|
||||
return;
|
||||
}
|
||||
|
||||
llama_kv_cache_clear(ctx_dft);
|
||||
batch.n_tokens = 0;
|
||||
const int32_t batch_len = n_keep + 1;
|
||||
const llama_pos draft_pos_base = last_target_pos >= 0 ? last_target_pos + 1 : (llama_pos) target_window_rows;
|
||||
const llama_pos seed_pos = last_target_pos >= 0 ? last_target_pos : draft_pos_base - 1;
|
||||
last_draft_pos_base = draft_pos_base;
|
||||
common_batch_add(batch, id_last, seed_pos, { 0 }, false);
|
||||
for (int32_t i = 1; i < batch_len; ++i) {
|
||||
common_batch_add(batch, mask_token_id, draft_pos_base + (i - 1), { 0 }, i <= n_keep);
|
||||
}
|
||||
|
||||
const int64_t t_decode_us = ggml_time_us();
|
||||
if (llama_decode(ctx_dft, batch) != 0) {
|
||||
LOG_ERR("%s: llama_decode() failed for DFlash draft batch\n", __func__);
|
||||
n_decode_fail++;
|
||||
batch.n_tokens = 0;
|
||||
return;
|
||||
}
|
||||
t_draft_decode_us += (uint64_t) (ggml_time_us() - t_decode_us);
|
||||
|
||||
result.reserve((size_t) n_keep);
|
||||
const int64_t t_sample_us = ggml_time_us();
|
||||
for (int32_t i = 0; i < n_keep; ++i) {
|
||||
llama_token id = llama_get_dflash_draft_token_ith(ctx_dft, i);
|
||||
if (id == LLAMA_TOKEN_NULL) {
|
||||
id = common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr);
|
||||
}
|
||||
result.push_back(id);
|
||||
}
|
||||
t_draft_sample_us += (uint64_t) (ggml_time_us() - t_sample_us);
|
||||
|
||||
batch.n_tokens = 0;
|
||||
dflash_contract_log_draft(*this, n_keep, result.size());
|
||||
}
|
||||
|
||||
void accept(uint16_t n_accepted) override {
|
||||
GGML_UNUSED(n_accepted);
|
||||
}
|
||||
};
|
||||
|
||||
static void dflash_contract_log_append(
|
||||
const common_speculative_state_dflash & state,
|
||||
llama_seq_id seq_id,
|
||||
const std::vector<llama_pos> & new_positions) {
|
||||
if (!dflash_contract_log_enabled()) {
|
||||
return;
|
||||
}
|
||||
|
||||
static std::atomic<uint64_t> counter = 0;
|
||||
const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed);
|
||||
if (ordinal >= 8) {
|
||||
return;
|
||||
}
|
||||
|
||||
const dflash_contract_pos_summary incoming = dflash_contract_summarize_positions(new_positions);
|
||||
const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos);
|
||||
|
||||
LOG_INF("dflash contract append[%llu]: seq=%d incoming_rows=%zu incoming_pos=%s pos=[%d..%d] gaps=%d nonmono=%d window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d\n",
|
||||
(unsigned long long) (ordinal + 1),
|
||||
(int) seq_id,
|
||||
new_positions.size(),
|
||||
dflash_contract_format_values(new_positions).c_str(),
|
||||
(int) incoming.first,
|
||||
(int) incoming.last,
|
||||
incoming.gap_count,
|
||||
incoming.nonmono_count,
|
||||
state.target_window_rows,
|
||||
dflash_contract_format_values(state.target_window_pos).c_str(),
|
||||
(int) window.first,
|
||||
(int) window.last,
|
||||
window.gap_count,
|
||||
window.nonmono_count,
|
||||
(int) state.last_target_pos);
|
||||
}
|
||||
|
||||
static void dflash_contract_log_draft(
|
||||
const common_speculative_state_dflash & state,
|
||||
int32_t n_keep,
|
||||
size_t result_size) {
|
||||
if (!dflash_contract_log_enabled()) {
|
||||
return;
|
||||
}
|
||||
|
||||
static std::atomic<uint64_t> counter = 0;
|
||||
const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed);
|
||||
if (ordinal >= 8) {
|
||||
return;
|
||||
}
|
||||
|
||||
const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos);
|
||||
llama_dflash_profile_stats graph_stats = {};
|
||||
llama_dflash_profile_get_stats(state.ctx_dft, &graph_stats);
|
||||
const int draft_delta = (state.last_target_pos >= 0 && state.last_draft_pos_base >= 0)
|
||||
? (int) (state.last_draft_pos_base - state.last_target_pos)
|
||||
: -1;
|
||||
const llama_pos seed_pos = state.last_target_pos;
|
||||
const llama_pos mask_first_pos = state.last_draft_pos_base;
|
||||
const llama_pos mask_last_pos = state.last_draft_pos_base >= 0
|
||||
? state.last_draft_pos_base + n_keep - 1
|
||||
: -1;
|
||||
|
||||
LOG_INF("dflash contract draft[%llu]: window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d seed_pos=%d mask_pos=[%d..%d] sample_rows=[1..%d] output_rows=[1..%d] draft_pos_base=%d delta=%d n_keep=%d result=%zu set_target(missing/nonmono)=%llu/%llu graph(fallback/nonmono)=%llu/%llu graph_pos=[%d..%d]\n",
|
||||
(unsigned long long) (ordinal + 1),
|
||||
state.target_window_rows,
|
||||
dflash_contract_format_values(state.target_window_pos).c_str(),
|
||||
(int) window.first,
|
||||
(int) window.last,
|
||||
window.gap_count,
|
||||
window.nonmono_count,
|
||||
(int) state.last_target_pos,
|
||||
(int) seed_pos,
|
||||
(int) mask_first_pos,
|
||||
(int) mask_last_pos,
|
||||
n_keep,
|
||||
n_keep,
|
||||
(int) state.last_draft_pos_base,
|
||||
draft_delta,
|
||||
n_keep,
|
||||
result_size,
|
||||
(unsigned long long) graph_stats.set_target_missing_positions,
|
||||
(unsigned long long) graph_stats.set_target_non_monotonic_positions,
|
||||
(unsigned long long) graph_stats.graph_pos_fallbacks,
|
||||
(unsigned long long) graph_stats.graph_pos_non_monotonic,
|
||||
(int) graph_stats.last_pos_first,
|
||||
(int) graph_stats.last_pos_last);
|
||||
}
|
||||
|
||||
struct dflash_append_breakdown {
|
||||
uint64_t filter_us = 0;
|
||||
uint64_t window_alloc_us = 0;
|
||||
uint64_t replace_us = 0;
|
||||
uint64_t keep_old_us = 0;
|
||||
uint64_t new_rows_us = 0;
|
||||
uint64_t commit_us = 0;
|
||||
uint64_t log_us = 0;
|
||||
bool replace_call = false;
|
||||
};
|
||||
|
||||
static void dflash_record_window_update(
|
||||
common_speculative_state_dflash & state,
|
||||
int32_t keep_rows,
|
||||
int32_t append_rows,
|
||||
bool replace) {
|
||||
state.target_window_keep_rows = std::max<int32_t>(0, keep_rows);
|
||||
state.target_window_append_rows = std::max<int32_t>(0, append_rows);
|
||||
state.target_window_replace = replace;
|
||||
state.target_window_version++;
|
||||
}
|
||||
|
||||
static void dflash_ring_reset_rows(
|
||||
common_speculative_state_dflash & state,
|
||||
const float * rows,
|
||||
int32_t n_rows) {
|
||||
const size_t row_width = (size_t) state.n_target_features;
|
||||
if (n_rows <= 0 || rows == nullptr) {
|
||||
state.target_window_ring_write_pos = 0;
|
||||
state.target_window_ring_filled = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) {
|
||||
state.target_window_ring.resize((size_t) state.cross_ctx * row_width);
|
||||
}
|
||||
|
||||
std::memcpy(state.target_window_ring.data(), rows, (size_t) n_rows * row_width * sizeof(float));
|
||||
state.target_window_ring_write_pos = n_rows % state.cross_ctx;
|
||||
state.target_window_ring_filled = n_rows;
|
||||
state.target_window_materialized = false;
|
||||
}
|
||||
|
||||
static void dflash_ring_append_rows(
|
||||
common_speculative_state_dflash & state,
|
||||
const float * rows,
|
||||
int32_t n_rows) {
|
||||
const size_t row_width = (size_t) state.n_target_features;
|
||||
if (n_rows <= 0 || rows == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) {
|
||||
state.target_window_ring.resize((size_t) state.cross_ctx * row_width);
|
||||
}
|
||||
|
||||
int32_t write_pos = state.target_window_ring_write_pos;
|
||||
int32_t remaining = n_rows;
|
||||
const float * src = rows;
|
||||
while (remaining > 0) {
|
||||
const int32_t chunk_rows = std::min<int32_t>(remaining, state.cross_ctx - write_pos);
|
||||
std::memcpy(
|
||||
state.target_window_ring.data() + (size_t) write_pos * row_width,
|
||||
src,
|
||||
(size_t) chunk_rows * row_width * sizeof(float));
|
||||
src += (size_t) chunk_rows * row_width;
|
||||
remaining -= chunk_rows;
|
||||
write_pos = (write_pos + chunk_rows) % state.cross_ctx;
|
||||
}
|
||||
|
||||
state.target_window_ring_write_pos = write_pos;
|
||||
state.target_window_ring_filled = std::min(state.cross_ctx, state.target_window_ring_filled + n_rows);
|
||||
state.target_window_materialized = false;
|
||||
}
|
||||
|
||||
static void dflash_materialize_target_window_features(common_speculative_state_dflash & state) {
|
||||
if (state.target_window_materialized || state.target_window_rows <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t row_width = (size_t) state.n_target_features;
|
||||
state.target_window.resize((size_t) state.target_window_rows * row_width);
|
||||
|
||||
const int32_t read_start = (state.target_window_ring_write_pos - state.target_window_rows + state.cross_ctx) % state.cross_ctx;
|
||||
const int32_t first_rows = std::min<int32_t>(state.target_window_rows, state.cross_ctx - read_start);
|
||||
std::memcpy(
|
||||
state.target_window.data(),
|
||||
state.target_window_ring.data() + (size_t) read_start * row_width,
|
||||
(size_t) first_rows * row_width * sizeof(float));
|
||||
|
||||
const int32_t second_rows = state.target_window_rows - first_rows;
|
||||
if (second_rows > 0) {
|
||||
std::memcpy(
|
||||
state.target_window.data() + (size_t) first_rows * row_width,
|
||||
state.target_window_ring.data(),
|
||||
(size_t) second_rows * row_width * sizeof(float));
|
||||
}
|
||||
|
||||
state.target_window_materialized = true;
|
||||
}
|
||||
|
||||
static bool dflash_append_target_features(
|
||||
common_speculative_state_dflash & state,
|
||||
const common_speculative_feature_view & features,
|
||||
const llama_batch & batch,
|
||||
llama_seq_id seq_id,
|
||||
dflash_append_breakdown * breakdown = nullptr) {
|
||||
GGML_UNUSED(batch);
|
||||
|
||||
if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE ||
|
||||
features.width != state.n_target_features ||
|
||||
features.rows.empty() ||
|
||||
state.cross_ctx <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t row_width = (size_t) state.n_target_features;
|
||||
std::vector<float> new_rows;
|
||||
std::vector<llama_pos> new_positions;
|
||||
new_rows.reserve(features.rows.size() * row_width);
|
||||
new_positions.reserve(features.rows.size());
|
||||
|
||||
const int64_t t_filter_us = ggml_time_us();
|
||||
for (const auto & row : features.rows) {
|
||||
if (row.seq_id != seq_id || row.data == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
new_positions.push_back(row.pos);
|
||||
new_rows.insert(new_rows.end(), row.data, row.data + row_width);
|
||||
}
|
||||
if (breakdown != nullptr) {
|
||||
breakdown->filter_us += (uint64_t) (ggml_time_us() - t_filter_us);
|
||||
}
|
||||
|
||||
if (new_positions.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int32_t n_rows = (int32_t) new_positions.size();
|
||||
state.n_window_updates++;
|
||||
state.n_rows_seen += (size_t) n_rows;
|
||||
if (n_rows >= state.cross_ctx) {
|
||||
state.n_rows_dropped += (size_t) state.target_window_rows + (size_t) (n_rows - state.cross_ctx);
|
||||
const int32_t keep_from = n_rows - state.cross_ctx;
|
||||
const int64_t t_replace_us = ggml_time_us();
|
||||
state.target_window_pos.assign(new_positions.begin() + keep_from, new_positions.end());
|
||||
state.target_window_append_features.assign(
|
||||
new_rows.begin() + (ptrdiff_t) keep_from * (ptrdiff_t) row_width,
|
||||
new_rows.end());
|
||||
dflash_ring_reset_rows(state, state.target_window_append_features.data(), state.cross_ctx);
|
||||
if (breakdown != nullptr) {
|
||||
breakdown->replace_us += (uint64_t) (ggml_time_us() - t_replace_us);
|
||||
breakdown->replace_call = true;
|
||||
}
|
||||
|
||||
const int64_t t_commit_us = ggml_time_us();
|
||||
state.target_window_rows = state.cross_ctx;
|
||||
state.target_window_ring_filled = state.target_window_rows;
|
||||
state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
|
||||
dflash_record_window_update(state, 0, state.target_window_rows, true);
|
||||
if (breakdown != nullptr) {
|
||||
breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us);
|
||||
}
|
||||
|
||||
const int64_t t_log_us = ggml_time_us();
|
||||
dflash_contract_log_append(state, seq_id, new_positions);
|
||||
if (breakdown != nullptr) {
|
||||
breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
const int32_t keep_old_rows = std::min<int32_t>(state.target_window_rows, state.cross_ctx - n_rows);
|
||||
state.n_rows_dropped += (size_t) std::max<int32_t>(0, state.target_window_rows - keep_old_rows);
|
||||
const int64_t t_window_alloc_us = ggml_time_us();
|
||||
std::vector<llama_pos> & next_window_pos = state.target_window_pos_stage;
|
||||
next_window_pos.resize((size_t) (keep_old_rows + n_rows));
|
||||
if (breakdown != nullptr) {
|
||||
breakdown->window_alloc_us += (uint64_t) (ggml_time_us() - t_window_alloc_us);
|
||||
}
|
||||
|
||||
if (keep_old_rows > 0) {
|
||||
const int64_t t_keep_old_us = ggml_time_us();
|
||||
std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), next_window_pos.begin());
|
||||
if (breakdown != nullptr) {
|
||||
breakdown->keep_old_us += (uint64_t) (ggml_time_us() - t_keep_old_us);
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t t_new_rows_us = ggml_time_us();
|
||||
state.target_window_append_features.assign(new_rows.begin(), new_rows.end());
|
||||
dflash_ring_append_rows(state, state.target_window_append_features.data(), n_rows);
|
||||
std::copy(new_positions.begin(), new_positions.end(), next_window_pos.begin() + keep_old_rows);
|
||||
if (breakdown != nullptr) {
|
||||
breakdown->new_rows_us += (uint64_t) (ggml_time_us() - t_new_rows_us);
|
||||
}
|
||||
|
||||
const int64_t t_commit_us = ggml_time_us();
|
||||
state.target_window_pos.swap(next_window_pos);
|
||||
next_window_pos.clear();
|
||||
state.target_window_rows = keep_old_rows + n_rows;
|
||||
state.target_window_ring_filled = state.target_window_rows;
|
||||
state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
|
||||
dflash_record_window_update(state, keep_old_rows, n_rows, false);
|
||||
if (breakdown != nullptr) {
|
||||
breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us);
|
||||
}
|
||||
|
||||
const int64_t t_log_us = ggml_time_us();
|
||||
dflash_contract_log_append(state, seq_id, new_positions);
|
||||
if (breakdown != nullptr) {
|
||||
breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static void dflash_clear_target_features(common_speculative_state_dflash & state) {
|
||||
state.target_window.clear();
|
||||
state.target_window_pos.clear();
|
||||
state.target_window_stage.clear();
|
||||
state.target_window_pos_stage.clear();
|
||||
state.target_window_append_features.clear();
|
||||
state.target_window_rows = 0;
|
||||
state.target_window_ring_write_pos = 0;
|
||||
state.target_window_ring_filled = 0;
|
||||
state.target_window_keep_rows = 0;
|
||||
state.target_window_append_rows = 0;
|
||||
state.target_window_replace = false;
|
||||
state.target_window_materialized = false;
|
||||
state.last_target_pos = -1;
|
||||
llama_reset_dflash_kv_cache_state(state.ctx_dft);
|
||||
}
|
||||
|
||||
static void dflash_context_shift(
|
||||
common_speculative_state_dflash & state,
|
||||
llama_pos kv_keep,
|
||||
llama_pos kv_discard,
|
||||
llama_pos kv_past) {
|
||||
if (kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window_pos.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
dflash_materialize_target_window_features(state);
|
||||
|
||||
const size_t row_width = (size_t) state.n_target_features;
|
||||
const llama_pos discard_begin = kv_keep;
|
||||
const llama_pos discard_end = kv_keep + kv_discard;
|
||||
|
||||
std::vector<float> shifted_rows;
|
||||
std::vector<llama_pos> shifted_positions;
|
||||
shifted_rows.reserve(state.target_window.size());
|
||||
shifted_positions.reserve(state.target_window_pos.size());
|
||||
|
||||
for (int32_t row = 0; row < state.target_window_rows; ++row) {
|
||||
llama_pos pos = state.target_window_pos[(size_t) row];
|
||||
if (pos >= discard_begin && pos < discard_end) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (pos >= discard_end && pos < kv_past) {
|
||||
pos -= kv_discard;
|
||||
}
|
||||
|
||||
const float * row_src = state.target_window.data() + (size_t) row * row_width;
|
||||
shifted_rows.insert(shifted_rows.end(), row_src, row_src + row_width);
|
||||
shifted_positions.push_back(pos);
|
||||
}
|
||||
|
||||
state.target_window = std::move(shifted_rows);
|
||||
state.target_window_pos = std::move(shifted_positions);
|
||||
state.target_window_rows = (int32_t) state.target_window_pos.size();
|
||||
dflash_ring_reset_rows(state, state.target_window.data(), state.target_window_rows);
|
||||
state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
|
||||
dflash_record_window_update(state, 0, state.target_window_rows, true);
|
||||
llama_reset_dflash_kv_cache_state(state.ctx_dft);
|
||||
state.n_context_shifts++;
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -570,6 +570,9 @@ class Model:
|
||||
if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
|
||||
# ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
|
||||
res = "grok-2"
|
||||
if chkhsh == "972da7b59cec44d1f0a490a86c96df53859e486e481563e5dddac155013d87ac":
|
||||
# ref: https://huggingface.co/poolside/Laguna-XS.2
|
||||
res = "laguna"
|
||||
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
|
||||
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
|
||||
res = "llama-bpe"
|
||||
@ -603,6 +606,9 @@ class Model:
|
||||
if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8":
|
||||
# ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01
|
||||
res = "command-r"
|
||||
if chkhsh == "52df12b4c8d4176e7481aab4b6e8454d1fd0a210a04a574f6d4e067d10e23c3e":
|
||||
# ref: https://huggingface.co/CohereLabs/North-Mini-Code-1.0
|
||||
res = "cohere2_moe"
|
||||
if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea":
|
||||
# ref: https://huggingface.co/Qwen/Qwen1.5-7B
|
||||
res = "qwen2"
|
||||
@ -684,6 +690,9 @@ class Model:
|
||||
if chkhsh == "f4f37b6c8eb9ea29b3eac6bb8c8487c5ab7885f8d8022e67edc1c68ce8403e95":
|
||||
# ref: https://huggingface.co/MiniMaxAI/MiniMax-M2
|
||||
res = "minimax-m2"
|
||||
if chkhsh == "9dcf830ee9990cdbf78cc523a5f7bd9ad8f3f9890c2d3581d2785ad10f07049d":
|
||||
# ref: https://huggingface.co/JetBrains/Mellum2-12B-A2.5B-Base
|
||||
res = "mellum2"
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
logger.warning("**************************************************************************************")
|
||||
@ -1580,7 +1589,12 @@ class LlamaModel(Model):
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
saved_intermediate_size = self.hparams.get("intermediate_size")
|
||||
saved_num_experts_per_tok = self.hparams.pop("num_experts_per_tok")
|
||||
self.hparams["intermediate_size"] = self.hparams["prefix_dense_intermediate_size"]
|
||||
super().set_gguf_parameters()
|
||||
self.hparams["intermediate_size"] = saved_intermediate_size
|
||||
self.hparams["num_experts_per_tok"] = saved_num_experts_per_tok
|
||||
hparams = self.hparams
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
|
||||
@ -3735,7 +3749,7 @@ class Gemma4Model(Gemma4BaseModel):
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@Model.register("Gemma4AssistantForCausalLM")
|
||||
@Model.register("Gemma4AssistantForCausalLM", "Gemma4UnifiedAssistantForCausalLM")
|
||||
class Gemma4AssistantModel(Gemma4BaseModel):
|
||||
model_arch = gguf.MODEL_ARCH.GEMMA4_MTP
|
||||
|
||||
@ -3950,6 +3964,86 @@ class CommandR2Model(Model):
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
|
||||
|
||||
|
||||
@Model.register("Cohere2MoeForCausalLM")
|
||||
class Cohere2MoeModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.COHERE2_MOE
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
saved_intermediate_size = self.hparams["intermediate_size"]
|
||||
saved_num_experts_per_tok = self.hparams.pop("num_experts_per_tok")
|
||||
self.hparams["intermediate_size"] = self.hparams["prefix_dense_intermediate_size"]
|
||||
super().set_gguf_parameters()
|
||||
self.hparams["intermediate_size"] = saved_intermediate_size
|
||||
self.hparams["num_experts_per_tok"] = saved_num_experts_per_tok
|
||||
hparams = self.hparams
|
||||
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
self.gguf_writer.add_logit_scale(hparams.get("logit_scale", 1.0))
|
||||
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
|
||||
self.gguf_writer.add_sliding_window_pattern([
|
||||
layer_type == "sliding_attention"
|
||||
for layer_type in hparams["layer_types"]
|
||||
])
|
||||
self.gguf_writer.add_rope_dimension_count(hparams.get("head_dim", hparams["hidden_size"] // hparams["num_attention_heads"]))
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
|
||||
|
||||
self.gguf_writer.add_expert_feed_forward_length(hparams["intermediate_size"])
|
||||
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
|
||||
self.gguf_writer.add_expert_count(hparams["num_experts"])
|
||||
self.gguf_writer.add_expert_used_count(hparams["num_experts_per_tok"])
|
||||
self.gguf_writer.add_expert_weights_norm(bool(hparams.get("norm_topk_prob", False)))
|
||||
|
||||
expert_selection_fn = hparams.get("expert_selection_fn", "softmax")
|
||||
if expert_selection_fn != "sigmoid":
|
||||
raise ValueError(f"Unsupported Cohere2-MoE expert_selection_fn={expert_selection_fn!r}")
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
|
||||
if hparams.get("num_shared_experts", 0) != 0:
|
||||
raise ValueError("Cohere2-MoE shared experts are not supported in this GGUF converter yet")
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# Cohere2-MoE HF tensors already use the interleaved RoPE layout expected here.
|
||||
|
||||
if ".mlp.experts." in name:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) < n_experts * 3:
|
||||
return []
|
||||
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
for src, dst in [
|
||||
("gate_proj", "gate_proj"),
|
||||
("down_proj", "down_proj"),
|
||||
("up_proj", "up_proj"),
|
||||
]:
|
||||
datas: list[Tensor] = []
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.mlp.experts.{xid}.{src}.weight"
|
||||
datas.append(self._experts[bid][ename])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
merged_name = f"model.layers.{bid}.mlp.experts.{dst}.weight"
|
||||
tensors.append((self.map_tensor_name(merged_name), torch.stack(datas, dim=0)))
|
||||
yield from tensors
|
||||
return
|
||||
|
||||
if name == "model.embed_tokens.weight":
|
||||
yield self.map_tensor_name(name), data_torch
|
||||
if self.tensor_names is None or "lm_head.weight" not in self.tensor_names:
|
||||
yield self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT, suffix=".weight"), data_torch
|
||||
return
|
||||
|
||||
yield self.map_tensor_name(name), data_torch
|
||||
|
||||
|
||||
@Model.register("OlmoForCausalLM")
|
||||
@Model.register("OLMoForCausalLM")
|
||||
class OlmoModel(Model):
|
||||
@ -5307,6 +5401,138 @@ class BailingMoeV2Model(Model):
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@Model.register("LagunaForCausalLM")
|
||||
class LagunaModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.LAGUNA
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
|
||||
n_layers = int(hparams["num_hidden_layers"])
|
||||
n_head_base = int(hparams["num_attention_heads"])
|
||||
n_kv_base = int(hparams.get("num_key_value_heads", n_head_base))
|
||||
head_dim = int(hparams.get("head_dim", hparams["hidden_size"] // n_head_base))
|
||||
|
||||
heads_per_layer = hparams.get("num_attention_heads_per_layer")
|
||||
kv_per_layer = hparams.get("num_key_value_heads_per_layer")
|
||||
|
||||
head_arr: list[int] = []
|
||||
kv_arr: list[int] = []
|
||||
for i in range(n_layers):
|
||||
head_arr.append(int(heads_per_layer[i]) if heads_per_layer is not None else n_head_base)
|
||||
kv_arr.append(int(kv_per_layer[i]) if kv_per_layer is not None else n_kv_base)
|
||||
|
||||
rope_params = hparams.get("rope_parameters", {})
|
||||
full_rope = rope_params.get("full_attention", rope_params)
|
||||
swa_rope = rope_params.get("sliding_attention", {})
|
||||
|
||||
self.gguf_writer.add_context_length(int(hparams["max_position_embeddings"]))
|
||||
self.gguf_writer.add_embedding_length(int(hparams["hidden_size"]))
|
||||
self.gguf_writer.add_block_count(n_layers)
|
||||
self.gguf_writer.add_feed_forward_length(int(hparams["intermediate_size"]))
|
||||
self.gguf_writer.add_head_count(head_arr)
|
||||
if all(n_kv == kv_arr[0] for n_kv in kv_arr):
|
||||
self.gguf_writer.add_head_count_kv(kv_arr[0])
|
||||
else:
|
||||
self.gguf_writer.add_head_count_kv(kv_arr)
|
||||
self.gguf_writer.add_key_length(head_dim)
|
||||
self.gguf_writer.add_value_length(head_dim)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(float(hparams["rms_norm_eps"]))
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
self.gguf_writer.add_sliding_window(int(hparams["sliding_window"]))
|
||||
self.gguf_writer.add_rope_dimension_count(head_dim // 2)
|
||||
self.gguf_writer.add_uint32(f"{arch}.rope.dimension_count_swa", head_dim)
|
||||
self.gguf_writer.add_rope_freq_base(float(full_rope.get("rope_theta", 500000.0)))
|
||||
self.gguf_writer.add_float32(f"{arch}.rope.freq_base_swa", float(swa_rope.get("rope_theta", 10000.0)))
|
||||
if full_rope.get("rope_type") == "yarn":
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(float(full_rope.get("factor", 1.0)))
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(int(full_rope.get(
|
||||
"original_max_position_embeddings",
|
||||
rope_params.get("original_max_position_embeddings", hparams["max_position_embeddings"]),
|
||||
)))
|
||||
self.gguf_writer.add_rope_scaling_yarn_ext_factor(float(full_rope.get("factor", 1.0)))
|
||||
self.gguf_writer.add_rope_scaling_yarn_attn_factor(float(full_rope.get("attention_factor", 1.0)))
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_fast(float(full_rope.get("beta_fast", 32.0)))
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_slow(float(full_rope.get("beta_slow", 1.0)))
|
||||
|
||||
self.gguf_writer.add_expert_count(int(hparams["num_experts"]))
|
||||
self.gguf_writer.add_expert_used_count(int(hparams["num_experts_per_tok"]))
|
||||
self.gguf_writer.add_expert_feed_forward_length(int(hparams["moe_intermediate_size"]))
|
||||
if (shared_dim := hparams.get("shared_expert_intermediate_size")) is not None and int(shared_dim) > 0:
|
||||
self.gguf_writer.add_expert_shared_feed_forward_length(int(shared_dim))
|
||||
if (routing_scale := hparams.get("moe_routed_scaling_factor")) is not None:
|
||||
self.gguf_writer.add_expert_weights_scale(float(routing_scale))
|
||||
self.gguf_writer.add_expert_weights_norm(True)
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
|
||||
leading_dense = 0
|
||||
for mlp_type in hparams.get("mlp_layer_types", []):
|
||||
if mlp_type != "dense":
|
||||
break
|
||||
leading_dense += 1
|
||||
self.gguf_writer.add_uint32(f"{arch}.leading_dense_block_count", leading_dense)
|
||||
|
||||
if hparams.get("moe_apply_router_weight_on_input", False):
|
||||
raise ValueError("moe_apply_router_weight_on_input=True is not supported for Laguna")
|
||||
|
||||
def set_vocab(self) -> None:
|
||||
super().set_vocab()
|
||||
if isinstance(eos_token_id := self.hparams.get("eos_token_id"), list) and len(eos_token_id) > 1:
|
||||
# Poolside uses token 24 (</assistant>) as a turn boundary.
|
||||
self.gguf_writer.add_eot_token_id(int(eos_token_id[1]))
|
||||
template_file = self.dir_model / "chat_template.jinja"
|
||||
if template_file.is_file():
|
||||
self.gguf_writer.add_chat_template(template_file.read_text(encoding="utf-8"))
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if bid is not None and name in (
|
||||
f"model.layers.{bid}.mlp.experts.e_score_correction_bias",
|
||||
f"model.layers.{bid}.mlp.experts.e_score_correction",
|
||||
):
|
||||
# The C++ loader asks for this tensor through the ".bias" suffix.
|
||||
# Keep the Laguna converter aligned with existing community GGUFs.
|
||||
yield f"blk.{bid}.exp_probs_b.bias", data_torch
|
||||
return
|
||||
|
||||
if name.endswith(".self_attn.g_proj.weight"):
|
||||
# HF stores the head-wise attention gate with a singleton dimension.
|
||||
data_torch = data_torch.squeeze().contiguous()
|
||||
|
||||
if bid is not None and re.match(r"model\.layers\.\d+\.mlp\.experts\.\d+\.(gate_proj|up_proj|down_proj)\.weight$", name):
|
||||
n_experts = int(self.find_hparam(["num_experts"]))
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
if len(self._experts[bid]) < n_experts * 3:
|
||||
return
|
||||
|
||||
for w_name in ("down_proj", "gate_proj", "up_proj"):
|
||||
datas: list[Tensor] = []
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
merged = torch.stack(datas, dim=0)
|
||||
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
|
||||
yield from super().modify_tensors(merged, merged_name, bid)
|
||||
return
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
if self._experts is not None:
|
||||
experts = [k for d in self._experts for k in d.keys()]
|
||||
if experts:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
#include "llama-spec-features.h"
|
||||
#include "log.h"
|
||||
#include "sampling.h"
|
||||
#include "speculative.h"
|
||||
@ -13,12 +14,8 @@
|
||||
#include "mtmd-helper.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <atomic>
|
||||
#include <cstring>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <regex>
|
||||
#include <sstream>
|
||||
#include <exception>
|
||||
|
||||
static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1, int32_t offset = 0) {
|
||||
@ -49,83 +46,6 @@ static void log_text(const gpt_params & params_base, const std::string & text) {
|
||||
}
|
||||
}
|
||||
|
||||
static bool server_dflash_contract_log_enabled() {
|
||||
const char * env = std::getenv("IK_DFLASH_CONTRACT_LOG");
|
||||
if (env == nullptr || *env == '\0') {
|
||||
return false;
|
||||
}
|
||||
|
||||
return std::strcmp(env, "0") != 0 &&
|
||||
std::strcmp(env, "false") != 0 &&
|
||||
std::strcmp(env, "off") != 0;
|
||||
}
|
||||
|
||||
static std::string server_dflash_contract_format_indices(
|
||||
const std::vector<int32_t> & values,
|
||||
size_t edge_count = 4) {
|
||||
std::ostringstream oss;
|
||||
oss << '[';
|
||||
if (values.empty()) {
|
||||
oss << ']';
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
const size_t head = std::min(edge_count, values.size());
|
||||
for (size_t i = 0; i < head; ++i) {
|
||||
if (i > 0) {
|
||||
oss << ',';
|
||||
}
|
||||
oss << values[i];
|
||||
}
|
||||
|
||||
if (values.size() > edge_count * 2) {
|
||||
oss << ",...,";
|
||||
for (size_t i = values.size() - edge_count; i < values.size(); ++i) {
|
||||
if (i > values.size() - edge_count) {
|
||||
oss << ',';
|
||||
}
|
||||
oss << values[i];
|
||||
}
|
||||
} else {
|
||||
for (size_t i = head; i < values.size(); ++i) {
|
||||
oss << ',' << values[i];
|
||||
}
|
||||
}
|
||||
|
||||
oss << ']';
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
static void server_dflash_contract_log_accept(
|
||||
const server_slot & slot,
|
||||
common_speculative_type spec_type_used,
|
||||
const char * path,
|
||||
bool any_rejected,
|
||||
size_t n_draft,
|
||||
const std::vector<llama_token> & ids,
|
||||
llama_pos pos_base,
|
||||
const std::vector<int32_t> & output_indices) {
|
||||
if (!server_dflash_contract_log_enabled() || spec_type_used != COMMON_SPECULATIVE_TYPE_DFLASH) {
|
||||
return;
|
||||
}
|
||||
|
||||
static std::atomic<uint64_t> counter = 0;
|
||||
const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed);
|
||||
if (ordinal >= 8) {
|
||||
return;
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("dflash contract accept[%llu]: slot=%d path=%s rejected=%s drafted=%zu accepted=%zu pos_base=%d output_indices=%s\n",
|
||||
(unsigned long long) (ordinal + 1),
|
||||
slot.id,
|
||||
path,
|
||||
any_rejected ? "true" : "false",
|
||||
n_draft,
|
||||
ids.size(),
|
||||
(int) pos_base,
|
||||
server_dflash_contract_format_indices(output_indices).c_str());
|
||||
}
|
||||
|
||||
static bool server_slot_prompt_batch_overlaps(
|
||||
const server_slot & slot,
|
||||
int32_t batch_i0,
|
||||
@ -158,12 +78,13 @@ static bool server_response_needs_chat_parse(oaicompat_type oaicompat) {
|
||||
oaicompat == OAICOMPAT_TYPE_RESP;
|
||||
}
|
||||
|
||||
static bool server_speculative_has_dflash(const common_params_speculative & spec) {
|
||||
return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH);
|
||||
static bool server_speculative_uses_target_features(const common_params_speculative & spec) {
|
||||
return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) ||
|
||||
spec.has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH);
|
||||
}
|
||||
|
||||
static bool server_speculative_has_target_features(const common_params_speculative & spec) {
|
||||
return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) || server_speculative_has_dflash(spec);
|
||||
static bool server_speculative_requires_single_slot(const common_params_speculative & spec) {
|
||||
return spec.has_stage_chain();
|
||||
}
|
||||
|
||||
static bool server_speculative_same_stage_types(
|
||||
@ -295,8 +216,8 @@ bool server_context::load_model(const gpt_params& params_) {
|
||||
|
||||
common_speculative_prepare_startup(params_base, false);
|
||||
|
||||
if (server_speculative_has_dflash(params_base.speculative) && params_base.n_parallel > 1) {
|
||||
LOG_ERROR("DFlash is currently limited to a single server slot (-np 1).\n", {
|
||||
if (server_speculative_requires_single_slot(params_base.speculative) && params_base.n_parallel > 1) {
|
||||
LOG_ERROR("Speculative decoding is currently limited to a single server slot (-np 1).\n", {
|
||||
{"n_parallel", params_base.n_parallel},
|
||||
});
|
||||
return false;
|
||||
@ -4127,7 +4048,7 @@ void server_context::speculative_decoding_accept() {
|
||||
}
|
||||
|
||||
std::vector<int32_t> accepted_output_indices;
|
||||
if (server_speculative_has_target_features(slot.params.speculative)) {
|
||||
if (server_speculative_uses_target_features(slot.params.speculative)) {
|
||||
if (!ids.empty()) {
|
||||
accepted_output_indices.assign(slot.i_batch_dft.begin(), slot.i_batch_dft.begin() + ids.size());
|
||||
}
|
||||
@ -4160,14 +4081,14 @@ void server_context::speculative_decoding_accept() {
|
||||
const common_speculative_checkpoint * ckpt = common_speculative_get_checkpoint(slot.spec);
|
||||
const bool will_restore = any_rejected && ckpt != nullptr && ckpt->valid;
|
||||
|
||||
if (server_speculative_has_target_features(slot.params.speculative) && !accepted_output_indices.empty()) {
|
||||
server_dflash_contract_log_accept(
|
||||
slot,
|
||||
spec_type_used,
|
||||
if (server_speculative_uses_target_features(slot.params.speculative) && !accepted_output_indices.empty()) {
|
||||
llama_dflash_contract_log_accept(
|
||||
slot.id,
|
||||
spec_type_used == COMMON_SPECULATIVE_TYPE_DFLASH,
|
||||
will_restore ? "restore" : "direct",
|
||||
any_rejected,
|
||||
n_draft,
|
||||
ids,
|
||||
ids.size(),
|
||||
spec_pos_base,
|
||||
accepted_output_indices);
|
||||
}
|
||||
@ -4556,9 +4477,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
continue; // continue loop of n_batch
|
||||
}
|
||||
|
||||
if (server_speculative_has_target_features(params_base.speculative)) {
|
||||
if (server_speculative_uses_target_features(params_base.speculative)) {
|
||||
for (auto & slot : slots) {
|
||||
if (!slot.spec || !server_speculative_has_target_features(slot.params.speculative)) {
|
||||
if (!slot.spec || !server_speculative_uses_target_features(slot.params.speculative)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
340
src/llama-dflash-profile.h
Normal file
340
src/llama-dflash-profile.h
Normal file
@ -0,0 +1,340 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
|
||||
enum llama_dflash_kv_node_kind {
|
||||
LLAMA_DFLASH_KV_NODE_NONE = 0,
|
||||
LLAMA_DFLASH_KV_NODE_FUSED_TARGET,
|
||||
LLAMA_DFLASH_KV_NODE_K_PROJ,
|
||||
LLAMA_DFLASH_KV_NODE_K_NORM,
|
||||
LLAMA_DFLASH_KV_NODE_K_ROPE,
|
||||
LLAMA_DFLASH_KV_NODE_V_PROJ,
|
||||
LLAMA_DFLASH_KV_NODE_K_STORE,
|
||||
LLAMA_DFLASH_KV_NODE_V_STORE,
|
||||
};
|
||||
|
||||
enum llama_dflash_main_node_kind {
|
||||
LLAMA_DFLASH_MAIN_NODE_NONE = 0,
|
||||
LLAMA_DFLASH_MAIN_NODE_QCUR,
|
||||
LLAMA_DFLASH_MAIN_NODE_K_DRAFT,
|
||||
LLAMA_DFLASH_MAIN_NODE_V_DRAFT,
|
||||
LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW,
|
||||
LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW,
|
||||
LLAMA_DFLASH_MAIN_NODE_K_CONCAT,
|
||||
LLAMA_DFLASH_MAIN_NODE_V_CONCAT,
|
||||
LLAMA_DFLASH_MAIN_NODE_K_PAD,
|
||||
LLAMA_DFLASH_MAIN_NODE_V_PAD,
|
||||
LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT,
|
||||
LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT,
|
||||
LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN,
|
||||
LLAMA_DFLASH_MAIN_NODE_ATTN_OUT,
|
||||
LLAMA_DFLASH_MAIN_NODE_FFN,
|
||||
LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS,
|
||||
LLAMA_DFLASH_MAIN_NODE_RESULT_NORM,
|
||||
LLAMA_DFLASH_MAIN_NODE_RESULT,
|
||||
};
|
||||
|
||||
struct llama_dflash_kv_node_profiler {
|
||||
llama_dflash_profile_stats * profile = nullptr;
|
||||
int64_t t_start_us = 0;
|
||||
llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE;
|
||||
};
|
||||
|
||||
struct llama_dflash_main_node_profiler {
|
||||
llama_dflash_profile_stats * profile = nullptr;
|
||||
ggml_backend_sched_eval_callback prev_callback = nullptr;
|
||||
void * prev_user_data = nullptr;
|
||||
bool prev_active = false;
|
||||
int64_t t_start_us = 0;
|
||||
llama_dflash_main_node_kind active_kind = LLAMA_DFLASH_MAIN_NODE_NONE;
|
||||
};
|
||||
|
||||
static inline bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) {
|
||||
if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') {
|
||||
return false;
|
||||
}
|
||||
|
||||
return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0;
|
||||
}
|
||||
|
||||
static inline bool llama_dflash_tensor_name_matches_label(const struct ggml_tensor * tensor, const char * label) {
|
||||
if (!llama_dflash_tensor_name_has_prefix(tensor, label)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t label_len = std::strlen(label);
|
||||
const char next = tensor->name[label_len];
|
||||
return next == '\0' || next == '-';
|
||||
}
|
||||
|
||||
static inline llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) {
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_fused_target")) {
|
||||
return LLAMA_DFLASH_KV_NODE_FUSED_TARGET;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_proj")) {
|
||||
return LLAMA_DFLASH_KV_NODE_K_PROJ;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_norm")) {
|
||||
return LLAMA_DFLASH_KV_NODE_K_NORM;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_rope")) {
|
||||
return LLAMA_DFLASH_KV_NODE_K_ROPE;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_proj")) {
|
||||
return LLAMA_DFLASH_KV_NODE_V_PROJ;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_store")) {
|
||||
return LLAMA_DFLASH_KV_NODE_K_STORE;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_store")) {
|
||||
return LLAMA_DFLASH_KV_NODE_V_STORE;
|
||||
}
|
||||
|
||||
return LLAMA_DFLASH_KV_NODE_NONE;
|
||||
}
|
||||
|
||||
static inline void llama_dflash_kv_node_profile_add(
|
||||
llama_dflash_profile_stats & profile,
|
||||
llama_dflash_kv_node_kind kind,
|
||||
uint64_t elapsed_us) {
|
||||
switch (kind) {
|
||||
case LLAMA_DFLASH_KV_NODE_FUSED_TARGET:
|
||||
profile.graph_kv_node_fused_target_calls++;
|
||||
profile.graph_kv_node_fused_target_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_K_PROJ:
|
||||
profile.graph_kv_node_k_proj_calls++;
|
||||
profile.graph_kv_node_k_proj_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_K_NORM:
|
||||
profile.graph_kv_node_k_norm_calls++;
|
||||
profile.graph_kv_node_k_norm_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_K_ROPE:
|
||||
profile.graph_kv_node_k_rope_calls++;
|
||||
profile.graph_kv_node_k_rope_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_V_PROJ:
|
||||
profile.graph_kv_node_v_proj_calls++;
|
||||
profile.graph_kv_node_v_proj_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_K_STORE:
|
||||
profile.graph_kv_node_k_store_calls++;
|
||||
profile.graph_kv_node_k_store_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_V_STORE:
|
||||
profile.graph_kv_node_v_store_calls++;
|
||||
profile.graph_kv_node_v_store_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_NONE:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static inline llama_dflash_main_node_kind llama_dflash_main_node_kind_from_tensor(const struct ggml_tensor * tensor) {
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "Qcur")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_QCUR;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_noise")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_K_DRAFT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_noise")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_V_DRAFT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_ctx_cache")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_ctx_cache")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_concat")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_K_CONCAT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_concat")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_V_CONCAT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_pad")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_K_PAD;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_pad")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_V_PAD;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_perm_cont")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_perm_cont")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "flash_attn_reshaped")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_NONE;
|
||||
}
|
||||
if (llama_dflash_tensor_name_matches_label(tensor, "flash_attn")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "kqv_out")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_ATTN_OUT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "ffn_out")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_FFN;
|
||||
}
|
||||
if (llama_dflash_tensor_name_matches_label(tensor, "result_output_rows")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS;
|
||||
}
|
||||
if (llama_dflash_tensor_name_matches_label(tensor, "result_norm")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_RESULT_NORM;
|
||||
}
|
||||
if (llama_dflash_tensor_name_matches_label(tensor, "output")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_RESULT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_matches_label(tensor, "result_output")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_RESULT;
|
||||
}
|
||||
|
||||
return LLAMA_DFLASH_MAIN_NODE_NONE;
|
||||
}
|
||||
|
||||
static inline void llama_dflash_main_node_profile_add(
|
||||
llama_dflash_profile_stats & profile,
|
||||
llama_dflash_main_node_kind kind,
|
||||
uint64_t elapsed_us) {
|
||||
switch (kind) {
|
||||
case LLAMA_DFLASH_MAIN_NODE_QCUR:
|
||||
profile.graph_main_node_qcur_calls++;
|
||||
profile.graph_main_node_qcur_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_K_DRAFT:
|
||||
profile.graph_main_node_k_draft_calls++;
|
||||
profile.graph_main_node_k_draft_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_V_DRAFT:
|
||||
profile.graph_main_node_v_draft_calls++;
|
||||
profile.graph_main_node_v_draft_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW:
|
||||
profile.graph_main_node_k_ctx_view_calls++;
|
||||
profile.graph_main_node_k_ctx_view_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW:
|
||||
profile.graph_main_node_v_ctx_view_calls++;
|
||||
profile.graph_main_node_v_ctx_view_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_K_CONCAT:
|
||||
profile.graph_main_node_k_concat_calls++;
|
||||
profile.graph_main_node_k_concat_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_V_CONCAT:
|
||||
profile.graph_main_node_v_concat_calls++;
|
||||
profile.graph_main_node_v_concat_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_K_PAD:
|
||||
profile.graph_main_node_k_pad_calls++;
|
||||
profile.graph_main_node_k_pad_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_V_PAD:
|
||||
profile.graph_main_node_v_pad_calls++;
|
||||
profile.graph_main_node_v_pad_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT:
|
||||
profile.graph_main_node_k_perm_cont_calls++;
|
||||
profile.graph_main_node_k_perm_cont_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT:
|
||||
profile.graph_main_node_v_perm_cont_calls++;
|
||||
profile.graph_main_node_v_perm_cont_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN:
|
||||
profile.graph_main_node_flash_attn_calls++;
|
||||
profile.graph_main_node_flash_attn_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_ATTN_OUT:
|
||||
profile.graph_main_node_attn_out_calls++;
|
||||
profile.graph_main_node_attn_out_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_FFN:
|
||||
profile.graph_main_node_ffn_calls++;
|
||||
profile.graph_main_node_ffn_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS:
|
||||
profile.graph_main_node_result_rows_calls++;
|
||||
profile.graph_main_node_result_rows_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_RESULT_NORM:
|
||||
profile.graph_main_node_result_norm_calls++;
|
||||
profile.graph_main_node_result_norm_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_RESULT:
|
||||
profile.graph_main_node_result_calls++;
|
||||
profile.graph_main_node_result_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_NONE:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
|
||||
auto * profiler = static_cast<llama_dflash_kv_node_profiler *>(user_data);
|
||||
if (profiler == nullptr || profiler->profile == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const llama_dflash_kv_node_kind kind = llama_dflash_kv_node_kind_from_tensor(tensor);
|
||||
if (ask) {
|
||||
if (kind == LLAMA_DFLASH_KV_NODE_NONE) {
|
||||
return false;
|
||||
}
|
||||
|
||||
profiler->active_kind = kind;
|
||||
profiler->t_start_us = ggml_time_us();
|
||||
return true;
|
||||
}
|
||||
|
||||
if (kind != LLAMA_DFLASH_KV_NODE_NONE && profiler->active_kind == kind && profiler->t_start_us > 0) {
|
||||
llama_dflash_kv_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us));
|
||||
}
|
||||
|
||||
profiler->active_kind = LLAMA_DFLASH_KV_NODE_NONE;
|
||||
profiler->t_start_us = 0;
|
||||
return true;
|
||||
}
|
||||
|
||||
static inline bool llama_dflash_main_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
|
||||
auto * profiler = static_cast<llama_dflash_main_node_profiler *>(user_data);
|
||||
if (profiler == nullptr || profiler->profile == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const llama_dflash_main_node_kind kind = llama_dflash_main_node_kind_from_tensor(tensor);
|
||||
if (ask) {
|
||||
profiler->prev_active = profiler->prev_callback != nullptr
|
||||
? profiler->prev_callback(tensor, ask, profiler->prev_user_data)
|
||||
: false;
|
||||
|
||||
if (kind == LLAMA_DFLASH_MAIN_NODE_NONE) {
|
||||
profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE;
|
||||
profiler->t_start_us = 0;
|
||||
return profiler->prev_active;
|
||||
}
|
||||
|
||||
profiler->active_kind = kind;
|
||||
profiler->t_start_us = ggml_time_us();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool prev_result = false;
|
||||
if (profiler->prev_active && profiler->prev_callback != nullptr) {
|
||||
prev_result = profiler->prev_callback(tensor, ask, profiler->prev_user_data);
|
||||
}
|
||||
|
||||
const bool tracked = kind != LLAMA_DFLASH_MAIN_NODE_NONE &&
|
||||
profiler->active_kind == kind &&
|
||||
profiler->t_start_us > 0;
|
||||
if (tracked) {
|
||||
llama_dflash_main_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us));
|
||||
}
|
||||
|
||||
profiler->prev_active = false;
|
||||
profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE;
|
||||
profiler->t_start_us = 0;
|
||||
return prev_result || tracked;
|
||||
}
|
||||
@ -5,6 +5,7 @@
|
||||
#include "llama-context.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-spec-features.h"
|
||||
#include "llama-dflash-profile.h"
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-backend.h"
|
||||
@ -27,342 +28,6 @@ static bool llama_dflash_stats_log_enabled() {
|
||||
return llama_env_flag_enabled_local("IK_DFLASH_STATS_LOG");
|
||||
}
|
||||
|
||||
enum llama_dflash_kv_node_kind {
|
||||
LLAMA_DFLASH_KV_NODE_NONE = 0,
|
||||
LLAMA_DFLASH_KV_NODE_FUSED_TARGET,
|
||||
LLAMA_DFLASH_KV_NODE_K_PROJ,
|
||||
LLAMA_DFLASH_KV_NODE_K_NORM,
|
||||
LLAMA_DFLASH_KV_NODE_K_ROPE,
|
||||
LLAMA_DFLASH_KV_NODE_V_PROJ,
|
||||
LLAMA_DFLASH_KV_NODE_K_STORE,
|
||||
LLAMA_DFLASH_KV_NODE_V_STORE,
|
||||
};
|
||||
|
||||
enum llama_dflash_main_node_kind {
|
||||
LLAMA_DFLASH_MAIN_NODE_NONE = 0,
|
||||
LLAMA_DFLASH_MAIN_NODE_QCUR,
|
||||
LLAMA_DFLASH_MAIN_NODE_K_DRAFT,
|
||||
LLAMA_DFLASH_MAIN_NODE_V_DRAFT,
|
||||
LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW,
|
||||
LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW,
|
||||
LLAMA_DFLASH_MAIN_NODE_K_CONCAT,
|
||||
LLAMA_DFLASH_MAIN_NODE_V_CONCAT,
|
||||
LLAMA_DFLASH_MAIN_NODE_K_PAD,
|
||||
LLAMA_DFLASH_MAIN_NODE_V_PAD,
|
||||
LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT,
|
||||
LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT,
|
||||
LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN,
|
||||
LLAMA_DFLASH_MAIN_NODE_ATTN_OUT,
|
||||
LLAMA_DFLASH_MAIN_NODE_FFN,
|
||||
LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS,
|
||||
LLAMA_DFLASH_MAIN_NODE_RESULT_NORM,
|
||||
LLAMA_DFLASH_MAIN_NODE_RESULT,
|
||||
};
|
||||
|
||||
struct llama_dflash_kv_node_profiler {
|
||||
llama_dflash_profile_stats * profile = nullptr;
|
||||
int64_t t_start_us = 0;
|
||||
llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE;
|
||||
};
|
||||
|
||||
struct llama_dflash_main_node_profiler {
|
||||
llama_dflash_profile_stats * profile = nullptr;
|
||||
ggml_backend_sched_eval_callback prev_callback = nullptr;
|
||||
void * prev_user_data = nullptr;
|
||||
bool prev_active = false;
|
||||
int64_t t_start_us = 0;
|
||||
llama_dflash_main_node_kind active_kind = LLAMA_DFLASH_MAIN_NODE_NONE;
|
||||
};
|
||||
|
||||
static bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) {
|
||||
if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') {
|
||||
return false;
|
||||
}
|
||||
|
||||
return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0;
|
||||
}
|
||||
|
||||
static bool llama_dflash_tensor_name_matches_label(const struct ggml_tensor * tensor, const char * label) {
|
||||
if (!llama_dflash_tensor_name_has_prefix(tensor, label)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t label_len = std::strlen(label);
|
||||
const char next = tensor->name[label_len];
|
||||
return next == '\0' || next == '-';
|
||||
}
|
||||
|
||||
static llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) {
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_fused_target")) {
|
||||
return LLAMA_DFLASH_KV_NODE_FUSED_TARGET;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_proj")) {
|
||||
return LLAMA_DFLASH_KV_NODE_K_PROJ;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_norm")) {
|
||||
return LLAMA_DFLASH_KV_NODE_K_NORM;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_rope")) {
|
||||
return LLAMA_DFLASH_KV_NODE_K_ROPE;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_proj")) {
|
||||
return LLAMA_DFLASH_KV_NODE_V_PROJ;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_store")) {
|
||||
return LLAMA_DFLASH_KV_NODE_K_STORE;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_store")) {
|
||||
return LLAMA_DFLASH_KV_NODE_V_STORE;
|
||||
}
|
||||
|
||||
return LLAMA_DFLASH_KV_NODE_NONE;
|
||||
}
|
||||
|
||||
static void llama_dflash_kv_node_profile_add(
|
||||
llama_dflash_profile_stats & profile,
|
||||
llama_dflash_kv_node_kind kind,
|
||||
uint64_t elapsed_us) {
|
||||
switch (kind) {
|
||||
case LLAMA_DFLASH_KV_NODE_FUSED_TARGET:
|
||||
profile.graph_kv_node_fused_target_calls++;
|
||||
profile.graph_kv_node_fused_target_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_K_PROJ:
|
||||
profile.graph_kv_node_k_proj_calls++;
|
||||
profile.graph_kv_node_k_proj_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_K_NORM:
|
||||
profile.graph_kv_node_k_norm_calls++;
|
||||
profile.graph_kv_node_k_norm_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_K_ROPE:
|
||||
profile.graph_kv_node_k_rope_calls++;
|
||||
profile.graph_kv_node_k_rope_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_V_PROJ:
|
||||
profile.graph_kv_node_v_proj_calls++;
|
||||
profile.graph_kv_node_v_proj_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_K_STORE:
|
||||
profile.graph_kv_node_k_store_calls++;
|
||||
profile.graph_kv_node_k_store_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_V_STORE:
|
||||
profile.graph_kv_node_v_store_calls++;
|
||||
profile.graph_kv_node_v_store_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_NONE:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static llama_dflash_main_node_kind llama_dflash_main_node_kind_from_tensor(const struct ggml_tensor * tensor) {
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "Qcur")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_QCUR;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_noise")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_K_DRAFT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_noise")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_V_DRAFT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_ctx_cache")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_ctx_cache")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_concat")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_K_CONCAT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_concat")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_V_CONCAT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_pad")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_K_PAD;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_pad")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_V_PAD;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_perm_cont")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_perm_cont")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "flash_attn_reshaped")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_NONE;
|
||||
}
|
||||
if (llama_dflash_tensor_name_matches_label(tensor, "flash_attn")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "kqv_out")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_ATTN_OUT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "ffn_out")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_FFN;
|
||||
}
|
||||
if (llama_dflash_tensor_name_matches_label(tensor, "result_output_rows")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS;
|
||||
}
|
||||
if (llama_dflash_tensor_name_matches_label(tensor, "result_norm")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_RESULT_NORM;
|
||||
}
|
||||
if (llama_dflash_tensor_name_matches_label(tensor, "output")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_RESULT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_matches_label(tensor, "result_output")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_RESULT;
|
||||
}
|
||||
|
||||
return LLAMA_DFLASH_MAIN_NODE_NONE;
|
||||
}
|
||||
|
||||
static void llama_dflash_main_node_profile_add(
|
||||
llama_dflash_profile_stats & profile,
|
||||
llama_dflash_main_node_kind kind,
|
||||
uint64_t elapsed_us) {
|
||||
switch (kind) {
|
||||
case LLAMA_DFLASH_MAIN_NODE_QCUR:
|
||||
profile.graph_main_node_qcur_calls++;
|
||||
profile.graph_main_node_qcur_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_K_DRAFT:
|
||||
profile.graph_main_node_k_draft_calls++;
|
||||
profile.graph_main_node_k_draft_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_V_DRAFT:
|
||||
profile.graph_main_node_v_draft_calls++;
|
||||
profile.graph_main_node_v_draft_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW:
|
||||
profile.graph_main_node_k_ctx_view_calls++;
|
||||
profile.graph_main_node_k_ctx_view_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW:
|
||||
profile.graph_main_node_v_ctx_view_calls++;
|
||||
profile.graph_main_node_v_ctx_view_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_K_CONCAT:
|
||||
profile.graph_main_node_k_concat_calls++;
|
||||
profile.graph_main_node_k_concat_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_V_CONCAT:
|
||||
profile.graph_main_node_v_concat_calls++;
|
||||
profile.graph_main_node_v_concat_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_K_PAD:
|
||||
profile.graph_main_node_k_pad_calls++;
|
||||
profile.graph_main_node_k_pad_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_V_PAD:
|
||||
profile.graph_main_node_v_pad_calls++;
|
||||
profile.graph_main_node_v_pad_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT:
|
||||
profile.graph_main_node_k_perm_cont_calls++;
|
||||
profile.graph_main_node_k_perm_cont_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT:
|
||||
profile.graph_main_node_v_perm_cont_calls++;
|
||||
profile.graph_main_node_v_perm_cont_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN:
|
||||
profile.graph_main_node_flash_attn_calls++;
|
||||
profile.graph_main_node_flash_attn_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_ATTN_OUT:
|
||||
profile.graph_main_node_attn_out_calls++;
|
||||
profile.graph_main_node_attn_out_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_FFN:
|
||||
profile.graph_main_node_ffn_calls++;
|
||||
profile.graph_main_node_ffn_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS:
|
||||
profile.graph_main_node_result_rows_calls++;
|
||||
profile.graph_main_node_result_rows_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_RESULT_NORM:
|
||||
profile.graph_main_node_result_norm_calls++;
|
||||
profile.graph_main_node_result_norm_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_RESULT:
|
||||
profile.graph_main_node_result_calls++;
|
||||
profile.graph_main_node_result_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_NONE:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
|
||||
auto * profiler = static_cast<llama_dflash_kv_node_profiler *>(user_data);
|
||||
if (profiler == nullptr || profiler->profile == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const llama_dflash_kv_node_kind kind = llama_dflash_kv_node_kind_from_tensor(tensor);
|
||||
if (ask) {
|
||||
if (kind == LLAMA_DFLASH_KV_NODE_NONE) {
|
||||
return false;
|
||||
}
|
||||
|
||||
profiler->active_kind = kind;
|
||||
profiler->t_start_us = ggml_time_us();
|
||||
return true;
|
||||
}
|
||||
|
||||
if (kind != LLAMA_DFLASH_KV_NODE_NONE && profiler->active_kind == kind && profiler->t_start_us > 0) {
|
||||
llama_dflash_kv_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us));
|
||||
}
|
||||
|
||||
profiler->active_kind = LLAMA_DFLASH_KV_NODE_NONE;
|
||||
profiler->t_start_us = 0;
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool llama_dflash_main_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
|
||||
auto * profiler = static_cast<llama_dflash_main_node_profiler *>(user_data);
|
||||
if (profiler == nullptr || profiler->profile == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const llama_dflash_main_node_kind kind = llama_dflash_main_node_kind_from_tensor(tensor);
|
||||
if (ask) {
|
||||
profiler->prev_active = profiler->prev_callback != nullptr
|
||||
? profiler->prev_callback(tensor, ask, profiler->prev_user_data)
|
||||
: false;
|
||||
|
||||
if (kind == LLAMA_DFLASH_MAIN_NODE_NONE) {
|
||||
profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE;
|
||||
profiler->t_start_us = 0;
|
||||
return profiler->prev_active;
|
||||
}
|
||||
|
||||
profiler->active_kind = kind;
|
||||
profiler->t_start_us = ggml_time_us();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool prev_result = false;
|
||||
if (profiler->prev_active && profiler->prev_callback != nullptr) {
|
||||
prev_result = profiler->prev_callback(tensor, ask, profiler->prev_user_data);
|
||||
}
|
||||
|
||||
const bool tracked = kind != LLAMA_DFLASH_MAIN_NODE_NONE &&
|
||||
profiler->active_kind == kind &&
|
||||
profiler->t_start_us > 0;
|
||||
if (tracked) {
|
||||
llama_dflash_main_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us));
|
||||
}
|
||||
|
||||
profiler->prev_active = false;
|
||||
profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE;
|
||||
profiler->t_start_us = 0;
|
||||
return prev_result || tracked;
|
||||
}
|
||||
|
||||
void llama_sync_dflash_workspace_if_pending(struct llama_context & lctx) {
|
||||
if (!lctx.dflash_kv_workspace_sync_pending || lctx.dflash_workspace_sched == nullptr) {
|
||||
return;
|
||||
|
||||
@ -880,6 +880,36 @@ static void llama_dflash_contract_log_output_indices(
|
||||
have_capture ? "true" : "false");
|
||||
}
|
||||
|
||||
void llama_dflash_contract_log_accept(
|
||||
int slot_id,
|
||||
bool is_dflash,
|
||||
const char * path,
|
||||
bool any_rejected,
|
||||
size_t n_draft,
|
||||
size_t n_accepted,
|
||||
llama_pos pos_base,
|
||||
const std::vector<int32_t> & output_indices) {
|
||||
if (!llama_dflash_contract_log_enabled() || !is_dflash) {
|
||||
return;
|
||||
}
|
||||
|
||||
static std::atomic<uint64_t> counter = 0;
|
||||
const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed);
|
||||
if (ordinal >= 8) {
|
||||
return;
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("dflash contract accept[%llu]: slot=%d path=%s rejected=%s drafted=%zu accepted=%zu pos_base=%d output_indices=%s\n",
|
||||
(unsigned long long) (ordinal + 1),
|
||||
slot_id,
|
||||
path,
|
||||
any_rejected ? "true" : "false",
|
||||
n_draft,
|
||||
n_accepted,
|
||||
(int) pos_base,
|
||||
llama_dflash_contract_format_values(output_indices).c_str());
|
||||
}
|
||||
|
||||
static bool llama_spec_materialize_dflash_rows_prepared(
|
||||
struct llama_context * ctx,
|
||||
int32_t row_count,
|
||||
|
||||
@ -277,3 +277,13 @@ bool llama_spec_copy_dflash_rows_from_output_indices(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<int32_t> & output_indices,
|
||||
std::vector<float> & hidden_rows);
|
||||
|
||||
void llama_dflash_contract_log_accept(
|
||||
int slot_id,
|
||||
bool is_dflash,
|
||||
const char * path,
|
||||
bool any_rejected,
|
||||
size_t n_draft,
|
||||
size_t n_accepted,
|
||||
llama_pos pos_base,
|
||||
const std::vector<int32_t> & output_indices);
|
||||
|
||||
@ -1,11 +1,6 @@
|
||||
#include "llama-spec-features.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <random>
|
||||
#include <sstream>
|
||||
|
||||
#include "llama-model.h"
|
||||
#include "llama-context.h"
|
||||
|
||||
338
src/llama.cpp
338
src/llama.cpp
@ -19,6 +19,7 @@
|
||||
#include "llama-context.h"
|
||||
#include "llama-spec-features.h"
|
||||
#include "llama-dflash.h"
|
||||
#include "llama-dflash-profile.h"
|
||||
#include "llama-quantize.h"
|
||||
|
||||
#include "unicode.h"
|
||||
@ -180,343 +181,6 @@ static bool llama_env_flag_enabled(const char * name) {
|
||||
std::strcmp(env, "off") != 0;
|
||||
}
|
||||
|
||||
enum llama_dflash_kv_node_kind {
|
||||
LLAMA_DFLASH_KV_NODE_NONE = 0,
|
||||
LLAMA_DFLASH_KV_NODE_FUSED_TARGET,
|
||||
LLAMA_DFLASH_KV_NODE_K_PROJ,
|
||||
LLAMA_DFLASH_KV_NODE_K_NORM,
|
||||
LLAMA_DFLASH_KV_NODE_K_ROPE,
|
||||
LLAMA_DFLASH_KV_NODE_V_PROJ,
|
||||
LLAMA_DFLASH_KV_NODE_K_STORE,
|
||||
LLAMA_DFLASH_KV_NODE_V_STORE,
|
||||
};
|
||||
|
||||
enum llama_dflash_main_node_kind {
|
||||
LLAMA_DFLASH_MAIN_NODE_NONE = 0,
|
||||
LLAMA_DFLASH_MAIN_NODE_QCUR,
|
||||
LLAMA_DFLASH_MAIN_NODE_K_DRAFT,
|
||||
LLAMA_DFLASH_MAIN_NODE_V_DRAFT,
|
||||
LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW,
|
||||
LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW,
|
||||
LLAMA_DFLASH_MAIN_NODE_K_CONCAT,
|
||||
LLAMA_DFLASH_MAIN_NODE_V_CONCAT,
|
||||
LLAMA_DFLASH_MAIN_NODE_K_PAD,
|
||||
LLAMA_DFLASH_MAIN_NODE_V_PAD,
|
||||
LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT,
|
||||
LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT,
|
||||
LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN,
|
||||
LLAMA_DFLASH_MAIN_NODE_ATTN_OUT,
|
||||
LLAMA_DFLASH_MAIN_NODE_FFN,
|
||||
LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS,
|
||||
LLAMA_DFLASH_MAIN_NODE_RESULT_NORM,
|
||||
LLAMA_DFLASH_MAIN_NODE_RESULT,
|
||||
};
|
||||
|
||||
struct llama_dflash_kv_node_profiler {
|
||||
llama_dflash_profile_stats * profile = nullptr;
|
||||
int64_t t_start_us = 0;
|
||||
llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE;
|
||||
};
|
||||
|
||||
struct llama_dflash_main_node_profiler {
|
||||
llama_dflash_profile_stats * profile = nullptr;
|
||||
ggml_backend_sched_eval_callback prev_callback = nullptr;
|
||||
void * prev_user_data = nullptr;
|
||||
bool prev_active = false;
|
||||
int64_t t_start_us = 0;
|
||||
llama_dflash_main_node_kind active_kind = LLAMA_DFLASH_MAIN_NODE_NONE;
|
||||
};
|
||||
|
||||
static bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) {
|
||||
if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') {
|
||||
return false;
|
||||
}
|
||||
|
||||
return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0;
|
||||
}
|
||||
|
||||
static bool llama_dflash_tensor_name_matches_label(const struct ggml_tensor * tensor, const char * label) {
|
||||
if (!llama_dflash_tensor_name_has_prefix(tensor, label)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t label_len = std::strlen(label);
|
||||
const char next = tensor->name[label_len];
|
||||
return next == '\0' || next == '-';
|
||||
}
|
||||
|
||||
static llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) {
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_fused_target")) {
|
||||
return LLAMA_DFLASH_KV_NODE_FUSED_TARGET;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_proj")) {
|
||||
return LLAMA_DFLASH_KV_NODE_K_PROJ;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_norm")) {
|
||||
return LLAMA_DFLASH_KV_NODE_K_NORM;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_rope")) {
|
||||
return LLAMA_DFLASH_KV_NODE_K_ROPE;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_proj")) {
|
||||
return LLAMA_DFLASH_KV_NODE_V_PROJ;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_store")) {
|
||||
return LLAMA_DFLASH_KV_NODE_K_STORE;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_store")) {
|
||||
return LLAMA_DFLASH_KV_NODE_V_STORE;
|
||||
}
|
||||
|
||||
return LLAMA_DFLASH_KV_NODE_NONE;
|
||||
}
|
||||
|
||||
static void llama_dflash_kv_node_profile_add(
|
||||
llama_dflash_profile_stats & profile,
|
||||
llama_dflash_kv_node_kind kind,
|
||||
uint64_t elapsed_us) {
|
||||
switch (kind) {
|
||||
case LLAMA_DFLASH_KV_NODE_FUSED_TARGET:
|
||||
profile.graph_kv_node_fused_target_calls++;
|
||||
profile.graph_kv_node_fused_target_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_K_PROJ:
|
||||
profile.graph_kv_node_k_proj_calls++;
|
||||
profile.graph_kv_node_k_proj_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_K_NORM:
|
||||
profile.graph_kv_node_k_norm_calls++;
|
||||
profile.graph_kv_node_k_norm_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_K_ROPE:
|
||||
profile.graph_kv_node_k_rope_calls++;
|
||||
profile.graph_kv_node_k_rope_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_V_PROJ:
|
||||
profile.graph_kv_node_v_proj_calls++;
|
||||
profile.graph_kv_node_v_proj_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_K_STORE:
|
||||
profile.graph_kv_node_k_store_calls++;
|
||||
profile.graph_kv_node_k_store_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_V_STORE:
|
||||
profile.graph_kv_node_v_store_calls++;
|
||||
profile.graph_kv_node_v_store_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_KV_NODE_NONE:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static llama_dflash_main_node_kind llama_dflash_main_node_kind_from_tensor(const struct ggml_tensor * tensor) {
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "Qcur")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_QCUR;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_noise")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_K_DRAFT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_noise")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_V_DRAFT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_ctx_cache")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_ctx_cache")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_concat")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_K_CONCAT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_concat")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_V_CONCAT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_pad")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_K_PAD;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_pad")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_V_PAD;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_perm_cont")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_perm_cont")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "flash_attn_reshaped")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_NONE;
|
||||
}
|
||||
if (llama_dflash_tensor_name_matches_label(tensor, "flash_attn")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "kqv_out")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_ATTN_OUT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_has_prefix(tensor, "ffn_out")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_FFN;
|
||||
}
|
||||
if (llama_dflash_tensor_name_matches_label(tensor, "result_output_rows")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS;
|
||||
}
|
||||
if (llama_dflash_tensor_name_matches_label(tensor, "result_norm")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_RESULT_NORM;
|
||||
}
|
||||
if (llama_dflash_tensor_name_matches_label(tensor, "output")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_RESULT;
|
||||
}
|
||||
if (llama_dflash_tensor_name_matches_label(tensor, "result_output")) {
|
||||
return LLAMA_DFLASH_MAIN_NODE_RESULT;
|
||||
}
|
||||
|
||||
return LLAMA_DFLASH_MAIN_NODE_NONE;
|
||||
}
|
||||
|
||||
static void llama_dflash_main_node_profile_add(
|
||||
llama_dflash_profile_stats & profile,
|
||||
llama_dflash_main_node_kind kind,
|
||||
uint64_t elapsed_us) {
|
||||
switch (kind) {
|
||||
case LLAMA_DFLASH_MAIN_NODE_QCUR:
|
||||
profile.graph_main_node_qcur_calls++;
|
||||
profile.graph_main_node_qcur_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_K_DRAFT:
|
||||
profile.graph_main_node_k_draft_calls++;
|
||||
profile.graph_main_node_k_draft_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_V_DRAFT:
|
||||
profile.graph_main_node_v_draft_calls++;
|
||||
profile.graph_main_node_v_draft_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW:
|
||||
profile.graph_main_node_k_ctx_view_calls++;
|
||||
profile.graph_main_node_k_ctx_view_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW:
|
||||
profile.graph_main_node_v_ctx_view_calls++;
|
||||
profile.graph_main_node_v_ctx_view_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_K_CONCAT:
|
||||
profile.graph_main_node_k_concat_calls++;
|
||||
profile.graph_main_node_k_concat_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_V_CONCAT:
|
||||
profile.graph_main_node_v_concat_calls++;
|
||||
profile.graph_main_node_v_concat_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_K_PAD:
|
||||
profile.graph_main_node_k_pad_calls++;
|
||||
profile.graph_main_node_k_pad_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_V_PAD:
|
||||
profile.graph_main_node_v_pad_calls++;
|
||||
profile.graph_main_node_v_pad_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT:
|
||||
profile.graph_main_node_k_perm_cont_calls++;
|
||||
profile.graph_main_node_k_perm_cont_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT:
|
||||
profile.graph_main_node_v_perm_cont_calls++;
|
||||
profile.graph_main_node_v_perm_cont_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN:
|
||||
profile.graph_main_node_flash_attn_calls++;
|
||||
profile.graph_main_node_flash_attn_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_ATTN_OUT:
|
||||
profile.graph_main_node_attn_out_calls++;
|
||||
profile.graph_main_node_attn_out_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_FFN:
|
||||
profile.graph_main_node_ffn_calls++;
|
||||
profile.graph_main_node_ffn_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS:
|
||||
profile.graph_main_node_result_rows_calls++;
|
||||
profile.graph_main_node_result_rows_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_RESULT_NORM:
|
||||
profile.graph_main_node_result_norm_calls++;
|
||||
profile.graph_main_node_result_norm_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_RESULT:
|
||||
profile.graph_main_node_result_calls++;
|
||||
profile.graph_main_node_result_us += elapsed_us;
|
||||
break;
|
||||
case LLAMA_DFLASH_MAIN_NODE_NONE:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
|
||||
auto * profiler = static_cast<llama_dflash_kv_node_profiler *>(user_data);
|
||||
if (profiler == nullptr || profiler->profile == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const llama_dflash_kv_node_kind kind = llama_dflash_kv_node_kind_from_tensor(tensor);
|
||||
if (ask) {
|
||||
if (kind == LLAMA_DFLASH_KV_NODE_NONE) {
|
||||
return false;
|
||||
}
|
||||
|
||||
profiler->active_kind = kind;
|
||||
profiler->t_start_us = ggml_time_us();
|
||||
return true;
|
||||
}
|
||||
|
||||
if (kind != LLAMA_DFLASH_KV_NODE_NONE && profiler->active_kind == kind && profiler->t_start_us > 0) {
|
||||
llama_dflash_kv_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us));
|
||||
}
|
||||
|
||||
profiler->active_kind = LLAMA_DFLASH_KV_NODE_NONE;
|
||||
profiler->t_start_us = 0;
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool llama_dflash_main_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
|
||||
auto * profiler = static_cast<llama_dflash_main_node_profiler *>(user_data);
|
||||
if (profiler == nullptr || profiler->profile == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const llama_dflash_main_node_kind kind = llama_dflash_main_node_kind_from_tensor(tensor);
|
||||
if (ask) {
|
||||
profiler->prev_active = profiler->prev_callback != nullptr
|
||||
? profiler->prev_callback(tensor, ask, profiler->prev_user_data)
|
||||
: false;
|
||||
|
||||
if (kind == LLAMA_DFLASH_MAIN_NODE_NONE) {
|
||||
profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE;
|
||||
profiler->t_start_us = 0;
|
||||
return profiler->prev_active;
|
||||
}
|
||||
|
||||
profiler->active_kind = kind;
|
||||
profiler->t_start_us = ggml_time_us();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool prev_result = false;
|
||||
if (profiler->prev_active && profiler->prev_callback != nullptr) {
|
||||
prev_result = profiler->prev_callback(tensor, ask, profiler->prev_user_data);
|
||||
}
|
||||
|
||||
const bool tracked = kind != LLAMA_DFLASH_MAIN_NODE_NONE &&
|
||||
profiler->active_kind == kind &&
|
||||
profiler->t_start_us > 0;
|
||||
if (tracked) {
|
||||
llama_dflash_main_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us));
|
||||
}
|
||||
|
||||
profiler->prev_active = false;
|
||||
profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE;
|
||||
profiler->t_start_us = 0;
|
||||
return prev_result || tracked;
|
||||
}
|
||||
|
||||
|
||||
// extract ip and port from RPC[ip:port] for rpc and keep other device names
|
||||
static std::vector<rpc_device> extract_device_from_rpc_device(std::vector<std::string> devices) {
|
||||
std::vector<rpc_device> rpc_servers;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user