Use windows update in kv cache

This commit is contained in:
SamuelOliveirads 2026-05-31 14:51:21 -03:00
parent 1369e68471
commit ed403dca27
7 changed files with 1133 additions and 115 deletions

View File

@ -253,6 +253,8 @@ static const common_speculative_state_mtp * common_speculative_get_mtp_state(con
static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec);
static const common_speculative_state_dflash * common_speculative_get_dflash_state(const common_speculative * spec);
static int32_t common_speculative_feature_width(const common_speculative * spec);
static void dflash_materialize_target_window_features(common_speculative_state_dflash & state);
static void dflash_ring_reset_rows(common_speculative_state_dflash & state, const float * rows, int32_t n_rows);
static void dflash_append_target_features(
common_speculative_state_dflash & state,
const float * feature_rows,
@ -284,6 +286,17 @@ static bool dflash_contract_log_enabled() {
std::strcmp(env, "off") != 0;
}
static bool dflash_use_kv_cache_experiment() {
const char * env = std::getenv("IK_DFLASH_KV_CACHE");
if (env == nullptr || *env == '\0') {
return false;
}
return std::strcmp(env, "0") != 0 &&
std::strcmp(env, "false") != 0 &&
std::strcmp(env, "off") != 0;
}
template <typename T>
static std::string dflash_contract_format_values(
const std::vector<T> & values,
@ -479,7 +492,18 @@ struct common_speculative_state_dflash : public common_speculative_state {
std::vector<int32_t> target_layer_ids;
std::vector<float> target_window;
std::vector<llama_pos> target_window_pos;
std::vector<float> target_window_stage;
std::vector<llama_pos> target_window_pos_stage;
std::vector<float> target_window_ring;
std::vector<float> target_window_append_features;
int32_t target_window_rows = 0;
int32_t target_window_ring_write_pos = 0;
int32_t target_window_ring_filled = 0;
uint64_t target_window_version = 0;
int32_t target_window_keep_rows = 0;
int32_t target_window_append_rows = 0;
bool target_window_replace = false;
bool target_window_materialized = false;
llama_pos last_target_pos = -1;
size_t n_window_updates = 0;
size_t n_rows_seen = 0;
@ -497,6 +521,13 @@ struct common_speculative_state_dflash : public common_speculative_state {
uint64_t t_accept_output_copy_us = 0;
uint64_t t_accept_commit_us = 0;
uint64_t t_accept_append_us = 0;
uint64_t t_accept_append_filter_us = 0;
uint64_t t_accept_append_window_alloc_us = 0;
uint64_t t_accept_append_replace_us = 0;
uint64_t t_accept_append_keep_old_us = 0;
uint64_t t_accept_append_new_rows_us = 0;
uint64_t t_accept_append_commit_detail_us = 0;
uint64_t t_accept_append_log_us = 0;
size_t n_warmup_collect_calls = 0;
size_t n_warmup_collect_rows = 0;
size_t n_warmup_append_calls = 0;
@ -507,6 +538,8 @@ struct common_speculative_state_dflash : public common_speculative_state {
size_t n_accept_commit_rows = 0;
size_t n_accept_append_calls = 0;
size_t n_accept_append_rows = 0;
size_t n_accept_append_replace_calls = 0;
size_t n_accept_append_slide_calls = 0;
common_speculative_state_dflash(
enum common_speculative_type type,
@ -614,6 +647,12 @@ struct common_speculative_state_dflash : public common_speculative_state {
}
batch = llama_batch_init(std::max(1, block_size), 0, 1);
target_window.reserve((size_t) this->cross_ctx * (size_t) n_target_features);
target_window_stage.reserve((size_t) this->cross_ctx * (size_t) n_target_features);
target_window_ring.resize((size_t) this->cross_ctx * (size_t) n_target_features);
target_window_append_features.reserve((size_t) this->cross_ctx * (size_t) n_target_features);
target_window_pos.reserve((size_t) this->cross_ctx);
target_window_pos_stage.reserve((size_t) this->cross_ctx);
ready = true;
llama_set_dflash_visible_cross_ctx(ctx_dft, this->cross_ctx);
@ -648,6 +687,7 @@ struct common_speculative_state_dflash : public common_speculative_state {
void begin(const llama_tokens & prompt) override {
GGML_UNUSED(prompt);
llama_kv_cache_clear(ctx_dft);
llama_reset_dflash_kv_cache_state(ctx_dft);
n_window_updates = 0;
n_rows_seen = 0;
n_rows_dropped = 0;
@ -663,6 +703,13 @@ struct common_speculative_state_dflash : public common_speculative_state {
t_accept_output_copy_us = 0;
t_accept_commit_us = 0;
t_accept_append_us = 0;
t_accept_append_filter_us = 0;
t_accept_append_window_alloc_us = 0;
t_accept_append_replace_us = 0;
t_accept_append_keep_old_us = 0;
t_accept_append_new_rows_us = 0;
t_accept_append_commit_detail_us = 0;
t_accept_append_log_us = 0;
n_warmup_collect_calls = 0;
n_warmup_collect_rows = 0;
n_warmup_append_calls = 0;
@ -673,6 +720,8 @@ struct common_speculative_state_dflash : public common_speculative_state {
n_accept_commit_rows = 0;
n_accept_append_calls = 0;
n_accept_append_rows = 0;
n_accept_append_replace_calls = 0;
n_accept_append_slide_calls = 0;
llama_dflash_profile_reset(ctx_tgt);
llama_dflash_profile_reset(ctx_dft);
}
@ -695,7 +744,33 @@ struct common_speculative_state_dflash : public common_speculative_state {
return;
}
if (!llama_set_dflash_target_features_view(ctx_dft, target_window.data(), target_window.size(), target_window_rows, target_window_pos.data())) {
const bool use_kv_cache = dflash_use_kv_cache_experiment();
const float * target_features = nullptr;
size_t target_feature_floats = 0;
llama_dflash_window_update window_update = {
target_window_version,
target_window_keep_rows,
target_window_append_rows,
target_window_replace,
target_window_append_features.empty() ? nullptr : target_window_append_features.data(),
target_window_append_features.size(),
};
const llama_dflash_kv_cache_transition cache_plan = use_kv_cache
? llama_plan_dflash_kv_cache_transition_for_ctx(ctx_dft, window_update, target_window_rows)
: llama_dflash_kv_cache_transition{};
if (!use_kv_cache || cache_plan.rebuild_cache) {
dflash_materialize_target_window_features(*this);
target_features = target_window.data();
target_feature_floats = target_window.size();
}
if (use_kv_cache && cache_plan.rebuild_cache) {
window_update.append_features = target_window.data();
window_update.append_floats = target_window.size();
window_update.append_rows = target_window_rows;
}
if (!llama_set_dflash_target_features_view(ctx_dft, target_features, target_feature_floats, target_window_rows, target_window_pos.data(), &window_update)) {
LOG_ERR("%s: failed to set DFlash target features\n", __func__);
n_set_target_fail++;
return;
@ -2522,6 +2597,24 @@ void common_speculative_print_stats(const common_speculative * spec, double slot
const double kv_compute_ms = (double) graph_stats.graph_kv_cache_compute_us / 1000.0;
const double kv_sync_ms = (double) graph_stats.graph_kv_cache_sync_us / 1000.0;
const double replay_append_ms = (double) dflash_state->t_accept_append_us / 1000.0;
const double feature_path_ms = (double) (
capture_stats.capture_prepare_sync_us +
capture_stats.capture_materialize_us +
graph_stats.set_target_copy_us +
graph_stats.graph_feature_copy_us +
graph_stats.graph_pos_copy_us +
graph_stats.graph_mask_build_us) / 1000.0;
const double decode_internal_ms = (double) (
graph_stats.decode_prelude_us +
graph_stats.decode_sched_reset_us +
graph_stats.decode_build_graph_us +
graph_stats.decode_sched_alloc_graph_us +
graph_stats.decode_prepare_us +
graph_stats.decode_set_inputs_us +
graph_stats.decode_graph_compute_us +
graph_stats.decode_result_us +
graph_stats.decode_embedding_us +
graph_stats.decode_final_sched_reset_us) / 1000.0;
LOG_INF("statistics dflash profile: capture(sync/materialize)=%.3f/%.3f ms calls=%llu/%llu bytes=%llu phase(prompt/verify batches changes)=%llu/%llu %llu/%llu, set_target=%.3f ms rows=%llu bytes=%llu, decode(llama_output_reserve/prepare)=%.3f/%.3f ms calls=%llu/%llu realloc(bytes)=%llu/%llu, prep(total/features/pos/mask)=%.3f/%.3f/%.3f/%.3f ms kv_cache(total/build/reserve/reset/alloc/up_f/up_p/compute/sync/read)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls(prepare/cache/read)=%llu/%llu/%llu bytes(feature/pos/mask/read)=%llu/%llu/%llu/%llu host_layers=%d, fallback_pos(copy/graph)=%llu/%llu, nonmono(copy/graph)=%llu/%llu, capture_fail=%llu/%llu decode_prepare_fail=%llu, visible_kv_max=%llu, last(rows=%d width=%d left_pad=%d n_tokens=%d n_kv=%d pos=[%d..%d])\n",
(double) capture_stats.capture_prepare_sync_us / 1000.0,
@ -2580,6 +2673,81 @@ void common_speculative_print_stats(const common_speculative * spec, double slot
(int) graph_stats.last_pos_first,
(int) graph_stats.last_pos_last);
LOG_INF("statistics dflash features: total=%.3f ms capture(sync/materialize)=%.3f/%.3f ms set_target=%.3f ms prep(feature/pos/mask)=%.3f/%.3f/%.3f ms rows(materialize/set_target)=%llu/%llu bytes(materialize/set_target/feature/pos/mask)=%llu/%llu/%llu/%llu/%llu\n",
feature_path_ms,
(double) capture_stats.capture_prepare_sync_us / 1000.0,
(double) capture_stats.capture_materialize_us / 1000.0,
(double) graph_stats.set_target_copy_us / 1000.0,
(double) graph_stats.graph_feature_copy_us / 1000.0,
(double) graph_stats.graph_pos_copy_us / 1000.0,
(double) graph_stats.graph_mask_build_us / 1000.0,
(unsigned long long) capture_stats.capture_materialize_rows,
(unsigned long long) graph_stats.set_target_rows,
(unsigned long long) capture_stats.capture_materialize_bytes,
(unsigned long long) graph_stats.set_target_copy_bytes,
(unsigned long long) graph_stats.graph_feature_bytes,
(unsigned long long) graph_stats.graph_pos_bytes,
(unsigned long long) graph_stats.graph_mask_bytes);
LOG_INF("statistics dflash kv: total=%.3f ms build/reserve/reset/alloc/upload_f/upload_p/compute/sync/read=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu cached_bytes=%llu host_layers=%d\n",
kv_cache_total_ms,
(double) graph_stats.graph_kv_cache_build_us / 1000.0,
(double) graph_stats.graph_kv_cache_reserve_us / 1000.0,
(double) graph_stats.graph_kv_cache_reset_us / 1000.0,
(double) graph_stats.graph_kv_cache_alloc_us / 1000.0,
(double) graph_stats.graph_kv_cache_feature_upload_us / 1000.0,
(double) graph_stats.graph_kv_cache_pos_upload_us / 1000.0,
(double) graph_stats.graph_kv_cache_compute_us / 1000.0,
(double) graph_stats.graph_kv_cache_sync_us / 1000.0,
(double) graph_stats.graph_kv_cache_read_concat_pad_us / 1000.0,
(unsigned long long) graph_stats.graph_kv_cache_calls,
(unsigned long long) graph_stats.graph_kv_cache_cached_bytes,
graph_stats.last_kv_cache_host_layers);
if (graph_stats.decode_internal_chunks > 0) {
LOG_INF("statistics dflash decode: llama_decode(total)=%.3f ms calls=%zu chunks=%llu rebuilds=%llu sync_points=%llu internal(total/prelude/sched_reset/build/alloc/prepare/set_inputs/compute/get_result/get_embedding/final_reset)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms\n",
(double) dflash_state->t_draft_decode_us / 1000.0,
dflash_state->n_call_draft,
(unsigned long long) graph_stats.decode_internal_chunks,
(unsigned long long) graph_stats.decode_graph_rebuilds,
(unsigned long long) graph_stats.decode_sync_profile_points,
decode_internal_ms,
(double) graph_stats.decode_prelude_us / 1000.0,
(double) graph_stats.decode_sched_reset_us / 1000.0,
(double) graph_stats.decode_build_graph_us / 1000.0,
(double) graph_stats.decode_sched_alloc_graph_us / 1000.0,
(double) graph_stats.decode_prepare_us / 1000.0,
(double) graph_stats.decode_set_inputs_us / 1000.0,
(double) graph_stats.decode_graph_compute_us / 1000.0,
(double) graph_stats.decode_result_us / 1000.0,
(double) graph_stats.decode_embedding_us / 1000.0,
(double) graph_stats.decode_final_sched_reset_us / 1000.0);
}
if (graph_stats.graph_kv_node_fused_target_calls > 0 ||
graph_stats.graph_kv_node_k_proj_calls > 0 ||
graph_stats.graph_kv_node_k_norm_calls > 0 ||
graph_stats.graph_kv_node_k_rope_calls > 0 ||
graph_stats.graph_kv_node_v_proj_calls > 0 ||
graph_stats.graph_kv_node_k_store_calls > 0 ||
graph_stats.graph_kv_node_v_store_calls > 0) {
LOG_INF("statistics dflash kv nodes: fused_target/k_proj/k_norm/k_rope/v_proj/k_store/v_store=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu/%llu/%llu/%llu/%llu/%llu/%llu\n",
(double) graph_stats.graph_kv_node_fused_target_us / 1000.0,
(double) graph_stats.graph_kv_node_k_proj_us / 1000.0,
(double) graph_stats.graph_kv_node_k_norm_us / 1000.0,
(double) graph_stats.graph_kv_node_k_rope_us / 1000.0,
(double) graph_stats.graph_kv_node_v_proj_us / 1000.0,
(double) graph_stats.graph_kv_node_k_store_us / 1000.0,
(double) graph_stats.graph_kv_node_v_store_us / 1000.0,
(unsigned long long) graph_stats.graph_kv_node_fused_target_calls,
(unsigned long long) graph_stats.graph_kv_node_k_proj_calls,
(unsigned long long) graph_stats.graph_kv_node_k_norm_calls,
(unsigned long long) graph_stats.graph_kv_node_k_rope_calls,
(unsigned long long) graph_stats.graph_kv_node_v_proj_calls,
(unsigned long long) graph_stats.graph_kv_node_k_store_calls,
(unsigned long long) graph_stats.graph_kv_node_v_store_calls);
}
LOG_INF("statistics dflash hot: kv(upload_f/upload_p/upload/compute/sync)=%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu replay(accepted_prefix_append)=%.3f ms calls=%zu rows=%zu\n",
kv_upload_feature_ms,
kv_upload_pos_ms,
@ -2609,6 +2777,20 @@ void common_speculative_print_stats(const common_speculative * spec, double slot
dflash_state->n_accept_commit_rows,
dflash_state->n_accept_output_copy_rows,
dflash_state->n_accept_append_rows);
if (dflash_state->n_accept_append_calls > 0) {
LOG_INF("statistics dflash replay: append(filter/window_alloc/replace/keep_old/new_rows/commit/log)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%zu replace/slide=%zu/%zu\n",
(double) dflash_state->t_accept_append_filter_us / 1000.0,
(double) dflash_state->t_accept_append_window_alloc_us / 1000.0,
(double) dflash_state->t_accept_append_replace_us / 1000.0,
(double) dflash_state->t_accept_append_keep_old_us / 1000.0,
(double) dflash_state->t_accept_append_new_rows_us / 1000.0,
(double) dflash_state->t_accept_append_commit_detail_us / 1000.0,
(double) dflash_state->t_accept_append_log_us / 1000.0,
dflash_state->n_accept_append_calls,
dflash_state->n_accept_append_replace_calls,
dflash_state->n_accept_append_slide_calls);
}
}
}
}
@ -2728,11 +2910,113 @@ static void mtp_clear_target_hidden(common_speculative_state_mtp & state, llama_
state.draft_cache_by_seq.erase(seq_id);
}
struct dflash_append_breakdown {
uint64_t filter_us = 0;
uint64_t window_alloc_us = 0;
uint64_t replace_us = 0;
uint64_t keep_old_us = 0;
uint64_t new_rows_us = 0;
uint64_t commit_us = 0;
uint64_t log_us = 0;
bool replace_call = false;
};
static void dflash_record_window_update(
common_speculative_state_dflash & state,
int32_t keep_rows,
int32_t append_rows,
bool replace) {
state.target_window_keep_rows = std::max<int32_t>(0, keep_rows);
state.target_window_append_rows = std::max<int32_t>(0, append_rows);
state.target_window_replace = replace;
state.target_window_version++;
}
static void dflash_ring_reset_rows(
common_speculative_state_dflash & state,
const float * rows,
int32_t n_rows) {
const size_t row_width = (size_t) state.n_target_features;
if (n_rows <= 0 || rows == nullptr) {
state.target_window_ring_write_pos = 0;
state.target_window_ring_filled = 0;
return;
}
if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) {
state.target_window_ring.resize((size_t) state.cross_ctx * row_width);
}
std::memcpy(state.target_window_ring.data(), rows, (size_t) n_rows * row_width * sizeof(float));
state.target_window_ring_write_pos = n_rows % state.cross_ctx;
state.target_window_ring_filled = n_rows;
state.target_window_materialized = false;
}
static void dflash_ring_append_rows(
common_speculative_state_dflash & state,
const float * rows,
int32_t n_rows) {
const size_t row_width = (size_t) state.n_target_features;
if (n_rows <= 0 || rows == nullptr) {
return;
}
if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) {
state.target_window_ring.resize((size_t) state.cross_ctx * row_width);
}
int32_t write_pos = state.target_window_ring_write_pos;
int32_t remaining = n_rows;
const float * src = rows;
while (remaining > 0) {
const int32_t chunk_rows = std::min<int32_t>(remaining, state.cross_ctx - write_pos);
std::memcpy(
state.target_window_ring.data() + (size_t) write_pos * row_width,
src,
(size_t) chunk_rows * row_width * sizeof(float));
src += (size_t) chunk_rows * row_width;
remaining -= chunk_rows;
write_pos = (write_pos + chunk_rows) % state.cross_ctx;
}
state.target_window_ring_write_pos = write_pos;
state.target_window_ring_filled = std::min(state.cross_ctx, state.target_window_ring_filled + n_rows);
state.target_window_materialized = false;
}
static void dflash_materialize_target_window_features(common_speculative_state_dflash & state) {
if (state.target_window_materialized || state.target_window_rows <= 0) {
return;
}
const size_t row_width = (size_t) state.n_target_features;
state.target_window.resize((size_t) state.target_window_rows * row_width);
const int32_t read_start = (state.target_window_ring_write_pos - state.target_window_rows + state.cross_ctx) % state.cross_ctx;
const int32_t first_rows = std::min<int32_t>(state.target_window_rows, state.cross_ctx - read_start);
std::memcpy(
state.target_window.data(),
state.target_window_ring.data() + (size_t) read_start * row_width,
(size_t) first_rows * row_width * sizeof(float));
const int32_t second_rows = state.target_window_rows - first_rows;
if (second_rows > 0) {
std::memcpy(
state.target_window.data() + (size_t) first_rows * row_width,
state.target_window_ring.data(),
(size_t) second_rows * row_width * sizeof(float));
}
state.target_window_materialized = true;
}
static bool dflash_append_target_features(
common_speculative_state_dflash & state,
const common_speculative_feature_view & features,
const llama_batch & batch,
llama_seq_id seq_id) {
llama_seq_id seq_id,
dflash_append_breakdown * breakdown = nullptr) {
GGML_UNUSED(batch);
if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE ||
@ -2748,6 +3032,7 @@ static bool dflash_append_target_features(
new_rows.reserve(features.rows.size() * row_width);
new_positions.reserve(features.rows.size());
const int64_t t_filter_us = ggml_time_us();
for (const auto & row : features.rows) {
if (row.seq_id != seq_id || row.data == nullptr) {
continue;
@ -2756,6 +3041,9 @@ static bool dflash_append_target_features(
new_positions.push_back(row.pos);
new_rows.insert(new_rows.end(), row.data, row.data + row_width);
}
if (breakdown != nullptr) {
breakdown->filter_us += (uint64_t) (ggml_time_us() - t_filter_us);
}
if (new_positions.empty()) {
return false;
@ -2767,46 +3055,93 @@ static bool dflash_append_target_features(
if (n_rows >= state.cross_ctx) {
state.n_rows_dropped += (size_t) state.target_window_rows + (size_t) (n_rows - state.cross_ctx);
const int32_t keep_from = n_rows - state.cross_ctx;
state.target_window.assign(
const int64_t t_replace_us = ggml_time_us();
state.target_window_pos.assign(new_positions.begin() + keep_from, new_positions.end());
state.target_window_append_features.assign(
new_rows.begin() + (ptrdiff_t) keep_from * (ptrdiff_t) row_width,
new_rows.end());
state.target_window_pos.assign(new_positions.begin() + keep_from, new_positions.end());
dflash_ring_reset_rows(state, state.target_window_append_features.data(), state.cross_ctx);
if (breakdown != nullptr) {
breakdown->replace_us += (uint64_t) (ggml_time_us() - t_replace_us);
breakdown->replace_call = true;
}
const int64_t t_commit_us = ggml_time_us();
state.target_window_rows = state.cross_ctx;
state.target_window_ring_filled = state.target_window_rows;
state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
dflash_record_window_update(state, 0, state.target_window_rows, true);
if (breakdown != nullptr) {
breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us);
}
const int64_t t_log_us = ggml_time_us();
dflash_contract_log_append(state, seq_id, new_positions);
if (breakdown != nullptr) {
breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us);
}
return true;
}
const int32_t keep_old_rows = std::min<int32_t>(state.target_window_rows, state.cross_ctx - n_rows);
state.n_rows_dropped += (size_t) std::max<int32_t>(0, state.target_window_rows - keep_old_rows);
std::vector<float> next_window((size_t) (keep_old_rows + n_rows) * row_width);
std::vector<llama_pos> next_window_pos((size_t) (keep_old_rows + n_rows));
if (keep_old_rows > 0) {
const float * old_src = state.target_window.data() + (size_t) (state.target_window_rows - keep_old_rows) * row_width;
std::memcpy(next_window.data(), old_src, (size_t) keep_old_rows * row_width * sizeof(float));
std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), next_window_pos.begin());
const int64_t t_window_alloc_us = ggml_time_us();
std::vector<llama_pos> & next_window_pos = state.target_window_pos_stage;
next_window_pos.resize((size_t) (keep_old_rows + n_rows));
if (breakdown != nullptr) {
breakdown->window_alloc_us += (uint64_t) (ggml_time_us() - t_window_alloc_us);
}
std::memcpy(
next_window.data() + (size_t) keep_old_rows * row_width,
new_rows.data(),
(size_t) n_rows * row_width * sizeof(float));
std::copy(new_positions.begin(), new_positions.end(), next_window_pos.begin() + keep_old_rows);
if (keep_old_rows > 0) {
const int64_t t_keep_old_us = ggml_time_us();
std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), next_window_pos.begin());
if (breakdown != nullptr) {
breakdown->keep_old_us += (uint64_t) (ggml_time_us() - t_keep_old_us);
}
}
state.target_window = std::move(next_window);
state.target_window_pos = std::move(next_window_pos);
const int64_t t_new_rows_us = ggml_time_us();
state.target_window_append_features.assign(new_rows.begin(), new_rows.end());
dflash_ring_append_rows(state, state.target_window_append_features.data(), n_rows);
std::copy(new_positions.begin(), new_positions.end(), next_window_pos.begin() + keep_old_rows);
if (breakdown != nullptr) {
breakdown->new_rows_us += (uint64_t) (ggml_time_us() - t_new_rows_us);
}
const int64_t t_commit_us = ggml_time_us();
state.target_window_pos.swap(next_window_pos);
next_window_pos.clear();
state.target_window_rows = keep_old_rows + n_rows;
state.target_window_ring_filled = state.target_window_rows;
state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
dflash_record_window_update(state, keep_old_rows, n_rows, false);
if (breakdown != nullptr) {
breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us);
}
const int64_t t_log_us = ggml_time_us();
dflash_contract_log_append(state, seq_id, new_positions);
if (breakdown != nullptr) {
breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us);
}
return true;
}
static void dflash_clear_target_features(common_speculative_state_dflash & state) {
state.target_window.clear();
state.target_window_pos.clear();
state.target_window_stage.clear();
state.target_window_pos_stage.clear();
state.target_window_append_features.clear();
state.target_window_rows = 0;
state.target_window_ring_write_pos = 0;
state.target_window_ring_filled = 0;
state.target_window_keep_rows = 0;
state.target_window_append_rows = 0;
state.target_window_replace = false;
state.target_window_materialized = false;
state.last_target_pos = -1;
llama_reset_dflash_kv_cache_state(state.ctx_dft);
}
static void dflash_context_shift(
@ -2814,10 +3149,12 @@ static void dflash_context_shift(
llama_pos kv_keep,
llama_pos kv_discard,
llama_pos kv_past) {
if (kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window.empty() || state.target_window_pos.empty()) {
if (kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window_pos.empty()) {
return;
}
dflash_materialize_target_window_features(state);
const size_t row_width = (size_t) state.n_target_features;
const llama_pos discard_begin = kv_keep;
const llama_pos discard_end = kv_keep + kv_discard;
@ -2845,7 +3182,10 @@ static void dflash_context_shift(
state.target_window = std::move(shifted_rows);
state.target_window_pos = std::move(shifted_positions);
state.target_window_rows = (int32_t) state.target_window_pos.size();
dflash_ring_reset_rows(state, state.target_window.data(), state.target_window_rows);
state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
dflash_record_window_update(state, 0, state.target_window_rows, true);
llama_reset_dflash_kv_cache_state(state.ctx_dft);
state.n_context_shifts++;
}
@ -2959,8 +3299,9 @@ int32_t common_speculative_on_target_batch(
}
}
dflash_append_breakdown append_breakdown;
const int64_t t_append_us = ggml_time_us();
if (!dflash_append_target_features(*dflash_state, features, batch, seq_id)) {
if (!dflash_append_target_features(*dflash_state, features, batch, seq_id, &append_breakdown)) {
return -1;
}
@ -2971,8 +3312,20 @@ int32_t common_speculative_on_target_batch(
dflash_state->n_warmup_append_rows += (size_t) batch.n_tokens;
} else {
dflash_state->t_accept_append_us += append_us;
dflash_state->t_accept_append_filter_us += append_breakdown.filter_us;
dflash_state->t_accept_append_window_alloc_us += append_breakdown.window_alloc_us;
dflash_state->t_accept_append_replace_us += append_breakdown.replace_us;
dflash_state->t_accept_append_keep_old_us += append_breakdown.keep_old_us;
dflash_state->t_accept_append_new_rows_us += append_breakdown.new_rows_us;
dflash_state->t_accept_append_commit_detail_us += append_breakdown.commit_us;
dflash_state->t_accept_append_log_us += append_breakdown.log_us;
dflash_state->n_accept_append_calls++;
dflash_state->n_accept_append_rows += (size_t) batch.n_tokens;
if (append_breakdown.replace_call) {
dflash_state->n_accept_append_replace_calls++;
} else {
dflash_state->n_accept_append_slide_calls++;
}
}
return 0;

View File

@ -23,38 +23,132 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() {
const int64_t ctx_len = lctx.dflash_visible_cross_ctx > 0
? (int64_t) lctx.dflash_visible_cross_ctx
: std::max<int64_t>(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size);
const int64_t update_rows = std::max<int64_t>(1, lctx.dflash_kv_cache_update_rows > 0 ? lctx.dflash_kv_cache_update_rows : ctx_len);
const int32_t write_pos = lctx.dflash_kv_cache_write_pos;
GGML_ASSERT(n_embd_head_k == n_embd_head_v);
GGML_ASSERT(n_target_features > 0);
GGML_ASSERT(lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len));
GGML_ASSERT(update_rows > 0 && update_rows <= ctx_len);
GGML_ASSERT(write_pos >= 0 && write_pos < ctx_len);
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max<int64_t>(1, ctx_len)) + 24 * n_layer, false);
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max<int64_t>(1, update_rows)) + 24 * n_layer, false);
lctx.dflash_kv_input_target_features = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_target_features, ctx_len);
lctx.dflash_kv_input_target_features = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_target_features, update_rows);
ggml_set_input(lctx.dflash_kv_input_target_features);
cb(lctx.dflash_kv_input_target_features, "dflash_kv_input_target_features", -1);
lctx.dflash_kv_input_pos_ctx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctx_len);
lctx.dflash_kv_input_pos_ctx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, update_rows);
ggml_set_input(lctx.dflash_kv_input_pos_ctx);
cb(lctx.dflash_kv_input_pos_ctx, "dflash_kv_input_pos_ctx", -1);
ggml_tensor * fused_target = llm_build_lora_mm(lctx, ctx0, model.dflash_fc, lctx.dflash_kv_input_target_features);
fused_target = llm_build_norm(ctx0, fused_target, hparams, model.dflash_hidden_norm, nullptr, LLM_NORM_RMS, cb, -1);
cb(fused_target, "dflash_kv_fused_target", -1);
for (int il = 0; il < n_layer; ++il) {
GGML_ASSERT((size_t) il < lctx.dflash_k_ctx_cache.size());
GGML_ASSERT((size_t) il < lctx.dflash_v_ctx_cache.size());
ggml_tensor * Kcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused_target);
Kcur_ctx = ggml_reshape_3d(ctx0, Kcur_ctx, n_embd_head_k, n_head_kv, ctx_len);
ggml_tensor * Kcur_ctx_proj = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused_target);
cb(Kcur_ctx_proj, "dflash_kv_k_proj", il);
ggml_tensor * Kcur_ctx = ggml_reshape_3d(ctx0, Kcur_ctx_proj, n_embd_head_k, n_head_kv, update_rows);
Kcur_ctx = llm_build_norm(ctx0, Kcur_ctx, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(Kcur_ctx, "dflash_kv_k_norm", il);
Kcur_ctx = ggml_rope_ext(ctx0, Kcur_ctx, lctx.dflash_kv_input_pos_ctx, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur_ctx, "dflash_kv_k_rope", il);
ggml_tensor * Vcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, fused_target);
Vcur_ctx = ggml_reshape_3d(ctx0, Vcur_ctx, n_embd_head_v, n_head_kv, ctx_len);
cb(Vcur_ctx, "dflash_kv_v_proj", il);
Vcur_ctx = ggml_reshape_3d(ctx0, Vcur_ctx, n_embd_head_v, n_head_kv, update_rows);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur_ctx, lctx.dflash_k_ctx_cache[(size_t) il]));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur_ctx, lctx.dflash_v_ctx_cache[(size_t) il]));
const int32_t first_rows = std::min<int32_t>((int32_t) update_rows, (int32_t) ctx_len - write_pos);
const int32_t second_rows = (int32_t) update_rows - first_rows;
if (first_rows > 0) {
ggml_tensor * Ksrc_first = first_rows == update_rows
? Kcur_ctx
: ggml_view_3d(ctx0, Kcur_ctx,
Kcur_ctx->ne[0],
Kcur_ctx->ne[1],
first_rows,
Kcur_ctx->nb[1],
Kcur_ctx->nb[2],
0);
ggml_tensor * Vsrc_first = first_rows == update_rows
? Vcur_ctx
: ggml_view_3d(ctx0, Vcur_ctx,
Vcur_ctx->ne[0],
Vcur_ctx->ne[1],
first_rows,
Vcur_ctx->nb[1],
Vcur_ctx->nb[2],
0);
ggml_tensor * Kdst_first = ggml_view_3d(ctx0, lctx.dflash_k_ctx_cache[(size_t) il],
lctx.dflash_k_ctx_cache[(size_t) il]->ne[0],
lctx.dflash_k_ctx_cache[(size_t) il]->ne[1],
first_rows,
lctx.dflash_k_ctx_cache[(size_t) il]->nb[1],
lctx.dflash_k_ctx_cache[(size_t) il]->nb[2],
(size_t) write_pos * lctx.dflash_k_ctx_cache[(size_t) il]->nb[2]);
ggml_tensor * Vdst_first = ggml_view_3d(ctx0, lctx.dflash_v_ctx_cache[(size_t) il],
lctx.dflash_v_ctx_cache[(size_t) il]->ne[0],
lctx.dflash_v_ctx_cache[(size_t) il]->ne[1],
first_rows,
lctx.dflash_v_ctx_cache[(size_t) il]->nb[1],
lctx.dflash_v_ctx_cache[(size_t) il]->nb[2],
(size_t) write_pos * lctx.dflash_v_ctx_cache[(size_t) il]->nb[2]);
ggml_tensor * Kstore_first = ggml_cpy(ctx0, Ksrc_first, Kdst_first);
cb(Kstore_first, "dflash_kv_k_store", il);
ggml_build_forward_expand(gf, Kstore_first);
ggml_tensor * Vstore_first = ggml_cpy(ctx0, Vsrc_first, Vdst_first);
cb(Vstore_first, "dflash_kv_v_store", il);
ggml_build_forward_expand(gf, Vstore_first);
}
if (second_rows > 0) {
ggml_tensor * Ksrc_second = ggml_view_3d(ctx0, Kcur_ctx,
Kcur_ctx->ne[0],
Kcur_ctx->ne[1],
second_rows,
Kcur_ctx->nb[1],
Kcur_ctx->nb[2],
(size_t) first_rows * Kcur_ctx->nb[2]);
ggml_tensor * Vsrc_second = ggml_view_3d(ctx0, Vcur_ctx,
Vcur_ctx->ne[0],
Vcur_ctx->ne[1],
second_rows,
Vcur_ctx->nb[1],
Vcur_ctx->nb[2],
(size_t) first_rows * Vcur_ctx->nb[2]);
ggml_tensor * Kdst_second = ggml_view_3d(ctx0, lctx.dflash_k_ctx_cache[(size_t) il],
lctx.dflash_k_ctx_cache[(size_t) il]->ne[0],
lctx.dflash_k_ctx_cache[(size_t) il]->ne[1],
second_rows,
lctx.dflash_k_ctx_cache[(size_t) il]->nb[1],
lctx.dflash_k_ctx_cache[(size_t) il]->nb[2],
0);
ggml_tensor * Vdst_second = ggml_view_3d(ctx0, lctx.dflash_v_ctx_cache[(size_t) il],
lctx.dflash_v_ctx_cache[(size_t) il]->ne[0],
lctx.dflash_v_ctx_cache[(size_t) il]->ne[1],
second_rows,
lctx.dflash_v_ctx_cache[(size_t) il]->nb[1],
lctx.dflash_v_ctx_cache[(size_t) il]->nb[2],
0);
ggml_tensor * Kstore_second = ggml_cpy(ctx0, Ksrc_second, Kdst_second);
cb(Kstore_second, "dflash_kv_k_store", il);
ggml_build_forward_expand(gf, Kstore_second);
ggml_tensor * Vstore_second = ggml_cpy(ctx0, Vsrc_second, Vdst_second);
cb(Vstore_second, "dflash_kv_v_store", il);
ggml_build_forward_expand(gf, Vstore_second);
}
}
return gf;
@ -69,12 +163,17 @@ ggml_cgraph * llm_build_context::build_dflash() {
const int64_t ctx_len = lctx.dflash_visible_cross_ctx > 0
? (int64_t) lctx.dflash_visible_cross_ctx
: std::max<int64_t>(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size);
const int32_t cache_rows = use_kv_cache ? std::clamp(lctx.dflash_kv_cache_view_n_filled, 0, (int32_t) ctx_len) : 0;
const int32_t cache_write_pos = use_kv_cache && ctx_len > 0
? ((lctx.dflash_kv_cache_view_write_pos % (int32_t) ctx_len) + (int32_t) ctx_len) % (int32_t) ctx_len
: 0;
const int64_t n_kv_total = GGML_PAD(ctx_len + n_tokens, flash_attn ? 256 : 32);
const int64_t n_kv_pad = n_kv_total - (ctx_len + n_tokens);
GGML_ASSERT(n_embd_head_k == n_embd_head_v);
GGML_ASSERT(n_target_features > 0);
GGML_ASSERT(!use_kv_cache || lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len));
GGML_ASSERT(!use_kv_cache || (cache_write_pos >= 0 && cache_write_pos < ctx_len));
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max<int64_t>(n_tokens, ctx_len)) + 32 * n_layer, false);
@ -160,8 +259,52 @@ ggml_cgraph * llm_build_context::build_dflash() {
ggml_tensor * Kcur_ctx = nullptr;
ggml_tensor * Vcur_ctx = nullptr;
if (use_kv_cache) {
Kcur_ctx = lctx.dflash_k_ctx_cache[(size_t) il];
Vcur_ctx = lctx.dflash_v_ctx_cache[(size_t) il];
auto build_ordered_cache_view = [&](ggml_tensor * cache) -> ggml_tensor * {
if (!lctx.dflash_kv_cache_view_valid || cache_rows <= 0) {
return cache;
}
if (cache_rows < ctx_len) {
ggml_tensor * zero_pad = ggml_view_3d(ctx0, cache,
cache->ne[0],
cache->ne[1],
ctx_len - cache_rows,
cache->nb[1],
cache->nb[2],
(size_t) cache_rows * cache->nb[2]);
ggml_tensor * valid = ggml_view_3d(ctx0, cache,
cache->ne[0],
cache->ne[1],
cache_rows,
cache->nb[1],
cache->nb[2],
0);
return ggml_concat(ctx0, zero_pad, valid, 2);
}
if (cache_write_pos == 0) {
return cache;
}
ggml_tensor * tail = ggml_view_3d(ctx0, cache,
cache->ne[0],
cache->ne[1],
ctx_len - cache_write_pos,
cache->nb[1],
cache->nb[2],
(size_t) cache_write_pos * cache->nb[2]);
ggml_tensor * head = ggml_view_3d(ctx0, cache,
cache->ne[0],
cache->ne[1],
cache_write_pos,
cache->nb[1],
cache->nb[2],
0);
return ggml_concat(ctx0, tail, head, 2);
};
Kcur_ctx = build_ordered_cache_view(lctx.dflash_k_ctx_cache[(size_t) il]);
Vcur_ctx = build_ordered_cache_view(lctx.dflash_v_ctx_cache[(size_t) il]);
cb(Kcur_ctx, "Kcur_ctx_cache", il);
cb(Vcur_ctx, "Vcur_ctx_cache", il);
} else {

View File

@ -2173,7 +2173,27 @@ struct ggml_cgraph * llm_build_context::llama_build_graph_dflash_kv_cache(llama_
llama_batch dummy;
dummy.n_tokens = 0;
llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) {
if (il >= 0) {
int j = 0;
for (; j < GGML_MAX_NAME - 1; ++j) {
cur->name[j] = name[j];
if (!name[j]) {
break;
}
}
if (j < GGML_MAX_NAME - 3) {
cur->name[j++] = '-';
auto sil = std::to_string(il);
for (int k = 0; k < (int) sil.size() && j < GGML_MAX_NAME - 1; ++k) {
cur->name[j++] = sil[k];
}
}
cur->name[j] = 0;
} else {
ggml_set_name(cur, name);
}
};
struct llm_build_context llm(lctx, dummy, cb, false, false, 0, false, &lctx.dflash_buf_compute_meta);

View File

@ -281,9 +281,17 @@ struct llama_context {
const float * dflash_target_features = nullptr;
size_t dflash_target_features_n_floats = 0;
int32_t dflash_target_features_n_rows = 0;
const float * dflash_target_append_features = nullptr;
size_t dflash_target_append_features_n_floats = 0;
int32_t dflash_target_append_features_n_rows = 0;
const llama_pos * dflash_target_positions = nullptr;
size_t dflash_target_positions_n = 0;
uint64_t dflash_target_window_version = 0;
int32_t dflash_target_window_keep_rows = 0;
int32_t dflash_target_window_append_rows = 0;
bool dflash_target_window_replace = false;
std::vector<float> dflash_target_features_owned;
std::vector<float> dflash_target_append_features_owned;
std::vector<llama_pos> dflash_target_positions_owned;
std::vector<float> dflash_target_features_padded;
std::vector<float> dflash_feature_view_buffer;
@ -295,6 +303,15 @@ struct llama_context {
std::vector<struct ggml_tensor *> dflash_v_ctx_cache;
struct ggml_context * dflash_cache_ctx = nullptr;
std::vector<ggml_backend_buffer_t> dflash_cache_bufs;
int32_t dflash_kv_cache_write_pos = 0;
int32_t dflash_kv_cache_n_filled = 0;
int32_t dflash_kv_cache_update_rows = 0;
int32_t dflash_kv_cache_reserved_rows = 0;
int32_t dflash_kv_cache_view_write_pos = 0;
int32_t dflash_kv_cache_view_n_filled = 0;
uint64_t dflash_kv_cache_applied_window_version = 0;
bool dflash_kv_cache_valid = false;
bool dflash_kv_cache_view_valid = false;
std::vector<uint8_t> dflash_buf_compute_meta;
ggml_backend_sched_t dflash_sched = nullptr;
struct ggml_tensor * dflash_kv_input_target_features = nullptr;

View File

@ -55,6 +55,56 @@ void llama_dflash_profile_reset(struct llama_context * ctx) {
ctx->dflash_profile = {};
}
void llama_reset_dflash_kv_cache_state(struct llama_context * ctx) {
if (ctx == nullptr) {
return;
}
ctx->dflash_kv_cache_write_pos = 0;
ctx->dflash_kv_cache_n_filled = 0;
ctx->dflash_kv_cache_update_rows = 0;
ctx->dflash_kv_cache_view_write_pos = 0;
ctx->dflash_kv_cache_view_n_filled = 0;
ctx->dflash_kv_cache_applied_window_version = 0;
ctx->dflash_kv_cache_valid = false;
ctx->dflash_kv_cache_view_valid = false;
for (ggml_backend_buffer_t buf : ctx->dflash_cache_bufs) {
if (buf != nullptr) {
ggml_backend_buffer_clear(buf, 0);
}
}
}
llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition_for_ctx(
const struct llama_context * ctx,
const llama_dflash_window_update & window_update,
int32_t n_rows) {
if (ctx == nullptr) {
llama_dflash_kv_cache_transition plan;
plan.rebuild_cache = true;
plan.append_rows = std::clamp(window_update.append_rows, 0, n_rows);
plan.next_n_filled = n_rows;
return plan;
}
const int32_t cross_ctx = ctx->dflash_visible_cross_ctx > 0
? ctx->dflash_visible_cross_ctx
: std::max<int32_t>(1, (int32_t) ctx->cparams.n_ctx - (int32_t) ctx->model.hparams.dflash_block_size);
return llama_plan_dflash_kv_cache_transition(
cross_ctx,
ctx->dflash_kv_cache_n_filled,
ctx->dflash_kv_cache_write_pos,
ctx->dflash_kv_cache_valid,
ctx->dflash_kv_cache_applied_window_version,
window_update.version,
window_update.keep_rows,
window_update.append_rows,
window_update.replace,
n_rows);
}
void llama_set_dflash_visible_cross_ctx(
struct llama_context * ctx,
int32_t cross_ctx) {
@ -205,26 +255,91 @@ static bool llama_set_dflash_target_features_impl(
size_t n_floats,
int32_t n_rows,
const llama_pos * target_positions,
bool copy_data) {
if (ctx == nullptr || target_features == nullptr || n_floats == 0 || n_rows <= 0) {
bool copy_data,
const llama_dflash_window_update * window_update) {
const bool have_full_features = target_features != nullptr && n_floats > 0;
const bool have_append_features = window_update != nullptr &&
window_update->append_features != nullptr &&
window_update->append_floats > 0 &&
window_update->append_rows > 0;
if (ctx == nullptr || n_rows <= 0 || (!have_full_features && !have_append_features)) {
return false;
}
auto & profile = ctx->dflash_profile;
const int64_t t_start_us = ggml_time_us();
const int32_t row_width = n_rows > 0 ? (int32_t) (n_floats / (size_t) n_rows) : 0;
const int32_t row_width = have_full_features
? (n_rows > 0 ? (int32_t) (n_floats / (size_t) n_rows) : 0)
: (window_update->append_rows > 0 ? (int32_t) (window_update->append_floats / (size_t) window_update->append_rows) : 0);
llama_pos first_pos = -1;
llama_pos last_pos = -1;
if (copy_data) {
if (have_full_features && copy_data) {
ctx->dflash_target_features_owned.assign(target_features, target_features + n_floats);
ctx->dflash_target_features = ctx->dflash_target_features_owned.data();
} else {
} else if (have_full_features) {
ctx->dflash_target_features_owned.clear();
ctx->dflash_target_features = target_features;
} else {
ctx->dflash_target_features_owned.clear();
ctx->dflash_target_features = nullptr;
}
ctx->dflash_target_features_n_floats = n_floats;
ctx->dflash_target_features_n_floats = have_full_features ? n_floats : 0;
ctx->dflash_target_features_n_rows = n_rows;
if (have_append_features && copy_data) {
ctx->dflash_target_append_features_owned.assign(
window_update->append_features,
window_update->append_features + window_update->append_floats);
ctx->dflash_target_append_features = ctx->dflash_target_append_features_owned.data();
} else if (have_append_features) {
ctx->dflash_target_append_features_owned.clear();
ctx->dflash_target_append_features = window_update->append_features;
} else {
ctx->dflash_target_append_features_owned.clear();
ctx->dflash_target_append_features = nullptr;
}
ctx->dflash_target_append_features_n_floats = have_append_features ? window_update->append_floats : 0;
ctx->dflash_target_append_features_n_rows = have_append_features ? window_update->append_rows : 0;
ctx->dflash_target_window_version = window_update != nullptr && window_update->version > 0
? window_update->version
: ctx->dflash_target_window_version + 1;
ctx->dflash_target_window_keep_rows = window_update != nullptr
? std::max<int32_t>(0, std::min(n_rows, window_update->keep_rows))
: 0;
ctx->dflash_target_window_append_rows = window_update != nullptr
? std::max<int32_t>(0, std::min(n_rows, window_update->append_rows))
: n_rows;
ctx->dflash_target_window_replace = window_update != nullptr
? window_update->replace
: true;
if (ctx->dflash_target_window_keep_rows + ctx->dflash_target_window_append_rows > n_rows) {
ctx->dflash_target_window_keep_rows = std::max<int32_t>(0, n_rows - ctx->dflash_target_window_append_rows);
}
const int32_t cross_ctx = ctx->dflash_visible_cross_ctx > 0
? ctx->dflash_visible_cross_ctx
: std::max<int32_t>(1, (int32_t) ctx->cparams.n_ctx - (int32_t) ctx->model.hparams.dflash_block_size);
const llama_dflash_window_update cache_window_update = {
ctx->dflash_target_window_version,
ctx->dflash_target_window_keep_rows,
ctx->dflash_target_window_append_rows,
ctx->dflash_target_window_replace,
ctx->dflash_target_append_features,
ctx->dflash_target_append_features_n_floats,
};
const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition_for_ctx(ctx, cache_window_update, n_rows);
if (cache_plan.cache_up_to_date) {
ctx->dflash_kv_cache_view_n_filled = ctx->dflash_kv_cache_n_filled;
ctx->dflash_kv_cache_view_write_pos = ctx->dflash_kv_cache_write_pos;
ctx->dflash_kv_cache_view_valid = ctx->dflash_kv_cache_valid;
} else if (cross_ctx > 0) {
ctx->dflash_kv_cache_view_n_filled = cache_plan.next_n_filled;
ctx->dflash_kv_cache_view_write_pos = cache_plan.next_write_pos;
ctx->dflash_kv_cache_view_valid = cache_plan.next_n_filled > 0;
}
if (target_positions != nullptr) {
if (copy_data) {
ctx->dflash_target_positions_owned.assign(target_positions, target_positions + n_rows);
@ -243,7 +358,10 @@ static bool llama_set_dflash_target_features_impl(
profile.set_target_copy_calls++;
profile.set_target_copy_us += (uint64_t) (ggml_time_us() - t_start_us);
profile.set_target_rows += (uint64_t) n_rows;
profile.set_target_copy_bytes += n_floats * sizeof(float) + (target_positions ? (size_t) n_rows * sizeof(llama_pos) : 0);
profile.set_target_copy_bytes +=
(have_full_features ? n_floats : 0) * sizeof(float) +
(have_append_features ? window_update->append_floats : 0) * sizeof(float) +
(target_positions ? (size_t) n_rows * sizeof(llama_pos) : 0);
profile.last_n_rows = n_rows;
profile.last_width = row_width;
@ -267,8 +385,9 @@ bool llama_set_dflash_target_features_copy(
const float * target_features,
size_t n_floats,
int32_t n_rows,
const llama_pos * target_positions) {
return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, true);
const llama_pos * target_positions,
const llama_dflash_window_update * window_update) {
return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, true, window_update);
}
bool llama_set_dflash_target_features_view(
@ -276,8 +395,9 @@ bool llama_set_dflash_target_features_view(
const float * target_features,
size_t n_floats,
int32_t n_rows,
const llama_pos * target_positions) {
return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, false);
const llama_pos * target_positions,
const llama_dflash_window_update * window_update) {
return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, false, window_update);
}
static void llama_record_dflash_capture_phase(

View File

@ -2,6 +2,8 @@
#include "llama.h"
#include <algorithm>
#include <cstdint>
#include <vector>
struct llama_context;
@ -24,6 +26,19 @@ struct llama_spec_feature_view {
};
struct llama_dflash_profile_stats {
uint64_t decode_internal_chunks = 0;
uint64_t decode_graph_rebuilds = 0;
uint64_t decode_sync_profile_points = 0;
uint64_t decode_prelude_us = 0;
uint64_t decode_sched_reset_us = 0;
uint64_t decode_build_graph_us = 0;
uint64_t decode_sched_alloc_graph_us = 0;
uint64_t decode_set_inputs_us = 0;
uint64_t decode_graph_compute_us = 0;
uint64_t decode_result_us = 0;
uint64_t decode_embedding_us = 0;
uint64_t decode_final_sched_reset_us = 0;
uint64_t decode_output_reserve_calls = 0;
uint64_t decode_output_reserve_us = 0;
uint64_t decode_output_reserve_reallocs = 0;
@ -71,6 +86,20 @@ struct llama_dflash_profile_stats {
uint64_t graph_kv_cache_read_concat_pad_calls = 0;
uint64_t graph_kv_cache_cached_bytes = 0;
uint64_t graph_kv_cache_calls = 0;
uint64_t graph_kv_node_fused_target_calls = 0;
uint64_t graph_kv_node_fused_target_us = 0;
uint64_t graph_kv_node_k_proj_calls = 0;
uint64_t graph_kv_node_k_proj_us = 0;
uint64_t graph_kv_node_k_norm_calls = 0;
uint64_t graph_kv_node_k_norm_us = 0;
uint64_t graph_kv_node_k_rope_calls = 0;
uint64_t graph_kv_node_k_rope_us = 0;
uint64_t graph_kv_node_v_proj_calls = 0;
uint64_t graph_kv_node_v_proj_us = 0;
uint64_t graph_kv_node_k_store_calls = 0;
uint64_t graph_kv_node_k_store_us = 0;
uint64_t graph_kv_node_v_store_calls = 0;
uint64_t graph_kv_node_v_store_us = 0;
uint64_t graph_feature_bytes = 0;
uint64_t graph_pos_bytes = 0;
uint64_t graph_mask_bytes = 0;
@ -96,10 +125,76 @@ struct llama_dflash_profile_stats {
llama_pos last_pos_last = -1;
};
struct llama_dflash_window_update {
uint64_t version = 0;
int32_t keep_rows = 0;
int32_t append_rows = 0;
bool replace = false;
const float * append_features = nullptr;
size_t append_floats = 0;
};
struct llama_dflash_kv_cache_transition {
bool cache_up_to_date = false;
bool rebuild_cache = false;
int32_t append_rows = 0;
int32_t next_n_filled = 0;
int32_t next_write_pos = 0;
};
static inline llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition(
int32_t cross_ctx,
int32_t current_n_filled,
int32_t current_write_pos,
bool cache_valid,
uint64_t applied_window_version,
uint64_t target_window_version,
int32_t keep_rows,
int32_t append_rows,
bool replace,
int32_t n_rows) {
llama_dflash_kv_cache_transition plan;
const int32_t safe_cross_ctx = std::max<int32_t>(1, cross_ctx);
const int32_t bounded_n_filled = std::clamp(current_n_filled, 0, safe_cross_ctx);
const int32_t bounded_append_rows = std::clamp(append_rows, 0, n_rows);
const int32_t bounded_keep_rows = std::clamp(keep_rows, 0, n_rows);
const int32_t expected_keep_rows = std::min(bounded_n_filled, std::max<int32_t>(0, safe_cross_ctx - bounded_append_rows));
plan.cache_up_to_date = cache_valid && applied_window_version == target_window_version;
plan.rebuild_cache = !cache_valid || replace || bounded_append_rows <= 0 || bounded_append_rows > n_rows;
if (!plan.rebuild_cache && bounded_keep_rows != expected_keep_rows) {
plan.rebuild_cache = true;
}
plan.append_rows = bounded_append_rows;
if (plan.cache_up_to_date) {
plan.next_n_filled = bounded_n_filled;
plan.next_write_pos = safe_cross_ctx > 0
? ((current_write_pos % safe_cross_ctx) + safe_cross_ctx) % safe_cross_ctx
: 0;
} else if (plan.rebuild_cache) {
plan.next_n_filled = std::min(safe_cross_ctx, n_rows);
plan.next_write_pos = plan.next_n_filled % safe_cross_ctx;
} else {
plan.next_n_filled = std::min(safe_cross_ctx, bounded_n_filled + bounded_append_rows);
plan.next_write_pos = (current_write_pos + bounded_append_rows) % safe_cross_ctx;
}
return plan;
}
llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition_for_ctx(
const struct llama_context * ctx,
const llama_dflash_window_update & window_update,
int32_t n_rows);
uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx);
void llama_dflash_profile_reset(struct llama_context * ctx);
void llama_reset_dflash_kv_cache_state(struct llama_context * ctx);
void llama_set_dflash_visible_cross_ctx(
struct llama_context * ctx,
int32_t cross_ctx);
@ -156,14 +251,16 @@ bool llama_set_dflash_target_features_copy(
const float * target_features,
size_t n_floats,
int32_t n_rows,
const llama_pos * target_positions);
const llama_pos * target_positions,
const llama_dflash_window_update * window_update = nullptr);
bool llama_set_dflash_target_features_view(
struct llama_context * ctx,
const float * target_features,
size_t n_floats,
int32_t n_rows,
const llama_pos * target_positions);
const llama_pos * target_positions,
const llama_dflash_window_update * window_update = nullptr);
bool llama_set_dflash_capture_layers(
struct llama_context * ctx,

View File

@ -171,6 +171,129 @@ static std::vector<std::string> string_split(const std::string& str, const std::
return parts;
}
static bool llama_env_flag_enabled(const char * name) {
const char * env = std::getenv(name);
return env != nullptr && *env != '\0' &&
std::strcmp(env, "0") != 0 &&
std::strcmp(env, "false") != 0 &&
std::strcmp(env, "off") != 0;
}
enum llama_dflash_kv_node_kind {
LLAMA_DFLASH_KV_NODE_NONE = 0,
LLAMA_DFLASH_KV_NODE_FUSED_TARGET,
LLAMA_DFLASH_KV_NODE_K_PROJ,
LLAMA_DFLASH_KV_NODE_K_NORM,
LLAMA_DFLASH_KV_NODE_K_ROPE,
LLAMA_DFLASH_KV_NODE_V_PROJ,
LLAMA_DFLASH_KV_NODE_K_STORE,
LLAMA_DFLASH_KV_NODE_V_STORE,
};
struct llama_dflash_kv_node_profiler {
llama_dflash_profile_stats * profile = nullptr;
int64_t t_start_us = 0;
llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE;
};
static bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) {
if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') {
return false;
}
return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0;
}
static llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) {
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_fused_target")) {
return LLAMA_DFLASH_KV_NODE_FUSED_TARGET;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_proj")) {
return LLAMA_DFLASH_KV_NODE_K_PROJ;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_norm")) {
return LLAMA_DFLASH_KV_NODE_K_NORM;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_rope")) {
return LLAMA_DFLASH_KV_NODE_K_ROPE;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_proj")) {
return LLAMA_DFLASH_KV_NODE_V_PROJ;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_store")) {
return LLAMA_DFLASH_KV_NODE_K_STORE;
}
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_store")) {
return LLAMA_DFLASH_KV_NODE_V_STORE;
}
return LLAMA_DFLASH_KV_NODE_NONE;
}
static void llama_dflash_kv_node_profile_add(
llama_dflash_profile_stats & profile,
llama_dflash_kv_node_kind kind,
uint64_t elapsed_us) {
switch (kind) {
case LLAMA_DFLASH_KV_NODE_FUSED_TARGET:
profile.graph_kv_node_fused_target_calls++;
profile.graph_kv_node_fused_target_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_K_PROJ:
profile.graph_kv_node_k_proj_calls++;
profile.graph_kv_node_k_proj_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_K_NORM:
profile.graph_kv_node_k_norm_calls++;
profile.graph_kv_node_k_norm_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_K_ROPE:
profile.graph_kv_node_k_rope_calls++;
profile.graph_kv_node_k_rope_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_V_PROJ:
profile.graph_kv_node_v_proj_calls++;
profile.graph_kv_node_v_proj_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_K_STORE:
profile.graph_kv_node_k_store_calls++;
profile.graph_kv_node_k_store_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_V_STORE:
profile.graph_kv_node_v_store_calls++;
profile.graph_kv_node_v_store_us += elapsed_us;
break;
case LLAMA_DFLASH_KV_NODE_NONE:
break;
}
}
static bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
auto * profiler = static_cast<llama_dflash_kv_node_profiler *>(user_data);
if (profiler == nullptr || profiler->profile == nullptr) {
return false;
}
const llama_dflash_kv_node_kind kind = llama_dflash_kv_node_kind_from_tensor(tensor);
if (ask) {
if (kind == LLAMA_DFLASH_KV_NODE_NONE) {
return false;
}
profiler->active_kind = kind;
profiler->t_start_us = ggml_time_us();
return true;
}
if (kind != LLAMA_DFLASH_KV_NODE_NONE && profiler->active_kind == kind && profiler->t_start_us > 0) {
llama_dflash_kv_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us));
}
profiler->active_kind = LLAMA_DFLASH_KV_NODE_NONE;
profiler->t_start_us = 0;
return true;
}
// extract ip and port from RPC[ip:port] for rpc and keep other device names
static std::vector<rpc_device> extract_device_from_rpc_device(std::vector<std::string> devices) {
std::vector<rpc_device> rpc_servers;
@ -689,6 +812,7 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
}
dflash_profile.last_kv_cache_host_layers = host_layers;
llama_reset_dflash_kv_cache_state(this);
LLAMA_LOG_INFO("%s: DFlash K/V cache placement cross_ctx=%d host_layers=%d/%d first=%s last=%s\n",
__func__,
target_cross_ctx,
@ -703,6 +827,15 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
void llama_context::free_dflash_kv_cache_tensors() {
dflash_k_ctx_cache.clear();
dflash_v_ctx_cache.clear();
dflash_kv_cache_write_pos = 0;
dflash_kv_cache_n_filled = 0;
dflash_kv_cache_update_rows = 0;
dflash_kv_cache_reserved_rows = 0;
dflash_kv_cache_view_write_pos = 0;
dflash_kv_cache_view_n_filled = 0;
dflash_kv_cache_applied_window_version = 0;
dflash_kv_cache_valid = false;
dflash_kv_cache_view_valid = false;
dflash_kv_input_target_features = nullptr;
dflash_kv_input_pos_ctx = nullptr;
dflash_kq_mask_tensor = nullptr;
@ -5271,11 +5404,8 @@ static bool validate_dflash_graph_contract(const llama_context & lctx) {
static bool prepare_dflash_graph_inputs(
struct llama_context & lctx,
uint32_t n_tokens) {
const char * dflash_kv_cache_env = std::getenv("IK_DFLASH_KV_CACHE");
const bool use_kv_cache = dflash_kv_cache_env != nullptr && *dflash_kv_cache_env != '\0' &&
std::strcmp(dflash_kv_cache_env, "0") != 0 &&
std::strcmp(dflash_kv_cache_env, "false") != 0 &&
std::strcmp(dflash_kv_cache_env, "off") != 0;
const bool use_kv_cache = llama_env_flag_enabled("IK_DFLASH_KV_CACHE");
const bool kv_node_timing = llama_env_flag_enabled("IK_DFLASH_KV_NODE_TIMING");
auto & profile = lctx.dflash_profile;
const int32_t cross_ctx = lctx.dflash_visible_cross_ctx > 0
? lctx.dflash_visible_cross_ctx
@ -5304,10 +5434,13 @@ static bool prepare_dflash_graph_inputs(
}
const float * src = lctx.dflash_target_features;
const float * append_src = lctx.dflash_target_append_features;
const llama_pos * src_pos = lctx.dflash_target_positions;
const size_t total_floats = lctx.dflash_target_features_n_floats;
const size_t append_floats = lctx.dflash_target_append_features_n_floats;
const size_t total_positions = lctx.dflash_target_positions_n;
const int32_t n_rows = lctx.dflash_target_features_n_rows;
const int32_t append_rows_available = lctx.dflash_target_append_features_n_rows;
const int32_t width = (int32_t) lctx.model.hparams.dflash_n_target_features;
const int32_t graph_cross_ctx = use_kv_cache
? (lctx.dflash_k_ctx_cache.front() != nullptr ? (int32_t) lctx.dflash_k_ctx_cache.front()->ne[2] : 0)
@ -5330,19 +5463,26 @@ static bool prepare_dflash_graph_inputs(
__func__, graph_cross_ctx, cross_ctx);
return false;
}
if (src == nullptr || total_floats == 0 || n_rows <= 0) {
if (n_rows <= 0) {
profile.graph_shape_failures++;
LLAMA_LOG_ERROR("%s: missing DFlash target features\n", __func__);
LLAMA_LOG_ERROR("%s: missing DFlash target feature rows\n", __func__);
return false;
}
if (n_rows > cross_ctx || total_floats != (size_t) n_rows * (size_t) width) {
const bool have_full_src = src != nullptr && total_floats == (size_t) n_rows * (size_t) width;
if (n_rows > cross_ctx || (src != nullptr && !have_full_src)) {
profile.graph_shape_failures++;
LLAMA_LOG_ERROR("%s: invalid DFlash target feature shape (rows=%d width=%d floats=%zu cross_ctx=%d)\n",
__func__, n_rows, width, total_floats, cross_ctx);
return false;
}
if (!use_kv_cache && !have_full_src) {
profile.graph_shape_failures++;
LLAMA_LOG_ERROR("%s: missing contiguous DFlash target features for inline path\n", __func__);
return false;
}
if (n_kv_total < cross_ctx + (int32_t) n_tokens) {
profile.graph_mask_overflow++;
LLAMA_LOG_ERROR("%s: invalid DFlash mask shape (n_kv_total=%d < cross_ctx+n_tokens=%d)\n",
@ -5351,24 +5491,26 @@ static bool prepare_dflash_graph_inputs(
}
const int32_t left_pad = cross_ctx - n_rows;
const size_t padded_floats = (size_t) cross_ctx * (size_t) width;
const size_t dst_offset = (size_t) left_pad * (size_t) width;
const int64_t t_feature_us = ggml_time_us();
profile.last_left_pad = left_pad;
if (lctx.dflash_target_features_padded.size() != padded_floats) {
lctx.dflash_target_features_padded.resize(padded_floats);
}
if (left_pad == 0 && total_floats == padded_floats) {
std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin());
} else {
if (dst_offset > 0) {
std::fill(lctx.dflash_target_features_padded.begin(),
lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset, 0.0f);
if (!use_kv_cache) {
const size_t padded_floats = (size_t) cross_ctx * (size_t) width;
const size_t dst_offset = (size_t) left_pad * (size_t) width;
const int64_t t_feature_us = ggml_time_us();
if (lctx.dflash_target_features_padded.size() != padded_floats) {
lctx.dflash_target_features_padded.resize(padded_floats);
}
std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset);
if (left_pad == 0 && total_floats == padded_floats) {
std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin());
} else {
if (dst_offset > 0) {
std::fill(lctx.dflash_target_features_padded.begin(),
lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset, 0.0f);
}
std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset);
}
profile.graph_feature_copy_us += (uint64_t) (ggml_time_us() - t_feature_us);
profile.graph_feature_bytes += padded_floats * sizeof(float);
}
profile.graph_feature_copy_us += (uint64_t) (ggml_time_us() - t_feature_us);
profile.graph_feature_bytes += padded_floats * sizeof(float);
const int64_t t_pos_us = ggml_time_us();
lctx.dflash_pos_ctx_data.resize((size_t) cross_ctx);
@ -5403,22 +5545,32 @@ static bool prepare_dflash_graph_inputs(
profile.graph_pos_bytes += lctx.dflash_pos_ctx_data.size() * sizeof(llama_pos);
if (use_kv_cache) {
const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition(
cross_ctx,
lctx.dflash_kv_cache_n_filled,
lctx.dflash_kv_cache_write_pos,
lctx.dflash_kv_cache_valid,
lctx.dflash_kv_cache_applied_window_version,
lctx.dflash_target_window_version,
lctx.dflash_target_window_keep_rows,
lctx.dflash_target_window_append_rows,
lctx.dflash_target_window_replace,
n_rows);
const bool have_append_src = append_src != nullptr &&
append_rows_available == cache_plan.append_rows &&
append_floats == (size_t) cache_plan.append_rows * (size_t) width;
const int32_t update_rows = cache_plan.cache_up_to_date
? 0
: (cache_plan.rebuild_cache ? n_rows : cache_plan.append_rows);
const size_t max_nodes = lctx.model.max_nodes((int) std::max<int32_t>(1, cross_ctx)) + 24 * lctx.model.hparams.n_layer;
const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false);
if (lctx.dflash_buf_compute_meta.size() != meta_size) {
lctx.dflash_buf_compute_meta.resize(meta_size);
}
const int64_t t_build_us = ggml_time_us();
ggml_cgraph * gf_kv = llm_build_context::llama_build_graph_dflash_kv_cache(lctx);
profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us);
if (gf_kv == nullptr || lctx.dflash_kv_input_target_features == nullptr || lctx.dflash_kv_input_pos_ctx == nullptr) {
profile.graph_shape_failures++;
LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__);
return false;
}
if (lctx.dflash_sched == nullptr) {
if (lctx.dflash_sched == nullptr || lctx.dflash_kv_cache_reserved_rows != cross_ctx) {
std::vector<ggml_backend_buffer_type_t> backend_buft;
backend_buft.reserve(lctx.backends.size());
for (auto * backend : lctx.backends) {
@ -5429,51 +5581,117 @@ static bool prepare_dflash_graph_inputs(
}
}
if (lctx.dflash_sched != nullptr) {
ggml_backend_sched_free(lctx.dflash_sched);
lctx.dflash_sched = nullptr;
}
const int32_t saved_update_rows = lctx.dflash_kv_cache_update_rows;
lctx.dflash_kv_cache_update_rows = cross_ctx;
const int64_t t_build_us = ggml_time_us();
ggml_cgraph * gf_reserve = llm_build_context::llama_build_graph_dflash_kv_cache(lctx);
profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us);
lctx.dflash_kv_cache_update_rows = saved_update_rows;
if (gf_reserve == nullptr) {
profile.graph_shape_failures++;
LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache reserve graph\n", __func__);
return false;
}
const int64_t t_reserve_us = ggml_time_us();
lctx.dflash_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false);
const bool reserved = lctx.dflash_sched != nullptr && ggml_backend_sched_reserve(lctx.dflash_sched, gf_kv);
const bool reserved = lctx.dflash_sched != nullptr && ggml_backend_sched_reserve(lctx.dflash_sched, gf_reserve);
profile.graph_kv_cache_reserve_us += (uint64_t) (ggml_time_us() - t_reserve_us);
if (!reserved) {
profile.graph_shape_failures++;
LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V scheduler\n", __func__);
return false;
}
lctx.dflash_kv_cache_reserved_rows = cross_ctx;
}
const int64_t t_reset_us = ggml_time_us();
ggml_backend_sched_reset(lctx.dflash_sched);
profile.graph_kv_cache_reset_us += (uint64_t) (ggml_time_us() - t_reset_us);
if (update_rows > 0) {
const float * update_src = nullptr;
if (have_append_src && update_rows == cache_plan.append_rows) {
update_src = append_src;
} else if (have_full_src) {
update_src = src + (size_t) (n_rows - update_rows) * (size_t) width;
}
const llama_pos * update_pos = src_pos + (n_rows - update_rows);
const int64_t t_alloc_us = ggml_time_us();
ggml_backend_sched_alloc_graph(lctx.dflash_sched, gf_kv);
profile.graph_kv_cache_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us);
if (update_src == nullptr) {
profile.graph_shape_failures++;
LLAMA_LOG_ERROR("%s: missing DFlash appended target features for cached update (rows=%d append_rows=%d floats=%zu)\n",
__func__, n_rows, update_rows, append_floats);
return false;
}
ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_target_features);
const int64_t t_feature_upload_us = ggml_time_us();
if (kv_feature_backend != nullptr) {
ggml_backend_tensor_set_async(kv_feature_backend, lctx.dflash_kv_input_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.dflash_kv_input_target_features));
} else {
ggml_backend_tensor_set(lctx.dflash_kv_input_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.dflash_kv_input_target_features));
if (cache_plan.rebuild_cache) {
llama_reset_dflash_kv_cache_state(&lctx);
}
lctx.dflash_kv_cache_update_rows = update_rows;
const int64_t t_build_us = ggml_time_us();
ggml_cgraph * gf_kv = llm_build_context::llama_build_graph_dflash_kv_cache(lctx);
profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us);
if (gf_kv == nullptr || lctx.dflash_kv_input_target_features == nullptr || lctx.dflash_kv_input_pos_ctx == nullptr) {
profile.graph_shape_failures++;
LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__);
return false;
}
const int64_t t_reset_us = ggml_time_us();
ggml_backend_sched_reset(lctx.dflash_sched);
profile.graph_kv_cache_reset_us += (uint64_t) (ggml_time_us() - t_reset_us);
const int64_t t_alloc_us = ggml_time_us();
ggml_backend_sched_alloc_graph(lctx.dflash_sched, gf_kv);
profile.graph_kv_cache_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us);
ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_target_features);
const int64_t t_feature_upload_us = ggml_time_us();
if (kv_feature_backend != nullptr) {
ggml_backend_tensor_set_async(kv_feature_backend, lctx.dflash_kv_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash_kv_input_target_features));
} else {
ggml_backend_tensor_set(lctx.dflash_kv_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash_kv_input_target_features));
}
profile.graph_kv_cache_feature_upload_us += (uint64_t) (ggml_time_us() - t_feature_upload_us);
profile.graph_feature_bytes += (size_t) update_rows * (size_t) width * sizeof(float);
ggml_backend_t kv_pos_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_pos_ctx);
const int64_t t_pos_upload_us = ggml_time_us();
if (kv_pos_backend != nullptr) {
ggml_backend_tensor_set_async(kv_pos_backend, lctx.dflash_kv_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx));
} else {
ggml_backend_tensor_set(lctx.dflash_kv_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx));
}
profile.graph_kv_cache_pos_upload_us += (uint64_t) (ggml_time_us() - t_pos_upload_us);
const int64_t t_kv_cache_us = ggml_time_us();
llama_dflash_kv_node_profiler kv_node_profiler;
if (kv_node_timing) {
kv_node_profiler.profile = &profile;
ggml_backend_sched_set_eval_callback(lctx.dflash_sched, llama_dflash_kv_node_eval_callback, &kv_node_profiler);
}
llama_graph_compute_sched(lctx, lctx.dflash_sched, gf_kv, lctx.cparams.n_threads);
if (kv_node_timing) {
ggml_backend_sched_set_eval_callback(lctx.dflash_sched, nullptr, nullptr);
}
profile.graph_kv_cache_compute_us += (uint64_t) (ggml_time_us() - t_kv_cache_us);
const int64_t t_sync_us = ggml_time_us();
ggml_backend_sched_synchronize(lctx.dflash_sched);
profile.graph_kv_cache_sync_us += (uint64_t) (ggml_time_us() - t_sync_us);
profile.graph_kv_cache_calls++;
lctx.dflash_kv_cache_n_filled = std::min(cross_ctx, lctx.dflash_kv_cache_n_filled + update_rows);
lctx.dflash_kv_cache_write_pos = (lctx.dflash_kv_cache_write_pos + update_rows) % cross_ctx;
lctx.dflash_kv_cache_applied_window_version = lctx.dflash_target_window_version;
lctx.dflash_kv_cache_valid = true;
lctx.dflash_kv_cache_view_n_filled = lctx.dflash_kv_cache_n_filled;
lctx.dflash_kv_cache_view_write_pos = lctx.dflash_kv_cache_write_pos;
lctx.dflash_kv_cache_view_valid = true;
}
profile.graph_kv_cache_feature_upload_us += (uint64_t) (ggml_time_us() - t_feature_upload_us);
ggml_backend_t kv_pos_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_pos_ctx);
const int64_t t_pos_upload_us = ggml_time_us();
if (kv_pos_backend != nullptr) {
ggml_backend_tensor_set_async(kv_pos_backend, lctx.dflash_kv_input_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx));
} else {
ggml_backend_tensor_set(lctx.dflash_kv_input_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx));
}
profile.graph_kv_cache_pos_upload_us += (uint64_t) (ggml_time_us() - t_pos_upload_us);
const int64_t t_kv_cache_us = ggml_time_us();
llama_graph_compute_sched(lctx, lctx.dflash_sched, gf_kv, lctx.cparams.n_threads);
profile.graph_kv_cache_compute_us += (uint64_t) (ggml_time_us() - t_kv_cache_us);
const int64_t t_sync_us = ggml_time_us();
ggml_backend_sched_synchronize(lctx.dflash_sched);
profile.graph_kv_cache_sync_us += (uint64_t) (ggml_time_us() - t_sync_us);
profile.graph_kv_cache_calls++;
} else {
ggml_backend_tensor_set(lctx.inp_dflash_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.inp_dflash_target_features));
ggml_backend_tensor_set(lctx.inp_dflash_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.inp_dflash_pos_ctx));
@ -5586,6 +5804,9 @@ static int llama_decode_internal(
}
lctx.n_queued_tokens += n_tokens_all;
auto * dflash_profile = lctx.model.arch == LLM_ARCH_DFLASH_DRAFT ? &lctx.dflash_profile : nullptr;
const bool dflash_decode_timing = dflash_profile != nullptr && llama_env_flag_enabled("IK_DFLASH_DECODE_TIMING");
auto & kv_self = lctx.kv_self;
const int64_t n_embd = hparams.n_embd;
@ -5670,6 +5891,10 @@ static int llama_decode_internal(
#if IK_PRINT_TIMING
auto tim1 = ggml_time_us();
#endif
const int64_t t_dflash_prelude_us = dflash_decode_timing ? ggml_time_us() : 0;
if (dflash_decode_timing) {
dflash_profile->decode_internal_chunks++;
}
uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
if (llm_arch_is_hybrid(model.arch) &&
n_tokens > 1 &&
@ -5804,6 +6029,9 @@ static int llama_decode_internal(
auto tim2 = ggml_time_us();
printf("prelude(...): %d us\n", int(tim2-tim1));
#endif
if (dflash_decode_timing) {
dflash_profile->decode_prelude_us += (uint64_t) (ggml_time_us() - t_dflash_prelude_us);
}
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
@ -5811,30 +6039,45 @@ static int llama_decode_internal(
auto & prev = cparams.mtp_op_type == MTP_OP_NONE ? lctx.prev : lctx.prev_mtp;
ggml_cgraph * gf = nullptr;
if (!lctx.can_reuse_graph(u_batch)) {
if (dflash_decode_timing) {
dflash_profile->decode_graph_rebuilds++;
}
const int64_t t_dflash_sched_reset_us = dflash_decode_timing ? ggml_time_us() : 0;
lctx.reset_scheduler();
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
#if IK_PRINT_TIMING
tim2 = ggml_time_us();
printf("sched_reset(...): %d us\n", int(tim2-tim1));
#endif
if (dflash_decode_timing) {
dflash_profile->decode_sched_reset_us += (uint64_t) (ggml_time_us() - t_dflash_sched_reset_us);
}
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
#endif
const int64_t t_dflash_build_graph_us = dflash_decode_timing ? ggml_time_us() : 0;
gf = llm_build_context::llama_build_graph(lctx, u_batch, false);
#if IK_PRINT_TIMING
tim2 = ggml_time_us();
printf("build_graph(...): %d us\n", int(tim2-tim1));
#endif
if (dflash_decode_timing) {
dflash_profile->decode_build_graph_us += (uint64_t) (ggml_time_us() - t_dflash_build_graph_us);
}
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
#endif
const int64_t t_dflash_sched_alloc_us = dflash_decode_timing ? ggml_time_us() : 0;
ggml_backend_sched_alloc_graph(lctx.sched, gf);
#if IK_PRINT_TIMING
tim2 = ggml_time_us();
printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1));
#endif
if (dflash_decode_timing) {
dflash_profile->decode_sched_alloc_graph_us += (uint64_t) (ggml_time_us() - t_dflash_sched_alloc_us);
}
//if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
if (u_batch.embd == nullptr && lctx.cparams.graph_reuse &&
!(lctx.model.arch == LLM_ARCH_GEMMA4_MTP && lctx.mtp_target_ctx != nullptr)) {
@ -5855,16 +6098,15 @@ static int llama_decode_internal(
}
}
if (lctx.model.arch == LLM_ARCH_DFLASH_DRAFT) {
auto & profile = lctx.dflash_profile;
profile.decode_prepare_calls++;
if (dflash_profile != nullptr) {
dflash_profile->decode_prepare_calls++;
const int64_t t_prepare_dflash_us = ggml_time_us();
if (!prepare_dflash_graph_inputs(lctx, n_tokens)) {
profile.decode_prepare_failures++;
profile.decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us);
dflash_profile->decode_prepare_failures++;
dflash_profile->decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us);
return GGML_STATUS_FAILED;
}
profile.decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us);
dflash_profile->decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us);
}
// the output is always the last tensor in the graph
@ -5910,16 +6152,26 @@ static int llama_decode_internal(
#if IK_PRINT_TIMING == 1
tim1 = ggml_time_us();
#endif
const int64_t t_dflash_set_inputs_us = dflash_decode_timing ? ggml_time_us() : 0;
llama_set_inputs(lctx, u_batch);
#if IK_PRINT_TIMING == 1
tim2 = ggml_time_us();
printf("set_inputs(...): %d us\n", int(tim2-tim1));
#endif
if (dflash_decode_timing) {
dflash_profile->decode_set_inputs_us += (uint64_t) (ggml_time_us() - t_dflash_set_inputs_us);
}
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
#endif
const int64_t t_dflash_graph_compute_us = dflash_decode_timing ? ggml_time_us() : 0;
llama_graph_compute(lctx, gf, n_threads);
if (dflash_decode_timing) {
llama_synchronize(&lctx);
dflash_profile->decode_sync_profile_points++;
dflash_profile->decode_graph_compute_us += (uint64_t) (ggml_time_us() - t_dflash_graph_compute_us);
}
#if IK_PRINT_TIMING
llama_synchronize(&lctx);
tim2 = ggml_time_us();
@ -5950,6 +6202,7 @@ static int llama_decode_internal(
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
#endif
const int64_t t_dflash_get_result_us = dflash_decode_timing ? ggml_time_us() : 0;
// Do not process logits if MTP is only updating the KV cache.
if (cparams.mtp_op_type != MTP_OP_WARMUP) { // && cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
@ -5980,6 +6233,11 @@ static int llama_decode_internal(
}
}
}
if (dflash_decode_timing) {
llama_synchronize(&lctx);
dflash_profile->decode_sync_profile_points++;
dflash_profile->decode_result_us += (uint64_t) (ggml_time_us() - t_dflash_get_result_us);
}
#if IK_PRINT_TIMING
tim2 = ggml_time_us();
printf("get_result(...): %d us\n", int(tim2-tim1));
@ -5992,6 +6250,7 @@ static int llama_decode_internal(
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
#endif
const int64_t t_dflash_get_embedding_us = dflash_decode_timing ? ggml_time_us() : 0;
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
GGML_ASSERT(backend_embd != nullptr);
@ -6031,6 +6290,11 @@ static int llama_decode_internal(
GGML_ABORT("unknown pooling type");
}
}
if (dflash_decode_timing) {
llama_synchronize(&lctx);
dflash_profile->decode_sync_profile_points++;
dflash_profile->decode_embedding_us += (uint64_t) (ggml_time_us() - t_dflash_get_embedding_us);
}
#if IK_PRINT_TIMING
tim2 = ggml_time_us();
printf("get_embedding(...): %d us\n", int(tim2-tim1));
@ -6074,9 +6338,13 @@ static int llama_decode_internal(
#if IK_PRINT_TIMING
auto tim1 = ggml_time_us();
#endif
const int64_t t_dflash_final_sched_reset_us = dflash_decode_timing ? ggml_time_us() : 0;
if (!lctx.prev) {
lctx.reset_scheduler();
}
if (dflash_decode_timing) {
dflash_profile->decode_final_sched_reset_us += (uint64_t) (ggml_time_us() - t_dflash_final_sched_reset_us);
}
#if IK_PRINT_TIMING
auto tim2 = ggml_time_us();
printf("sched_reset(...): %d us\n", int(tim2-tim1));