mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Use windows update in kv cache
This commit is contained in:
parent
1369e68471
commit
ed403dca27
@ -253,6 +253,8 @@ static const common_speculative_state_mtp * common_speculative_get_mtp_state(con
|
||||
static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec);
|
||||
static const common_speculative_state_dflash * common_speculative_get_dflash_state(const common_speculative * spec);
|
||||
static int32_t common_speculative_feature_width(const common_speculative * spec);
|
||||
static void dflash_materialize_target_window_features(common_speculative_state_dflash & state);
|
||||
static void dflash_ring_reset_rows(common_speculative_state_dflash & state, const float * rows, int32_t n_rows);
|
||||
static void dflash_append_target_features(
|
||||
common_speculative_state_dflash & state,
|
||||
const float * feature_rows,
|
||||
@ -284,6 +286,17 @@ static bool dflash_contract_log_enabled() {
|
||||
std::strcmp(env, "off") != 0;
|
||||
}
|
||||
|
||||
// Reports whether the experimental DFlash KV-cache path is switched on.
//
// The IK_DFLASH_KV_CACHE environment variable is consulted on every call:
//   - unset or empty                        -> disabled
//   - exactly "0", "false" or "off"         -> disabled (comparison is case-sensitive)
//   - any other non-empty value             -> enabled
static bool dflash_use_kv_cache_experiment() {
    const char * value = std::getenv("IK_DFLASH_KV_CACHE");
    if (value == nullptr || value[0] == '\0') {
        return false;
    }

    // Spellings that explicitly turn the experiment off.
    const char * const disabled_spellings[] = { "0", "false", "off" };
    for (const char * spelling : disabled_spellings) {
        if (std::strcmp(value, spelling) == 0) {
            return false;
        }
    }

    return true;
}
|
||||
|
||||
template <typename T>
|
||||
static std::string dflash_contract_format_values(
|
||||
const std::vector<T> & values,
|
||||
@ -479,7 +492,18 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
std::vector<int32_t> target_layer_ids;
|
||||
std::vector<float> target_window;
|
||||
std::vector<llama_pos> target_window_pos;
|
||||
std::vector<float> target_window_stage;
|
||||
std::vector<llama_pos> target_window_pos_stage;
|
||||
std::vector<float> target_window_ring;
|
||||
std::vector<float> target_window_append_features;
|
||||
int32_t target_window_rows = 0;
|
||||
int32_t target_window_ring_write_pos = 0;
|
||||
int32_t target_window_ring_filled = 0;
|
||||
uint64_t target_window_version = 0;
|
||||
int32_t target_window_keep_rows = 0;
|
||||
int32_t target_window_append_rows = 0;
|
||||
bool target_window_replace = false;
|
||||
bool target_window_materialized = false;
|
||||
llama_pos last_target_pos = -1;
|
||||
size_t n_window_updates = 0;
|
||||
size_t n_rows_seen = 0;
|
||||
@ -497,6 +521,13 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
uint64_t t_accept_output_copy_us = 0;
|
||||
uint64_t t_accept_commit_us = 0;
|
||||
uint64_t t_accept_append_us = 0;
|
||||
uint64_t t_accept_append_filter_us = 0;
|
||||
uint64_t t_accept_append_window_alloc_us = 0;
|
||||
uint64_t t_accept_append_replace_us = 0;
|
||||
uint64_t t_accept_append_keep_old_us = 0;
|
||||
uint64_t t_accept_append_new_rows_us = 0;
|
||||
uint64_t t_accept_append_commit_detail_us = 0;
|
||||
uint64_t t_accept_append_log_us = 0;
|
||||
size_t n_warmup_collect_calls = 0;
|
||||
size_t n_warmup_collect_rows = 0;
|
||||
size_t n_warmup_append_calls = 0;
|
||||
@ -507,6 +538,8 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
size_t n_accept_commit_rows = 0;
|
||||
size_t n_accept_append_calls = 0;
|
||||
size_t n_accept_append_rows = 0;
|
||||
size_t n_accept_append_replace_calls = 0;
|
||||
size_t n_accept_append_slide_calls = 0;
|
||||
|
||||
common_speculative_state_dflash(
|
||||
enum common_speculative_type type,
|
||||
@ -614,6 +647,12 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
}
|
||||
|
||||
batch = llama_batch_init(std::max(1, block_size), 0, 1);
|
||||
target_window.reserve((size_t) this->cross_ctx * (size_t) n_target_features);
|
||||
target_window_stage.reserve((size_t) this->cross_ctx * (size_t) n_target_features);
|
||||
target_window_ring.resize((size_t) this->cross_ctx * (size_t) n_target_features);
|
||||
target_window_append_features.reserve((size_t) this->cross_ctx * (size_t) n_target_features);
|
||||
target_window_pos.reserve((size_t) this->cross_ctx);
|
||||
target_window_pos_stage.reserve((size_t) this->cross_ctx);
|
||||
ready = true;
|
||||
|
||||
llama_set_dflash_visible_cross_ctx(ctx_dft, this->cross_ctx);
|
||||
@ -648,6 +687,7 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
void begin(const llama_tokens & prompt) override {
|
||||
GGML_UNUSED(prompt);
|
||||
llama_kv_cache_clear(ctx_dft);
|
||||
llama_reset_dflash_kv_cache_state(ctx_dft);
|
||||
n_window_updates = 0;
|
||||
n_rows_seen = 0;
|
||||
n_rows_dropped = 0;
|
||||
@ -663,6 +703,13 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
t_accept_output_copy_us = 0;
|
||||
t_accept_commit_us = 0;
|
||||
t_accept_append_us = 0;
|
||||
t_accept_append_filter_us = 0;
|
||||
t_accept_append_window_alloc_us = 0;
|
||||
t_accept_append_replace_us = 0;
|
||||
t_accept_append_keep_old_us = 0;
|
||||
t_accept_append_new_rows_us = 0;
|
||||
t_accept_append_commit_detail_us = 0;
|
||||
t_accept_append_log_us = 0;
|
||||
n_warmup_collect_calls = 0;
|
||||
n_warmup_collect_rows = 0;
|
||||
n_warmup_append_calls = 0;
|
||||
@ -673,6 +720,8 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
n_accept_commit_rows = 0;
|
||||
n_accept_append_calls = 0;
|
||||
n_accept_append_rows = 0;
|
||||
n_accept_append_replace_calls = 0;
|
||||
n_accept_append_slide_calls = 0;
|
||||
llama_dflash_profile_reset(ctx_tgt);
|
||||
llama_dflash_profile_reset(ctx_dft);
|
||||
}
|
||||
@ -695,7 +744,33 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!llama_set_dflash_target_features_view(ctx_dft, target_window.data(), target_window.size(), target_window_rows, target_window_pos.data())) {
|
||||
const bool use_kv_cache = dflash_use_kv_cache_experiment();
|
||||
const float * target_features = nullptr;
|
||||
size_t target_feature_floats = 0;
|
||||
llama_dflash_window_update window_update = {
|
||||
target_window_version,
|
||||
target_window_keep_rows,
|
||||
target_window_append_rows,
|
||||
target_window_replace,
|
||||
target_window_append_features.empty() ? nullptr : target_window_append_features.data(),
|
||||
target_window_append_features.size(),
|
||||
};
|
||||
const llama_dflash_kv_cache_transition cache_plan = use_kv_cache
|
||||
? llama_plan_dflash_kv_cache_transition_for_ctx(ctx_dft, window_update, target_window_rows)
|
||||
: llama_dflash_kv_cache_transition{};
|
||||
|
||||
if (!use_kv_cache || cache_plan.rebuild_cache) {
|
||||
dflash_materialize_target_window_features(*this);
|
||||
target_features = target_window.data();
|
||||
target_feature_floats = target_window.size();
|
||||
}
|
||||
if (use_kv_cache && cache_plan.rebuild_cache) {
|
||||
window_update.append_features = target_window.data();
|
||||
window_update.append_floats = target_window.size();
|
||||
window_update.append_rows = target_window_rows;
|
||||
}
|
||||
|
||||
if (!llama_set_dflash_target_features_view(ctx_dft, target_features, target_feature_floats, target_window_rows, target_window_pos.data(), &window_update)) {
|
||||
LOG_ERR("%s: failed to set DFlash target features\n", __func__);
|
||||
n_set_target_fail++;
|
||||
return;
|
||||
@ -2522,6 +2597,24 @@ void common_speculative_print_stats(const common_speculative * spec, double slot
|
||||
const double kv_compute_ms = (double) graph_stats.graph_kv_cache_compute_us / 1000.0;
|
||||
const double kv_sync_ms = (double) graph_stats.graph_kv_cache_sync_us / 1000.0;
|
||||
const double replay_append_ms = (double) dflash_state->t_accept_append_us / 1000.0;
|
||||
const double feature_path_ms = (double) (
|
||||
capture_stats.capture_prepare_sync_us +
|
||||
capture_stats.capture_materialize_us +
|
||||
graph_stats.set_target_copy_us +
|
||||
graph_stats.graph_feature_copy_us +
|
||||
graph_stats.graph_pos_copy_us +
|
||||
graph_stats.graph_mask_build_us) / 1000.0;
|
||||
const double decode_internal_ms = (double) (
|
||||
graph_stats.decode_prelude_us +
|
||||
graph_stats.decode_sched_reset_us +
|
||||
graph_stats.decode_build_graph_us +
|
||||
graph_stats.decode_sched_alloc_graph_us +
|
||||
graph_stats.decode_prepare_us +
|
||||
graph_stats.decode_set_inputs_us +
|
||||
graph_stats.decode_graph_compute_us +
|
||||
graph_stats.decode_result_us +
|
||||
graph_stats.decode_embedding_us +
|
||||
graph_stats.decode_final_sched_reset_us) / 1000.0;
|
||||
|
||||
LOG_INF("statistics dflash profile: capture(sync/materialize)=%.3f/%.3f ms calls=%llu/%llu bytes=%llu phase(prompt/verify batches changes)=%llu/%llu %llu/%llu, set_target=%.3f ms rows=%llu bytes=%llu, decode(llama_output_reserve/prepare)=%.3f/%.3f ms calls=%llu/%llu realloc(bytes)=%llu/%llu, prep(total/features/pos/mask)=%.3f/%.3f/%.3f/%.3f ms kv_cache(total/build/reserve/reset/alloc/up_f/up_p/compute/sync/read)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls(prepare/cache/read)=%llu/%llu/%llu bytes(feature/pos/mask/read)=%llu/%llu/%llu/%llu host_layers=%d, fallback_pos(copy/graph)=%llu/%llu, nonmono(copy/graph)=%llu/%llu, capture_fail=%llu/%llu decode_prepare_fail=%llu, visible_kv_max=%llu, last(rows=%d width=%d left_pad=%d n_tokens=%d n_kv=%d pos=[%d..%d])\n",
|
||||
(double) capture_stats.capture_prepare_sync_us / 1000.0,
|
||||
@ -2580,6 +2673,81 @@ void common_speculative_print_stats(const common_speculative * spec, double slot
|
||||
(int) graph_stats.last_pos_first,
|
||||
(int) graph_stats.last_pos_last);
|
||||
|
||||
LOG_INF("statistics dflash features: total=%.3f ms capture(sync/materialize)=%.3f/%.3f ms set_target=%.3f ms prep(feature/pos/mask)=%.3f/%.3f/%.3f ms rows(materialize/set_target)=%llu/%llu bytes(materialize/set_target/feature/pos/mask)=%llu/%llu/%llu/%llu/%llu\n",
|
||||
feature_path_ms,
|
||||
(double) capture_stats.capture_prepare_sync_us / 1000.0,
|
||||
(double) capture_stats.capture_materialize_us / 1000.0,
|
||||
(double) graph_stats.set_target_copy_us / 1000.0,
|
||||
(double) graph_stats.graph_feature_copy_us / 1000.0,
|
||||
(double) graph_stats.graph_pos_copy_us / 1000.0,
|
||||
(double) graph_stats.graph_mask_build_us / 1000.0,
|
||||
(unsigned long long) capture_stats.capture_materialize_rows,
|
||||
(unsigned long long) graph_stats.set_target_rows,
|
||||
(unsigned long long) capture_stats.capture_materialize_bytes,
|
||||
(unsigned long long) graph_stats.set_target_copy_bytes,
|
||||
(unsigned long long) graph_stats.graph_feature_bytes,
|
||||
(unsigned long long) graph_stats.graph_pos_bytes,
|
||||
(unsigned long long) graph_stats.graph_mask_bytes);
|
||||
|
||||
LOG_INF("statistics dflash kv: total=%.3f ms build/reserve/reset/alloc/upload_f/upload_p/compute/sync/read=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu cached_bytes=%llu host_layers=%d\n",
|
||||
kv_cache_total_ms,
|
||||
(double) graph_stats.graph_kv_cache_build_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_cache_reserve_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_cache_reset_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_cache_alloc_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_cache_feature_upload_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_cache_pos_upload_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_cache_compute_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_cache_sync_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_cache_read_concat_pad_us / 1000.0,
|
||||
(unsigned long long) graph_stats.graph_kv_cache_calls,
|
||||
(unsigned long long) graph_stats.graph_kv_cache_cached_bytes,
|
||||
graph_stats.last_kv_cache_host_layers);
|
||||
|
||||
if (graph_stats.decode_internal_chunks > 0) {
|
||||
LOG_INF("statistics dflash decode: llama_decode(total)=%.3f ms calls=%zu chunks=%llu rebuilds=%llu sync_points=%llu internal(total/prelude/sched_reset/build/alloc/prepare/set_inputs/compute/get_result/get_embedding/final_reset)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms\n",
|
||||
(double) dflash_state->t_draft_decode_us / 1000.0,
|
||||
dflash_state->n_call_draft,
|
||||
(unsigned long long) graph_stats.decode_internal_chunks,
|
||||
(unsigned long long) graph_stats.decode_graph_rebuilds,
|
||||
(unsigned long long) graph_stats.decode_sync_profile_points,
|
||||
decode_internal_ms,
|
||||
(double) graph_stats.decode_prelude_us / 1000.0,
|
||||
(double) graph_stats.decode_sched_reset_us / 1000.0,
|
||||
(double) graph_stats.decode_build_graph_us / 1000.0,
|
||||
(double) graph_stats.decode_sched_alloc_graph_us / 1000.0,
|
||||
(double) graph_stats.decode_prepare_us / 1000.0,
|
||||
(double) graph_stats.decode_set_inputs_us / 1000.0,
|
||||
(double) graph_stats.decode_graph_compute_us / 1000.0,
|
||||
(double) graph_stats.decode_result_us / 1000.0,
|
||||
(double) graph_stats.decode_embedding_us / 1000.0,
|
||||
(double) graph_stats.decode_final_sched_reset_us / 1000.0);
|
||||
}
|
||||
|
||||
if (graph_stats.graph_kv_node_fused_target_calls > 0 ||
|
||||
graph_stats.graph_kv_node_k_proj_calls > 0 ||
|
||||
graph_stats.graph_kv_node_k_norm_calls > 0 ||
|
||||
graph_stats.graph_kv_node_k_rope_calls > 0 ||
|
||||
graph_stats.graph_kv_node_v_proj_calls > 0 ||
|
||||
graph_stats.graph_kv_node_k_store_calls > 0 ||
|
||||
graph_stats.graph_kv_node_v_store_calls > 0) {
|
||||
LOG_INF("statistics dflash kv nodes: fused_target/k_proj/k_norm/k_rope/v_proj/k_store/v_store=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu/%llu/%llu/%llu/%llu/%llu/%llu\n",
|
||||
(double) graph_stats.graph_kv_node_fused_target_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_node_k_proj_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_node_k_norm_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_node_k_rope_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_node_v_proj_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_node_k_store_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_node_v_store_us / 1000.0,
|
||||
(unsigned long long) graph_stats.graph_kv_node_fused_target_calls,
|
||||
(unsigned long long) graph_stats.graph_kv_node_k_proj_calls,
|
||||
(unsigned long long) graph_stats.graph_kv_node_k_norm_calls,
|
||||
(unsigned long long) graph_stats.graph_kv_node_k_rope_calls,
|
||||
(unsigned long long) graph_stats.graph_kv_node_v_proj_calls,
|
||||
(unsigned long long) graph_stats.graph_kv_node_k_store_calls,
|
||||
(unsigned long long) graph_stats.graph_kv_node_v_store_calls);
|
||||
}
|
||||
|
||||
LOG_INF("statistics dflash hot: kv(upload_f/upload_p/upload/compute/sync)=%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu replay(accepted_prefix_append)=%.3f ms calls=%zu rows=%zu\n",
|
||||
kv_upload_feature_ms,
|
||||
kv_upload_pos_ms,
|
||||
@ -2609,6 +2777,20 @@ void common_speculative_print_stats(const common_speculative * spec, double slot
|
||||
dflash_state->n_accept_commit_rows,
|
||||
dflash_state->n_accept_output_copy_rows,
|
||||
dflash_state->n_accept_append_rows);
|
||||
|
||||
if (dflash_state->n_accept_append_calls > 0) {
|
||||
LOG_INF("statistics dflash replay: append(filter/window_alloc/replace/keep_old/new_rows/commit/log)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%zu replace/slide=%zu/%zu\n",
|
||||
(double) dflash_state->t_accept_append_filter_us / 1000.0,
|
||||
(double) dflash_state->t_accept_append_window_alloc_us / 1000.0,
|
||||
(double) dflash_state->t_accept_append_replace_us / 1000.0,
|
||||
(double) dflash_state->t_accept_append_keep_old_us / 1000.0,
|
||||
(double) dflash_state->t_accept_append_new_rows_us / 1000.0,
|
||||
(double) dflash_state->t_accept_append_commit_detail_us / 1000.0,
|
||||
(double) dflash_state->t_accept_append_log_us / 1000.0,
|
||||
dflash_state->n_accept_append_calls,
|
||||
dflash_state->n_accept_append_replace_calls,
|
||||
dflash_state->n_accept_append_slide_calls);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2728,11 +2910,113 @@ static void mtp_clear_target_hidden(common_speculative_state_mtp & state, llama_
|
||||
state.draft_cache_by_seq.erase(seq_id);
|
||||
}
|
||||
|
||||
// Optional per-call timing breakdown for dflash_append_target_features().
// Every duration is in microseconds and is accumulated with += by the callee,
// so a default-constructed instance starts from zero.
struct dflash_append_breakdown {
    uint64_t filter_us       = 0; // selecting the rows that belong to the requested sequence
    uint64_t window_alloc_us = 0; // resizing the staged position buffer (slide path)
    uint64_t replace_us      = 0; // whole-window replacement (append covers the full window)
    uint64_t keep_old_us     = 0; // carrying over the positions of retained old rows
    uint64_t new_rows_us     = 0; // writing the new rows and their positions
    uint64_t commit_us       = 0; // publishing the updated window bookkeeping
    uint64_t log_us          = 0; // contract logging of the appended positions
    bool     replace_call    = false; // true when the replacement path was taken
};
|
||||
|
||||
// Records the shape of the most recent target-window mutation on `state` and
// bumps the window version counter.
//
//   keep_rows   - number of previously held rows that were retained
//   append_rows - number of rows that were added
//   replace     - whether the window was replaced rather than slid
//
// Negative row counts are stored as zero.
static void dflash_record_window_update(
        common_speculative_state_dflash & state,
        int32_t keep_rows,
        int32_t append_rows,
        bool replace) {
    state.target_window_keep_rows   = keep_rows   > 0 ? keep_rows   : 0;
    state.target_window_append_rows = append_rows > 0 ? append_rows : 0;
    state.target_window_replace     = replace;
    ++state.target_window_version;
}
|
||||
|
||||
// Discards the current ring contents and refills the ring from slot 0 with
// `n_rows` feature rows read from `rows` (row-major, n_target_features floats
// per row).
//
// - Empty input (n_rows <= 0 or rows == nullptr) or a non-positive ring
//   capacity (cross_ctx) just rewinds the ring to the empty state.
// - If more rows are supplied than the ring can hold, only the newest
//   cross_ctx rows are kept; copying all of them would write past the end of
//   the ring storage.
//
// `rows` must not alias the ring storage. The materialized flat window is
// always marked stale.
static void dflash_ring_reset_rows(
        common_speculative_state_dflash & state,
        const float * rows,
        int32_t n_rows) {
    const size_t  row_width = (size_t) state.n_target_features;
    const int32_t capacity  = state.cross_ctx;

    // Whatever happens below, the flat copy no longer mirrors the ring.
    state.target_window_materialized = false;

    if (n_rows <= 0 || rows == nullptr || capacity <= 0) {
        state.target_window_ring_write_pos = 0;
        state.target_window_ring_filled = 0;
        return;
    }

    const size_t ring_floats = (size_t) capacity * row_width;
    if (state.target_window_ring.size() != ring_floats) {
        state.target_window_ring.resize(ring_floats);
    }

    // Clamp to the ring capacity, preferring the most recent rows.
    const int32_t kept_rows = std::min<int32_t>(n_rows, capacity);
    const float * src = rows + (size_t) (n_rows - kept_rows) * row_width;

    std::memcpy(state.target_window_ring.data(), src, (size_t) kept_rows * row_width * sizeof(float));

    // A completely full ring wraps the write cursor back to slot 0.
    state.target_window_ring_write_pos = kept_rows % capacity;
    state.target_window_ring_filled = kept_rows;
}
|
||||
|
||||
// Appends `n_rows` feature rows (row-major, n_target_features floats per row)
// to the ring at the current write cursor, wrapping around and overwriting the
// oldest rows once the ring is full.
//
// - Empty input (n_rows <= 0 or rows == nullptr) is a no-op.
// - A non-positive ring capacity (cross_ctx) is also a no-op: without this
//   guard the chunk size would never be positive, so the copy loop could not
//   terminate, and the cursor update would take a remainder by zero.
// - When a single call supplies more rows than the ring holds, the leading
//   rows would be overwritten by later rows of the same call. They are skipped
//   while the cursor is advanced as if they had been written, so the final
//   ring contents and cursor are the same as copying every row.
//
// `rows` must not alias the ring storage. The materialized flat window is
// marked stale whenever rows are written.
static void dflash_ring_append_rows(
        common_speculative_state_dflash & state,
        const float * rows,
        int32_t n_rows) {
    const size_t  row_width = (size_t) state.n_target_features;
    const int32_t capacity  = state.cross_ctx;

    if (n_rows <= 0 || rows == nullptr || capacity <= 0) {
        return;
    }

    const size_t ring_floats = (size_t) capacity * row_width;
    if (state.target_window_ring.size() != ring_floats) {
        state.target_window_ring.resize(ring_floats);
    }

    // Rows that this very call would overwrite again are not copied at all.
    const int32_t skipped = n_rows > capacity ? n_rows - capacity : 0;

    int32_t write_pos = (int32_t) (((int64_t) state.target_window_ring_write_pos + skipped) % capacity);
    int32_t remaining = n_rows - skipped;
    const float * src = rows + (size_t) skipped * row_width;

    while (remaining > 0) {
        // Copy up to the physical end of the ring, then wrap to slot 0.
        const int32_t chunk_rows = std::min<int32_t>(remaining, capacity - write_pos);
        std::memcpy(
            state.target_window_ring.data() + (size_t) write_pos * row_width,
            src,
            (size_t) chunk_rows * row_width * sizeof(float));
        src += (size_t) chunk_rows * row_width;
        remaining -= chunk_rows;
        write_pos = (write_pos + chunk_rows) % capacity;
    }

    state.target_window_ring_write_pos = write_pos;
    state.target_window_ring_filled = (int32_t) std::min<int64_t>(
        (int64_t) capacity, (int64_t) state.target_window_ring_filled + (int64_t) n_rows);
    state.target_window_materialized = false;
}
|
||||
|
||||
// Rebuilds the flat `target_window` buffer from the ring so that it holds the
// `target_window_rows` rows ending at the ring's write cursor, in order.
// Does nothing when the window is empty or the flat copy is already current.
static void dflash_materialize_target_window_features(common_speculative_state_dflash & state) {
    const int32_t n_rows = state.target_window_rows;
    if (n_rows <= 0 || state.target_window_materialized) {
        return;
    }

    const size_t  row_width = (size_t) state.n_target_features;
    const int32_t capacity  = state.cross_ctx;

    state.target_window.resize((size_t) n_rows * row_width);

    // Slot of the first row of the window: n_rows slots behind the write cursor, modulo capacity.
    const int32_t oldest = (state.target_window_ring_write_pos - n_rows + capacity) % capacity;

    // The window may straddle the physical end of the ring: copy the part up
    // to the end first, then the wrapped remainder from slot 0.
    const int32_t head_rows = std::min<int32_t>(n_rows, capacity - oldest);
    const int32_t tail_rows = n_rows - head_rows;

    const float * ring = state.target_window_ring.data();
    float * flat = state.target_window.data();

    std::memcpy(flat, ring + (size_t) oldest * row_width, (size_t) head_rows * row_width * sizeof(float));
    if (tail_rows > 0) {
        std::memcpy(flat + (size_t) head_rows * row_width, ring, (size_t) tail_rows * row_width * sizeof(float));
    }

    state.target_window_materialized = true;
}
|
||||
|
||||
static bool dflash_append_target_features(
|
||||
common_speculative_state_dflash & state,
|
||||
const common_speculative_feature_view & features,
|
||||
const llama_batch & batch,
|
||||
llama_seq_id seq_id) {
|
||||
llama_seq_id seq_id,
|
||||
dflash_append_breakdown * breakdown = nullptr) {
|
||||
GGML_UNUSED(batch);
|
||||
|
||||
if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE ||
|
||||
@ -2748,6 +3032,7 @@ static bool dflash_append_target_features(
|
||||
new_rows.reserve(features.rows.size() * row_width);
|
||||
new_positions.reserve(features.rows.size());
|
||||
|
||||
const int64_t t_filter_us = ggml_time_us();
|
||||
for (const auto & row : features.rows) {
|
||||
if (row.seq_id != seq_id || row.data == nullptr) {
|
||||
continue;
|
||||
@ -2756,6 +3041,9 @@ static bool dflash_append_target_features(
|
||||
new_positions.push_back(row.pos);
|
||||
new_rows.insert(new_rows.end(), row.data, row.data + row_width);
|
||||
}
|
||||
if (breakdown != nullptr) {
|
||||
breakdown->filter_us += (uint64_t) (ggml_time_us() - t_filter_us);
|
||||
}
|
||||
|
||||
if (new_positions.empty()) {
|
||||
return false;
|
||||
@ -2767,46 +3055,93 @@ static bool dflash_append_target_features(
|
||||
if (n_rows >= state.cross_ctx) {
|
||||
state.n_rows_dropped += (size_t) state.target_window_rows + (size_t) (n_rows - state.cross_ctx);
|
||||
const int32_t keep_from = n_rows - state.cross_ctx;
|
||||
state.target_window.assign(
|
||||
const int64_t t_replace_us = ggml_time_us();
|
||||
state.target_window_pos.assign(new_positions.begin() + keep_from, new_positions.end());
|
||||
state.target_window_append_features.assign(
|
||||
new_rows.begin() + (ptrdiff_t) keep_from * (ptrdiff_t) row_width,
|
||||
new_rows.end());
|
||||
state.target_window_pos.assign(new_positions.begin() + keep_from, new_positions.end());
|
||||
dflash_ring_reset_rows(state, state.target_window_append_features.data(), state.cross_ctx);
|
||||
if (breakdown != nullptr) {
|
||||
breakdown->replace_us += (uint64_t) (ggml_time_us() - t_replace_us);
|
||||
breakdown->replace_call = true;
|
||||
}
|
||||
|
||||
const int64_t t_commit_us = ggml_time_us();
|
||||
state.target_window_rows = state.cross_ctx;
|
||||
state.target_window_ring_filled = state.target_window_rows;
|
||||
state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
|
||||
dflash_record_window_update(state, 0, state.target_window_rows, true);
|
||||
if (breakdown != nullptr) {
|
||||
breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us);
|
||||
}
|
||||
|
||||
const int64_t t_log_us = ggml_time_us();
|
||||
dflash_contract_log_append(state, seq_id, new_positions);
|
||||
if (breakdown != nullptr) {
|
||||
breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
const int32_t keep_old_rows = std::min<int32_t>(state.target_window_rows, state.cross_ctx - n_rows);
|
||||
state.n_rows_dropped += (size_t) std::max<int32_t>(0, state.target_window_rows - keep_old_rows);
|
||||
std::vector<float> next_window((size_t) (keep_old_rows + n_rows) * row_width);
|
||||
std::vector<llama_pos> next_window_pos((size_t) (keep_old_rows + n_rows));
|
||||
|
||||
if (keep_old_rows > 0) {
|
||||
const float * old_src = state.target_window.data() + (size_t) (state.target_window_rows - keep_old_rows) * row_width;
|
||||
std::memcpy(next_window.data(), old_src, (size_t) keep_old_rows * row_width * sizeof(float));
|
||||
std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), next_window_pos.begin());
|
||||
const int64_t t_window_alloc_us = ggml_time_us();
|
||||
std::vector<llama_pos> & next_window_pos = state.target_window_pos_stage;
|
||||
next_window_pos.resize((size_t) (keep_old_rows + n_rows));
|
||||
if (breakdown != nullptr) {
|
||||
breakdown->window_alloc_us += (uint64_t) (ggml_time_us() - t_window_alloc_us);
|
||||
}
|
||||
|
||||
std::memcpy(
|
||||
next_window.data() + (size_t) keep_old_rows * row_width,
|
||||
new_rows.data(),
|
||||
(size_t) n_rows * row_width * sizeof(float));
|
||||
std::copy(new_positions.begin(), new_positions.end(), next_window_pos.begin() + keep_old_rows);
|
||||
if (keep_old_rows > 0) {
|
||||
const int64_t t_keep_old_us = ggml_time_us();
|
||||
std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), next_window_pos.begin());
|
||||
if (breakdown != nullptr) {
|
||||
breakdown->keep_old_us += (uint64_t) (ggml_time_us() - t_keep_old_us);
|
||||
}
|
||||
}
|
||||
|
||||
state.target_window = std::move(next_window);
|
||||
state.target_window_pos = std::move(next_window_pos);
|
||||
const int64_t t_new_rows_us = ggml_time_us();
|
||||
state.target_window_append_features.assign(new_rows.begin(), new_rows.end());
|
||||
dflash_ring_append_rows(state, state.target_window_append_features.data(), n_rows);
|
||||
std::copy(new_positions.begin(), new_positions.end(), next_window_pos.begin() + keep_old_rows);
|
||||
if (breakdown != nullptr) {
|
||||
breakdown->new_rows_us += (uint64_t) (ggml_time_us() - t_new_rows_us);
|
||||
}
|
||||
|
||||
const int64_t t_commit_us = ggml_time_us();
|
||||
state.target_window_pos.swap(next_window_pos);
|
||||
next_window_pos.clear();
|
||||
state.target_window_rows = keep_old_rows + n_rows;
|
||||
state.target_window_ring_filled = state.target_window_rows;
|
||||
state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
|
||||
dflash_record_window_update(state, keep_old_rows, n_rows, false);
|
||||
if (breakdown != nullptr) {
|
||||
breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us);
|
||||
}
|
||||
|
||||
const int64_t t_log_us = ggml_time_us();
|
||||
dflash_contract_log_append(state, seq_id, new_positions);
|
||||
if (breakdown != nullptr) {
|
||||
breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns the DFlash target-feature window of `state` to its empty state and
// resets the draft context's DFlash KV-cache state.
// The ring storage itself is not touched (only its cursors are rewound), and
// the window version counter is left as it is.
static void dflash_clear_target_features(common_speculative_state_dflash & state) {
    // Flat window, staging buffers and the pending append payload.
    state.target_window.clear();
    state.target_window_stage.clear();
    state.target_window_append_features.clear();
    state.target_window_pos.clear();
    state.target_window_pos_stage.clear();

    // Ring cursors and row count.
    state.target_window_rows = 0;
    state.target_window_ring_write_pos = 0;
    state.target_window_ring_filled = 0;

    // Description of the last recorded window update.
    state.target_window_keep_rows = 0;
    state.target_window_append_rows = 0;
    state.target_window_replace = false;
    state.target_window_materialized = false;

    state.last_target_pos = -1;

    llama_reset_dflash_kv_cache_state(state.ctx_dft);
}
|
||||
|
||||
static void dflash_context_shift(
|
||||
@ -2814,10 +3149,12 @@ static void dflash_context_shift(
|
||||
llama_pos kv_keep,
|
||||
llama_pos kv_discard,
|
||||
llama_pos kv_past) {
|
||||
if (kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window.empty() || state.target_window_pos.empty()) {
|
||||
if (kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window_pos.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
dflash_materialize_target_window_features(state);
|
||||
|
||||
const size_t row_width = (size_t) state.n_target_features;
|
||||
const llama_pos discard_begin = kv_keep;
|
||||
const llama_pos discard_end = kv_keep + kv_discard;
|
||||
@ -2845,7 +3182,10 @@ static void dflash_context_shift(
|
||||
state.target_window = std::move(shifted_rows);
|
||||
state.target_window_pos = std::move(shifted_positions);
|
||||
state.target_window_rows = (int32_t) state.target_window_pos.size();
|
||||
dflash_ring_reset_rows(state, state.target_window.data(), state.target_window_rows);
|
||||
state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
|
||||
dflash_record_window_update(state, 0, state.target_window_rows, true);
|
||||
llama_reset_dflash_kv_cache_state(state.ctx_dft);
|
||||
state.n_context_shifts++;
|
||||
}
|
||||
|
||||
@ -2959,8 +3299,9 @@ int32_t common_speculative_on_target_batch(
|
||||
}
|
||||
}
|
||||
|
||||
dflash_append_breakdown append_breakdown;
|
||||
const int64_t t_append_us = ggml_time_us();
|
||||
if (!dflash_append_target_features(*dflash_state, features, batch, seq_id)) {
|
||||
if (!dflash_append_target_features(*dflash_state, features, batch, seq_id, &append_breakdown)) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
@ -2971,8 +3312,20 @@ int32_t common_speculative_on_target_batch(
|
||||
dflash_state->n_warmup_append_rows += (size_t) batch.n_tokens;
|
||||
} else {
|
||||
dflash_state->t_accept_append_us += append_us;
|
||||
dflash_state->t_accept_append_filter_us += append_breakdown.filter_us;
|
||||
dflash_state->t_accept_append_window_alloc_us += append_breakdown.window_alloc_us;
|
||||
dflash_state->t_accept_append_replace_us += append_breakdown.replace_us;
|
||||
dflash_state->t_accept_append_keep_old_us += append_breakdown.keep_old_us;
|
||||
dflash_state->t_accept_append_new_rows_us += append_breakdown.new_rows_us;
|
||||
dflash_state->t_accept_append_commit_detail_us += append_breakdown.commit_us;
|
||||
dflash_state->t_accept_append_log_us += append_breakdown.log_us;
|
||||
dflash_state->n_accept_append_calls++;
|
||||
dflash_state->n_accept_append_rows += (size_t) batch.n_tokens;
|
||||
if (append_breakdown.replace_call) {
|
||||
dflash_state->n_accept_append_replace_calls++;
|
||||
} else {
|
||||
dflash_state->n_accept_append_slide_calls++;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@ -23,38 +23,132 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() {
|
||||
const int64_t ctx_len = lctx.dflash_visible_cross_ctx > 0
|
||||
? (int64_t) lctx.dflash_visible_cross_ctx
|
||||
: std::max<int64_t>(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size);
|
||||
const int64_t update_rows = std::max<int64_t>(1, lctx.dflash_kv_cache_update_rows > 0 ? lctx.dflash_kv_cache_update_rows : ctx_len);
|
||||
const int32_t write_pos = lctx.dflash_kv_cache_write_pos;
|
||||
|
||||
GGML_ASSERT(n_embd_head_k == n_embd_head_v);
|
||||
GGML_ASSERT(n_target_features > 0);
|
||||
GGML_ASSERT(lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len));
|
||||
GGML_ASSERT(update_rows > 0 && update_rows <= ctx_len);
|
||||
GGML_ASSERT(write_pos >= 0 && write_pos < ctx_len);
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max<int64_t>(1, ctx_len)) + 24 * n_layer, false);
|
||||
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max<int64_t>(1, update_rows)) + 24 * n_layer, false);
|
||||
|
||||
lctx.dflash_kv_input_target_features = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_target_features, ctx_len);
|
||||
lctx.dflash_kv_input_target_features = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_target_features, update_rows);
|
||||
ggml_set_input(lctx.dflash_kv_input_target_features);
|
||||
cb(lctx.dflash_kv_input_target_features, "dflash_kv_input_target_features", -1);
|
||||
|
||||
lctx.dflash_kv_input_pos_ctx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctx_len);
|
||||
lctx.dflash_kv_input_pos_ctx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, update_rows);
|
||||
ggml_set_input(lctx.dflash_kv_input_pos_ctx);
|
||||
cb(lctx.dflash_kv_input_pos_ctx, "dflash_kv_input_pos_ctx", -1);
|
||||
|
||||
ggml_tensor * fused_target = llm_build_lora_mm(lctx, ctx0, model.dflash_fc, lctx.dflash_kv_input_target_features);
|
||||
fused_target = llm_build_norm(ctx0, fused_target, hparams, model.dflash_hidden_norm, nullptr, LLM_NORM_RMS, cb, -1);
|
||||
cb(fused_target, "dflash_kv_fused_target", -1);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
GGML_ASSERT((size_t) il < lctx.dflash_k_ctx_cache.size());
|
||||
GGML_ASSERT((size_t) il < lctx.dflash_v_ctx_cache.size());
|
||||
|
||||
ggml_tensor * Kcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused_target);
|
||||
Kcur_ctx = ggml_reshape_3d(ctx0, Kcur_ctx, n_embd_head_k, n_head_kv, ctx_len);
|
||||
ggml_tensor * Kcur_ctx_proj = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused_target);
|
||||
cb(Kcur_ctx_proj, "dflash_kv_k_proj", il);
|
||||
|
||||
ggml_tensor * Kcur_ctx = ggml_reshape_3d(ctx0, Kcur_ctx_proj, n_embd_head_k, n_head_kv, update_rows);
|
||||
Kcur_ctx = llm_build_norm(ctx0, Kcur_ctx, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
|
||||
cb(Kcur_ctx, "dflash_kv_k_norm", il);
|
||||
Kcur_ctx = ggml_rope_ext(ctx0, Kcur_ctx, lctx.dflash_kv_input_pos_ctx, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Kcur_ctx, "dflash_kv_k_rope", il);
|
||||
|
||||
ggml_tensor * Vcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, fused_target);
|
||||
Vcur_ctx = ggml_reshape_3d(ctx0, Vcur_ctx, n_embd_head_v, n_head_kv, ctx_len);
|
||||
cb(Vcur_ctx, "dflash_kv_v_proj", il);
|
||||
Vcur_ctx = ggml_reshape_3d(ctx0, Vcur_ctx, n_embd_head_v, n_head_kv, update_rows);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur_ctx, lctx.dflash_k_ctx_cache[(size_t) il]));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur_ctx, lctx.dflash_v_ctx_cache[(size_t) il]));
|
||||
const int32_t first_rows = std::min<int32_t>((int32_t) update_rows, (int32_t) ctx_len - write_pos);
|
||||
const int32_t second_rows = (int32_t) update_rows - first_rows;
|
||||
|
||||
if (first_rows > 0) {
|
||||
ggml_tensor * Ksrc_first = first_rows == update_rows
|
||||
? Kcur_ctx
|
||||
: ggml_view_3d(ctx0, Kcur_ctx,
|
||||
Kcur_ctx->ne[0],
|
||||
Kcur_ctx->ne[1],
|
||||
first_rows,
|
||||
Kcur_ctx->nb[1],
|
||||
Kcur_ctx->nb[2],
|
||||
0);
|
||||
ggml_tensor * Vsrc_first = first_rows == update_rows
|
||||
? Vcur_ctx
|
||||
: ggml_view_3d(ctx0, Vcur_ctx,
|
||||
Vcur_ctx->ne[0],
|
||||
Vcur_ctx->ne[1],
|
||||
first_rows,
|
||||
Vcur_ctx->nb[1],
|
||||
Vcur_ctx->nb[2],
|
||||
0);
|
||||
ggml_tensor * Kdst_first = ggml_view_3d(ctx0, lctx.dflash_k_ctx_cache[(size_t) il],
|
||||
lctx.dflash_k_ctx_cache[(size_t) il]->ne[0],
|
||||
lctx.dflash_k_ctx_cache[(size_t) il]->ne[1],
|
||||
first_rows,
|
||||
lctx.dflash_k_ctx_cache[(size_t) il]->nb[1],
|
||||
lctx.dflash_k_ctx_cache[(size_t) il]->nb[2],
|
||||
(size_t) write_pos * lctx.dflash_k_ctx_cache[(size_t) il]->nb[2]);
|
||||
ggml_tensor * Vdst_first = ggml_view_3d(ctx0, lctx.dflash_v_ctx_cache[(size_t) il],
|
||||
lctx.dflash_v_ctx_cache[(size_t) il]->ne[0],
|
||||
lctx.dflash_v_ctx_cache[(size_t) il]->ne[1],
|
||||
first_rows,
|
||||
lctx.dflash_v_ctx_cache[(size_t) il]->nb[1],
|
||||
lctx.dflash_v_ctx_cache[(size_t) il]->nb[2],
|
||||
(size_t) write_pos * lctx.dflash_v_ctx_cache[(size_t) il]->nb[2]);
|
||||
|
||||
ggml_tensor * Kstore_first = ggml_cpy(ctx0, Ksrc_first, Kdst_first);
|
||||
cb(Kstore_first, "dflash_kv_k_store", il);
|
||||
ggml_build_forward_expand(gf, Kstore_first);
|
||||
|
||||
ggml_tensor * Vstore_first = ggml_cpy(ctx0, Vsrc_first, Vdst_first);
|
||||
cb(Vstore_first, "dflash_kv_v_store", il);
|
||||
ggml_build_forward_expand(gf, Vstore_first);
|
||||
}
|
||||
|
||||
if (second_rows > 0) {
|
||||
ggml_tensor * Ksrc_second = ggml_view_3d(ctx0, Kcur_ctx,
|
||||
Kcur_ctx->ne[0],
|
||||
Kcur_ctx->ne[1],
|
||||
second_rows,
|
||||
Kcur_ctx->nb[1],
|
||||
Kcur_ctx->nb[2],
|
||||
(size_t) first_rows * Kcur_ctx->nb[2]);
|
||||
ggml_tensor * Vsrc_second = ggml_view_3d(ctx0, Vcur_ctx,
|
||||
Vcur_ctx->ne[0],
|
||||
Vcur_ctx->ne[1],
|
||||
second_rows,
|
||||
Vcur_ctx->nb[1],
|
||||
Vcur_ctx->nb[2],
|
||||
(size_t) first_rows * Vcur_ctx->nb[2]);
|
||||
ggml_tensor * Kdst_second = ggml_view_3d(ctx0, lctx.dflash_k_ctx_cache[(size_t) il],
|
||||
lctx.dflash_k_ctx_cache[(size_t) il]->ne[0],
|
||||
lctx.dflash_k_ctx_cache[(size_t) il]->ne[1],
|
||||
second_rows,
|
||||
lctx.dflash_k_ctx_cache[(size_t) il]->nb[1],
|
||||
lctx.dflash_k_ctx_cache[(size_t) il]->nb[2],
|
||||
0);
|
||||
ggml_tensor * Vdst_second = ggml_view_3d(ctx0, lctx.dflash_v_ctx_cache[(size_t) il],
|
||||
lctx.dflash_v_ctx_cache[(size_t) il]->ne[0],
|
||||
lctx.dflash_v_ctx_cache[(size_t) il]->ne[1],
|
||||
second_rows,
|
||||
lctx.dflash_v_ctx_cache[(size_t) il]->nb[1],
|
||||
lctx.dflash_v_ctx_cache[(size_t) il]->nb[2],
|
||||
0);
|
||||
|
||||
ggml_tensor * Kstore_second = ggml_cpy(ctx0, Ksrc_second, Kdst_second);
|
||||
cb(Kstore_second, "dflash_kv_k_store", il);
|
||||
ggml_build_forward_expand(gf, Kstore_second);
|
||||
|
||||
ggml_tensor * Vstore_second = ggml_cpy(ctx0, Vsrc_second, Vdst_second);
|
||||
cb(Vstore_second, "dflash_kv_v_store", il);
|
||||
ggml_build_forward_expand(gf, Vstore_second);
|
||||
}
|
||||
}
|
||||
|
||||
return gf;
|
||||
@ -69,12 +163,17 @@ ggml_cgraph * llm_build_context::build_dflash() {
|
||||
const int64_t ctx_len = lctx.dflash_visible_cross_ctx > 0
|
||||
? (int64_t) lctx.dflash_visible_cross_ctx
|
||||
: std::max<int64_t>(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size);
|
||||
const int32_t cache_rows = use_kv_cache ? std::clamp(lctx.dflash_kv_cache_view_n_filled, 0, (int32_t) ctx_len) : 0;
|
||||
const int32_t cache_write_pos = use_kv_cache && ctx_len > 0
|
||||
? ((lctx.dflash_kv_cache_view_write_pos % (int32_t) ctx_len) + (int32_t) ctx_len) % (int32_t) ctx_len
|
||||
: 0;
|
||||
const int64_t n_kv_total = GGML_PAD(ctx_len + n_tokens, flash_attn ? 256 : 32);
|
||||
const int64_t n_kv_pad = n_kv_total - (ctx_len + n_tokens);
|
||||
|
||||
GGML_ASSERT(n_embd_head_k == n_embd_head_v);
|
||||
GGML_ASSERT(n_target_features > 0);
|
||||
GGML_ASSERT(!use_kv_cache || lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len));
|
||||
GGML_ASSERT(!use_kv_cache || (cache_write_pos >= 0 && cache_write_pos < ctx_len));
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max<int64_t>(n_tokens, ctx_len)) + 32 * n_layer, false);
|
||||
|
||||
@ -160,8 +259,52 @@ ggml_cgraph * llm_build_context::build_dflash() {
|
||||
ggml_tensor * Kcur_ctx = nullptr;
|
||||
ggml_tensor * Vcur_ctx = nullptr;
|
||||
if (use_kv_cache) {
|
||||
Kcur_ctx = lctx.dflash_k_ctx_cache[(size_t) il];
|
||||
Vcur_ctx = lctx.dflash_v_ctx_cache[(size_t) il];
|
||||
auto build_ordered_cache_view = [&](ggml_tensor * cache) -> ggml_tensor * {
|
||||
if (!lctx.dflash_kv_cache_view_valid || cache_rows <= 0) {
|
||||
return cache;
|
||||
}
|
||||
|
||||
if (cache_rows < ctx_len) {
|
||||
ggml_tensor * zero_pad = ggml_view_3d(ctx0, cache,
|
||||
cache->ne[0],
|
||||
cache->ne[1],
|
||||
ctx_len - cache_rows,
|
||||
cache->nb[1],
|
||||
cache->nb[2],
|
||||
(size_t) cache_rows * cache->nb[2]);
|
||||
ggml_tensor * valid = ggml_view_3d(ctx0, cache,
|
||||
cache->ne[0],
|
||||
cache->ne[1],
|
||||
cache_rows,
|
||||
cache->nb[1],
|
||||
cache->nb[2],
|
||||
0);
|
||||
return ggml_concat(ctx0, zero_pad, valid, 2);
|
||||
}
|
||||
|
||||
if (cache_write_pos == 0) {
|
||||
return cache;
|
||||
}
|
||||
|
||||
ggml_tensor * tail = ggml_view_3d(ctx0, cache,
|
||||
cache->ne[0],
|
||||
cache->ne[1],
|
||||
ctx_len - cache_write_pos,
|
||||
cache->nb[1],
|
||||
cache->nb[2],
|
||||
(size_t) cache_write_pos * cache->nb[2]);
|
||||
ggml_tensor * head = ggml_view_3d(ctx0, cache,
|
||||
cache->ne[0],
|
||||
cache->ne[1],
|
||||
cache_write_pos,
|
||||
cache->nb[1],
|
||||
cache->nb[2],
|
||||
0);
|
||||
return ggml_concat(ctx0, tail, head, 2);
|
||||
};
|
||||
|
||||
Kcur_ctx = build_ordered_cache_view(lctx.dflash_k_ctx_cache[(size_t) il]);
|
||||
Vcur_ctx = build_ordered_cache_view(lctx.dflash_v_ctx_cache[(size_t) il]);
|
||||
cb(Kcur_ctx, "Kcur_ctx_cache", il);
|
||||
cb(Vcur_ctx, "Vcur_ctx_cache", il);
|
||||
} else {
|
||||
|
||||
@ -2173,7 +2173,27 @@ struct ggml_cgraph * llm_build_context::llama_build_graph_dflash_kv_cache(llama_
|
||||
llama_batch dummy;
|
||||
dummy.n_tokens = 0;
|
||||
|
||||
llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
|
||||
llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) {
|
||||
if (il >= 0) {
|
||||
int j = 0;
|
||||
for (; j < GGML_MAX_NAME - 1; ++j) {
|
||||
cur->name[j] = name[j];
|
||||
if (!name[j]) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (j < GGML_MAX_NAME - 3) {
|
||||
cur->name[j++] = '-';
|
||||
auto sil = std::to_string(il);
|
||||
for (int k = 0; k < (int) sil.size() && j < GGML_MAX_NAME - 1; ++k) {
|
||||
cur->name[j++] = sil[k];
|
||||
}
|
||||
}
|
||||
cur->name[j] = 0;
|
||||
} else {
|
||||
ggml_set_name(cur, name);
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_context llm(lctx, dummy, cb, false, false, 0, false, &lctx.dflash_buf_compute_meta);
|
||||
|
||||
|
||||
@ -281,9 +281,17 @@ struct llama_context {
|
||||
const float * dflash_target_features = nullptr;
|
||||
size_t dflash_target_features_n_floats = 0;
|
||||
int32_t dflash_target_features_n_rows = 0;
|
||||
const float * dflash_target_append_features = nullptr;
|
||||
size_t dflash_target_append_features_n_floats = 0;
|
||||
int32_t dflash_target_append_features_n_rows = 0;
|
||||
const llama_pos * dflash_target_positions = nullptr;
|
||||
size_t dflash_target_positions_n = 0;
|
||||
uint64_t dflash_target_window_version = 0;
|
||||
int32_t dflash_target_window_keep_rows = 0;
|
||||
int32_t dflash_target_window_append_rows = 0;
|
||||
bool dflash_target_window_replace = false;
|
||||
std::vector<float> dflash_target_features_owned;
|
||||
std::vector<float> dflash_target_append_features_owned;
|
||||
std::vector<llama_pos> dflash_target_positions_owned;
|
||||
std::vector<float> dflash_target_features_padded;
|
||||
std::vector<float> dflash_feature_view_buffer;
|
||||
@ -295,6 +303,15 @@ struct llama_context {
|
||||
std::vector<struct ggml_tensor *> dflash_v_ctx_cache;
|
||||
struct ggml_context * dflash_cache_ctx = nullptr;
|
||||
std::vector<ggml_backend_buffer_t> dflash_cache_bufs;
|
||||
int32_t dflash_kv_cache_write_pos = 0;
|
||||
int32_t dflash_kv_cache_n_filled = 0;
|
||||
int32_t dflash_kv_cache_update_rows = 0;
|
||||
int32_t dflash_kv_cache_reserved_rows = 0;
|
||||
int32_t dflash_kv_cache_view_write_pos = 0;
|
||||
int32_t dflash_kv_cache_view_n_filled = 0;
|
||||
uint64_t dflash_kv_cache_applied_window_version = 0;
|
||||
bool dflash_kv_cache_valid = false;
|
||||
bool dflash_kv_cache_view_valid = false;
|
||||
std::vector<uint8_t> dflash_buf_compute_meta;
|
||||
ggml_backend_sched_t dflash_sched = nullptr;
|
||||
struct ggml_tensor * dflash_kv_input_target_features = nullptr;
|
||||
|
||||
@ -55,6 +55,56 @@ void llama_dflash_profile_reset(struct llama_context * ctx) {
|
||||
ctx->dflash_profile = {};
|
||||
}
|
||||
|
||||
void llama_reset_dflash_kv_cache_state(struct llama_context * ctx) {
|
||||
if (ctx == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
ctx->dflash_kv_cache_write_pos = 0;
|
||||
ctx->dflash_kv_cache_n_filled = 0;
|
||||
ctx->dflash_kv_cache_update_rows = 0;
|
||||
ctx->dflash_kv_cache_view_write_pos = 0;
|
||||
ctx->dflash_kv_cache_view_n_filled = 0;
|
||||
ctx->dflash_kv_cache_applied_window_version = 0;
|
||||
ctx->dflash_kv_cache_valid = false;
|
||||
ctx->dflash_kv_cache_view_valid = false;
|
||||
|
||||
for (ggml_backend_buffer_t buf : ctx->dflash_cache_bufs) {
|
||||
if (buf != nullptr) {
|
||||
ggml_backend_buffer_clear(buf, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Plans the next DFlash K/V cache transition for `ctx`, using the cache state
// currently stored on the context and the requested `window_update`.
//
// ctx           - context whose cache state is consulted; may be null.
// window_update - version / keep / append / replace description of the window.
// n_rows        - number of target rows in the window; negative values are
//                 treated as 0.
//
// Returns the plan produced by llama_plan_dflash_kv_cache_transition(). With a
// null `ctx` there is no cache state to reuse, so the plan always requests a
// rebuild covering `n_rows` rows.
llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition_for_ctx(
        const struct llama_context * ctx,
        const llama_dflash_window_update & window_update,
        int32_t n_rows) {
    // std::clamp(v, lo, hi) has undefined behavior when hi < lo, so a negative
    // row count must be normalized before it is used as an upper bound.
    const int32_t safe_n_rows = std::max<int32_t>(0, n_rows);

    if (ctx == nullptr) {
        llama_dflash_kv_cache_transition plan;
        plan.rebuild_cache = true;
        plan.append_rows   = std::clamp(window_update.append_rows, 0, safe_n_rows);
        plan.next_n_filled = safe_n_rows;
        return plan;
    }

    // Same fallback as the graph builder: when no explicit visible cross
    // context is set, use n_ctx minus the DFlash block size (at least 1).
    const int32_t cross_ctx = ctx->dflash_visible_cross_ctx > 0
        ? ctx->dflash_visible_cross_ctx
        : std::max<int32_t>(1, (int32_t) ctx->cparams.n_ctx - (int32_t) ctx->model.hparams.dflash_block_size);

    return llama_plan_dflash_kv_cache_transition(
        cross_ctx,
        ctx->dflash_kv_cache_n_filled,
        ctx->dflash_kv_cache_write_pos,
        ctx->dflash_kv_cache_valid,
        ctx->dflash_kv_cache_applied_window_version,
        window_update.version,
        window_update.keep_rows,
        window_update.append_rows,
        window_update.replace,
        safe_n_rows);
}
|
||||
|
||||
void llama_set_dflash_visible_cross_ctx(
|
||||
struct llama_context * ctx,
|
||||
int32_t cross_ctx) {
|
||||
@ -205,26 +255,91 @@ static bool llama_set_dflash_target_features_impl(
|
||||
size_t n_floats,
|
||||
int32_t n_rows,
|
||||
const llama_pos * target_positions,
|
||||
bool copy_data) {
|
||||
if (ctx == nullptr || target_features == nullptr || n_floats == 0 || n_rows <= 0) {
|
||||
bool copy_data,
|
||||
const llama_dflash_window_update * window_update) {
|
||||
const bool have_full_features = target_features != nullptr && n_floats > 0;
|
||||
const bool have_append_features = window_update != nullptr &&
|
||||
window_update->append_features != nullptr &&
|
||||
window_update->append_floats > 0 &&
|
||||
window_update->append_rows > 0;
|
||||
|
||||
if (ctx == nullptr || n_rows <= 0 || (!have_full_features && !have_append_features)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto & profile = ctx->dflash_profile;
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
const int32_t row_width = n_rows > 0 ? (int32_t) (n_floats / (size_t) n_rows) : 0;
|
||||
const int32_t row_width = have_full_features
|
||||
? (n_rows > 0 ? (int32_t) (n_floats / (size_t) n_rows) : 0)
|
||||
: (window_update->append_rows > 0 ? (int32_t) (window_update->append_floats / (size_t) window_update->append_rows) : 0);
|
||||
llama_pos first_pos = -1;
|
||||
llama_pos last_pos = -1;
|
||||
|
||||
if (copy_data) {
|
||||
if (have_full_features && copy_data) {
|
||||
ctx->dflash_target_features_owned.assign(target_features, target_features + n_floats);
|
||||
ctx->dflash_target_features = ctx->dflash_target_features_owned.data();
|
||||
} else {
|
||||
} else if (have_full_features) {
|
||||
ctx->dflash_target_features_owned.clear();
|
||||
ctx->dflash_target_features = target_features;
|
||||
} else {
|
||||
ctx->dflash_target_features_owned.clear();
|
||||
ctx->dflash_target_features = nullptr;
|
||||
}
|
||||
ctx->dflash_target_features_n_floats = n_floats;
|
||||
ctx->dflash_target_features_n_floats = have_full_features ? n_floats : 0;
|
||||
ctx->dflash_target_features_n_rows = n_rows;
|
||||
if (have_append_features && copy_data) {
|
||||
ctx->dflash_target_append_features_owned.assign(
|
||||
window_update->append_features,
|
||||
window_update->append_features + window_update->append_floats);
|
||||
ctx->dflash_target_append_features = ctx->dflash_target_append_features_owned.data();
|
||||
} else if (have_append_features) {
|
||||
ctx->dflash_target_append_features_owned.clear();
|
||||
ctx->dflash_target_append_features = window_update->append_features;
|
||||
} else {
|
||||
ctx->dflash_target_append_features_owned.clear();
|
||||
ctx->dflash_target_append_features = nullptr;
|
||||
}
|
||||
ctx->dflash_target_append_features_n_floats = have_append_features ? window_update->append_floats : 0;
|
||||
ctx->dflash_target_append_features_n_rows = have_append_features ? window_update->append_rows : 0;
|
||||
ctx->dflash_target_window_version = window_update != nullptr && window_update->version > 0
|
||||
? window_update->version
|
||||
: ctx->dflash_target_window_version + 1;
|
||||
ctx->dflash_target_window_keep_rows = window_update != nullptr
|
||||
? std::max<int32_t>(0, std::min(n_rows, window_update->keep_rows))
|
||||
: 0;
|
||||
ctx->dflash_target_window_append_rows = window_update != nullptr
|
||||
? std::max<int32_t>(0, std::min(n_rows, window_update->append_rows))
|
||||
: n_rows;
|
||||
ctx->dflash_target_window_replace = window_update != nullptr
|
||||
? window_update->replace
|
||||
: true;
|
||||
if (ctx->dflash_target_window_keep_rows + ctx->dflash_target_window_append_rows > n_rows) {
|
||||
ctx->dflash_target_window_keep_rows = std::max<int32_t>(0, n_rows - ctx->dflash_target_window_append_rows);
|
||||
}
|
||||
|
||||
const int32_t cross_ctx = ctx->dflash_visible_cross_ctx > 0
|
||||
? ctx->dflash_visible_cross_ctx
|
||||
: std::max<int32_t>(1, (int32_t) ctx->cparams.n_ctx - (int32_t) ctx->model.hparams.dflash_block_size);
|
||||
const llama_dflash_window_update cache_window_update = {
|
||||
ctx->dflash_target_window_version,
|
||||
ctx->dflash_target_window_keep_rows,
|
||||
ctx->dflash_target_window_append_rows,
|
||||
ctx->dflash_target_window_replace,
|
||||
ctx->dflash_target_append_features,
|
||||
ctx->dflash_target_append_features_n_floats,
|
||||
};
|
||||
const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition_for_ctx(ctx, cache_window_update, n_rows);
|
||||
|
||||
if (cache_plan.cache_up_to_date) {
|
||||
ctx->dflash_kv_cache_view_n_filled = ctx->dflash_kv_cache_n_filled;
|
||||
ctx->dflash_kv_cache_view_write_pos = ctx->dflash_kv_cache_write_pos;
|
||||
ctx->dflash_kv_cache_view_valid = ctx->dflash_kv_cache_valid;
|
||||
} else if (cross_ctx > 0) {
|
||||
ctx->dflash_kv_cache_view_n_filled = cache_plan.next_n_filled;
|
||||
ctx->dflash_kv_cache_view_write_pos = cache_plan.next_write_pos;
|
||||
ctx->dflash_kv_cache_view_valid = cache_plan.next_n_filled > 0;
|
||||
}
|
||||
|
||||
if (target_positions != nullptr) {
|
||||
if (copy_data) {
|
||||
ctx->dflash_target_positions_owned.assign(target_positions, target_positions + n_rows);
|
||||
@ -243,7 +358,10 @@ static bool llama_set_dflash_target_features_impl(
|
||||
profile.set_target_copy_calls++;
|
||||
profile.set_target_copy_us += (uint64_t) (ggml_time_us() - t_start_us);
|
||||
profile.set_target_rows += (uint64_t) n_rows;
|
||||
profile.set_target_copy_bytes += n_floats * sizeof(float) + (target_positions ? (size_t) n_rows * sizeof(llama_pos) : 0);
|
||||
profile.set_target_copy_bytes +=
|
||||
(have_full_features ? n_floats : 0) * sizeof(float) +
|
||||
(have_append_features ? window_update->append_floats : 0) * sizeof(float) +
|
||||
(target_positions ? (size_t) n_rows * sizeof(llama_pos) : 0);
|
||||
profile.last_n_rows = n_rows;
|
||||
profile.last_width = row_width;
|
||||
|
||||
@ -267,8 +385,9 @@ bool llama_set_dflash_target_features_copy(
|
||||
const float * target_features,
|
||||
size_t n_floats,
|
||||
int32_t n_rows,
|
||||
const llama_pos * target_positions) {
|
||||
return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, true);
|
||||
const llama_pos * target_positions,
|
||||
const llama_dflash_window_update * window_update) {
|
||||
return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, true, window_update);
|
||||
}
|
||||
|
||||
bool llama_set_dflash_target_features_view(
|
||||
@ -276,8 +395,9 @@ bool llama_set_dflash_target_features_view(
|
||||
const float * target_features,
|
||||
size_t n_floats,
|
||||
int32_t n_rows,
|
||||
const llama_pos * target_positions) {
|
||||
return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, false);
|
||||
const llama_pos * target_positions,
|
||||
const llama_dflash_window_update * window_update) {
|
||||
return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, false, window_update);
|
||||
}
|
||||
|
||||
static void llama_record_dflash_capture_phase(
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
struct llama_context;
|
||||
@ -24,6 +26,19 @@ struct llama_spec_feature_view {
|
||||
};
|
||||
|
||||
struct llama_dflash_profile_stats {
|
||||
uint64_t decode_internal_chunks = 0;
|
||||
uint64_t decode_graph_rebuilds = 0;
|
||||
uint64_t decode_sync_profile_points = 0;
|
||||
uint64_t decode_prelude_us = 0;
|
||||
uint64_t decode_sched_reset_us = 0;
|
||||
uint64_t decode_build_graph_us = 0;
|
||||
uint64_t decode_sched_alloc_graph_us = 0;
|
||||
uint64_t decode_set_inputs_us = 0;
|
||||
uint64_t decode_graph_compute_us = 0;
|
||||
uint64_t decode_result_us = 0;
|
||||
uint64_t decode_embedding_us = 0;
|
||||
uint64_t decode_final_sched_reset_us = 0;
|
||||
|
||||
uint64_t decode_output_reserve_calls = 0;
|
||||
uint64_t decode_output_reserve_us = 0;
|
||||
uint64_t decode_output_reserve_reallocs = 0;
|
||||
@ -71,6 +86,20 @@ struct llama_dflash_profile_stats {
|
||||
uint64_t graph_kv_cache_read_concat_pad_calls = 0;
|
||||
uint64_t graph_kv_cache_cached_bytes = 0;
|
||||
uint64_t graph_kv_cache_calls = 0;
|
||||
uint64_t graph_kv_node_fused_target_calls = 0;
|
||||
uint64_t graph_kv_node_fused_target_us = 0;
|
||||
uint64_t graph_kv_node_k_proj_calls = 0;
|
||||
uint64_t graph_kv_node_k_proj_us = 0;
|
||||
uint64_t graph_kv_node_k_norm_calls = 0;
|
||||
uint64_t graph_kv_node_k_norm_us = 0;
|
||||
uint64_t graph_kv_node_k_rope_calls = 0;
|
||||
uint64_t graph_kv_node_k_rope_us = 0;
|
||||
uint64_t graph_kv_node_v_proj_calls = 0;
|
||||
uint64_t graph_kv_node_v_proj_us = 0;
|
||||
uint64_t graph_kv_node_k_store_calls = 0;
|
||||
uint64_t graph_kv_node_k_store_us = 0;
|
||||
uint64_t graph_kv_node_v_store_calls = 0;
|
||||
uint64_t graph_kv_node_v_store_us = 0;
|
||||
uint64_t graph_feature_bytes = 0;
|
||||
uint64_t graph_pos_bytes = 0;
|
||||
uint64_t graph_mask_bytes = 0;
|
||||
@ -96,10 +125,76 @@ struct llama_dflash_profile_stats {
|
||||
llama_pos last_pos_last = -1;
|
||||
};
|
||||
|
||||
// Describes how the target-feature window changed since the previous call.
struct llama_dflash_window_update {
    uint64_t      version         = 0;       // window version; compared with the version last applied to the cache
    int32_t       keep_rows       = 0;       // rows of the previous window that are retained
    int32_t       append_rows     = 0;       // rows newly appended to the window
    bool          replace         = false;   // true forces a full cache rebuild
    const float * append_features = nullptr; // feature data for the appended rows (may be null)
    size_t        append_floats   = 0;       // number of floats at `append_features`
};

// Result of planning a K/V cache update.
struct llama_dflash_kv_cache_transition {
    bool    cache_up_to_date = false; // cache already reflects the target window version
    bool    rebuild_cache    = false; // cache must be recomputed from the full window
    int32_t append_rows      = 0;     // append row count after clamping to [0, n_rows]
    int32_t next_n_filled    = 0;     // filled row count after the transition
    int32_t next_write_pos   = 0;     // ring write position after the transition, in [0, cross_ctx)
};

// Decides whether the ring-shaped DFlash K/V cache is already current, can be
// updated by appending rows, or has to be rebuilt, and computes the resulting
// fill count and write position.
//
// cross_ctx              - ring capacity in rows; values < 1 are treated as 1.
// current_n_filled       - rows currently filled (clamped to [0, cross_ctx]).
// current_write_pos      - current ring write position; any integer is
//                          accepted and wrapped into [0, cross_ctx).
// cache_valid            - whether the cache content is usable at all.
// applied_window_version - window version the cache was last updated for.
// target_window_version  - window version being requested.
// keep_rows, append_rows - requested window change (clamped to [0, n_rows]).
// replace                - true forces a rebuild.
// n_rows                 - total rows in the window; negative is treated as 0.
static inline llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition(
        int32_t  cross_ctx,
        int32_t  current_n_filled,
        int32_t  current_write_pos,
        bool     cache_valid,
        uint64_t applied_window_version,
        uint64_t target_window_version,
        int32_t  keep_rows,
        int32_t  append_rows,
        bool     replace,
        int32_t  n_rows) {
    llama_dflash_kv_cache_transition plan;

    const int32_t safe_cross_ctx = std::max<int32_t>(1, cross_ctx);
    // std::clamp requires lo <= hi, so a negative n_rows must be normalized
    // before it is used as an upper bound.
    const int32_t safe_n_rows         = std::max<int32_t>(0, n_rows);
    const int32_t bounded_n_filled    = std::clamp(current_n_filled, 0, safe_cross_ctx);
    const int32_t bounded_append_rows = std::clamp(append_rows, 0, safe_n_rows);
    const int32_t bounded_keep_rows   = std::clamp(keep_rows, 0, safe_n_rows);
    // An append is only consistent if the caller keeps exactly the rows that
    // still fit in the ring next to the appended ones.
    const int32_t expected_keep_rows  = std::min(bounded_n_filled, std::max<int32_t>(0, safe_cross_ctx - bounded_append_rows));

    // Wrap the write position into [0, safe_cross_ctx) without intermediate
    // overflow; used by both the up-to-date and the append paths.
    int32_t wrapped_write_pos = current_write_pos % safe_cross_ctx;
    if (wrapped_write_pos < 0) {
        wrapped_write_pos += safe_cross_ctx;
    }

    plan.cache_up_to_date = cache_valid && applied_window_version == target_window_version;
    plan.rebuild_cache    = !cache_valid || replace || bounded_append_rows <= 0 ||
                            bounded_keep_rows != expected_keep_rows;
    plan.append_rows      = bounded_append_rows;

    if (plan.cache_up_to_date) {
        // Nothing to do: report the current (normalized) state.
        plan.next_n_filled  = bounded_n_filled;
        plan.next_write_pos = wrapped_write_pos;
    } else if (plan.rebuild_cache) {
        // The whole window is rewritten starting at row 0.
        plan.next_n_filled  = std::min(safe_cross_ctx, safe_n_rows);
        plan.next_write_pos = plan.next_n_filled % safe_cross_ctx;
    } else {
        // Incremental append at the current write position, wrapping around.
        plan.next_n_filled  = (int32_t) std::min<int64_t>(safe_cross_ctx, (int64_t) bounded_n_filled + bounded_append_rows);
        plan.next_write_pos = (int32_t) (((int64_t) wrapped_write_pos + bounded_append_rows) % safe_cross_ctx);
    }

    return plan;
}
|
||||
|
||||
llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition_for_ctx(
|
||||
const struct llama_context * ctx,
|
||||
const llama_dflash_window_update & window_update,
|
||||
int32_t n_rows);
|
||||
|
||||
uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx);
|
||||
|
||||
void llama_dflash_profile_reset(struct llama_context * ctx);
|
||||
|
||||
void llama_reset_dflash_kv_cache_state(struct llama_context * ctx);
|
||||
|
||||
void llama_set_dflash_visible_cross_ctx(
|
||||
struct llama_context * ctx,
|
||||
int32_t cross_ctx);
|
||||
@ -156,14 +251,16 @@ bool llama_set_dflash_target_features_copy(
|
||||
const float * target_features,
|
||||
size_t n_floats,
|
||||
int32_t n_rows,
|
||||
const llama_pos * target_positions);
|
||||
const llama_pos * target_positions,
|
||||
const llama_dflash_window_update * window_update = nullptr);
|
||||
|
||||
bool llama_set_dflash_target_features_view(
|
||||
struct llama_context * ctx,
|
||||
const float * target_features,
|
||||
size_t n_floats,
|
||||
int32_t n_rows,
|
||||
const llama_pos * target_positions);
|
||||
const llama_pos * target_positions,
|
||||
const llama_dflash_window_update * window_update = nullptr);
|
||||
|
||||
bool llama_set_dflash_capture_layers(
|
||||
struct llama_context * ctx,
|
||||
|
||||
410
src/llama.cpp
410
src/llama.cpp
@ -171,6 +171,129 @@ static std::vector<std::string> string_split(const std::string& str, const std::
|
||||
return parts;
|
||||
}
|
||||
|
||||
// Reports whether the environment variable `name` is set to an "enabled"
// value. An unset or empty variable, and the exact (case-sensitive) strings
// "0", "false" and "off", count as disabled; any other value enables the flag.
static bool llama_env_flag_enabled(const char * name) {
    const char * value = std::getenv(name);
    if (value == nullptr || value[0] == '\0') {
        return false;
    }

    static const char * const disabled_values[] = { "0", "false", "off" };
    for (const char * disabled : disabled_values) {
        if (std::strcmp(value, disabled) == 0) {
            return false;
        }
    }
    return true;
}
|
||||
|
||||
// Category of a node in the DFlash K/V cache graph, used to bucket per-node
// timing. Each non-NONE value corresponds to a tensor-name prefix
// ("dflash_kv_fused_target", "dflash_kv_k_proj", ...).
enum llama_dflash_kv_node_kind {
    LLAMA_DFLASH_KV_NODE_NONE = 0,     // tensor is not one of the profiled nodes
    LLAMA_DFLASH_KV_NODE_FUSED_TARGET, // "dflash_kv_fused_target"
    LLAMA_DFLASH_KV_NODE_K_PROJ,       // "dflash_kv_k_proj"
    LLAMA_DFLASH_KV_NODE_K_NORM,       // "dflash_kv_k_norm"
    LLAMA_DFLASH_KV_NODE_K_ROPE,       // "dflash_kv_k_rope"
    LLAMA_DFLASH_KV_NODE_V_PROJ,       // "dflash_kv_v_proj"
    LLAMA_DFLASH_KV_NODE_K_STORE,      // "dflash_kv_k_store"
    LLAMA_DFLASH_KV_NODE_V_STORE,      // "dflash_kv_v_store"
};
|
||||
|
||||
struct llama_dflash_kv_node_profiler {
|
||||
llama_dflash_profile_stats * profile = nullptr;
|
||||
int64_t t_start_us = 0;
|
||||
llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE;
|
||||
};
|
||||
|
||||
static bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) {
|
||||
if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') {
|
||||
return false;
|
||||
}
|
||||
|
||||
return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0;
|
||||
}
|
||||
|
||||
// Maps a graph tensor to its profiling category by matching the start of the
// tensor's name against the known DFlash K/V node prefixes. Returns
// LLAMA_DFLASH_KV_NODE_NONE when no prefix matches (including a null tensor).
static llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) {
    struct prefix_entry {
        const char *              prefix;
        llama_dflash_kv_node_kind kind;
    };

    // Checked in order; the first matching prefix wins.
    static constexpr prefix_entry k_prefixes[] = {
        { "dflash_kv_fused_target", LLAMA_DFLASH_KV_NODE_FUSED_TARGET },
        { "dflash_kv_k_proj",       LLAMA_DFLASH_KV_NODE_K_PROJ       },
        { "dflash_kv_k_norm",       LLAMA_DFLASH_KV_NODE_K_NORM       },
        { "dflash_kv_k_rope",       LLAMA_DFLASH_KV_NODE_K_ROPE       },
        { "dflash_kv_v_proj",       LLAMA_DFLASH_KV_NODE_V_PROJ       },
        { "dflash_kv_k_store",      LLAMA_DFLASH_KV_NODE_K_STORE      },
        { "dflash_kv_v_store",      LLAMA_DFLASH_KV_NODE_V_STORE      },
    };

    for (const prefix_entry & entry : k_prefixes) {
        if (llama_dflash_tensor_name_has_prefix(tensor, entry.prefix)) {
            return entry.kind;
        }
    }
    return LLAMA_DFLASH_KV_NODE_NONE;
}
|
||||
|
||||
// Adds one timed execution of a node of the given `kind` to `profile`:
// increments the matching `*_calls` counter and adds `elapsed_us` to the
// matching `*_us` total. LLAMA_DFLASH_KV_NODE_NONE (or any value without a
// counter pair) leaves `profile` unchanged.
static void llama_dflash_kv_node_profile_add(
        llama_dflash_profile_stats & profile,
        llama_dflash_kv_node_kind kind,
        uint64_t elapsed_us) {
    uint64_t * n_calls  = nullptr;
    uint64_t * total_us = nullptr;

    // Pick the counter pair for this node kind.
    switch (kind) {
        case LLAMA_DFLASH_KV_NODE_FUSED_TARGET:
            n_calls  = &profile.graph_kv_node_fused_target_calls;
            total_us = &profile.graph_kv_node_fused_target_us;
            break;
        case LLAMA_DFLASH_KV_NODE_K_PROJ:
            n_calls  = &profile.graph_kv_node_k_proj_calls;
            total_us = &profile.graph_kv_node_k_proj_us;
            break;
        case LLAMA_DFLASH_KV_NODE_K_NORM:
            n_calls  = &profile.graph_kv_node_k_norm_calls;
            total_us = &profile.graph_kv_node_k_norm_us;
            break;
        case LLAMA_DFLASH_KV_NODE_K_ROPE:
            n_calls  = &profile.graph_kv_node_k_rope_calls;
            total_us = &profile.graph_kv_node_k_rope_us;
            break;
        case LLAMA_DFLASH_KV_NODE_V_PROJ:
            n_calls  = &profile.graph_kv_node_v_proj_calls;
            total_us = &profile.graph_kv_node_v_proj_us;
            break;
        case LLAMA_DFLASH_KV_NODE_K_STORE:
            n_calls  = &profile.graph_kv_node_k_store_calls;
            total_us = &profile.graph_kv_node_k_store_us;
            break;
        case LLAMA_DFLASH_KV_NODE_V_STORE:
            n_calls  = &profile.graph_kv_node_v_store_calls;
            total_us = &profile.graph_kv_node_v_store_us;
            break;
        case LLAMA_DFLASH_KV_NODE_NONE:
            break;
    }

    if (n_calls == nullptr) {
        return;
    }
    ++(*n_calls);
    *total_us += elapsed_us;
}
|
||||
|
||||
static bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
|
||||
auto * profiler = static_cast<llama_dflash_kv_node_profiler *>(user_data);
|
||||
if (profiler == nullptr || profiler->profile == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const llama_dflash_kv_node_kind kind = llama_dflash_kv_node_kind_from_tensor(tensor);
|
||||
if (ask) {
|
||||
if (kind == LLAMA_DFLASH_KV_NODE_NONE) {
|
||||
return false;
|
||||
}
|
||||
|
||||
profiler->active_kind = kind;
|
||||
profiler->t_start_us = ggml_time_us();
|
||||
return true;
|
||||
}
|
||||
|
||||
if (kind != LLAMA_DFLASH_KV_NODE_NONE && profiler->active_kind == kind && profiler->t_start_us > 0) {
|
||||
llama_dflash_kv_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us));
|
||||
}
|
||||
|
||||
profiler->active_kind = LLAMA_DFLASH_KV_NODE_NONE;
|
||||
profiler->t_start_us = 0;
|
||||
return true;
|
||||
}
|
||||
|
||||
// extract ip and port from RPC[ip:port] for rpc and keep other device names
|
||||
static std::vector<rpc_device> extract_device_from_rpc_device(std::vector<std::string> devices) {
|
||||
std::vector<rpc_device> rpc_servers;
|
||||
@ -689,6 +812,7 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
|
||||
}
|
||||
|
||||
dflash_profile.last_kv_cache_host_layers = host_layers;
|
||||
llama_reset_dflash_kv_cache_state(this);
|
||||
LLAMA_LOG_INFO("%s: DFlash K/V cache placement cross_ctx=%d host_layers=%d/%d first=%s last=%s\n",
|
||||
__func__,
|
||||
target_cross_ctx,
|
||||
@ -703,6 +827,15 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
|
||||
void llama_context::free_dflash_kv_cache_tensors() {
|
||||
dflash_k_ctx_cache.clear();
|
||||
dflash_v_ctx_cache.clear();
|
||||
dflash_kv_cache_write_pos = 0;
|
||||
dflash_kv_cache_n_filled = 0;
|
||||
dflash_kv_cache_update_rows = 0;
|
||||
dflash_kv_cache_reserved_rows = 0;
|
||||
dflash_kv_cache_view_write_pos = 0;
|
||||
dflash_kv_cache_view_n_filled = 0;
|
||||
dflash_kv_cache_applied_window_version = 0;
|
||||
dflash_kv_cache_valid = false;
|
||||
dflash_kv_cache_view_valid = false;
|
||||
dflash_kv_input_target_features = nullptr;
|
||||
dflash_kv_input_pos_ctx = nullptr;
|
||||
dflash_kq_mask_tensor = nullptr;
|
||||
@ -5271,11 +5404,8 @@ static bool validate_dflash_graph_contract(const llama_context & lctx) {
|
||||
static bool prepare_dflash_graph_inputs(
|
||||
struct llama_context & lctx,
|
||||
uint32_t n_tokens) {
|
||||
const char * dflash_kv_cache_env = std::getenv("IK_DFLASH_KV_CACHE");
|
||||
const bool use_kv_cache = dflash_kv_cache_env != nullptr && *dflash_kv_cache_env != '\0' &&
|
||||
std::strcmp(dflash_kv_cache_env, "0") != 0 &&
|
||||
std::strcmp(dflash_kv_cache_env, "false") != 0 &&
|
||||
std::strcmp(dflash_kv_cache_env, "off") != 0;
|
||||
const bool use_kv_cache = llama_env_flag_enabled("IK_DFLASH_KV_CACHE");
|
||||
const bool kv_node_timing = llama_env_flag_enabled("IK_DFLASH_KV_NODE_TIMING");
|
||||
auto & profile = lctx.dflash_profile;
|
||||
const int32_t cross_ctx = lctx.dflash_visible_cross_ctx > 0
|
||||
? lctx.dflash_visible_cross_ctx
|
||||
@ -5304,10 +5434,13 @@ static bool prepare_dflash_graph_inputs(
|
||||
}
|
||||
|
||||
const float * src = lctx.dflash_target_features;
|
||||
const float * append_src = lctx.dflash_target_append_features;
|
||||
const llama_pos * src_pos = lctx.dflash_target_positions;
|
||||
const size_t total_floats = lctx.dflash_target_features_n_floats;
|
||||
const size_t append_floats = lctx.dflash_target_append_features_n_floats;
|
||||
const size_t total_positions = lctx.dflash_target_positions_n;
|
||||
const int32_t n_rows = lctx.dflash_target_features_n_rows;
|
||||
const int32_t append_rows_available = lctx.dflash_target_append_features_n_rows;
|
||||
const int32_t width = (int32_t) lctx.model.hparams.dflash_n_target_features;
|
||||
const int32_t graph_cross_ctx = use_kv_cache
|
||||
? (lctx.dflash_k_ctx_cache.front() != nullptr ? (int32_t) lctx.dflash_k_ctx_cache.front()->ne[2] : 0)
|
||||
@ -5330,19 +5463,26 @@ static bool prepare_dflash_graph_inputs(
|
||||
__func__, graph_cross_ctx, cross_ctx);
|
||||
return false;
|
||||
}
|
||||
if (src == nullptr || total_floats == 0 || n_rows <= 0) {
|
||||
if (n_rows <= 0) {
|
||||
profile.graph_shape_failures++;
|
||||
LLAMA_LOG_ERROR("%s: missing DFlash target features\n", __func__);
|
||||
LLAMA_LOG_ERROR("%s: missing DFlash target feature rows\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (n_rows > cross_ctx || total_floats != (size_t) n_rows * (size_t) width) {
|
||||
const bool have_full_src = src != nullptr && total_floats == (size_t) n_rows * (size_t) width;
|
||||
if (n_rows > cross_ctx || (src != nullptr && !have_full_src)) {
|
||||
profile.graph_shape_failures++;
|
||||
LLAMA_LOG_ERROR("%s: invalid DFlash target feature shape (rows=%d width=%d floats=%zu cross_ctx=%d)\n",
|
||||
__func__, n_rows, width, total_floats, cross_ctx);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!use_kv_cache && !have_full_src) {
|
||||
profile.graph_shape_failures++;
|
||||
LLAMA_LOG_ERROR("%s: missing contiguous DFlash target features for inline path\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (n_kv_total < cross_ctx + (int32_t) n_tokens) {
|
||||
profile.graph_mask_overflow++;
|
||||
LLAMA_LOG_ERROR("%s: invalid DFlash mask shape (n_kv_total=%d < cross_ctx+n_tokens=%d)\n",
|
||||
@ -5351,24 +5491,26 @@ static bool prepare_dflash_graph_inputs(
|
||||
}
|
||||
|
||||
const int32_t left_pad = cross_ctx - n_rows;
|
||||
const size_t padded_floats = (size_t) cross_ctx * (size_t) width;
|
||||
const size_t dst_offset = (size_t) left_pad * (size_t) width;
|
||||
const int64_t t_feature_us = ggml_time_us();
|
||||
profile.last_left_pad = left_pad;
|
||||
if (lctx.dflash_target_features_padded.size() != padded_floats) {
|
||||
lctx.dflash_target_features_padded.resize(padded_floats);
|
||||
}
|
||||
if (left_pad == 0 && total_floats == padded_floats) {
|
||||
std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin());
|
||||
} else {
|
||||
if (dst_offset > 0) {
|
||||
std::fill(lctx.dflash_target_features_padded.begin(),
|
||||
lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset, 0.0f);
|
||||
if (!use_kv_cache) {
|
||||
const size_t padded_floats = (size_t) cross_ctx * (size_t) width;
|
||||
const size_t dst_offset = (size_t) left_pad * (size_t) width;
|
||||
const int64_t t_feature_us = ggml_time_us();
|
||||
if (lctx.dflash_target_features_padded.size() != padded_floats) {
|
||||
lctx.dflash_target_features_padded.resize(padded_floats);
|
||||
}
|
||||
std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset);
|
||||
if (left_pad == 0 && total_floats == padded_floats) {
|
||||
std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin());
|
||||
} else {
|
||||
if (dst_offset > 0) {
|
||||
std::fill(lctx.dflash_target_features_padded.begin(),
|
||||
lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset, 0.0f);
|
||||
}
|
||||
std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset);
|
||||
}
|
||||
profile.graph_feature_copy_us += (uint64_t) (ggml_time_us() - t_feature_us);
|
||||
profile.graph_feature_bytes += padded_floats * sizeof(float);
|
||||
}
|
||||
profile.graph_feature_copy_us += (uint64_t) (ggml_time_us() - t_feature_us);
|
||||
profile.graph_feature_bytes += padded_floats * sizeof(float);
|
||||
|
||||
const int64_t t_pos_us = ggml_time_us();
|
||||
lctx.dflash_pos_ctx_data.resize((size_t) cross_ctx);
|
||||
@ -5403,22 +5545,32 @@ static bool prepare_dflash_graph_inputs(
|
||||
profile.graph_pos_bytes += lctx.dflash_pos_ctx_data.size() * sizeof(llama_pos);
|
||||
|
||||
if (use_kv_cache) {
|
||||
const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition(
|
||||
cross_ctx,
|
||||
lctx.dflash_kv_cache_n_filled,
|
||||
lctx.dflash_kv_cache_write_pos,
|
||||
lctx.dflash_kv_cache_valid,
|
||||
lctx.dflash_kv_cache_applied_window_version,
|
||||
lctx.dflash_target_window_version,
|
||||
lctx.dflash_target_window_keep_rows,
|
||||
lctx.dflash_target_window_append_rows,
|
||||
lctx.dflash_target_window_replace,
|
||||
n_rows);
|
||||
|
||||
const bool have_append_src = append_src != nullptr &&
|
||||
append_rows_available == cache_plan.append_rows &&
|
||||
append_floats == (size_t) cache_plan.append_rows * (size_t) width;
|
||||
|
||||
const int32_t update_rows = cache_plan.cache_up_to_date
|
||||
? 0
|
||||
: (cache_plan.rebuild_cache ? n_rows : cache_plan.append_rows);
|
||||
const size_t max_nodes = lctx.model.max_nodes((int) std::max<int32_t>(1, cross_ctx)) + 24 * lctx.model.hparams.n_layer;
|
||||
const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false);
|
||||
if (lctx.dflash_buf_compute_meta.size() != meta_size) {
|
||||
lctx.dflash_buf_compute_meta.resize(meta_size);
|
||||
}
|
||||
|
||||
const int64_t t_build_us = ggml_time_us();
|
||||
ggml_cgraph * gf_kv = llm_build_context::llama_build_graph_dflash_kv_cache(lctx);
|
||||
profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us);
|
||||
if (gf_kv == nullptr || lctx.dflash_kv_input_target_features == nullptr || lctx.dflash_kv_input_pos_ctx == nullptr) {
|
||||
profile.graph_shape_failures++;
|
||||
LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (lctx.dflash_sched == nullptr) {
|
||||
if (lctx.dflash_sched == nullptr || lctx.dflash_kv_cache_reserved_rows != cross_ctx) {
|
||||
std::vector<ggml_backend_buffer_type_t> backend_buft;
|
||||
backend_buft.reserve(lctx.backends.size());
|
||||
for (auto * backend : lctx.backends) {
|
||||
@ -5429,51 +5581,117 @@ static bool prepare_dflash_graph_inputs(
|
||||
}
|
||||
}
|
||||
|
||||
if (lctx.dflash_sched != nullptr) {
|
||||
ggml_backend_sched_free(lctx.dflash_sched);
|
||||
lctx.dflash_sched = nullptr;
|
||||
}
|
||||
|
||||
const int32_t saved_update_rows = lctx.dflash_kv_cache_update_rows;
|
||||
lctx.dflash_kv_cache_update_rows = cross_ctx;
|
||||
const int64_t t_build_us = ggml_time_us();
|
||||
ggml_cgraph * gf_reserve = llm_build_context::llama_build_graph_dflash_kv_cache(lctx);
|
||||
profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us);
|
||||
lctx.dflash_kv_cache_update_rows = saved_update_rows;
|
||||
if (gf_reserve == nullptr) {
|
||||
profile.graph_shape_failures++;
|
||||
LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache reserve graph\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
const int64_t t_reserve_us = ggml_time_us();
|
||||
lctx.dflash_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false);
|
||||
const bool reserved = lctx.dflash_sched != nullptr && ggml_backend_sched_reserve(lctx.dflash_sched, gf_kv);
|
||||
const bool reserved = lctx.dflash_sched != nullptr && ggml_backend_sched_reserve(lctx.dflash_sched, gf_reserve);
|
||||
profile.graph_kv_cache_reserve_us += (uint64_t) (ggml_time_us() - t_reserve_us);
|
||||
if (!reserved) {
|
||||
profile.graph_shape_failures++;
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V scheduler\n", __func__);
|
||||
return false;
|
||||
}
|
||||
lctx.dflash_kv_cache_reserved_rows = cross_ctx;
|
||||
}
|
||||
|
||||
const int64_t t_reset_us = ggml_time_us();
|
||||
ggml_backend_sched_reset(lctx.dflash_sched);
|
||||
profile.graph_kv_cache_reset_us += (uint64_t) (ggml_time_us() - t_reset_us);
|
||||
if (update_rows > 0) {
|
||||
const float * update_src = nullptr;
|
||||
if (have_append_src && update_rows == cache_plan.append_rows) {
|
||||
update_src = append_src;
|
||||
} else if (have_full_src) {
|
||||
update_src = src + (size_t) (n_rows - update_rows) * (size_t) width;
|
||||
}
|
||||
const llama_pos * update_pos = src_pos + (n_rows - update_rows);
|
||||
|
||||
const int64_t t_alloc_us = ggml_time_us();
|
||||
ggml_backend_sched_alloc_graph(lctx.dflash_sched, gf_kv);
|
||||
profile.graph_kv_cache_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us);
|
||||
if (update_src == nullptr) {
|
||||
profile.graph_shape_failures++;
|
||||
LLAMA_LOG_ERROR("%s: missing DFlash appended target features for cached update (rows=%d append_rows=%d floats=%zu)\n",
|
||||
__func__, n_rows, update_rows, append_floats);
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_target_features);
|
||||
const int64_t t_feature_upload_us = ggml_time_us();
|
||||
if (kv_feature_backend != nullptr) {
|
||||
ggml_backend_tensor_set_async(kv_feature_backend, lctx.dflash_kv_input_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.dflash_kv_input_target_features));
|
||||
} else {
|
||||
ggml_backend_tensor_set(lctx.dflash_kv_input_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.dflash_kv_input_target_features));
|
||||
if (cache_plan.rebuild_cache) {
|
||||
llama_reset_dflash_kv_cache_state(&lctx);
|
||||
}
|
||||
|
||||
lctx.dflash_kv_cache_update_rows = update_rows;
|
||||
const int64_t t_build_us = ggml_time_us();
|
||||
ggml_cgraph * gf_kv = llm_build_context::llama_build_graph_dflash_kv_cache(lctx);
|
||||
profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us);
|
||||
if (gf_kv == nullptr || lctx.dflash_kv_input_target_features == nullptr || lctx.dflash_kv_input_pos_ctx == nullptr) {
|
||||
profile.graph_shape_failures++;
|
||||
LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
const int64_t t_reset_us = ggml_time_us();
|
||||
ggml_backend_sched_reset(lctx.dflash_sched);
|
||||
profile.graph_kv_cache_reset_us += (uint64_t) (ggml_time_us() - t_reset_us);
|
||||
|
||||
const int64_t t_alloc_us = ggml_time_us();
|
||||
ggml_backend_sched_alloc_graph(lctx.dflash_sched, gf_kv);
|
||||
profile.graph_kv_cache_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us);
|
||||
|
||||
ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_target_features);
|
||||
const int64_t t_feature_upload_us = ggml_time_us();
|
||||
if (kv_feature_backend != nullptr) {
|
||||
ggml_backend_tensor_set_async(kv_feature_backend, lctx.dflash_kv_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash_kv_input_target_features));
|
||||
} else {
|
||||
ggml_backend_tensor_set(lctx.dflash_kv_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash_kv_input_target_features));
|
||||
}
|
||||
profile.graph_kv_cache_feature_upload_us += (uint64_t) (ggml_time_us() - t_feature_upload_us);
|
||||
profile.graph_feature_bytes += (size_t) update_rows * (size_t) width * sizeof(float);
|
||||
|
||||
ggml_backend_t kv_pos_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_pos_ctx);
|
||||
const int64_t t_pos_upload_us = ggml_time_us();
|
||||
if (kv_pos_backend != nullptr) {
|
||||
ggml_backend_tensor_set_async(kv_pos_backend, lctx.dflash_kv_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx));
|
||||
} else {
|
||||
ggml_backend_tensor_set(lctx.dflash_kv_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx));
|
||||
}
|
||||
profile.graph_kv_cache_pos_upload_us += (uint64_t) (ggml_time_us() - t_pos_upload_us);
|
||||
|
||||
const int64_t t_kv_cache_us = ggml_time_us();
|
||||
llama_dflash_kv_node_profiler kv_node_profiler;
|
||||
if (kv_node_timing) {
|
||||
kv_node_profiler.profile = &profile;
|
||||
ggml_backend_sched_set_eval_callback(lctx.dflash_sched, llama_dflash_kv_node_eval_callback, &kv_node_profiler);
|
||||
}
|
||||
llama_graph_compute_sched(lctx, lctx.dflash_sched, gf_kv, lctx.cparams.n_threads);
|
||||
if (kv_node_timing) {
|
||||
ggml_backend_sched_set_eval_callback(lctx.dflash_sched, nullptr, nullptr);
|
||||
}
|
||||
profile.graph_kv_cache_compute_us += (uint64_t) (ggml_time_us() - t_kv_cache_us);
|
||||
|
||||
const int64_t t_sync_us = ggml_time_us();
|
||||
ggml_backend_sched_synchronize(lctx.dflash_sched);
|
||||
profile.graph_kv_cache_sync_us += (uint64_t) (ggml_time_us() - t_sync_us);
|
||||
profile.graph_kv_cache_calls++;
|
||||
|
||||
lctx.dflash_kv_cache_n_filled = std::min(cross_ctx, lctx.dflash_kv_cache_n_filled + update_rows);
|
||||
lctx.dflash_kv_cache_write_pos = (lctx.dflash_kv_cache_write_pos + update_rows) % cross_ctx;
|
||||
lctx.dflash_kv_cache_applied_window_version = lctx.dflash_target_window_version;
|
||||
lctx.dflash_kv_cache_valid = true;
|
||||
lctx.dflash_kv_cache_view_n_filled = lctx.dflash_kv_cache_n_filled;
|
||||
lctx.dflash_kv_cache_view_write_pos = lctx.dflash_kv_cache_write_pos;
|
||||
lctx.dflash_kv_cache_view_valid = true;
|
||||
}
|
||||
profile.graph_kv_cache_feature_upload_us += (uint64_t) (ggml_time_us() - t_feature_upload_us);
|
||||
|
||||
ggml_backend_t kv_pos_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_pos_ctx);
|
||||
const int64_t t_pos_upload_us = ggml_time_us();
|
||||
if (kv_pos_backend != nullptr) {
|
||||
ggml_backend_tensor_set_async(kv_pos_backend, lctx.dflash_kv_input_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx));
|
||||
} else {
|
||||
ggml_backend_tensor_set(lctx.dflash_kv_input_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx));
|
||||
}
|
||||
profile.graph_kv_cache_pos_upload_us += (uint64_t) (ggml_time_us() - t_pos_upload_us);
|
||||
|
||||
const int64_t t_kv_cache_us = ggml_time_us();
|
||||
llama_graph_compute_sched(lctx, lctx.dflash_sched, gf_kv, lctx.cparams.n_threads);
|
||||
profile.graph_kv_cache_compute_us += (uint64_t) (ggml_time_us() - t_kv_cache_us);
|
||||
|
||||
const int64_t t_sync_us = ggml_time_us();
|
||||
ggml_backend_sched_synchronize(lctx.dflash_sched);
|
||||
profile.graph_kv_cache_sync_us += (uint64_t) (ggml_time_us() - t_sync_us);
|
||||
profile.graph_kv_cache_calls++;
|
||||
} else {
|
||||
ggml_backend_tensor_set(lctx.inp_dflash_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.inp_dflash_target_features));
|
||||
ggml_backend_tensor_set(lctx.inp_dflash_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.inp_dflash_pos_ctx));
|
||||
@ -5586,6 +5804,9 @@ static int llama_decode_internal(
|
||||
}
|
||||
lctx.n_queued_tokens += n_tokens_all;
|
||||
|
||||
auto * dflash_profile = lctx.model.arch == LLM_ARCH_DFLASH_DRAFT ? &lctx.dflash_profile : nullptr;
|
||||
const bool dflash_decode_timing = dflash_profile != nullptr && llama_env_flag_enabled("IK_DFLASH_DECODE_TIMING");
|
||||
|
||||
auto & kv_self = lctx.kv_self;
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
@ -5670,6 +5891,10 @@ static int llama_decode_internal(
|
||||
#if IK_PRINT_TIMING
|
||||
auto tim1 = ggml_time_us();
|
||||
#endif
|
||||
const int64_t t_dflash_prelude_us = dflash_decode_timing ? ggml_time_us() : 0;
|
||||
if (dflash_decode_timing) {
|
||||
dflash_profile->decode_internal_chunks++;
|
||||
}
|
||||
uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
|
||||
if (llm_arch_is_hybrid(model.arch) &&
|
||||
n_tokens > 1 &&
|
||||
@ -5804,6 +6029,9 @@ static int llama_decode_internal(
|
||||
auto tim2 = ggml_time_us();
|
||||
printf("prelude(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
if (dflash_decode_timing) {
|
||||
dflash_profile->decode_prelude_us += (uint64_t) (ggml_time_us() - t_dflash_prelude_us);
|
||||
}
|
||||
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
@ -5811,30 +6039,45 @@ static int llama_decode_internal(
|
||||
auto & prev = cparams.mtp_op_type == MTP_OP_NONE ? lctx.prev : lctx.prev_mtp;
|
||||
ggml_cgraph * gf = nullptr;
|
||||
if (!lctx.can_reuse_graph(u_batch)) {
|
||||
if (dflash_decode_timing) {
|
||||
dflash_profile->decode_graph_rebuilds++;
|
||||
}
|
||||
const int64_t t_dflash_sched_reset_us = dflash_decode_timing ? ggml_time_us() : 0;
|
||||
lctx.reset_scheduler();
|
||||
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
|
||||
#if IK_PRINT_TIMING
|
||||
tim2 = ggml_time_us();
|
||||
printf("sched_reset(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
if (dflash_decode_timing) {
|
||||
dflash_profile->decode_sched_reset_us += (uint64_t) (ggml_time_us() - t_dflash_sched_reset_us);
|
||||
}
|
||||
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
const int64_t t_dflash_build_graph_us = dflash_decode_timing ? ggml_time_us() : 0;
|
||||
gf = llm_build_context::llama_build_graph(lctx, u_batch, false);
|
||||
#if IK_PRINT_TIMING
|
||||
tim2 = ggml_time_us();
|
||||
printf("build_graph(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
if (dflash_decode_timing) {
|
||||
dflash_profile->decode_build_graph_us += (uint64_t) (ggml_time_us() - t_dflash_build_graph_us);
|
||||
}
|
||||
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
const int64_t t_dflash_sched_alloc_us = dflash_decode_timing ? ggml_time_us() : 0;
|
||||
ggml_backend_sched_alloc_graph(lctx.sched, gf);
|
||||
#if IK_PRINT_TIMING
|
||||
tim2 = ggml_time_us();
|
||||
printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
if (dflash_decode_timing) {
|
||||
dflash_profile->decode_sched_alloc_graph_us += (uint64_t) (ggml_time_us() - t_dflash_sched_alloc_us);
|
||||
}
|
||||
//if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
|
||||
if (u_batch.embd == nullptr && lctx.cparams.graph_reuse &&
|
||||
!(lctx.model.arch == LLM_ARCH_GEMMA4_MTP && lctx.mtp_target_ctx != nullptr)) {
|
||||
@ -5855,16 +6098,15 @@ static int llama_decode_internal(
|
||||
}
|
||||
}
|
||||
|
||||
if (lctx.model.arch == LLM_ARCH_DFLASH_DRAFT) {
|
||||
auto & profile = lctx.dflash_profile;
|
||||
profile.decode_prepare_calls++;
|
||||
if (dflash_profile != nullptr) {
|
||||
dflash_profile->decode_prepare_calls++;
|
||||
const int64_t t_prepare_dflash_us = ggml_time_us();
|
||||
if (!prepare_dflash_graph_inputs(lctx, n_tokens)) {
|
||||
profile.decode_prepare_failures++;
|
||||
profile.decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us);
|
||||
dflash_profile->decode_prepare_failures++;
|
||||
dflash_profile->decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us);
|
||||
return GGML_STATUS_FAILED;
|
||||
}
|
||||
profile.decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us);
|
||||
dflash_profile->decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us);
|
||||
}
|
||||
|
||||
// the output is always the last tensor in the graph
|
||||
@ -5910,16 +6152,26 @@ static int llama_decode_internal(
|
||||
#if IK_PRINT_TIMING == 1
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
const int64_t t_dflash_set_inputs_us = dflash_decode_timing ? ggml_time_us() : 0;
|
||||
llama_set_inputs(lctx, u_batch);
|
||||
#if IK_PRINT_TIMING == 1
|
||||
tim2 = ggml_time_us();
|
||||
printf("set_inputs(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
if (dflash_decode_timing) {
|
||||
dflash_profile->decode_set_inputs_us += (uint64_t) (ggml_time_us() - t_dflash_set_inputs_us);
|
||||
}
|
||||
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
const int64_t t_dflash_graph_compute_us = dflash_decode_timing ? ggml_time_us() : 0;
|
||||
llama_graph_compute(lctx, gf, n_threads);
|
||||
if (dflash_decode_timing) {
|
||||
llama_synchronize(&lctx);
|
||||
dflash_profile->decode_sync_profile_points++;
|
||||
dflash_profile->decode_graph_compute_us += (uint64_t) (ggml_time_us() - t_dflash_graph_compute_us);
|
||||
}
|
||||
#if IK_PRINT_TIMING
|
||||
llama_synchronize(&lctx);
|
||||
tim2 = ggml_time_us();
|
||||
@ -5950,6 +6202,7 @@ static int llama_decode_internal(
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
const int64_t t_dflash_get_result_us = dflash_decode_timing ? ggml_time_us() : 0;
|
||||
// Do not process logits if MTP is only updating the KV cache.
|
||||
if (cparams.mtp_op_type != MTP_OP_WARMUP) { // && cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) {
|
||||
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
|
||||
@ -5980,6 +6233,11 @@ static int llama_decode_internal(
|
||||
}
|
||||
}
|
||||
}
|
||||
if (dflash_decode_timing) {
|
||||
llama_synchronize(&lctx);
|
||||
dflash_profile->decode_sync_profile_points++;
|
||||
dflash_profile->decode_result_us += (uint64_t) (ggml_time_us() - t_dflash_get_result_us);
|
||||
}
|
||||
#if IK_PRINT_TIMING
|
||||
tim2 = ggml_time_us();
|
||||
printf("get_result(...): %d us\n", int(tim2-tim1));
|
||||
@ -5992,6 +6250,7 @@ static int llama_decode_internal(
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
const int64_t t_dflash_get_embedding_us = dflash_decode_timing ? ggml_time_us() : 0;
|
||||
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
|
||||
GGML_ASSERT(backend_embd != nullptr);
|
||||
|
||||
@ -6031,6 +6290,11 @@ static int llama_decode_internal(
|
||||
GGML_ABORT("unknown pooling type");
|
||||
}
|
||||
}
|
||||
if (dflash_decode_timing) {
|
||||
llama_synchronize(&lctx);
|
||||
dflash_profile->decode_sync_profile_points++;
|
||||
dflash_profile->decode_embedding_us += (uint64_t) (ggml_time_us() - t_dflash_get_embedding_us);
|
||||
}
|
||||
#if IK_PRINT_TIMING
|
||||
tim2 = ggml_time_us();
|
||||
printf("get_embedding(...): %d us\n", int(tim2-tim1));
|
||||
@ -6074,9 +6338,13 @@ static int llama_decode_internal(
|
||||
#if IK_PRINT_TIMING
|
||||
auto tim1 = ggml_time_us();
|
||||
#endif
|
||||
const int64_t t_dflash_final_sched_reset_us = dflash_decode_timing ? ggml_time_us() : 0;
|
||||
if (!lctx.prev) {
|
||||
lctx.reset_scheduler();
|
||||
}
|
||||
if (dflash_decode_timing) {
|
||||
dflash_profile->decode_final_sched_reset_us += (uint64_t) (ggml_time_us() - t_dflash_final_sched_reset_us);
|
||||
}
|
||||
#if IK_PRINT_TIMING
|
||||
auto tim2 = ggml_time_us();
|
||||
printf("sched_reset(...): %d us\n", int(tim2-tim1));
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user