From ed403dca271e0b013be75d96ee9531fc97b768c2 Mon Sep 17 00:00:00 2001 From: SamuelOliveirads Date: Sun, 31 May 2026 14:51:21 -0300 Subject: [PATCH] Use windows update in kv cache --- common/speculative.cpp | 393 ++++++++++++++++++++++++++++++++-- src/graphs/build_dflash.cpp | 163 +++++++++++++- src/llama-build-context.cpp | 22 +- src/llama-context.h | 17 ++ src/llama-spec-features.cpp | 142 ++++++++++++- src/llama-spec-features.h | 101 ++++++++- src/llama.cpp | 410 +++++++++++++++++++++++++++++------- 7 files changed, 1133 insertions(+), 115 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 911526a8..e7ce71f9 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -253,6 +253,8 @@ static const common_speculative_state_mtp * common_speculative_get_mtp_state(con static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec); static const common_speculative_state_dflash * common_speculative_get_dflash_state(const common_speculative * spec); static int32_t common_speculative_feature_width(const common_speculative * spec); +static void dflash_materialize_target_window_features(common_speculative_state_dflash & state); +static void dflash_ring_reset_rows(common_speculative_state_dflash & state, const float * rows, int32_t n_rows); static void dflash_append_target_features( common_speculative_state_dflash & state, const float * feature_rows, @@ -284,6 +286,17 @@ static bool dflash_contract_log_enabled() { std::strcmp(env, "off") != 0; } +static bool dflash_use_kv_cache_experiment() { + const char * env = std::getenv("IK_DFLASH_KV_CACHE"); + if (env == nullptr || *env == '\0') { + return false; + } + + return std::strcmp(env, "0") != 0 && + std::strcmp(env, "false") != 0 && + std::strcmp(env, "off") != 0; +} + template static std::string dflash_contract_format_values( const std::vector & values, @@ -479,7 +492,18 @@ struct common_speculative_state_dflash : public common_speculative_state { std::vector target_layer_ids; std::vector target_window; std::vector target_window_pos; + std::vector target_window_stage; + std::vector target_window_pos_stage; + std::vector target_window_ring; + std::vector target_window_append_features; int32_t target_window_rows = 0; + int32_t target_window_ring_write_pos = 0; + int32_t target_window_ring_filled = 0; + uint64_t target_window_version = 0; + int32_t target_window_keep_rows = 0; + int32_t target_window_append_rows = 0; + bool target_window_replace = false; + bool target_window_materialized = false; llama_pos last_target_pos = -1; size_t n_window_updates = 0; size_t n_rows_seen = 0; @@ -497,6 +521,13 @@ struct common_speculative_state_dflash : public common_speculative_state { uint64_t t_accept_output_copy_us = 0; uint64_t t_accept_commit_us = 0; uint64_t t_accept_append_us = 0; + uint64_t t_accept_append_filter_us = 0; + uint64_t t_accept_append_window_alloc_us = 0; + uint64_t t_accept_append_replace_us = 0; + uint64_t t_accept_append_keep_old_us = 0; + uint64_t t_accept_append_new_rows_us = 0; + uint64_t t_accept_append_commit_detail_us = 0; + uint64_t t_accept_append_log_us = 0; size_t n_warmup_collect_calls = 0; size_t n_warmup_collect_rows = 0; size_t n_warmup_append_calls = 0; @@ -507,6 +538,8 @@ struct common_speculative_state_dflash : public common_speculative_state { size_t n_accept_commit_rows = 0; size_t n_accept_append_calls = 0; size_t n_accept_append_rows = 0; + size_t n_accept_append_replace_calls = 0; + size_t n_accept_append_slide_calls = 0; common_speculative_state_dflash( enum common_speculative_type type, @@ -614,6 +647,12 @@ struct common_speculative_state_dflash : public common_speculative_state { } batch = llama_batch_init(std::max(1, block_size), 0, 1); + target_window.reserve((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_stage.reserve((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_ring.resize((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_append_features.reserve((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_pos.reserve((size_t) this->cross_ctx); + target_window_pos_stage.reserve((size_t) this->cross_ctx); ready = true; llama_set_dflash_visible_cross_ctx(ctx_dft, this->cross_ctx); @@ -648,6 +687,7 @@ struct common_speculative_state_dflash : public common_speculative_state { void begin(const llama_tokens & prompt) override { GGML_UNUSED(prompt); llama_kv_cache_clear(ctx_dft); + llama_reset_dflash_kv_cache_state(ctx_dft); n_window_updates = 0; n_rows_seen = 0; n_rows_dropped = 0; @@ -663,6 +703,13 @@ struct common_speculative_state_dflash : public common_speculative_state { t_accept_output_copy_us = 0; t_accept_commit_us = 0; t_accept_append_us = 0; + t_accept_append_filter_us = 0; + t_accept_append_window_alloc_us = 0; + t_accept_append_replace_us = 0; + t_accept_append_keep_old_us = 0; + t_accept_append_new_rows_us = 0; + t_accept_append_commit_detail_us = 0; + t_accept_append_log_us = 0; n_warmup_collect_calls = 0; n_warmup_collect_rows = 0; n_warmup_append_calls = 0; @@ -673,6 +720,8 @@ struct common_speculative_state_dflash : public common_speculative_state { n_accept_commit_rows = 0; n_accept_append_calls = 0; n_accept_append_rows = 0; + n_accept_append_replace_calls = 0; + n_accept_append_slide_calls = 0; llama_dflash_profile_reset(ctx_tgt); llama_dflash_profile_reset(ctx_dft); } @@ -695,7 +744,33 @@ struct common_speculative_state_dflash : public common_speculative_state { return; } - if (!llama_set_dflash_target_features_view(ctx_dft, target_window.data(), target_window.size(), target_window_rows, target_window_pos.data())) { + const bool use_kv_cache = dflash_use_kv_cache_experiment(); + const float * target_features = nullptr; + size_t target_feature_floats = 0; + llama_dflash_window_update window_update = { + target_window_version, + target_window_keep_rows, + target_window_append_rows, + target_window_replace, + target_window_append_features.empty() ? nullptr : target_window_append_features.data(), + target_window_append_features.size(), + }; + const llama_dflash_kv_cache_transition cache_plan = use_kv_cache + ? llama_plan_dflash_kv_cache_transition_for_ctx(ctx_dft, window_update, target_window_rows) + : llama_dflash_kv_cache_transition{}; + + if (!use_kv_cache || cache_plan.rebuild_cache) { + dflash_materialize_target_window_features(*this); + target_features = target_window.data(); + target_feature_floats = target_window.size(); + } + if (use_kv_cache && cache_plan.rebuild_cache) { + window_update.append_features = target_window.data(); + window_update.append_floats = target_window.size(); + window_update.append_rows = target_window_rows; + } + + if (!llama_set_dflash_target_features_view(ctx_dft, target_features, target_feature_floats, target_window_rows, target_window_pos.data(), &window_update)) { LOG_ERR("%s: failed to set DFlash target features\n", __func__); n_set_target_fail++; return; @@ -2522,6 +2597,24 @@ void common_speculative_print_stats(const common_speculative * spec, double slot const double kv_compute_ms = (double) graph_stats.graph_kv_cache_compute_us / 1000.0; const double kv_sync_ms = (double) graph_stats.graph_kv_cache_sync_us / 1000.0; const double replay_append_ms = (double) dflash_state->t_accept_append_us / 1000.0; + const double feature_path_ms = (double) ( + capture_stats.capture_prepare_sync_us + + capture_stats.capture_materialize_us + + graph_stats.set_target_copy_us + + graph_stats.graph_feature_copy_us + + graph_stats.graph_pos_copy_us + + graph_stats.graph_mask_build_us) / 1000.0; + const double decode_internal_ms = (double) ( + graph_stats.decode_prelude_us + + graph_stats.decode_sched_reset_us + + graph_stats.decode_build_graph_us + + graph_stats.decode_sched_alloc_graph_us + + graph_stats.decode_prepare_us + + graph_stats.decode_set_inputs_us + + graph_stats.decode_graph_compute_us + + graph_stats.decode_result_us + + graph_stats.decode_embedding_us + + graph_stats.decode_final_sched_reset_us) / 1000.0; LOG_INF("statistics dflash profile: capture(sync/materialize)=%.3f/%.3f ms calls=%llu/%llu bytes=%llu phase(prompt/verify batches changes)=%llu/%llu %llu/%llu, set_target=%.3f ms rows=%llu bytes=%llu, decode(llama_output_reserve/prepare)=%.3f/%.3f ms calls=%llu/%llu realloc(bytes)=%llu/%llu, prep(total/features/pos/mask)=%.3f/%.3f/%.3f/%.3f ms kv_cache(total/build/reserve/reset/alloc/up_f/up_p/compute/sync/read)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls(prepare/cache/read)=%llu/%llu/%llu bytes(feature/pos/mask/read)=%llu/%llu/%llu/%llu host_layers=%d, fallback_pos(copy/graph)=%llu/%llu, nonmono(copy/graph)=%llu/%llu, capture_fail=%llu/%llu decode_prepare_fail=%llu, visible_kv_max=%llu, last(rows=%d width=%d left_pad=%d n_tokens=%d n_kv=%d pos=[%d..%d])\n", (double) capture_stats.capture_prepare_sync_us / 1000.0, @@ -2580,6 +2673,81 @@ void common_speculative_print_stats(const common_speculative * spec, double slot (int) graph_stats.last_pos_first, (int) graph_stats.last_pos_last); + LOG_INF("statistics dflash features: total=%.3f ms capture(sync/materialize)=%.3f/%.3f ms set_target=%.3f ms prep(feature/pos/mask)=%.3f/%.3f/%.3f ms rows(materialize/set_target)=%llu/%llu bytes(materialize/set_target/feature/pos/mask)=%llu/%llu/%llu/%llu/%llu\n", + feature_path_ms, + (double) capture_stats.capture_prepare_sync_us / 1000.0, + (double) capture_stats.capture_materialize_us / 1000.0, + (double) graph_stats.set_target_copy_us / 1000.0, + (double) graph_stats.graph_feature_copy_us / 1000.0, + (double) graph_stats.graph_pos_copy_us / 1000.0, + (double) graph_stats.graph_mask_build_us / 1000.0, + (unsigned long long) capture_stats.capture_materialize_rows, + (unsigned long long) graph_stats.set_target_rows, + (unsigned long long) capture_stats.capture_materialize_bytes, + (unsigned long long) graph_stats.set_target_copy_bytes, + (unsigned long long) graph_stats.graph_feature_bytes, + (unsigned long long) graph_stats.graph_pos_bytes, + (unsigned long long) graph_stats.graph_mask_bytes); + + LOG_INF("statistics dflash kv: total=%.3f ms build/reserve/reset/alloc/upload_f/upload_p/compute/sync/read=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu cached_bytes=%llu host_layers=%d\n", + kv_cache_total_ms, + (double) graph_stats.graph_kv_cache_build_us / 1000.0, + (double) graph_stats.graph_kv_cache_reserve_us / 1000.0, + (double) graph_stats.graph_kv_cache_reset_us / 1000.0, + (double) graph_stats.graph_kv_cache_alloc_us / 1000.0, + (double) graph_stats.graph_kv_cache_feature_upload_us / 1000.0, + (double) graph_stats.graph_kv_cache_pos_upload_us / 1000.0, + (double) graph_stats.graph_kv_cache_compute_us / 1000.0, + (double) graph_stats.graph_kv_cache_sync_us / 1000.0, + (double) graph_stats.graph_kv_cache_read_concat_pad_us / 1000.0, + (unsigned long long) graph_stats.graph_kv_cache_calls, + (unsigned long long) graph_stats.graph_kv_cache_cached_bytes, + graph_stats.last_kv_cache_host_layers); + + if (graph_stats.decode_internal_chunks > 0) { + LOG_INF("statistics dflash decode: llama_decode(total)=%.3f ms calls=%zu chunks=%llu rebuilds=%llu sync_points=%llu internal(total/prelude/sched_reset/build/alloc/prepare/set_inputs/compute/get_result/get_embedding/final_reset)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms\n", + (double) dflash_state->t_draft_decode_us / 1000.0, + dflash_state->n_call_draft, + (unsigned long long) graph_stats.decode_internal_chunks, + (unsigned long long) graph_stats.decode_graph_rebuilds, + (unsigned long long) graph_stats.decode_sync_profile_points, + decode_internal_ms, + (double) graph_stats.decode_prelude_us / 1000.0, + (double) graph_stats.decode_sched_reset_us / 1000.0, + (double) graph_stats.decode_build_graph_us / 1000.0, + (double) graph_stats.decode_sched_alloc_graph_us / 1000.0, + (double) graph_stats.decode_prepare_us / 1000.0, + (double) graph_stats.decode_set_inputs_us / 1000.0, + (double) graph_stats.decode_graph_compute_us / 1000.0, + (double) graph_stats.decode_result_us / 1000.0, + (double) graph_stats.decode_embedding_us / 1000.0, + (double) graph_stats.decode_final_sched_reset_us / 1000.0); + } + + if (graph_stats.graph_kv_node_fused_target_calls > 0 || + graph_stats.graph_kv_node_k_proj_calls > 0 || + graph_stats.graph_kv_node_k_norm_calls > 0 || + graph_stats.graph_kv_node_k_rope_calls > 0 || + graph_stats.graph_kv_node_v_proj_calls > 0 || + graph_stats.graph_kv_node_k_store_calls > 0 || + graph_stats.graph_kv_node_v_store_calls > 0) { + LOG_INF("statistics dflash kv nodes: fused_target/k_proj/k_norm/k_rope/v_proj/k_store/v_store=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu/%llu/%llu/%llu/%llu/%llu/%llu\n", + (double) graph_stats.graph_kv_node_fused_target_us / 1000.0, + (double) graph_stats.graph_kv_node_k_proj_us / 1000.0, + (double) graph_stats.graph_kv_node_k_norm_us / 1000.0, + (double) graph_stats.graph_kv_node_k_rope_us / 1000.0, + (double) graph_stats.graph_kv_node_v_proj_us / 1000.0, + (double) graph_stats.graph_kv_node_k_store_us / 1000.0, + (double) graph_stats.graph_kv_node_v_store_us / 1000.0, + (unsigned long long) graph_stats.graph_kv_node_fused_target_calls, + (unsigned long long) graph_stats.graph_kv_node_k_proj_calls, + (unsigned long long) graph_stats.graph_kv_node_k_norm_calls, + (unsigned long long) graph_stats.graph_kv_node_k_rope_calls, + (unsigned long long) graph_stats.graph_kv_node_v_proj_calls, + (unsigned long long) graph_stats.graph_kv_node_k_store_calls, + (unsigned long long) graph_stats.graph_kv_node_v_store_calls); + } + LOG_INF("statistics dflash hot: kv(upload_f/upload_p/upload/compute/sync)=%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu replay(accepted_prefix_append)=%.3f ms calls=%zu rows=%zu\n", kv_upload_feature_ms, kv_upload_pos_ms, @@ -2609,6 +2777,20 @@ void common_speculative_print_stats(const common_speculative * spec, double slot dflash_state->n_accept_commit_rows, dflash_state->n_accept_output_copy_rows, dflash_state->n_accept_append_rows); + + if (dflash_state->n_accept_append_calls > 0) { + LOG_INF("statistics dflash replay: append(filter/window_alloc/replace/keep_old/new_rows/commit/log)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%zu replace/slide=%zu/%zu\n", + (double) dflash_state->t_accept_append_filter_us / 1000.0, + (double) dflash_state->t_accept_append_window_alloc_us / 1000.0, + (double) dflash_state->t_accept_append_replace_us / 1000.0, + (double) dflash_state->t_accept_append_keep_old_us / 1000.0, + (double) dflash_state->t_accept_append_new_rows_us / 1000.0, + (double) dflash_state->t_accept_append_commit_detail_us / 1000.0, + (double) dflash_state->t_accept_append_log_us / 1000.0, + dflash_state->n_accept_append_calls, + dflash_state->n_accept_append_replace_calls, + dflash_state->n_accept_append_slide_calls); + } } } } @@ -2728,11 +2910,113 @@ static void mtp_clear_target_hidden(common_speculative_state_mtp & state, llama_ state.draft_cache_by_seq.erase(seq_id); } +struct dflash_append_breakdown { + uint64_t filter_us = 0; + uint64_t window_alloc_us = 0; + uint64_t replace_us = 0; + uint64_t keep_old_us = 0; + uint64_t new_rows_us = 0; + uint64_t commit_us = 0; + uint64_t log_us = 0; + bool replace_call = false; +}; + +static void dflash_record_window_update( + common_speculative_state_dflash & state, + int32_t keep_rows, + int32_t append_rows, + bool replace) { + state.target_window_keep_rows = std::max(0, keep_rows); + state.target_window_append_rows = std::max(0, append_rows); + state.target_window_replace = replace; + state.target_window_version++; +} + +static void dflash_ring_reset_rows( + common_speculative_state_dflash & state, + const float * rows, + int32_t n_rows) { + const size_t row_width = (size_t) state.n_target_features; + if (n_rows <= 0 || rows == nullptr) { + state.target_window_ring_write_pos = 0; + state.target_window_ring_filled = 0; + return; + } + + if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) { + state.target_window_ring.resize((size_t) state.cross_ctx * row_width); + } + + std::memcpy(state.target_window_ring.data(), rows, (size_t) n_rows * row_width * sizeof(float)); + state.target_window_ring_write_pos = n_rows % state.cross_ctx; + state.target_window_ring_filled = n_rows; + state.target_window_materialized = false; +} + +static void dflash_ring_append_rows( + common_speculative_state_dflash & state, + const float * rows, + int32_t n_rows) { + const size_t row_width = (size_t) state.n_target_features; + if (n_rows <= 0 || rows == nullptr) { + return; + } + + if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) { + state.target_window_ring.resize((size_t) state.cross_ctx * row_width); + } + + int32_t write_pos = state.target_window_ring_write_pos; + int32_t remaining = n_rows; + const float * src = rows; + while (remaining > 0) { + const int32_t chunk_rows = std::min(remaining, state.cross_ctx - write_pos); + std::memcpy( + state.target_window_ring.data() + (size_t) write_pos * row_width, + src, + (size_t) chunk_rows * row_width * sizeof(float)); + src += (size_t) chunk_rows * row_width; + remaining -= chunk_rows; + write_pos = (write_pos + chunk_rows) % state.cross_ctx; + } + + state.target_window_ring_write_pos = write_pos; + state.target_window_ring_filled = std::min(state.cross_ctx, state.target_window_ring_filled + n_rows); + state.target_window_materialized = false; +} + +static void dflash_materialize_target_window_features(common_speculative_state_dflash & state) { + if (state.target_window_materialized || state.target_window_rows <= 0) { + return; + } + + const size_t row_width = (size_t) state.n_target_features; + state.target_window.resize((size_t) state.target_window_rows * row_width); + + const int32_t read_start = (state.target_window_ring_write_pos - state.target_window_rows + state.cross_ctx) % state.cross_ctx; + const int32_t first_rows = std::min(state.target_window_rows, state.cross_ctx - read_start); + std::memcpy( + state.target_window.data(), + state.target_window_ring.data() + (size_t) read_start * row_width, + (size_t) first_rows * row_width * sizeof(float)); + + const int32_t second_rows = state.target_window_rows - first_rows; + if (second_rows > 0) { + std::memcpy( + state.target_window.data() + (size_t) first_rows * row_width, + state.target_window_ring.data(), + (size_t) second_rows * row_width * sizeof(float)); + } + + state.target_window_materialized = true; +} + static bool dflash_append_target_features( common_speculative_state_dflash & state, const common_speculative_feature_view & features, const llama_batch & batch, - llama_seq_id seq_id) { + llama_seq_id seq_id, + dflash_append_breakdown * breakdown = nullptr) { GGML_UNUSED(batch); if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || @@ -2748,6 +3032,7 @@ static bool dflash_append_target_features( new_rows.reserve(features.rows.size() * row_width); new_positions.reserve(features.rows.size()); + const int64_t t_filter_us = ggml_time_us(); for (const auto & row : features.rows) { if (row.seq_id != seq_id || row.data == nullptr) { continue; @@ -2756,6 +3041,9 @@ static bool dflash_append_target_features( new_positions.push_back(row.pos); new_rows.insert(new_rows.end(), row.data, row.data + row_width); } + if (breakdown != nullptr) { + breakdown->filter_us += (uint64_t) (ggml_time_us() - t_filter_us); + } if (new_positions.empty()) { return false; @@ -2767,46 +3055,93 @@ static bool dflash_append_target_features( if (n_rows >= state.cross_ctx) { state.n_rows_dropped += (size_t) state.target_window_rows + (size_t) (n_rows - state.cross_ctx); const int32_t keep_from = n_rows - state.cross_ctx; - state.target_window.assign( + const int64_t t_replace_us = ggml_time_us(); + state.target_window_pos.assign(new_positions.begin() + keep_from, new_positions.end()); + state.target_window_append_features.assign( new_rows.begin() + (ptrdiff_t) keep_from * (ptrdiff_t) row_width, new_rows.end()); - state.target_window_pos.assign(new_positions.begin() + keep_from, new_positions.end()); + dflash_ring_reset_rows(state, state.target_window_append_features.data(), state.cross_ctx); + if (breakdown != nullptr) { + breakdown->replace_us += (uint64_t) (ggml_time_us() - t_replace_us); + breakdown->replace_call = true; + } + + const int64_t t_commit_us = ggml_time_us(); state.target_window_rows = state.cross_ctx; + state.target_window_ring_filled = state.target_window_rows; state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); + dflash_record_window_update(state, 0, state.target_window_rows, true); + if (breakdown != nullptr) { + breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us); + } + + const int64_t t_log_us = ggml_time_us(); dflash_contract_log_append(state, seq_id, new_positions); + if (breakdown != nullptr) { + breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us); + } return true; } const int32_t keep_old_rows = std::min(state.target_window_rows, state.cross_ctx - n_rows); state.n_rows_dropped += (size_t) std::max(0, state.target_window_rows - keep_old_rows); - std::vector next_window((size_t) (keep_old_rows + n_rows) * row_width); - std::vector next_window_pos((size_t) (keep_old_rows + n_rows)); - - if (keep_old_rows > 0) { - const float * old_src = state.target_window.data() + (size_t) (state.target_window_rows - keep_old_rows) * row_width; - std::memcpy(next_window.data(), old_src, (size_t) keep_old_rows * row_width * sizeof(float)); - std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), next_window_pos.begin()); + const int64_t t_window_alloc_us = ggml_time_us(); + std::vector & next_window_pos = state.target_window_pos_stage; + next_window_pos.resize((size_t) (keep_old_rows + n_rows)); + if (breakdown != nullptr) { + breakdown->window_alloc_us += (uint64_t) (ggml_time_us() - t_window_alloc_us); } - std::memcpy( - next_window.data() + (size_t) keep_old_rows * row_width, - new_rows.data(), - (size_t) n_rows * row_width * sizeof(float)); - std::copy(new_positions.begin(), new_positions.end(), next_window_pos.begin() + keep_old_rows); + if (keep_old_rows > 0) { + const int64_t t_keep_old_us = ggml_time_us(); + std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), next_window_pos.begin()); + if (breakdown != nullptr) { + breakdown->keep_old_us += (uint64_t) (ggml_time_us() - t_keep_old_us); + } + } - state.target_window = std::move(next_window); - state.target_window_pos = std::move(next_window_pos); + const int64_t t_new_rows_us = ggml_time_us(); + state.target_window_append_features.assign(new_rows.begin(), new_rows.end()); + dflash_ring_append_rows(state, state.target_window_append_features.data(), n_rows); + std::copy(new_positions.begin(), new_positions.end(), next_window_pos.begin() + keep_old_rows); + if (breakdown != nullptr) { + breakdown->new_rows_us += (uint64_t) (ggml_time_us() - t_new_rows_us); + } + + const int64_t t_commit_us = ggml_time_us(); + state.target_window_pos.swap(next_window_pos); + next_window_pos.clear(); state.target_window_rows = keep_old_rows + n_rows; + state.target_window_ring_filled = state.target_window_rows; state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); + dflash_record_window_update(state, keep_old_rows, n_rows, false); + if (breakdown != nullptr) { + breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us); + } + + const int64_t t_log_us = ggml_time_us(); dflash_contract_log_append(state, seq_id, new_positions); + if (breakdown != nullptr) { + breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us); + } return true; } static void dflash_clear_target_features(common_speculative_state_dflash & state) { state.target_window.clear(); state.target_window_pos.clear(); + state.target_window_stage.clear(); + state.target_window_pos_stage.clear(); + state.target_window_append_features.clear(); state.target_window_rows = 0; + state.target_window_ring_write_pos = 0; + state.target_window_ring_filled = 0; + state.target_window_keep_rows = 0; + state.target_window_append_rows = 0; + state.target_window_replace = false; + state.target_window_materialized = false; state.last_target_pos = -1; + llama_reset_dflash_kv_cache_state(state.ctx_dft); } static void dflash_context_shift( @@ -2814,10 +3149,12 @@ static void dflash_context_shift( llama_pos kv_keep, llama_pos kv_discard, llama_pos kv_past) { - if (kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window.empty() || state.target_window_pos.empty()) { + if (kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window_pos.empty()) { return; } + dflash_materialize_target_window_features(state); + const size_t row_width = (size_t) state.n_target_features; const llama_pos discard_begin = kv_keep; const llama_pos discard_end = kv_keep + kv_discard; @@ -2845,7 +3182,10 @@ static void dflash_context_shift( state.target_window = std::move(shifted_rows); state.target_window_pos = std::move(shifted_positions); state.target_window_rows = (int32_t) state.target_window_pos.size(); + dflash_ring_reset_rows(state, state.target_window.data(), state.target_window_rows); state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); + dflash_record_window_update(state, 0, state.target_window_rows, true); + llama_reset_dflash_kv_cache_state(state.ctx_dft); state.n_context_shifts++; } @@ -2959,8 +3299,9 @@ int32_t common_speculative_on_target_batch( } } + dflash_append_breakdown append_breakdown; const int64_t t_append_us = ggml_time_us(); - if (!dflash_append_target_features(*dflash_state, features, batch, seq_id)) { + if (!dflash_append_target_features(*dflash_state, features, batch, seq_id, &append_breakdown)) { return -1; } @@ -2971,8 +3312,20 @@ int32_t common_speculative_on_target_batch( dflash_state->n_warmup_append_rows += (size_t) batch.n_tokens; } else { dflash_state->t_accept_append_us += append_us; + dflash_state->t_accept_append_filter_us += append_breakdown.filter_us; + dflash_state->t_accept_append_window_alloc_us += append_breakdown.window_alloc_us; + dflash_state->t_accept_append_replace_us += append_breakdown.replace_us; + dflash_state->t_accept_append_keep_old_us += append_breakdown.keep_old_us; + dflash_state->t_accept_append_new_rows_us += append_breakdown.new_rows_us; + dflash_state->t_accept_append_commit_detail_us += append_breakdown.commit_us; + dflash_state->t_accept_append_log_us += append_breakdown.log_us; dflash_state->n_accept_append_calls++; dflash_state->n_accept_append_rows += (size_t) batch.n_tokens; + if (append_breakdown.replace_call) { + dflash_state->n_accept_append_replace_calls++; + } else { + dflash_state->n_accept_append_slide_calls++; + } } return 0; diff --git a/src/graphs/build_dflash.cpp b/src/graphs/build_dflash.cpp index b9862c2a..a5b9a815 100644 --- a/src/graphs/build_dflash.cpp +++ b/src/graphs/build_dflash.cpp @@ -23,38 +23,132 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() { const int64_t ctx_len = lctx.dflash_visible_cross_ctx > 0 ? (int64_t) lctx.dflash_visible_cross_ctx : std::max(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size); + const int64_t update_rows = std::max(1, lctx.dflash_kv_cache_update_rows > 0 ? lctx.dflash_kv_cache_update_rows : ctx_len); + const int32_t write_pos = lctx.dflash_kv_cache_write_pos; GGML_ASSERT(n_embd_head_k == n_embd_head_v); GGML_ASSERT(n_target_features > 0); GGML_ASSERT(lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len)); + GGML_ASSERT(update_rows > 0 && update_rows <= ctx_len); + GGML_ASSERT(write_pos >= 0 && write_pos < ctx_len); - ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max(1, ctx_len)) + 24 * n_layer, false); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max(1, update_rows)) + 24 * n_layer, false); - lctx.dflash_kv_input_target_features = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_target_features, ctx_len); + lctx.dflash_kv_input_target_features = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_target_features, update_rows); ggml_set_input(lctx.dflash_kv_input_target_features); + cb(lctx.dflash_kv_input_target_features, "dflash_kv_input_target_features", -1); - lctx.dflash_kv_input_pos_ctx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctx_len); + lctx.dflash_kv_input_pos_ctx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, update_rows); ggml_set_input(lctx.dflash_kv_input_pos_ctx); + cb(lctx.dflash_kv_input_pos_ctx, "dflash_kv_input_pos_ctx", -1); ggml_tensor * fused_target = llm_build_lora_mm(lctx, ctx0, model.dflash_fc, lctx.dflash_kv_input_target_features); fused_target = llm_build_norm(ctx0, fused_target, hparams, model.dflash_hidden_norm, nullptr, LLM_NORM_RMS, cb, -1); + cb(fused_target, "dflash_kv_fused_target", -1); for (int il = 0; il < n_layer; ++il) { GGML_ASSERT((size_t) il < lctx.dflash_k_ctx_cache.size()); GGML_ASSERT((size_t) il < lctx.dflash_v_ctx_cache.size()); - ggml_tensor * Kcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused_target); - Kcur_ctx = ggml_reshape_3d(ctx0, Kcur_ctx, n_embd_head_k, n_head_kv, ctx_len); + ggml_tensor * Kcur_ctx_proj = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused_target); + cb(Kcur_ctx_proj, "dflash_kv_k_proj", il); + + ggml_tensor * Kcur_ctx = ggml_reshape_3d(ctx0, Kcur_ctx_proj, n_embd_head_k, n_head_kv, update_rows); Kcur_ctx = llm_build_norm(ctx0, Kcur_ctx, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il); + cb(Kcur_ctx, "dflash_kv_k_norm", il); Kcur_ctx = ggml_rope_ext(ctx0, Kcur_ctx, lctx.dflash_kv_input_pos_ctx, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + cb(Kcur_ctx, "dflash_kv_k_rope", il); ggml_tensor * Vcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, fused_target); - Vcur_ctx = ggml_reshape_3d(ctx0, Vcur_ctx, n_embd_head_v, n_head_kv, ctx_len); + cb(Vcur_ctx, "dflash_kv_v_proj", il); + Vcur_ctx = ggml_reshape_3d(ctx0, Vcur_ctx, n_embd_head_v, n_head_kv, update_rows); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur_ctx, lctx.dflash_k_ctx_cache[(size_t) il])); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur_ctx, lctx.dflash_v_ctx_cache[(size_t) il])); + const int32_t first_rows = std::min((int32_t) update_rows, (int32_t) ctx_len - write_pos); + const int32_t second_rows = (int32_t) update_rows - first_rows; + + if (first_rows > 0) { + ggml_tensor * Ksrc_first = first_rows == update_rows + ? Kcur_ctx + : ggml_view_3d(ctx0, Kcur_ctx, + Kcur_ctx->ne[0], + Kcur_ctx->ne[1], + first_rows, + Kcur_ctx->nb[1], + Kcur_ctx->nb[2], + 0); + ggml_tensor * Vsrc_first = first_rows == update_rows + ? Vcur_ctx + : ggml_view_3d(ctx0, Vcur_ctx, + Vcur_ctx->ne[0], + Vcur_ctx->ne[1], + first_rows, + Vcur_ctx->nb[1], + Vcur_ctx->nb[2], + 0); + ggml_tensor * Kdst_first = ggml_view_3d(ctx0, lctx.dflash_k_ctx_cache[(size_t) il], + lctx.dflash_k_ctx_cache[(size_t) il]->ne[0], + lctx.dflash_k_ctx_cache[(size_t) il]->ne[1], + first_rows, + lctx.dflash_k_ctx_cache[(size_t) il]->nb[1], + lctx.dflash_k_ctx_cache[(size_t) il]->nb[2], + (size_t) write_pos * lctx.dflash_k_ctx_cache[(size_t) il]->nb[2]); + ggml_tensor * Vdst_first = ggml_view_3d(ctx0, lctx.dflash_v_ctx_cache[(size_t) il], + lctx.dflash_v_ctx_cache[(size_t) il]->ne[0], + lctx.dflash_v_ctx_cache[(size_t) il]->ne[1], + first_rows, + lctx.dflash_v_ctx_cache[(size_t) il]->nb[1], + lctx.dflash_v_ctx_cache[(size_t) il]->nb[2], + (size_t) write_pos * lctx.dflash_v_ctx_cache[(size_t) il]->nb[2]); + + ggml_tensor * Kstore_first = ggml_cpy(ctx0, Ksrc_first, Kdst_first); + cb(Kstore_first, "dflash_kv_k_store", il); + ggml_build_forward_expand(gf, Kstore_first); + + ggml_tensor * Vstore_first = ggml_cpy(ctx0, Vsrc_first, Vdst_first); + cb(Vstore_first, "dflash_kv_v_store", il); + ggml_build_forward_expand(gf, Vstore_first); + } + + if (second_rows > 0) { + ggml_tensor * Ksrc_second = ggml_view_3d(ctx0, Kcur_ctx, + Kcur_ctx->ne[0], + Kcur_ctx->ne[1], + second_rows, + Kcur_ctx->nb[1], + Kcur_ctx->nb[2], + (size_t) first_rows * Kcur_ctx->nb[2]); + ggml_tensor * Vsrc_second = ggml_view_3d(ctx0, Vcur_ctx, + Vcur_ctx->ne[0], + Vcur_ctx->ne[1], + second_rows, + Vcur_ctx->nb[1], + Vcur_ctx->nb[2], + (size_t) first_rows * Vcur_ctx->nb[2]); + ggml_tensor * Kdst_second = ggml_view_3d(ctx0, lctx.dflash_k_ctx_cache[(size_t) il], + lctx.dflash_k_ctx_cache[(size_t) il]->ne[0], + lctx.dflash_k_ctx_cache[(size_t) il]->ne[1], + second_rows, + lctx.dflash_k_ctx_cache[(size_t) il]->nb[1], + lctx.dflash_k_ctx_cache[(size_t) il]->nb[2], + 0); + ggml_tensor * Vdst_second = ggml_view_3d(ctx0, lctx.dflash_v_ctx_cache[(size_t) il], + lctx.dflash_v_ctx_cache[(size_t) il]->ne[0], + lctx.dflash_v_ctx_cache[(size_t) il]->ne[1], + second_rows, + lctx.dflash_v_ctx_cache[(size_t) il]->nb[1], + lctx.dflash_v_ctx_cache[(size_t) il]->nb[2], + 0); + + ggml_tensor * Kstore_second = ggml_cpy(ctx0, Ksrc_second, Kdst_second); + cb(Kstore_second, "dflash_kv_k_store", il); + ggml_build_forward_expand(gf, Kstore_second); + + ggml_tensor * Vstore_second = ggml_cpy(ctx0, Vsrc_second, Vdst_second); + cb(Vstore_second, "dflash_kv_v_store", il); + ggml_build_forward_expand(gf, Vstore_second); + } } return gf; @@ -69,12 +163,17 @@ ggml_cgraph * llm_build_context::build_dflash() { const int64_t ctx_len = lctx.dflash_visible_cross_ctx > 0 ? (int64_t) lctx.dflash_visible_cross_ctx : std::max(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size); + const int32_t cache_rows = use_kv_cache ? std::clamp(lctx.dflash_kv_cache_view_n_filled, 0, (int32_t) ctx_len) : 0; + const int32_t cache_write_pos = use_kv_cache && ctx_len > 0 + ? ((lctx.dflash_kv_cache_view_write_pos % (int32_t) ctx_len) + (int32_t) ctx_len) % (int32_t) ctx_len + : 0; const int64_t n_kv_total = GGML_PAD(ctx_len + n_tokens, flash_attn ? 256 : 32); const int64_t n_kv_pad = n_kv_total - (ctx_len + n_tokens); GGML_ASSERT(n_embd_head_k == n_embd_head_v); GGML_ASSERT(n_target_features > 0); GGML_ASSERT(!use_kv_cache || lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len)); + GGML_ASSERT(!use_kv_cache || (cache_write_pos >= 0 && cache_write_pos < ctx_len)); ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max(n_tokens, ctx_len)) + 32 * n_layer, false); @@ -160,8 +259,52 @@ ggml_cgraph * llm_build_context::build_dflash() { ggml_tensor * Kcur_ctx = nullptr; ggml_tensor * Vcur_ctx = nullptr; if (use_kv_cache) { - Kcur_ctx = lctx.dflash_k_ctx_cache[(size_t) il]; - Vcur_ctx = lctx.dflash_v_ctx_cache[(size_t) il]; + auto build_ordered_cache_view = [&](ggml_tensor * cache) -> ggml_tensor * { + if (!lctx.dflash_kv_cache_view_valid || cache_rows <= 0) { + return cache; + } + + if (cache_rows < ctx_len) { + ggml_tensor * zero_pad = ggml_view_3d(ctx0, cache, + cache->ne[0], + cache->ne[1], + ctx_len - cache_rows, + cache->nb[1], + cache->nb[2], + (size_t) cache_rows * cache->nb[2]); + ggml_tensor * valid = ggml_view_3d(ctx0, cache, + cache->ne[0], + cache->ne[1], + cache_rows, + cache->nb[1], + cache->nb[2], + 0); + return ggml_concat(ctx0, zero_pad, valid, 2); + } + + if (cache_write_pos == 0) { + return cache; + } + + ggml_tensor * tail = ggml_view_3d(ctx0, cache, + cache->ne[0], + cache->ne[1], + ctx_len - cache_write_pos, + cache->nb[1], + cache->nb[2], + (size_t) cache_write_pos * cache->nb[2]); + ggml_tensor * head = ggml_view_3d(ctx0, cache, + cache->ne[0], + cache->ne[1], + cache_write_pos, + cache->nb[1], + cache->nb[2], + 0); + return ggml_concat(ctx0, tail, head, 2); + }; + + Kcur_ctx = build_ordered_cache_view(lctx.dflash_k_ctx_cache[(size_t) il]); + Vcur_ctx = build_ordered_cache_view(lctx.dflash_v_ctx_cache[(size_t) il]); cb(Kcur_ctx, "Kcur_ctx_cache", il); cb(Vcur_ctx, "Vcur_ctx_cache", il); } else { diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 081215dd..fc03353c 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -2173,7 +2173,27 @@ struct ggml_cgraph * llm_build_context::llama_build_graph_dflash_kv_cache(llama_ llama_batch dummy; dummy.n_tokens = 0; - llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; + llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) { + if (il >= 0) { + int j = 0; + for (; j < GGML_MAX_NAME - 1; ++j) { + cur->name[j] = name[j]; + if (!name[j]) { + break; + } + } + if (j < GGML_MAX_NAME - 3) { + cur->name[j++] = '-'; + auto sil = std::to_string(il); + for (int k = 0; k < (int) sil.size() && j < GGML_MAX_NAME - 1; ++k) { + cur->name[j++] = sil[k]; + } + } + cur->name[j] = 0; + } else { + ggml_set_name(cur, name); + } + }; struct llm_build_context llm(lctx, dummy, cb, false, false, 0, false, &lctx.dflash_buf_compute_meta); diff --git a/src/llama-context.h b/src/llama-context.h index cc207a36..1a7a9d80 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -281,9 +281,17 @@ struct llama_context { const float * dflash_target_features = nullptr; size_t dflash_target_features_n_floats = 0; int32_t dflash_target_features_n_rows = 0; + const float * dflash_target_append_features = nullptr; + size_t dflash_target_append_features_n_floats = 0; + int32_t dflash_target_append_features_n_rows = 0; const llama_pos * dflash_target_positions = nullptr; size_t dflash_target_positions_n = 0; + uint64_t dflash_target_window_version = 0; + int32_t dflash_target_window_keep_rows = 0; + int32_t dflash_target_window_append_rows = 0; + bool dflash_target_window_replace = false; std::vector dflash_target_features_owned; + std::vector dflash_target_append_features_owned; std::vector dflash_target_positions_owned; std::vector dflash_target_features_padded; std::vector dflash_feature_view_buffer; @@ -295,6 +303,15 @@ struct llama_context { std::vector dflash_v_ctx_cache; struct ggml_context * dflash_cache_ctx = nullptr; std::vector dflash_cache_bufs; + int32_t dflash_kv_cache_write_pos = 0; + int32_t dflash_kv_cache_n_filled = 0; + int32_t dflash_kv_cache_update_rows = 0; + int32_t dflash_kv_cache_reserved_rows = 0; + int32_t dflash_kv_cache_view_write_pos = 0; + int32_t dflash_kv_cache_view_n_filled = 0; + uint64_t dflash_kv_cache_applied_window_version = 0; + bool dflash_kv_cache_valid = false; + bool dflash_kv_cache_view_valid = false; std::vector dflash_buf_compute_meta; ggml_backend_sched_t dflash_sched = nullptr; struct ggml_tensor * dflash_kv_input_target_features = nullptr; diff --git a/src/llama-spec-features.cpp b/src/llama-spec-features.cpp index bcf1ca89..ab8efddb 100644 --- a/src/llama-spec-features.cpp +++ b/src/llama-spec-features.cpp @@ -55,6 +55,56 @@ void llama_dflash_profile_reset(struct llama_context * ctx) { ctx->dflash_profile = {}; } +void llama_reset_dflash_kv_cache_state(struct llama_context * ctx) { + if (ctx == nullptr) { + return; + } + + ctx->dflash_kv_cache_write_pos = 0; + ctx->dflash_kv_cache_n_filled = 0; + ctx->dflash_kv_cache_update_rows = 0; + ctx->dflash_kv_cache_view_write_pos = 0; + ctx->dflash_kv_cache_view_n_filled = 0; + ctx->dflash_kv_cache_applied_window_version = 0; + ctx->dflash_kv_cache_valid = false; + ctx->dflash_kv_cache_view_valid = false; + + for (ggml_backend_buffer_t buf : ctx->dflash_cache_bufs) { + if (buf != nullptr) { + ggml_backend_buffer_clear(buf, 0); + } + } +} + +llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition_for_ctx( + const struct llama_context * ctx, + const llama_dflash_window_update & window_update, + int32_t n_rows) { + if (ctx == nullptr) { + llama_dflash_kv_cache_transition plan; + plan.rebuild_cache = true; + plan.append_rows = std::clamp(window_update.append_rows, 0, n_rows); + plan.next_n_filled = n_rows; + return plan; + } + + const int32_t cross_ctx = ctx->dflash_visible_cross_ctx > 0 + ? ctx->dflash_visible_cross_ctx + : std::max(1, (int32_t) ctx->cparams.n_ctx - (int32_t) ctx->model.hparams.dflash_block_size); + + return llama_plan_dflash_kv_cache_transition( + cross_ctx, + ctx->dflash_kv_cache_n_filled, + ctx->dflash_kv_cache_write_pos, + ctx->dflash_kv_cache_valid, + ctx->dflash_kv_cache_applied_window_version, + window_update.version, + window_update.keep_rows, + window_update.append_rows, + window_update.replace, + n_rows); +} + void llama_set_dflash_visible_cross_ctx( struct llama_context * ctx, int32_t cross_ctx) { @@ -205,26 +255,91 @@ static bool llama_set_dflash_target_features_impl( size_t n_floats, int32_t n_rows, const llama_pos * target_positions, - bool copy_data) { - if (ctx == nullptr || target_features == nullptr || n_floats == 0 || n_rows <= 0) { + bool copy_data, + const llama_dflash_window_update * window_update) { + const bool have_full_features = target_features != nullptr && n_floats > 0; + const bool have_append_features = window_update != nullptr && + window_update->append_features != nullptr && + window_update->append_floats > 0 && + window_update->append_rows > 0; + + if (ctx == nullptr || n_rows <= 0 || (!have_full_features && !have_append_features)) { return false; } auto & profile = ctx->dflash_profile; const int64_t t_start_us = ggml_time_us(); - const int32_t row_width = n_rows > 0 ? (int32_t) (n_floats / (size_t) n_rows) : 0; + const int32_t row_width = have_full_features + ? (n_rows > 0 ? (int32_t) (n_floats / (size_t) n_rows) : 0) + : (window_update->append_rows > 0 ? (int32_t) (window_update->append_floats / (size_t) window_update->append_rows) : 0); llama_pos first_pos = -1; llama_pos last_pos = -1; - if (copy_data) { + if (have_full_features && copy_data) { ctx->dflash_target_features_owned.assign(target_features, target_features + n_floats); ctx->dflash_target_features = ctx->dflash_target_features_owned.data(); - } else { + } else if (have_full_features) { ctx->dflash_target_features_owned.clear(); ctx->dflash_target_features = target_features; + } else { + ctx->dflash_target_features_owned.clear(); + ctx->dflash_target_features = nullptr; } - ctx->dflash_target_features_n_floats = n_floats; + ctx->dflash_target_features_n_floats = have_full_features ? n_floats : 0; ctx->dflash_target_features_n_rows = n_rows; + if (have_append_features && copy_data) { + ctx->dflash_target_append_features_owned.assign( + window_update->append_features, + window_update->append_features + window_update->append_floats); + ctx->dflash_target_append_features = ctx->dflash_target_append_features_owned.data(); + } else if (have_append_features) { + ctx->dflash_target_append_features_owned.clear(); + ctx->dflash_target_append_features = window_update->append_features; + } else { + ctx->dflash_target_append_features_owned.clear(); + ctx->dflash_target_append_features = nullptr; + } + ctx->dflash_target_append_features_n_floats = have_append_features ? window_update->append_floats : 0; + ctx->dflash_target_append_features_n_rows = have_append_features ? window_update->append_rows : 0; + ctx->dflash_target_window_version = window_update != nullptr && window_update->version > 0 + ? window_update->version + : ctx->dflash_target_window_version + 1; + ctx->dflash_target_window_keep_rows = window_update != nullptr + ? std::max(0, std::min(n_rows, window_update->keep_rows)) + : 0; + ctx->dflash_target_window_append_rows = window_update != nullptr + ? std::max(0, std::min(n_rows, window_update->append_rows)) + : n_rows; + ctx->dflash_target_window_replace = window_update != nullptr + ? window_update->replace + : true; + if (ctx->dflash_target_window_keep_rows + ctx->dflash_target_window_append_rows > n_rows) { + ctx->dflash_target_window_keep_rows = std::max(0, n_rows - ctx->dflash_target_window_append_rows); + } + + const int32_t cross_ctx = ctx->dflash_visible_cross_ctx > 0 + ? ctx->dflash_visible_cross_ctx + : std::max(1, (int32_t) ctx->cparams.n_ctx - (int32_t) ctx->model.hparams.dflash_block_size); + const llama_dflash_window_update cache_window_update = { + ctx->dflash_target_window_version, + ctx->dflash_target_window_keep_rows, + ctx->dflash_target_window_append_rows, + ctx->dflash_target_window_replace, + ctx->dflash_target_append_features, + ctx->dflash_target_append_features_n_floats, + }; + const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition_for_ctx(ctx, cache_window_update, n_rows); + + if (cache_plan.cache_up_to_date) { + ctx->dflash_kv_cache_view_n_filled = ctx->dflash_kv_cache_n_filled; + ctx->dflash_kv_cache_view_write_pos = ctx->dflash_kv_cache_write_pos; + ctx->dflash_kv_cache_view_valid = ctx->dflash_kv_cache_valid; + } else if (cross_ctx > 0) { + ctx->dflash_kv_cache_view_n_filled = cache_plan.next_n_filled; + ctx->dflash_kv_cache_view_write_pos = cache_plan.next_write_pos; + ctx->dflash_kv_cache_view_valid = cache_plan.next_n_filled > 0; + } + if (target_positions != nullptr) { if (copy_data) { ctx->dflash_target_positions_owned.assign(target_positions, target_positions + n_rows); @@ -243,7 +358,10 @@ static bool llama_set_dflash_target_features_impl( profile.set_target_copy_calls++; profile.set_target_copy_us += (uint64_t) (ggml_time_us() - t_start_us); profile.set_target_rows += (uint64_t) n_rows; - profile.set_target_copy_bytes += n_floats * sizeof(float) + (target_positions ? (size_t) n_rows * sizeof(llama_pos) : 0); + profile.set_target_copy_bytes += + (have_full_features ? n_floats : 0) * sizeof(float) + + (have_append_features ? window_update->append_floats : 0) * sizeof(float) + + (target_positions ? (size_t) n_rows * sizeof(llama_pos) : 0); profile.last_n_rows = n_rows; profile.last_width = row_width; @@ -267,8 +385,9 @@ bool llama_set_dflash_target_features_copy( const float * target_features, size_t n_floats, int32_t n_rows, - const llama_pos * target_positions) { - return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, true); + const llama_pos * target_positions, + const llama_dflash_window_update * window_update) { + return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, true, window_update); } bool llama_set_dflash_target_features_view( @@ -276,8 +395,9 @@ bool llama_set_dflash_target_features_view( const float * target_features, size_t n_floats, int32_t n_rows, - const llama_pos * target_positions) { - return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, false); + const llama_pos * target_positions, + const llama_dflash_window_update * window_update) { + return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, false, window_update); } static void llama_record_dflash_capture_phase( diff --git a/src/llama-spec-features.h b/src/llama-spec-features.h index 9ec2e827..d976c89a 100644 --- a/src/llama-spec-features.h +++ b/src/llama-spec-features.h @@ -2,6 +2,8 @@ #include "llama.h" +#include +#include #include struct llama_context; @@ -24,6 +26,19 @@ struct llama_spec_feature_view { }; struct llama_dflash_profile_stats { + uint64_t decode_internal_chunks = 0; + uint64_t decode_graph_rebuilds = 0; + uint64_t decode_sync_profile_points = 0; + uint64_t decode_prelude_us = 0; + uint64_t decode_sched_reset_us = 0; + uint64_t decode_build_graph_us = 0; + uint64_t decode_sched_alloc_graph_us = 0; + uint64_t decode_set_inputs_us = 0; + uint64_t decode_graph_compute_us = 0; + uint64_t decode_result_us = 0; + uint64_t decode_embedding_us = 0; + uint64_t decode_final_sched_reset_us = 0; + uint64_t decode_output_reserve_calls = 0; uint64_t decode_output_reserve_us = 0; uint64_t decode_output_reserve_reallocs = 0; @@ -71,6 +86,20 @@ struct llama_dflash_profile_stats { uint64_t graph_kv_cache_read_concat_pad_calls = 0; uint64_t graph_kv_cache_cached_bytes = 0; uint64_t graph_kv_cache_calls = 0; + uint64_t graph_kv_node_fused_target_calls = 0; + uint64_t graph_kv_node_fused_target_us = 0; + uint64_t graph_kv_node_k_proj_calls = 0; + uint64_t graph_kv_node_k_proj_us = 0; + uint64_t graph_kv_node_k_norm_calls = 0; + uint64_t graph_kv_node_k_norm_us = 0; + uint64_t graph_kv_node_k_rope_calls = 0; + uint64_t graph_kv_node_k_rope_us = 0; + uint64_t graph_kv_node_v_proj_calls = 0; + uint64_t graph_kv_node_v_proj_us = 0; + uint64_t graph_kv_node_k_store_calls = 0; + uint64_t graph_kv_node_k_store_us = 0; + uint64_t graph_kv_node_v_store_calls = 0; + uint64_t graph_kv_node_v_store_us = 0; uint64_t graph_feature_bytes = 0; uint64_t graph_pos_bytes = 0; uint64_t graph_mask_bytes = 0; @@ -96,10 +125,76 @@ struct llama_dflash_profile_stats { llama_pos last_pos_last = -1; }; +struct llama_dflash_window_update { + uint64_t version = 0; + int32_t keep_rows = 0; + int32_t append_rows = 0; + bool replace = false; + const float * append_features = nullptr; + size_t append_floats = 0; +}; + +struct llama_dflash_kv_cache_transition { + bool cache_up_to_date = false; + bool rebuild_cache = false; + int32_t append_rows = 0; + int32_t next_n_filled = 0; + int32_t next_write_pos = 0; +}; + +static inline llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition( + int32_t cross_ctx, + int32_t current_n_filled, + int32_t current_write_pos, + bool cache_valid, + uint64_t applied_window_version, + uint64_t target_window_version, + int32_t keep_rows, + int32_t append_rows, + bool replace, + int32_t n_rows) { + llama_dflash_kv_cache_transition plan; + + const int32_t safe_cross_ctx = std::max(1, cross_ctx); + const int32_t bounded_n_filled = std::clamp(current_n_filled, 0, safe_cross_ctx); + const int32_t bounded_append_rows = std::clamp(append_rows, 0, n_rows); + const int32_t bounded_keep_rows = std::clamp(keep_rows, 0, n_rows); + const int32_t expected_keep_rows = std::min(bounded_n_filled, std::max(0, safe_cross_ctx - bounded_append_rows)); + + plan.cache_up_to_date = cache_valid && applied_window_version == target_window_version; + plan.rebuild_cache = !cache_valid || replace || bounded_append_rows <= 0 || bounded_append_rows > n_rows; + if (!plan.rebuild_cache && bounded_keep_rows != expected_keep_rows) { + plan.rebuild_cache = true; + } + + plan.append_rows = bounded_append_rows; + if (plan.cache_up_to_date) { + plan.next_n_filled = bounded_n_filled; + plan.next_write_pos = safe_cross_ctx > 0 + ? ((current_write_pos % safe_cross_ctx) + safe_cross_ctx) % safe_cross_ctx + : 0; + } else if (plan.rebuild_cache) { + plan.next_n_filled = std::min(safe_cross_ctx, n_rows); + plan.next_write_pos = plan.next_n_filled % safe_cross_ctx; + } else { + plan.next_n_filled = std::min(safe_cross_ctx, bounded_n_filled + bounded_append_rows); + plan.next_write_pos = (current_write_pos + bounded_append_rows) % safe_cross_ctx; + } + + return plan; +} + +llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition_for_ctx( + const struct llama_context * ctx, + const llama_dflash_window_update & window_update, + int32_t n_rows); + uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx); void llama_dflash_profile_reset(struct llama_context * ctx); +void llama_reset_dflash_kv_cache_state(struct llama_context * ctx); + void llama_set_dflash_visible_cross_ctx( struct llama_context * ctx, int32_t cross_ctx); @@ -156,14 +251,16 @@ bool llama_set_dflash_target_features_copy( const float * target_features, size_t n_floats, int32_t n_rows, - const llama_pos * target_positions); + const llama_pos * target_positions, + const llama_dflash_window_update * window_update = nullptr); bool llama_set_dflash_target_features_view( struct llama_context * ctx, const float * target_features, size_t n_floats, int32_t n_rows, - const llama_pos * target_positions); + const llama_pos * target_positions, + const llama_dflash_window_update * window_update = nullptr); bool llama_set_dflash_capture_layers( struct llama_context * ctx, diff --git a/src/llama.cpp b/src/llama.cpp index e3b91b0b..e53940a2 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -171,6 +171,129 @@ static std::vector string_split(const std::string& str, const std:: return parts; } +static bool llama_env_flag_enabled(const char * name) { + const char * env = std::getenv(name); + return env != nullptr && *env != '\0' && + std::strcmp(env, "0") != 0 && + std::strcmp(env, "false") != 0 && + std::strcmp(env, "off") != 0; +} + +enum llama_dflash_kv_node_kind { + LLAMA_DFLASH_KV_NODE_NONE = 0, + LLAMA_DFLASH_KV_NODE_FUSED_TARGET, + LLAMA_DFLASH_KV_NODE_K_PROJ, + LLAMA_DFLASH_KV_NODE_K_NORM, + LLAMA_DFLASH_KV_NODE_K_ROPE, + LLAMA_DFLASH_KV_NODE_V_PROJ, + LLAMA_DFLASH_KV_NODE_K_STORE, + LLAMA_DFLASH_KV_NODE_V_STORE, +}; + +struct llama_dflash_kv_node_profiler { + llama_dflash_profile_stats * profile = nullptr; + int64_t t_start_us = 0; + llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE; +}; + +static bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) { + if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') { + return false; + } + + return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0; +} + +static llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) { + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_fused_target")) { + return LLAMA_DFLASH_KV_NODE_FUSED_TARGET; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_proj")) { + return LLAMA_DFLASH_KV_NODE_K_PROJ; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_norm")) { + return LLAMA_DFLASH_KV_NODE_K_NORM; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_rope")) { + return LLAMA_DFLASH_KV_NODE_K_ROPE; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_proj")) { + return LLAMA_DFLASH_KV_NODE_V_PROJ; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_store")) { + return LLAMA_DFLASH_KV_NODE_K_STORE; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_store")) { + return LLAMA_DFLASH_KV_NODE_V_STORE; + } + + return LLAMA_DFLASH_KV_NODE_NONE; +} + +static void llama_dflash_kv_node_profile_add( + llama_dflash_profile_stats & profile, + llama_dflash_kv_node_kind kind, + uint64_t elapsed_us) { + switch (kind) { + case LLAMA_DFLASH_KV_NODE_FUSED_TARGET: + profile.graph_kv_node_fused_target_calls++; + profile.graph_kv_node_fused_target_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_PROJ: + profile.graph_kv_node_k_proj_calls++; + profile.graph_kv_node_k_proj_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_NORM: + profile.graph_kv_node_k_norm_calls++; + profile.graph_kv_node_k_norm_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_ROPE: + profile.graph_kv_node_k_rope_calls++; + profile.graph_kv_node_k_rope_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_V_PROJ: + profile.graph_kv_node_v_proj_calls++; + profile.graph_kv_node_v_proj_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_STORE: + profile.graph_kv_node_k_store_calls++; + profile.graph_kv_node_k_store_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_V_STORE: + profile.graph_kv_node_v_store_calls++; + profile.graph_kv_node_v_store_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_NONE: + break; + } +} + +static bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { + auto * profiler = static_cast(user_data); + if (profiler == nullptr || profiler->profile == nullptr) { + return false; + } + + const llama_dflash_kv_node_kind kind = llama_dflash_kv_node_kind_from_tensor(tensor); + if (ask) { + if (kind == LLAMA_DFLASH_KV_NODE_NONE) { + return false; + } + + profiler->active_kind = kind; + profiler->t_start_us = ggml_time_us(); + return true; + } + + if (kind != LLAMA_DFLASH_KV_NODE_NONE && profiler->active_kind == kind && profiler->t_start_us > 0) { + llama_dflash_kv_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); + } + + profiler->active_kind = LLAMA_DFLASH_KV_NODE_NONE; + profiler->t_start_us = 0; + return true; +} + // extract ip and port from RPC[ip:port] for rpc and keep other device names static std::vector extract_device_from_rpc_device(std::vector devices) { std::vector rpc_servers; @@ -689,6 +812,7 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { } dflash_profile.last_kv_cache_host_layers = host_layers; + llama_reset_dflash_kv_cache_state(this); LLAMA_LOG_INFO("%s: DFlash K/V cache placement cross_ctx=%d host_layers=%d/%d first=%s last=%s\n", __func__, target_cross_ctx, @@ -703,6 +827,15 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { void llama_context::free_dflash_kv_cache_tensors() { dflash_k_ctx_cache.clear(); dflash_v_ctx_cache.clear(); + dflash_kv_cache_write_pos = 0; + dflash_kv_cache_n_filled = 0; + dflash_kv_cache_update_rows = 0; + dflash_kv_cache_reserved_rows = 0; + dflash_kv_cache_view_write_pos = 0; + dflash_kv_cache_view_n_filled = 0; + dflash_kv_cache_applied_window_version = 0; + dflash_kv_cache_valid = false; + dflash_kv_cache_view_valid = false; dflash_kv_input_target_features = nullptr; dflash_kv_input_pos_ctx = nullptr; dflash_kq_mask_tensor = nullptr; @@ -5271,11 +5404,8 @@ static bool validate_dflash_graph_contract(const llama_context & lctx) { static bool prepare_dflash_graph_inputs( struct llama_context & lctx, uint32_t n_tokens) { - const char * dflash_kv_cache_env = std::getenv("IK_DFLASH_KV_CACHE"); - const bool use_kv_cache = dflash_kv_cache_env != nullptr && *dflash_kv_cache_env != '\0' && - std::strcmp(dflash_kv_cache_env, "0") != 0 && - std::strcmp(dflash_kv_cache_env, "false") != 0 && - std::strcmp(dflash_kv_cache_env, "off") != 0; + const bool use_kv_cache = llama_env_flag_enabled("IK_DFLASH_KV_CACHE"); + const bool kv_node_timing = llama_env_flag_enabled("IK_DFLASH_KV_NODE_TIMING"); auto & profile = lctx.dflash_profile; const int32_t cross_ctx = lctx.dflash_visible_cross_ctx > 0 ? lctx.dflash_visible_cross_ctx @@ -5304,10 +5434,13 @@ static bool prepare_dflash_graph_inputs( } const float * src = lctx.dflash_target_features; + const float * append_src = lctx.dflash_target_append_features; const llama_pos * src_pos = lctx.dflash_target_positions; const size_t total_floats = lctx.dflash_target_features_n_floats; + const size_t append_floats = lctx.dflash_target_append_features_n_floats; const size_t total_positions = lctx.dflash_target_positions_n; const int32_t n_rows = lctx.dflash_target_features_n_rows; + const int32_t append_rows_available = lctx.dflash_target_append_features_n_rows; const int32_t width = (int32_t) lctx.model.hparams.dflash_n_target_features; const int32_t graph_cross_ctx = use_kv_cache ? (lctx.dflash_k_ctx_cache.front() != nullptr ? (int32_t) lctx.dflash_k_ctx_cache.front()->ne[2] : 0) @@ -5330,19 +5463,26 @@ static bool prepare_dflash_graph_inputs( __func__, graph_cross_ctx, cross_ctx); return false; } - if (src == nullptr || total_floats == 0 || n_rows <= 0) { + if (n_rows <= 0) { profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: missing DFlash target features\n", __func__); + LLAMA_LOG_ERROR("%s: missing DFlash target feature rows\n", __func__); return false; } - if (n_rows > cross_ctx || total_floats != (size_t) n_rows * (size_t) width) { + const bool have_full_src = src != nullptr && total_floats == (size_t) n_rows * (size_t) width; + if (n_rows > cross_ctx || (src != nullptr && !have_full_src)) { profile.graph_shape_failures++; LLAMA_LOG_ERROR("%s: invalid DFlash target feature shape (rows=%d width=%d floats=%zu cross_ctx=%d)\n", __func__, n_rows, width, total_floats, cross_ctx); return false; } + if (!use_kv_cache && !have_full_src) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: missing contiguous DFlash target features for inline path\n", __func__); + return false; + } + if (n_kv_total < cross_ctx + (int32_t) n_tokens) { profile.graph_mask_overflow++; LLAMA_LOG_ERROR("%s: invalid DFlash mask shape (n_kv_total=%d < cross_ctx+n_tokens=%d)\n", @@ -5351,24 +5491,26 @@ static bool prepare_dflash_graph_inputs( } const int32_t left_pad = cross_ctx - n_rows; - const size_t padded_floats = (size_t) cross_ctx * (size_t) width; - const size_t dst_offset = (size_t) left_pad * (size_t) width; - const int64_t t_feature_us = ggml_time_us(); profile.last_left_pad = left_pad; - if (lctx.dflash_target_features_padded.size() != padded_floats) { - lctx.dflash_target_features_padded.resize(padded_floats); - } - if (left_pad == 0 && total_floats == padded_floats) { - std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin()); - } else { - if (dst_offset > 0) { - std::fill(lctx.dflash_target_features_padded.begin(), - lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset, 0.0f); + if (!use_kv_cache) { + const size_t padded_floats = (size_t) cross_ctx * (size_t) width; + const size_t dst_offset = (size_t) left_pad * (size_t) width; + const int64_t t_feature_us = ggml_time_us(); + if (lctx.dflash_target_features_padded.size() != padded_floats) { + lctx.dflash_target_features_padded.resize(padded_floats); } - std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset); + if (left_pad == 0 && total_floats == padded_floats) { + std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin()); + } else { + if (dst_offset > 0) { + std::fill(lctx.dflash_target_features_padded.begin(), + lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset, 0.0f); + } + std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset); + } + profile.graph_feature_copy_us += (uint64_t) (ggml_time_us() - t_feature_us); + profile.graph_feature_bytes += padded_floats * sizeof(float); } - profile.graph_feature_copy_us += (uint64_t) (ggml_time_us() - t_feature_us); - profile.graph_feature_bytes += padded_floats * sizeof(float); const int64_t t_pos_us = ggml_time_us(); lctx.dflash_pos_ctx_data.resize((size_t) cross_ctx); @@ -5403,22 +5545,32 @@ static bool prepare_dflash_graph_inputs( profile.graph_pos_bytes += lctx.dflash_pos_ctx_data.size() * sizeof(llama_pos); if (use_kv_cache) { + const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition( + cross_ctx, + lctx.dflash_kv_cache_n_filled, + lctx.dflash_kv_cache_write_pos, + lctx.dflash_kv_cache_valid, + lctx.dflash_kv_cache_applied_window_version, + lctx.dflash_target_window_version, + lctx.dflash_target_window_keep_rows, + lctx.dflash_target_window_append_rows, + lctx.dflash_target_window_replace, + n_rows); + + const bool have_append_src = append_src != nullptr && + append_rows_available == cache_plan.append_rows && + append_floats == (size_t) cache_plan.append_rows * (size_t) width; + + const int32_t update_rows = cache_plan.cache_up_to_date + ? 0 + : (cache_plan.rebuild_cache ? n_rows : cache_plan.append_rows); const size_t max_nodes = lctx.model.max_nodes((int) std::max(1, cross_ctx)) + 24 * lctx.model.hparams.n_layer; const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false); if (lctx.dflash_buf_compute_meta.size() != meta_size) { lctx.dflash_buf_compute_meta.resize(meta_size); } - const int64_t t_build_us = ggml_time_us(); - ggml_cgraph * gf_kv = llm_build_context::llama_build_graph_dflash_kv_cache(lctx); - profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us); - if (gf_kv == nullptr || lctx.dflash_kv_input_target_features == nullptr || lctx.dflash_kv_input_pos_ctx == nullptr) { - profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__); - return false; - } - - if (lctx.dflash_sched == nullptr) { + if (lctx.dflash_sched == nullptr || lctx.dflash_kv_cache_reserved_rows != cross_ctx) { std::vector backend_buft; backend_buft.reserve(lctx.backends.size()); for (auto * backend : lctx.backends) { @@ -5429,51 +5581,117 @@ static bool prepare_dflash_graph_inputs( } } + if (lctx.dflash_sched != nullptr) { + ggml_backend_sched_free(lctx.dflash_sched); + lctx.dflash_sched = nullptr; + } + + const int32_t saved_update_rows = lctx.dflash_kv_cache_update_rows; + lctx.dflash_kv_cache_update_rows = cross_ctx; + const int64_t t_build_us = ggml_time_us(); + ggml_cgraph * gf_reserve = llm_build_context::llama_build_graph_dflash_kv_cache(lctx); + profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us); + lctx.dflash_kv_cache_update_rows = saved_update_rows; + if (gf_reserve == nullptr) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache reserve graph\n", __func__); + return false; + } + const int64_t t_reserve_us = ggml_time_us(); lctx.dflash_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false); - const bool reserved = lctx.dflash_sched != nullptr && ggml_backend_sched_reserve(lctx.dflash_sched, gf_kv); + const bool reserved = lctx.dflash_sched != nullptr && ggml_backend_sched_reserve(lctx.dflash_sched, gf_reserve); profile.graph_kv_cache_reserve_us += (uint64_t) (ggml_time_us() - t_reserve_us); if (!reserved) { profile.graph_shape_failures++; LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V scheduler\n", __func__); return false; } + lctx.dflash_kv_cache_reserved_rows = cross_ctx; } - const int64_t t_reset_us = ggml_time_us(); - ggml_backend_sched_reset(lctx.dflash_sched); - profile.graph_kv_cache_reset_us += (uint64_t) (ggml_time_us() - t_reset_us); + if (update_rows > 0) { + const float * update_src = nullptr; + if (have_append_src && update_rows == cache_plan.append_rows) { + update_src = append_src; + } else if (have_full_src) { + update_src = src + (size_t) (n_rows - update_rows) * (size_t) width; + } + const llama_pos * update_pos = src_pos + (n_rows - update_rows); - const int64_t t_alloc_us = ggml_time_us(); - ggml_backend_sched_alloc_graph(lctx.dflash_sched, gf_kv); - profile.graph_kv_cache_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us); + if (update_src == nullptr) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: missing DFlash appended target features for cached update (rows=%d append_rows=%d floats=%zu)\n", + __func__, n_rows, update_rows, append_floats); + return false; + } - ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_target_features); - const int64_t t_feature_upload_us = ggml_time_us(); - if (kv_feature_backend != nullptr) { - ggml_backend_tensor_set_async(kv_feature_backend, lctx.dflash_kv_input_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); - } else { - ggml_backend_tensor_set(lctx.dflash_kv_input_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); + if (cache_plan.rebuild_cache) { + llama_reset_dflash_kv_cache_state(&lctx); + } + + lctx.dflash_kv_cache_update_rows = update_rows; + const int64_t t_build_us = ggml_time_us(); + ggml_cgraph * gf_kv = llm_build_context::llama_build_graph_dflash_kv_cache(lctx); + profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us); + if (gf_kv == nullptr || lctx.dflash_kv_input_target_features == nullptr || lctx.dflash_kv_input_pos_ctx == nullptr) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__); + return false; + } + + const int64_t t_reset_us = ggml_time_us(); + ggml_backend_sched_reset(lctx.dflash_sched); + profile.graph_kv_cache_reset_us += (uint64_t) (ggml_time_us() - t_reset_us); + + const int64_t t_alloc_us = ggml_time_us(); + ggml_backend_sched_alloc_graph(lctx.dflash_sched, gf_kv); + profile.graph_kv_cache_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us); + + ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_target_features); + const int64_t t_feature_upload_us = ggml_time_us(); + if (kv_feature_backend != nullptr) { + ggml_backend_tensor_set_async(kv_feature_backend, lctx.dflash_kv_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); + } else { + ggml_backend_tensor_set(lctx.dflash_kv_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); + } + profile.graph_kv_cache_feature_upload_us += (uint64_t) (ggml_time_us() - t_feature_upload_us); + profile.graph_feature_bytes += (size_t) update_rows * (size_t) width * sizeof(float); + + ggml_backend_t kv_pos_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_pos_ctx); + const int64_t t_pos_upload_us = ggml_time_us(); + if (kv_pos_backend != nullptr) { + ggml_backend_tensor_set_async(kv_pos_backend, lctx.dflash_kv_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); + } else { + ggml_backend_tensor_set(lctx.dflash_kv_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); + } + profile.graph_kv_cache_pos_upload_us += (uint64_t) (ggml_time_us() - t_pos_upload_us); + + const int64_t t_kv_cache_us = ggml_time_us(); + llama_dflash_kv_node_profiler kv_node_profiler; + if (kv_node_timing) { + kv_node_profiler.profile = &profile; + ggml_backend_sched_set_eval_callback(lctx.dflash_sched, llama_dflash_kv_node_eval_callback, &kv_node_profiler); + } + llama_graph_compute_sched(lctx, lctx.dflash_sched, gf_kv, lctx.cparams.n_threads); + if (kv_node_timing) { + ggml_backend_sched_set_eval_callback(lctx.dflash_sched, nullptr, nullptr); + } + profile.graph_kv_cache_compute_us += (uint64_t) (ggml_time_us() - t_kv_cache_us); + + const int64_t t_sync_us = ggml_time_us(); + ggml_backend_sched_synchronize(lctx.dflash_sched); + profile.graph_kv_cache_sync_us += (uint64_t) (ggml_time_us() - t_sync_us); + profile.graph_kv_cache_calls++; + + lctx.dflash_kv_cache_n_filled = std::min(cross_ctx, lctx.dflash_kv_cache_n_filled + update_rows); + lctx.dflash_kv_cache_write_pos = (lctx.dflash_kv_cache_write_pos + update_rows) % cross_ctx; + lctx.dflash_kv_cache_applied_window_version = lctx.dflash_target_window_version; + lctx.dflash_kv_cache_valid = true; + lctx.dflash_kv_cache_view_n_filled = lctx.dflash_kv_cache_n_filled; + lctx.dflash_kv_cache_view_write_pos = lctx.dflash_kv_cache_write_pos; + lctx.dflash_kv_cache_view_valid = true; } - profile.graph_kv_cache_feature_upload_us += (uint64_t) (ggml_time_us() - t_feature_upload_us); - - ggml_backend_t kv_pos_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_pos_ctx); - const int64_t t_pos_upload_us = ggml_time_us(); - if (kv_pos_backend != nullptr) { - ggml_backend_tensor_set_async(kv_pos_backend, lctx.dflash_kv_input_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); - } else { - ggml_backend_tensor_set(lctx.dflash_kv_input_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); - } - profile.graph_kv_cache_pos_upload_us += (uint64_t) (ggml_time_us() - t_pos_upload_us); - - const int64_t t_kv_cache_us = ggml_time_us(); - llama_graph_compute_sched(lctx, lctx.dflash_sched, gf_kv, lctx.cparams.n_threads); - profile.graph_kv_cache_compute_us += (uint64_t) (ggml_time_us() - t_kv_cache_us); - - const int64_t t_sync_us = ggml_time_us(); - ggml_backend_sched_synchronize(lctx.dflash_sched); - profile.graph_kv_cache_sync_us += (uint64_t) (ggml_time_us() - t_sync_us); - profile.graph_kv_cache_calls++; } else { ggml_backend_tensor_set(lctx.inp_dflash_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.inp_dflash_target_features)); ggml_backend_tensor_set(lctx.inp_dflash_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.inp_dflash_pos_ctx)); @@ -5586,6 +5804,9 @@ static int llama_decode_internal( } lctx.n_queued_tokens += n_tokens_all; + auto * dflash_profile = lctx.model.arch == LLM_ARCH_DFLASH_DRAFT ? &lctx.dflash_profile : nullptr; + const bool dflash_decode_timing = dflash_profile != nullptr && llama_env_flag_enabled("IK_DFLASH_DECODE_TIMING"); + auto & kv_self = lctx.kv_self; const int64_t n_embd = hparams.n_embd; @@ -5670,6 +5891,10 @@ static int llama_decode_internal( #if IK_PRINT_TIMING auto tim1 = ggml_time_us(); #endif + const int64_t t_dflash_prelude_us = dflash_decode_timing ? ggml_time_us() : 0; + if (dflash_decode_timing) { + dflash_profile->decode_internal_chunks++; + } uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token); if (llm_arch_is_hybrid(model.arch) && n_tokens > 1 && @@ -5804,6 +6029,9 @@ static int llama_decode_internal( auto tim2 = ggml_time_us(); printf("prelude(...): %d us\n", int(tim2-tim1)); #endif + if (dflash_decode_timing) { + dflash_profile->decode_prelude_us += (uint64_t) (ggml_time_us() - t_dflash_prelude_us); + } #if IK_PRINT_TIMING tim1 = ggml_time_us(); @@ -5811,30 +6039,45 @@ static int llama_decode_internal( auto & prev = cparams.mtp_op_type == MTP_OP_NONE ? lctx.prev : lctx.prev_mtp; ggml_cgraph * gf = nullptr; if (!lctx.can_reuse_graph(u_batch)) { + if (dflash_decode_timing) { + dflash_profile->decode_graph_rebuilds++; + } + const int64_t t_dflash_sched_reset_us = dflash_decode_timing ? ggml_time_us() : 0; lctx.reset_scheduler(); ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); #if IK_PRINT_TIMING tim2 = ggml_time_us(); printf("sched_reset(...): %d us\n", int(tim2-tim1)); #endif + if (dflash_decode_timing) { + dflash_profile->decode_sched_reset_us += (uint64_t) (ggml_time_us() - t_dflash_sched_reset_us); + } #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif + const int64_t t_dflash_build_graph_us = dflash_decode_timing ? ggml_time_us() : 0; gf = llm_build_context::llama_build_graph(lctx, u_batch, false); #if IK_PRINT_TIMING tim2 = ggml_time_us(); printf("build_graph(...): %d us\n", int(tim2-tim1)); #endif + if (dflash_decode_timing) { + dflash_profile->decode_build_graph_us += (uint64_t) (ggml_time_us() - t_dflash_build_graph_us); + } #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif + const int64_t t_dflash_sched_alloc_us = dflash_decode_timing ? ggml_time_us() : 0; ggml_backend_sched_alloc_graph(lctx.sched, gf); #if IK_PRINT_TIMING tim2 = ggml_time_us(); printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1)); #endif + if (dflash_decode_timing) { + dflash_profile->decode_sched_alloc_graph_us += (uint64_t) (ggml_time_us() - t_dflash_sched_alloc_us); + } //if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) { if (u_batch.embd == nullptr && lctx.cparams.graph_reuse && !(lctx.model.arch == LLM_ARCH_GEMMA4_MTP && lctx.mtp_target_ctx != nullptr)) { @@ -5855,16 +6098,15 @@ static int llama_decode_internal( } } - if (lctx.model.arch == LLM_ARCH_DFLASH_DRAFT) { - auto & profile = lctx.dflash_profile; - profile.decode_prepare_calls++; + if (dflash_profile != nullptr) { + dflash_profile->decode_prepare_calls++; const int64_t t_prepare_dflash_us = ggml_time_us(); if (!prepare_dflash_graph_inputs(lctx, n_tokens)) { - profile.decode_prepare_failures++; - profile.decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us); + dflash_profile->decode_prepare_failures++; + dflash_profile->decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us); return GGML_STATUS_FAILED; } - profile.decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us); + dflash_profile->decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us); } // the output is always the last tensor in the graph @@ -5910,16 +6152,26 @@ static int llama_decode_internal( #if IK_PRINT_TIMING == 1 tim1 = ggml_time_us(); #endif + const int64_t t_dflash_set_inputs_us = dflash_decode_timing ? ggml_time_us() : 0; llama_set_inputs(lctx, u_batch); #if IK_PRINT_TIMING == 1 tim2 = ggml_time_us(); printf("set_inputs(...): %d us\n", int(tim2-tim1)); #endif + if (dflash_decode_timing) { + dflash_profile->decode_set_inputs_us += (uint64_t) (ggml_time_us() - t_dflash_set_inputs_us); + } #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif + const int64_t t_dflash_graph_compute_us = dflash_decode_timing ? ggml_time_us() : 0; llama_graph_compute(lctx, gf, n_threads); + if (dflash_decode_timing) { + llama_synchronize(&lctx); + dflash_profile->decode_sync_profile_points++; + dflash_profile->decode_graph_compute_us += (uint64_t) (ggml_time_us() - t_dflash_graph_compute_us); + } #if IK_PRINT_TIMING llama_synchronize(&lctx); tim2 = ggml_time_us(); @@ -5950,6 +6202,7 @@ static int llama_decode_internal( #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif + const int64_t t_dflash_get_result_us = dflash_decode_timing ? ggml_time_us() : 0; // Do not process logits if MTP is only updating the KV cache. if (cparams.mtp_op_type != MTP_OP_WARMUP) { // && cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res); @@ -5980,6 +6233,11 @@ static int llama_decode_internal( } } } + if (dflash_decode_timing) { + llama_synchronize(&lctx); + dflash_profile->decode_sync_profile_points++; + dflash_profile->decode_result_us += (uint64_t) (ggml_time_us() - t_dflash_get_result_us); + } #if IK_PRINT_TIMING tim2 = ggml_time_us(); printf("get_result(...): %d us\n", int(tim2-tim1)); @@ -5992,6 +6250,7 @@ static int llama_decode_internal( #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif + const int64_t t_dflash_get_embedding_us = dflash_decode_timing ? ggml_time_us() : 0; ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd); GGML_ASSERT(backend_embd != nullptr); @@ -6031,6 +6290,11 @@ static int llama_decode_internal( GGML_ABORT("unknown pooling type"); } } + if (dflash_decode_timing) { + llama_synchronize(&lctx); + dflash_profile->decode_sync_profile_points++; + dflash_profile->decode_embedding_us += (uint64_t) (ggml_time_us() - t_dflash_get_embedding_us); + } #if IK_PRINT_TIMING tim2 = ggml_time_us(); printf("get_embedding(...): %d us\n", int(tim2-tim1)); @@ -6074,9 +6338,13 @@ static int llama_decode_internal( #if IK_PRINT_TIMING auto tim1 = ggml_time_us(); #endif + const int64_t t_dflash_final_sched_reset_us = dflash_decode_timing ? ggml_time_us() : 0; if (!lctx.prev) { lctx.reset_scheduler(); } + if (dflash_decode_timing) { + dflash_profile->decode_final_sched_reset_us += (uint64_t) (ggml_time_us() - t_dflash_final_sched_reset_us); + } #if IK_PRINT_TIMING auto tim2 = ggml_time_us(); printf("sched_reset(...): %d us\n", int(tim2-tim1));