From 0d75eee35a892d37d0d196e37d7e2e9b5a090f05 Mon Sep 17 00:00:00 2001 From: SamuelOliveirads Date: Sun, 14 Jun 2026 16:02:02 -0300 Subject: [PATCH] remove duplicated code and unnecesary refactor --- common/speculative-dflash-impl.h | 872 ++++++++++++++ common/speculative-impl.h | 1743 --------------------------- common/speculative.cpp | 1767 +++++++++++++++++++++------- convert_hf_to_gguf.py | 228 +++- examples/server/server-context.cpp | 111 +- src/llama-dflash-profile.h | 340 ++++++ src/llama-dflash.cpp | 337 +----- src/llama-spec-features-dflash.cpp | 30 + src/llama-spec-features-dflash.h | 10 + src/llama-spec-features.cpp | 5 - src/llama.cpp | 338 +----- 11 files changed, 2818 insertions(+), 2963 deletions(-) create mode 100644 common/speculative-dflash-impl.h delete mode 100644 common/speculative-impl.h create mode 100644 src/llama-dflash-profile.h diff --git a/common/speculative-dflash-impl.h b/common/speculative-dflash-impl.h new file mode 100644 index 00000000..c644ddbb --- /dev/null +++ b/common/speculative-dflash-impl.h @@ -0,0 +1,872 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +static bool common_speculative_are_dflash_compatible( + const llama_model * model_tgt, + const llama_model * model_dft) { + const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); + const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); + + if (llama_vocab_type(vocab_tgt) != llama_vocab_type(vocab_dft)) { + LOG_DBG("%s: DFlash draft model vocab type must match the target model\n", __func__); + return false; + } + + const bool add_bos_tgt = llama_vocab_get_add_bos(vocab_tgt); + const bool add_bos_dft = llama_vocab_get_add_bos(vocab_dft); + const bool add_eos_tgt = llama_vocab_get_add_eos(vocab_tgt); + const bool add_eos_dft = llama_vocab_get_add_eos(vocab_dft); + const llama_token bos_tgt = llama_vocab_bos(vocab_tgt); + const llama_token bos_dft = llama_vocab_bos(vocab_dft); + const llama_token eos_tgt = llama_vocab_eos(vocab_tgt); + const llama_token eos_dft = llama_vocab_eos(vocab_dft); + + if (add_bos_tgt != add_bos_dft || add_eos_tgt != add_eos_dft || + (add_bos_tgt && bos_tgt != bos_dft) || + (add_eos_tgt && eos_tgt != eos_dft)) { + LOG_DBG("%s: DFlash draft special tokens must match the target model (add_bos=%d/%d add_eos=%d/%d bos=%d/%d eos=%d/%d)\n", + __func__, + (int) add_bos_tgt, + (int) add_bos_dft, + (int) add_eos_tgt, + (int) add_eos_dft, + (int) bos_tgt, + (int) bos_dft, + (int) eos_tgt, + (int) eos_dft); + return false; + } + + const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt); + const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft); + const int vocab_diff = n_vocab_tgt > n_vocab_dft + ? n_vocab_tgt - n_vocab_dft + : n_vocab_dft - n_vocab_tgt; + + if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { + LOG_DBG("%s: DFlash draft vocab size differs too much from the target model (%d vs %d)\n", + __func__, n_vocab_dft, n_vocab_tgt); + return false; + } + + for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { + const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); + const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); + + if (std::strcmp(token_text_tgt, token_text_dft) != 0) { + LOG_DBG("%s: DFlash draft token %d differs - target '%s', draft '%s'\n", __func__, i, + common_token_to_piece(vocab_tgt, i).c_str(), + common_token_to_piece(vocab_dft, i).c_str()); + return false; + } + } + + return true; +} + +static bool dflash_contract_log_enabled() { + const char * env = std::getenv("IK_DFLASH_CONTRACT_LOG"); + if (env == nullptr || *env == '\0') { + return false; + } + + return std::strcmp(env, "0") != 0 && + std::strcmp(env, "false") != 0 && + std::strcmp(env, "off") != 0; +} + +static bool dflash_stats_log_enabled() { + const char * env = std::getenv("IK_DFLASH_STATS_LOG"); + if (env == nullptr || *env == '\0') { + return false; + } + + return std::strcmp(env, "0") != 0 && + std::strcmp(env, "false") != 0 && + std::strcmp(env, "off") != 0; +} + +template +static std::string dflash_contract_format_values( + const std::vector & values, + size_t edge_count = 4) { + std::ostringstream oss; + oss << '['; + if (values.empty()) { + oss << ']'; + return oss.str(); + } + + const size_t head = std::min(edge_count, values.size()); + for (size_t i = 0; i < head; ++i) { + if (i > 0) { + oss << ','; + } + oss << values[i]; + } + + if (values.size() > edge_count * 2) { + oss << ",...,"; + for (size_t i = values.size() - edge_count; i < values.size(); ++i) { + if (i > values.size() - edge_count) { + oss << ','; + } + oss << values[i]; + } + } else { + for (size_t i = head; i < values.size(); ++i) { + oss << ',' << values[i]; + } + } + + oss << ']'; + return oss.str(); +} + +struct dflash_contract_pos_summary { + llama_pos first = -1; + llama_pos last = -1; + int32_t gap_count = 0; + int32_t nonmono_count = 0; +}; + +static dflash_contract_pos_summary dflash_contract_summarize_positions( + const std::vector & positions) { + dflash_contract_pos_summary summary; + if (positions.empty()) { + return summary; + } + + summary.first = positions.front(); + summary.last = positions.back(); + for (size_t i = 1; i < positions.size(); ++i) { + if (positions[i] <= positions[i - 1]) { + summary.nonmono_count++; + } else if (positions[i] != positions[i - 1] + 1) { + summary.gap_count++; + } + } + + return summary; +} + +struct common_speculative_state_dflash; + +static void dflash_contract_log_append( + const common_speculative_state_dflash & state, + llama_seq_id seq_id, + const std::vector & new_positions); +static void dflash_contract_log_draft( + const common_speculative_state_dflash & state, + int32_t n_keep, + size_t result_size); +static void dflash_materialize_target_window_features(common_speculative_state_dflash & state); + +// DFlash runtime state and draft path. +struct common_speculative_state_dflash : public common_speculative_state { + llama_context * ctx_tgt; + llama_context * ctx_dft; + + llama_batch batch = {}; + + int32_t block_size = 0; + int32_t mask_token_id = -1; + int32_t n_target_features = 0; + int32_t cross_ctx = 0; + bool ready = false; + + std::vector target_layer_ids; + std::vector target_window; + std::vector target_window_pos; + std::vector target_window_stage; + std::vector target_window_pos_stage; + std::vector target_window_ring; + std::vector target_window_append_features; + int32_t target_window_rows = 0; + int32_t target_window_ring_write_pos = 0; + int32_t target_window_ring_filled = 0; + uint64_t target_window_version = 0; + int32_t target_window_keep_rows = 0; + int32_t target_window_append_rows = 0; + bool target_window_replace = false; + bool target_window_materialized = false; + llama_pos last_target_pos = -1; + size_t n_window_updates = 0; + size_t n_rows_seen = 0; + size_t n_rows_dropped = 0; + size_t n_context_shifts = 0; + size_t n_draft_empty = 0; + size_t n_set_target_fail = 0; + size_t n_decode_fail = 0; + llama_pos last_draft_pos_base = -1; + + uint64_t t_draft_decode_us = 0; + uint64_t t_draft_sample_us = 0; + uint64_t t_warmup_collect_us = 0; + uint64_t t_warmup_append_us = 0; + uint64_t t_accept_output_copy_us = 0; + uint64_t t_accept_commit_us = 0; + uint64_t t_accept_append_us = 0; + uint64_t t_accept_append_filter_us = 0; + uint64_t t_accept_append_window_alloc_us = 0; + uint64_t t_accept_append_replace_us = 0; + uint64_t t_accept_append_keep_old_us = 0; + uint64_t t_accept_append_new_rows_us = 0; + uint64_t t_accept_append_commit_detail_us = 0; + uint64_t t_accept_append_log_us = 0; + size_t n_warmup_collect_calls = 0; + size_t n_warmup_collect_rows = 0; + size_t n_warmup_append_calls = 0; + size_t n_warmup_append_rows = 0; + size_t n_accept_output_copy_calls = 0; + size_t n_accept_output_copy_rows = 0; + size_t n_accept_commit_calls = 0; + size_t n_accept_commit_rows = 0; + size_t n_accept_append_calls = 0; + size_t n_accept_append_rows = 0; + size_t n_accept_append_replace_calls = 0; + size_t n_accept_append_slide_calls = 0; + + common_speculative_state_dflash( + enum common_speculative_type type, + llama_context * ctx_tgt, + llama_context * ctx_dft, + int32_t cross_ctx) + : common_speculative_state(type) + , ctx_tgt(ctx_tgt) + , ctx_dft(ctx_dft) + , cross_ctx(std::max(1, cross_ctx)) + { + const llama_model * model_tgt = llama_get_model(ctx_tgt); + const llama_model * model_dft = llama_get_model(ctx_dft); + + if (!common_speculative_are_dflash_compatible(model_tgt, model_dft)) { + LOG_ERR("%s: DFlash draft model vocab/tokenizer is incompatible with the target model\n", __func__); + return; + } + + block_size = llama_model_dflash_block_size(model_dft); + mask_token_id = llama_model_dflash_mask_token_id(model_dft); + n_target_features = llama_model_dflash_n_target_features(model_dft); + const int32_t n_target_layers = llama_model_dflash_n_target_layers(model_dft); + + if (block_size <= 0 || mask_token_id < 0 || n_target_features <= 0 || n_target_layers <= 0) { + LOG_ERR("%s: invalid DFlash metadata (block_size=%d, mask_token_id=%d, n_target_features=%d, n_target_layers=%d)\n", + __func__, block_size, mask_token_id, n_target_features, n_target_layers); + return; + } + + target_layer_ids.resize((size_t) n_target_layers); + if (llama_model_dflash_target_layer_ids(model_dft, target_layer_ids.data(), n_target_layers) != n_target_layers) { + LOG_ERR("%s: failed to read DFlash target layer ids\n", __func__); + target_layer_ids.clear(); + return; + } + + const auto * vocab_tgt = llama_model_get_vocab(model_tgt); + const auto * vocab_dft = llama_model_get_vocab(model_dft); + const int32_t target_vocab_size = llama_vocab_n_tokens(vocab_tgt); + const int32_t draft_vocab_size = llama_vocab_n_tokens(vocab_dft); + const int32_t target_hidden_size = llama_model_n_embd(model_tgt); + const int32_t draft_hidden_size = llama_model_n_embd(model_dft); + const int32_t target_mask_token_id = llama_model_dflash_target_mask_token_id(model_tgt); + const int32_t expected_n_target_features = target_hidden_size > 0 ? target_hidden_size * n_target_layers : 0; + + if (target_mask_token_id != (int32_t) LLAMA_TOKEN_NULL && mask_token_id != target_mask_token_id) { + LOG_ERR("%s: DFlash mask token mismatch (draft=%d target=%d)\n", + __func__, mask_token_id, target_mask_token_id); + return; + } + + if (target_hidden_size <= 0 || draft_hidden_size <= 0) { + LOG_ERR("%s: invalid DFlash hidden sizes (draft=%d target=%d)\n", + __func__, draft_hidden_size, target_hidden_size); + return; + } + + if (expected_n_target_features <= 0 || n_target_features != expected_n_target_features) { + LOG_ERR("%s: DFlash target feature width mismatch (metadata=%d expected=%d target_hidden=%d target_layers=%d)\n", + __func__, n_target_features, expected_n_target_features, target_hidden_size, n_target_layers); + return; + } + + std::vector sorted_target_layer_ids = target_layer_ids; + std::sort(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()); + if (std::adjacent_find(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()) != sorted_target_layer_ids.end()) { + LOG_ERR("%s: duplicate DFlash target layer ids survived into runtime validation\n", __func__); + target_layer_ids.clear(); + return; + } + + const int32_t n_target_model_layers = llama_n_layer(model_tgt); + for (int32_t layer_id : target_layer_ids) { + if (layer_id < 0 || layer_id >= n_target_model_layers) { + LOG_ERR("%s: invalid DFlash target layer id %d for target model with %d layers\n", + __func__, layer_id, n_target_model_layers); + target_layer_ids.clear(); + return; + } + } + + const int32_t io_mode = llama_model_dflash_io_mode(model_dft, model_tgt); + if (io_mode == LLAMA_DFLASH_IO_MODE_INVALID) { + LOG_ERR("%s: DFlash draft is missing required IO tensors after target sharing\n", __func__); + return; + } + + if (io_mode == LLAMA_DFLASH_IO_MODE_MIXED) { + LOG_ERR("%s: DFlash IO contract must be fully shared or fully self-contained, but resolved to mixed mode\n", __func__); + return; + } + + if (io_mode == LLAMA_DFLASH_IO_MODE_SELF_CONTAINED && !llama_model_dflash_io_tensors_match(model_dft, target_hidden_size, target_vocab_size)) { + LOG_ERR("%s: DFlash self-contained IO tensors do not match the target hidden/vocab contract (target_hidden=%d target_vocab=%d)\n", + __func__, + target_hidden_size, + target_vocab_size); + return; + } + + if (!llama_set_dflash_capture_layers(ctx_tgt, target_layer_ids.data(), (int32_t) target_layer_ids.size())) { + LOG_ERR("%s: failed to configure DFlash target capture callback\n", __func__); + return; + } + + batch = llama_batch_init(std::max(1, block_size), 0, 1); + target_window.reserve((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_stage.reserve((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_ring.resize((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_append_features.reserve((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_pos.reserve((size_t) this->cross_ctx); + target_window_pos_stage.reserve((size_t) this->cross_ctx); + ready = true; + + llama_set_dflash_visible_cross_ctx(ctx_dft, this->cross_ctx); + llama_dflash_profile_reset(ctx_tgt); + llama_dflash_profile_reset(ctx_dft); + + std::ostringstream layers_oss; + for (size_t i = 0; i < target_layer_ids.size(); ++i) { + if (i > 0) { + layers_oss << ","; + } + layers_oss << target_layer_ids[i]; + } + + const char * io_mode_name = io_mode == LLAMA_DFLASH_IO_MODE_SHARED ? "shared" : "self-contained"; + LOG_INF("%s: DFlash context ready (n_ctx=%d, block_size=%d, cross_ctx=%d, n_target_features=%d, target_layer_ids=[%s])\n", + __func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features, layers_oss.str().c_str()); + LOG_INF("%s: DFlash artifact io=%s draft_vocab=%d target_vocab=%d draft_hidden=%d target_hidden=%d mask_token_id=%d target_mask_token_id=%d\n", + __func__, io_mode_name, draft_vocab_size, target_vocab_size, draft_hidden_size, target_hidden_size, mask_token_id, target_mask_token_id); + } + + ~common_speculative_state_dflash() override { + llama_clear_dflash_capture(ctx_tgt); + if (ctx_dft) { + llama_free(ctx_dft); + } + if (batch.token != nullptr) { + llama_batch_free(batch); + } + } + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + llama_kv_cache_clear(ctx_dft); + llama_reset_dflash_kv_cache_state(ctx_dft); + n_window_updates = 0; + n_rows_seen = 0; + n_rows_dropped = 0; + n_context_shifts = 0; + n_draft_empty = 0; + n_set_target_fail = 0; + n_decode_fail = 0; + last_draft_pos_base = -1; + t_draft_decode_us = 0; + t_draft_sample_us = 0; + t_warmup_collect_us = 0; + t_warmup_append_us = 0; + t_accept_output_copy_us = 0; + t_accept_commit_us = 0; + t_accept_append_us = 0; + t_accept_append_filter_us = 0; + t_accept_append_window_alloc_us = 0; + t_accept_append_replace_us = 0; + t_accept_append_keep_old_us = 0; + t_accept_append_new_rows_us = 0; + t_accept_append_commit_detail_us = 0; + t_accept_append_log_us = 0; + n_warmup_collect_calls = 0; + n_warmup_collect_rows = 0; + n_warmup_append_calls = 0; + n_warmup_append_rows = 0; + n_accept_output_copy_calls = 0; + n_accept_output_copy_rows = 0; + n_accept_commit_calls = 0; + n_accept_commit_rows = 0; + n_accept_append_calls = 0; + n_accept_append_rows = 0; + n_accept_append_replace_calls = 0; + n_accept_append_slide_calls = 0; + llama_dflash_profile_reset(ctx_tgt); + llama_dflash_profile_reset(ctx_dft); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + GGML_UNUSED(prompt_tgt); + + result.clear(); + if (!ready || target_window_rows <= 0) { + n_draft_empty++; + return; + } + + const int32_t n_keep = std::min(params.n_max, block_size - 1); + if (n_keep <= 0) { + return; + } + + const float * target_features = nullptr; + size_t target_feature_floats = 0; + llama_dflash_window_update window_update = { + target_window_version, + target_window_keep_rows, + target_window_append_rows, + target_window_replace, + target_window_append_features.empty() ? nullptr : target_window_append_features.data(), + target_window_append_features.size(), + }; + const llama_dflash_kv_cache_transition cache_plan = + llama_plan_dflash_kv_cache_transition_for_ctx(ctx_dft, window_update, target_window_rows); + + if (cache_plan.rebuild_cache) { + dflash_materialize_target_window_features(*this); + target_features = target_window.data(); + target_feature_floats = target_window.size(); + window_update.append_features = target_window.data(); + window_update.append_floats = target_window.size(); + window_update.append_rows = target_window_rows; + } + + if (!llama_set_dflash_target_features_view(ctx_dft, target_features, target_feature_floats, target_window_rows, target_window_pos.data(), &window_update)) { + LOG_ERR("%s: failed to set DFlash target features\n", __func__); + n_set_target_fail++; + return; + } + + llama_kv_cache_clear(ctx_dft); + batch.n_tokens = 0; + const int32_t batch_len = n_keep + 1; + const llama_pos draft_pos_base = last_target_pos >= 0 ? last_target_pos + 1 : (llama_pos) target_window_rows; + const llama_pos seed_pos = last_target_pos >= 0 ? last_target_pos : draft_pos_base - 1; + last_draft_pos_base = draft_pos_base; + common_batch_add(batch, id_last, seed_pos, { 0 }, false); + for (int32_t i = 1; i < batch_len; ++i) { + common_batch_add(batch, mask_token_id, draft_pos_base + (i - 1), { 0 }, i <= n_keep); + } + + const int64_t t_decode_us = ggml_time_us(); + if (llama_decode(ctx_dft, batch) != 0) { + LOG_ERR("%s: llama_decode() failed for DFlash draft batch\n", __func__); + n_decode_fail++; + batch.n_tokens = 0; + return; + } + t_draft_decode_us += (uint64_t) (ggml_time_us() - t_decode_us); + + result.reserve((size_t) n_keep); + const int64_t t_sample_us = ggml_time_us(); + for (int32_t i = 0; i < n_keep; ++i) { + llama_token id = llama_get_dflash_draft_token_ith(ctx_dft, i); + if (id == LLAMA_TOKEN_NULL) { + id = common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr); + } + result.push_back(id); + } + t_draft_sample_us += (uint64_t) (ggml_time_us() - t_sample_us); + + batch.n_tokens = 0; + dflash_contract_log_draft(*this, n_keep, result.size()); + } + + void accept(uint16_t n_accepted) override { + GGML_UNUSED(n_accepted); + } +}; + +static void dflash_contract_log_append( + const common_speculative_state_dflash & state, + llama_seq_id seq_id, + const std::vector & new_positions) { + if (!dflash_contract_log_enabled()) { + return; + } + + static std::atomic counter = 0; + const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); + if (ordinal >= 8) { + return; + } + + const dflash_contract_pos_summary incoming = dflash_contract_summarize_positions(new_positions); + const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos); + + LOG_INF("dflash contract append[%llu]: seq=%d incoming_rows=%zu incoming_pos=%s pos=[%d..%d] gaps=%d nonmono=%d window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d\n", + (unsigned long long) (ordinal + 1), + (int) seq_id, + new_positions.size(), + dflash_contract_format_values(new_positions).c_str(), + (int) incoming.first, + (int) incoming.last, + incoming.gap_count, + incoming.nonmono_count, + state.target_window_rows, + dflash_contract_format_values(state.target_window_pos).c_str(), + (int) window.first, + (int) window.last, + window.gap_count, + window.nonmono_count, + (int) state.last_target_pos); +} + +static void dflash_contract_log_draft( + const common_speculative_state_dflash & state, + int32_t n_keep, + size_t result_size) { + if (!dflash_contract_log_enabled()) { + return; + } + + static std::atomic counter = 0; + const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); + if (ordinal >= 8) { + return; + } + + const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos); + llama_dflash_profile_stats graph_stats = {}; + llama_dflash_profile_get_stats(state.ctx_dft, &graph_stats); + const int draft_delta = (state.last_target_pos >= 0 && state.last_draft_pos_base >= 0) + ? (int) (state.last_draft_pos_base - state.last_target_pos) + : -1; + const llama_pos seed_pos = state.last_target_pos; + const llama_pos mask_first_pos = state.last_draft_pos_base; + const llama_pos mask_last_pos = state.last_draft_pos_base >= 0 + ? state.last_draft_pos_base + n_keep - 1 + : -1; + + LOG_INF("dflash contract draft[%llu]: window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d seed_pos=%d mask_pos=[%d..%d] sample_rows=[1..%d] output_rows=[1..%d] draft_pos_base=%d delta=%d n_keep=%d result=%zu set_target(missing/nonmono)=%llu/%llu graph(fallback/nonmono)=%llu/%llu graph_pos=[%d..%d]\n", + (unsigned long long) (ordinal + 1), + state.target_window_rows, + dflash_contract_format_values(state.target_window_pos).c_str(), + (int) window.first, + (int) window.last, + window.gap_count, + window.nonmono_count, + (int) state.last_target_pos, + (int) seed_pos, + (int) mask_first_pos, + (int) mask_last_pos, + n_keep, + n_keep, + (int) state.last_draft_pos_base, + draft_delta, + n_keep, + result_size, + (unsigned long long) graph_stats.set_target_missing_positions, + (unsigned long long) graph_stats.set_target_non_monotonic_positions, + (unsigned long long) graph_stats.graph_pos_fallbacks, + (unsigned long long) graph_stats.graph_pos_non_monotonic, + (int) graph_stats.last_pos_first, + (int) graph_stats.last_pos_last); +} + +struct dflash_append_breakdown { + uint64_t filter_us = 0; + uint64_t window_alloc_us = 0; + uint64_t replace_us = 0; + uint64_t keep_old_us = 0; + uint64_t new_rows_us = 0; + uint64_t commit_us = 0; + uint64_t log_us = 0; + bool replace_call = false; +}; + +static void dflash_record_window_update( + common_speculative_state_dflash & state, + int32_t keep_rows, + int32_t append_rows, + bool replace) { + state.target_window_keep_rows = std::max(0, keep_rows); + state.target_window_append_rows = std::max(0, append_rows); + state.target_window_replace = replace; + state.target_window_version++; +} + +static void dflash_ring_reset_rows( + common_speculative_state_dflash & state, + const float * rows, + int32_t n_rows) { + const size_t row_width = (size_t) state.n_target_features; + if (n_rows <= 0 || rows == nullptr) { + state.target_window_ring_write_pos = 0; + state.target_window_ring_filled = 0; + return; + } + + if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) { + state.target_window_ring.resize((size_t) state.cross_ctx * row_width); + } + + std::memcpy(state.target_window_ring.data(), rows, (size_t) n_rows * row_width * sizeof(float)); + state.target_window_ring_write_pos = n_rows % state.cross_ctx; + state.target_window_ring_filled = n_rows; + state.target_window_materialized = false; +} + +static void dflash_ring_append_rows( + common_speculative_state_dflash & state, + const float * rows, + int32_t n_rows) { + const size_t row_width = (size_t) state.n_target_features; + if (n_rows <= 0 || rows == nullptr) { + return; + } + + if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) { + state.target_window_ring.resize((size_t) state.cross_ctx * row_width); + } + + int32_t write_pos = state.target_window_ring_write_pos; + int32_t remaining = n_rows; + const float * src = rows; + while (remaining > 0) { + const int32_t chunk_rows = std::min(remaining, state.cross_ctx - write_pos); + std::memcpy( + state.target_window_ring.data() + (size_t) write_pos * row_width, + src, + (size_t) chunk_rows * row_width * sizeof(float)); + src += (size_t) chunk_rows * row_width; + remaining -= chunk_rows; + write_pos = (write_pos + chunk_rows) % state.cross_ctx; + } + + state.target_window_ring_write_pos = write_pos; + state.target_window_ring_filled = std::min(state.cross_ctx, state.target_window_ring_filled + n_rows); + state.target_window_materialized = false; +} + +static void dflash_materialize_target_window_features(common_speculative_state_dflash & state) { + if (state.target_window_materialized || state.target_window_rows <= 0) { + return; + } + + const size_t row_width = (size_t) state.n_target_features; + state.target_window.resize((size_t) state.target_window_rows * row_width); + + const int32_t read_start = (state.target_window_ring_write_pos - state.target_window_rows + state.cross_ctx) % state.cross_ctx; + const int32_t first_rows = std::min(state.target_window_rows, state.cross_ctx - read_start); + std::memcpy( + state.target_window.data(), + state.target_window_ring.data() + (size_t) read_start * row_width, + (size_t) first_rows * row_width * sizeof(float)); + + const int32_t second_rows = state.target_window_rows - first_rows; + if (second_rows > 0) { + std::memcpy( + state.target_window.data() + (size_t) first_rows * row_width, + state.target_window_ring.data(), + (size_t) second_rows * row_width * sizeof(float)); + } + + state.target_window_materialized = true; +} + +static bool dflash_append_target_features( + common_speculative_state_dflash & state, + const common_speculative_feature_view & features, + const llama_batch & batch, + llama_seq_id seq_id, + dflash_append_breakdown * breakdown = nullptr) { + GGML_UNUSED(batch); + + if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || + features.width != state.n_target_features || + features.rows.empty() || + state.cross_ctx <= 0) { + return false; + } + + const size_t row_width = (size_t) state.n_target_features; + std::vector new_rows; + std::vector new_positions; + new_rows.reserve(features.rows.size() * row_width); + new_positions.reserve(features.rows.size()); + + const int64_t t_filter_us = ggml_time_us(); + for (const auto & row : features.rows) { + if (row.seq_id != seq_id || row.data == nullptr) { + continue; + } + + new_positions.push_back(row.pos); + new_rows.insert(new_rows.end(), row.data, row.data + row_width); + } + if (breakdown != nullptr) { + breakdown->filter_us += (uint64_t) (ggml_time_us() - t_filter_us); + } + + if (new_positions.empty()) { + return false; + } + + const int32_t n_rows = (int32_t) new_positions.size(); + state.n_window_updates++; + state.n_rows_seen += (size_t) n_rows; + if (n_rows >= state.cross_ctx) { + state.n_rows_dropped += (size_t) state.target_window_rows + (size_t) (n_rows - state.cross_ctx); + const int32_t keep_from = n_rows - state.cross_ctx; + const int64_t t_replace_us = ggml_time_us(); + state.target_window_pos.assign(new_positions.begin() + keep_from, new_positions.end()); + state.target_window_append_features.assign( + new_rows.begin() + (ptrdiff_t) keep_from * (ptrdiff_t) row_width, + new_rows.end()); + dflash_ring_reset_rows(state, state.target_window_append_features.data(), state.cross_ctx); + if (breakdown != nullptr) { + breakdown->replace_us += (uint64_t) (ggml_time_us() - t_replace_us); + breakdown->replace_call = true; + } + + const int64_t t_commit_us = ggml_time_us(); + state.target_window_rows = state.cross_ctx; + state.target_window_ring_filled = state.target_window_rows; + state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); + dflash_record_window_update(state, 0, state.target_window_rows, true); + if (breakdown != nullptr) { + breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us); + } + + const int64_t t_log_us = ggml_time_us(); + dflash_contract_log_append(state, seq_id, new_positions); + if (breakdown != nullptr) { + breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us); + } + return true; + } + + const int32_t keep_old_rows = std::min(state.target_window_rows, state.cross_ctx - n_rows); + state.n_rows_dropped += (size_t) std::max(0, state.target_window_rows - keep_old_rows); + const int64_t t_window_alloc_us = ggml_time_us(); + std::vector & next_window_pos = state.target_window_pos_stage; + next_window_pos.resize((size_t) (keep_old_rows + n_rows)); + if (breakdown != nullptr) { + breakdown->window_alloc_us += (uint64_t) (ggml_time_us() - t_window_alloc_us); + } + + if (keep_old_rows > 0) { + const int64_t t_keep_old_us = ggml_time_us(); + std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), next_window_pos.begin()); + if (breakdown != nullptr) { + breakdown->keep_old_us += (uint64_t) (ggml_time_us() - t_keep_old_us); + } + } + + const int64_t t_new_rows_us = ggml_time_us(); + state.target_window_append_features.assign(new_rows.begin(), new_rows.end()); + dflash_ring_append_rows(state, state.target_window_append_features.data(), n_rows); + std::copy(new_positions.begin(), new_positions.end(), next_window_pos.begin() + keep_old_rows); + if (breakdown != nullptr) { + breakdown->new_rows_us += (uint64_t) (ggml_time_us() - t_new_rows_us); + } + + const int64_t t_commit_us = ggml_time_us(); + state.target_window_pos.swap(next_window_pos); + next_window_pos.clear(); + state.target_window_rows = keep_old_rows + n_rows; + state.target_window_ring_filled = state.target_window_rows; + state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); + dflash_record_window_update(state, keep_old_rows, n_rows, false); + if (breakdown != nullptr) { + breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us); + } + + const int64_t t_log_us = ggml_time_us(); + dflash_contract_log_append(state, seq_id, new_positions); + if (breakdown != nullptr) { + breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us); + } + return true; +} + +static void dflash_clear_target_features(common_speculative_state_dflash & state) { + state.target_window.clear(); + state.target_window_pos.clear(); + state.target_window_stage.clear(); + state.target_window_pos_stage.clear(); + state.target_window_append_features.clear(); + state.target_window_rows = 0; + state.target_window_ring_write_pos = 0; + state.target_window_ring_filled = 0; + state.target_window_keep_rows = 0; + state.target_window_append_rows = 0; + state.target_window_replace = false; + state.target_window_materialized = false; + state.last_target_pos = -1; + llama_reset_dflash_kv_cache_state(state.ctx_dft); +} + +static void dflash_context_shift( + common_speculative_state_dflash & state, + llama_pos kv_keep, + llama_pos kv_discard, + llama_pos kv_past) { + if (kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window_pos.empty()) { + return; + } + + dflash_materialize_target_window_features(state); + + const size_t row_width = (size_t) state.n_target_features; + const llama_pos discard_begin = kv_keep; + const llama_pos discard_end = kv_keep + kv_discard; + + std::vector shifted_rows; + std::vector shifted_positions; + shifted_rows.reserve(state.target_window.size()); + shifted_positions.reserve(state.target_window_pos.size()); + + for (int32_t row = 0; row < state.target_window_rows; ++row) { + llama_pos pos = state.target_window_pos[(size_t) row]; + if (pos >= discard_begin && pos < discard_end) { + continue; + } + + if (pos >= discard_end && pos < kv_past) { + pos -= kv_discard; + } + + const float * row_src = state.target_window.data() + (size_t) row * row_width; + shifted_rows.insert(shifted_rows.end(), row_src, row_src + row_width); + shifted_positions.push_back(pos); + } + + state.target_window = std::move(shifted_rows); + state.target_window_pos = std::move(shifted_positions); + state.target_window_rows = (int32_t) state.target_window_pos.size(); + dflash_ring_reset_rows(state, state.target_window.data(), state.target_window_rows); + state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); + dflash_record_window_update(state, 0, state.target_window_rows, true); + llama_reset_dflash_kv_cache_state(state.ctx_dft); + state.n_context_shifts++; +} diff --git a/common/speculative-impl.h b/common/speculative-impl.h deleted file mode 100644 index dbf8cfb1..00000000 --- a/common/speculative-impl.h +++ /dev/null @@ -1,1743 +0,0 @@ -// DFlash runtime state and draft path. -struct common_speculative_state_dflash : public common_speculative_state { - llama_context * ctx_tgt; - llama_context * ctx_dft; - - llama_batch batch = {}; - - int32_t block_size = 0; - int32_t mask_token_id = -1; - int32_t n_target_features = 0; - int32_t cross_ctx = 0; - bool ready = false; - - std::vector target_layer_ids; - std::vector target_window; - std::vector target_window_pos; - std::vector target_window_stage; - std::vector target_window_pos_stage; - std::vector target_window_ring; - std::vector target_window_append_features; - int32_t target_window_rows = 0; - int32_t target_window_ring_write_pos = 0; - int32_t target_window_ring_filled = 0; - uint64_t target_window_version = 0; - int32_t target_window_keep_rows = 0; - int32_t target_window_append_rows = 0; - bool target_window_replace = false; - bool target_window_materialized = false; - llama_pos last_target_pos = -1; - size_t n_window_updates = 0; - size_t n_rows_seen = 0; - size_t n_rows_dropped = 0; - size_t n_context_shifts = 0; - size_t n_draft_empty = 0; - size_t n_set_target_fail = 0; - size_t n_decode_fail = 0; - llama_pos last_draft_pos_base = -1; - - uint64_t t_draft_decode_us = 0; - uint64_t t_draft_sample_us = 0; - uint64_t t_warmup_collect_us = 0; - uint64_t t_warmup_append_us = 0; - uint64_t t_accept_output_copy_us = 0; - uint64_t t_accept_commit_us = 0; - uint64_t t_accept_append_us = 0; - uint64_t t_accept_append_filter_us = 0; - uint64_t t_accept_append_window_alloc_us = 0; - uint64_t t_accept_append_replace_us = 0; - uint64_t t_accept_append_keep_old_us = 0; - uint64_t t_accept_append_new_rows_us = 0; - uint64_t t_accept_append_commit_detail_us = 0; - uint64_t t_accept_append_log_us = 0; - size_t n_warmup_collect_calls = 0; - size_t n_warmup_collect_rows = 0; - size_t n_warmup_append_calls = 0; - size_t n_warmup_append_rows = 0; - size_t n_accept_output_copy_calls = 0; - size_t n_accept_output_copy_rows = 0; - size_t n_accept_commit_calls = 0; - size_t n_accept_commit_rows = 0; - size_t n_accept_append_calls = 0; - size_t n_accept_append_rows = 0; - size_t n_accept_append_replace_calls = 0; - size_t n_accept_append_slide_calls = 0; - - common_speculative_state_dflash( - enum common_speculative_type type, - llama_context * ctx_tgt, - llama_context * ctx_dft, - int32_t cross_ctx) - : common_speculative_state(type) - , ctx_tgt(ctx_tgt) - , ctx_dft(ctx_dft) - , cross_ctx(std::max(1, cross_ctx)) - { - const llama_model * model_tgt = llama_get_model(ctx_tgt); - const llama_model * model_dft = llama_get_model(ctx_dft); - - if (!common_speculative_are_dflash_compatible(model_tgt, model_dft)) { - LOG_ERR("%s: DFlash draft model vocab/tokenizer is incompatible with the target model\n", __func__); - return; - } - - block_size = llama_model_dflash_block_size(model_dft); - mask_token_id = llama_model_dflash_mask_token_id(model_dft); - n_target_features = llama_model_dflash_n_target_features(model_dft); - const int32_t n_target_layers = llama_model_dflash_n_target_layers(model_dft); - - if (block_size <= 0 || mask_token_id < 0 || n_target_features <= 0 || n_target_layers <= 0) { - LOG_ERR("%s: invalid DFlash metadata (block_size=%d, mask_token_id=%d, n_target_features=%d, n_target_layers=%d)\n", - __func__, block_size, mask_token_id, n_target_features, n_target_layers); - return; - } - - target_layer_ids.resize((size_t) n_target_layers); - if (llama_model_dflash_target_layer_ids(model_dft, target_layer_ids.data(), n_target_layers) != n_target_layers) { - LOG_ERR("%s: failed to read DFlash target layer ids\n", __func__); - target_layer_ids.clear(); - return; - } - - const auto * vocab_tgt = llama_model_get_vocab(model_tgt); - const auto * vocab_dft = llama_model_get_vocab(model_dft); - const int32_t target_vocab_size = llama_vocab_n_tokens(vocab_tgt); - const int32_t draft_vocab_size = llama_vocab_n_tokens(vocab_dft); - const int32_t target_hidden_size = llama_model_n_embd(model_tgt); - const int32_t draft_hidden_size = llama_model_n_embd(model_dft); - const int32_t target_mask_token_id = llama_model_dflash_target_mask_token_id(model_tgt); - const int32_t expected_n_target_features = target_hidden_size > 0 ? target_hidden_size * n_target_layers : 0; - - if (target_mask_token_id != (int32_t) LLAMA_TOKEN_NULL && mask_token_id != target_mask_token_id) { - LOG_ERR("%s: DFlash mask token mismatch (draft=%d target=%d)\n", - __func__, mask_token_id, target_mask_token_id); - return; - } - - if (target_hidden_size <= 0 || draft_hidden_size <= 0) { - LOG_ERR("%s: invalid DFlash hidden sizes (draft=%d target=%d)\n", - __func__, draft_hidden_size, target_hidden_size); - return; - } - - if (expected_n_target_features <= 0 || n_target_features != expected_n_target_features) { - LOG_ERR("%s: DFlash target feature width mismatch (metadata=%d expected=%d target_hidden=%d target_layers=%d)\n", - __func__, n_target_features, expected_n_target_features, target_hidden_size, n_target_layers); - return; - } - - std::vector sorted_target_layer_ids = target_layer_ids; - std::sort(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()); - if (std::adjacent_find(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()) != sorted_target_layer_ids.end()) { - LOG_ERR("%s: duplicate DFlash target layer ids survived into runtime validation\n", __func__); - target_layer_ids.clear(); - return; - } - - const int32_t n_target_model_layers = llama_n_layer(model_tgt); - for (int32_t layer_id : target_layer_ids) { - if (layer_id < 0 || layer_id >= n_target_model_layers) { - LOG_ERR("%s: invalid DFlash target layer id %d for target model with %d layers\n", - __func__, layer_id, n_target_model_layers); - target_layer_ids.clear(); - return; - } - } - - const int32_t io_mode = llama_model_dflash_io_mode(model_dft, model_tgt); - if (io_mode == LLAMA_DFLASH_IO_MODE_INVALID) { - LOG_ERR("%s: DFlash draft is missing required IO tensors after target sharing\n", __func__); - return; - } - - if (io_mode == LLAMA_DFLASH_IO_MODE_MIXED) { - LOG_ERR("%s: DFlash IO contract must be fully shared or fully self-contained, but resolved to mixed mode\n", __func__); - return; - } - - if (io_mode == LLAMA_DFLASH_IO_MODE_SELF_CONTAINED && !llama_model_dflash_io_tensors_match(model_dft, target_hidden_size, target_vocab_size)) { - LOG_ERR("%s: DFlash self-contained IO tensors do not match the target hidden/vocab contract (target_hidden=%d target_vocab=%d)\n", - __func__, - target_hidden_size, - target_vocab_size); - return; - } - - if (!llama_set_dflash_capture_layers(ctx_tgt, target_layer_ids.data(), (int32_t) target_layer_ids.size())) { - LOG_ERR("%s: failed to configure DFlash target capture callback\n", __func__); - return; - } - - batch = llama_batch_init(std::max(1, block_size), 0, 1); - target_window.reserve((size_t) this->cross_ctx * (size_t) n_target_features); - target_window_stage.reserve((size_t) this->cross_ctx * (size_t) n_target_features); - target_window_ring.resize((size_t) this->cross_ctx * (size_t) n_target_features); - target_window_append_features.reserve((size_t) this->cross_ctx * (size_t) n_target_features); - target_window_pos.reserve((size_t) this->cross_ctx); - target_window_pos_stage.reserve((size_t) this->cross_ctx); - ready = true; - - llama_set_dflash_visible_cross_ctx(ctx_dft, this->cross_ctx); - llama_dflash_profile_reset(ctx_tgt); - llama_dflash_profile_reset(ctx_dft); - - std::ostringstream layers_oss; - for (size_t i = 0; i < target_layer_ids.size(); ++i) { - if (i > 0) { - layers_oss << ","; - } - layers_oss << target_layer_ids[i]; - } - - const char * io_mode_name = io_mode == LLAMA_DFLASH_IO_MODE_SHARED ? "shared" : "self-contained"; - LOG_INF("%s: DFlash context ready (n_ctx=%d, block_size=%d, cross_ctx=%d, n_target_features=%d, target_layer_ids=[%s])\n", - __func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features, layers_oss.str().c_str()); - LOG_INF("%s: DFlash artifact io=%s draft_vocab=%d target_vocab=%d draft_hidden=%d target_hidden=%d mask_token_id=%d target_mask_token_id=%d\n", - __func__, io_mode_name, draft_vocab_size, target_vocab_size, draft_hidden_size, target_hidden_size, mask_token_id, target_mask_token_id); - } - - ~common_speculative_state_dflash() override { - llama_clear_dflash_capture(ctx_tgt); - if (ctx_dft) { - llama_free(ctx_dft); - } - if (batch.token != nullptr) { - llama_batch_free(batch); - } - } - - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - llama_kv_cache_clear(ctx_dft); - llama_reset_dflash_kv_cache_state(ctx_dft); - n_window_updates = 0; - n_rows_seen = 0; - n_rows_dropped = 0; - n_context_shifts = 0; - n_draft_empty = 0; - n_set_target_fail = 0; - n_decode_fail = 0; - last_draft_pos_base = -1; - t_draft_decode_us = 0; - t_draft_sample_us = 0; - t_warmup_collect_us = 0; - t_warmup_append_us = 0; - t_accept_output_copy_us = 0; - t_accept_commit_us = 0; - t_accept_append_us = 0; - t_accept_append_filter_us = 0; - t_accept_append_window_alloc_us = 0; - t_accept_append_replace_us = 0; - t_accept_append_keep_old_us = 0; - t_accept_append_new_rows_us = 0; - t_accept_append_commit_detail_us = 0; - t_accept_append_log_us = 0; - n_warmup_collect_calls = 0; - n_warmup_collect_rows = 0; - n_warmup_append_calls = 0; - n_warmup_append_rows = 0; - n_accept_output_copy_calls = 0; - n_accept_output_copy_rows = 0; - n_accept_commit_calls = 0; - n_accept_commit_rows = 0; - n_accept_append_calls = 0; - n_accept_append_rows = 0; - n_accept_append_replace_calls = 0; - n_accept_append_slide_calls = 0; - llama_dflash_profile_reset(ctx_tgt); - llama_dflash_profile_reset(ctx_dft); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - GGML_UNUSED(prompt_tgt); - - result.clear(); - if (!ready || target_window_rows <= 0) { - n_draft_empty++; - return; - } - - const int32_t n_keep = std::min(params.n_max, block_size - 1); - if (n_keep <= 0) { - return; - } - - const float * target_features = nullptr; - size_t target_feature_floats = 0; - llama_dflash_window_update window_update = { - target_window_version, - target_window_keep_rows, - target_window_append_rows, - target_window_replace, - target_window_append_features.empty() ? nullptr : target_window_append_features.data(), - target_window_append_features.size(), - }; - const llama_dflash_kv_cache_transition cache_plan = - llama_plan_dflash_kv_cache_transition_for_ctx(ctx_dft, window_update, target_window_rows); - - if (cache_plan.rebuild_cache) { - dflash_materialize_target_window_features(*this); - target_features = target_window.data(); - target_feature_floats = target_window.size(); - window_update.append_features = target_window.data(); - window_update.append_floats = target_window.size(); - window_update.append_rows = target_window_rows; - } - - if (!llama_set_dflash_target_features_view(ctx_dft, target_features, target_feature_floats, target_window_rows, target_window_pos.data(), &window_update)) { - LOG_ERR("%s: failed to set DFlash target features\n", __func__); - n_set_target_fail++; - return; - } - - llama_kv_cache_clear(ctx_dft); - batch.n_tokens = 0; - const int32_t batch_len = n_keep + 1; - const llama_pos draft_pos_base = last_target_pos >= 0 ? last_target_pos + 1 : (llama_pos) target_window_rows; - const llama_pos seed_pos = last_target_pos >= 0 ? last_target_pos : draft_pos_base - 1; - last_draft_pos_base = draft_pos_base; - common_batch_add(batch, id_last, seed_pos, { 0 }, false); - for (int32_t i = 1; i < batch_len; ++i) { - common_batch_add(batch, mask_token_id, draft_pos_base + (i - 1), { 0 }, i <= n_keep); - } - - const int64_t t_decode_us = ggml_time_us(); - if (llama_decode(ctx_dft, batch) != 0) { - LOG_ERR("%s: llama_decode() failed for DFlash draft batch\n", __func__); - n_decode_fail++; - batch.n_tokens = 0; - return; - } - t_draft_decode_us += (uint64_t) (ggml_time_us() - t_decode_us); - - result.reserve((size_t) n_keep); - const int64_t t_sample_us = ggml_time_us(); - for (int32_t i = 0; i < n_keep; ++i) { - // Use argmax in GPU when available - llama_token id = llama_get_dflash_draft_token_ith(ctx_dft, i); - if (id == LLAMA_TOKEN_NULL) { - id = common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr); - } - result.push_back(id); - } - t_draft_sample_us += (uint64_t) (ggml_time_us() - t_sample_us); - - batch.n_tokens = 0; - dflash_contract_log_draft(*this, n_keep, result.size()); - } - - void accept(uint16_t n_accepted) override { - GGML_UNUSED(n_accepted); - } -}; - -static void dflash_contract_log_append( - const common_speculative_state_dflash & state, - llama_seq_id seq_id, - const std::vector & new_positions) { - if (!dflash_contract_log_enabled()) { - return; - } - - static std::atomic counter = 0; - const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); - if (ordinal >= 8) { - return; - } - - const dflash_contract_pos_summary incoming = dflash_contract_summarize_positions(new_positions); - const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos); - - LOG_INF("dflash contract append[%llu]: seq=%d incoming_rows=%zu incoming_pos=%s pos=[%d..%d] gaps=%d nonmono=%d window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d\n", - (unsigned long long) (ordinal + 1), - (int) seq_id, - new_positions.size(), - dflash_contract_format_values(new_positions).c_str(), - (int) incoming.first, - (int) incoming.last, - incoming.gap_count, - incoming.nonmono_count, - state.target_window_rows, - dflash_contract_format_values(state.target_window_pos).c_str(), - (int) window.first, - (int) window.last, - window.gap_count, - window.nonmono_count, - (int) state.last_target_pos); -} - -static void dflash_contract_log_draft( - const common_speculative_state_dflash & state, - int32_t n_keep, - size_t result_size) { - if (!dflash_contract_log_enabled()) { - return; - } - - static std::atomic counter = 0; - const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); - if (ordinal >= 8) { - return; - } - - const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos); - llama_dflash_profile_stats graph_stats = {}; - llama_dflash_profile_get_stats(state.ctx_dft, &graph_stats); - const int draft_delta = (state.last_target_pos >= 0 && state.last_draft_pos_base >= 0) - ? (int) (state.last_draft_pos_base - state.last_target_pos) - : -1; - const llama_pos seed_pos = state.last_target_pos; - const llama_pos mask_first_pos = state.last_draft_pos_base; - const llama_pos mask_last_pos = state.last_draft_pos_base >= 0 - ? state.last_draft_pos_base + n_keep - 1 - : -1; - - LOG_INF("dflash contract draft[%llu]: window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d seed_pos=%d mask_pos=[%d..%d] sample_rows=[1..%d] output_rows=[1..%d] draft_pos_base=%d delta=%d n_keep=%d result=%zu set_target(missing/nonmono)=%llu/%llu graph(fallback/nonmono)=%llu/%llu graph_pos=[%d..%d]\n", - (unsigned long long) (ordinal + 1), - state.target_window_rows, - dflash_contract_format_values(state.target_window_pos).c_str(), - (int) window.first, - (int) window.last, - window.gap_count, - window.nonmono_count, - (int) state.last_target_pos, - (int) seed_pos, - (int) mask_first_pos, - (int) mask_last_pos, - n_keep, - n_keep, - (int) state.last_draft_pos_base, - draft_delta, - n_keep, - result_size, - (unsigned long long) graph_stats.set_target_missing_positions, - (unsigned long long) graph_stats.set_target_non_monotonic_positions, - (unsigned long long) graph_stats.graph_pos_fallbacks, - (unsigned long long) graph_stats.graph_pos_non_monotonic, - (int) graph_stats.last_pos_first, - (int) graph_stats.last_pos_last); -} - -struct common_speculative_state_draft : public common_speculative_state { - llama_context * ctx_tgt; // only used for retokenizing from ctx_dft - llama_context * ctx_dft; - - common_sampler * smpl; - - llama_batch batch; - llama_tokens prompt_dft; - - bool vocab_cmpt = true; // whether retokenization is needed - std::unordered_map vocab_map; - - common_speculative_state_draft( - enum common_speculative_type type, - llama_context * ctx_tgt, - llama_context * ctx_dft, - const std::vector> & replacements) - : common_speculative_state(type) - , ctx_tgt(ctx_tgt) - , ctx_dft(ctx_dft) - { - batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); - smpl = nullptr; - { - struct common_params_sampling params; - params.top_k = 10; - params.samplers_sequence = { - llama_sampler_type::TOP_K, - llama_sampler_type::DIST, // needed to get probabilities - }; - smpl = common_sampler_init(llama_get_model(ctx_dft), params); - } - - vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); - LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt); - - if (!vocab_cmpt) { - LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n"); - - for (const auto & pair : replacements) { - vocab_map[pair.first] = pair.second; - } - } - } - - ~common_speculative_state_draft() override { - llama_free(ctx_dft); - - common_sampler_free(smpl); - - llama_batch_free(batch); - } - - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - auto * spec = this; - - auto & batch = spec->batch; - auto & ctx_tgt = spec->ctx_tgt; - auto & ctx_dft = spec->ctx_dft; - auto & smpl = spec->smpl; - auto & prompt_dft = spec->prompt_dft; - - int reuse_i = 0; - int reuse_n = 0; - - const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max; - - llama_tokens prompt_cnv; - if (!spec->vocab_cmpt) { - // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation - const auto * model_tgt = llama_get_model(ctx_tgt); - const auto * vocab_tgt = llama_model_get_vocab(model_tgt); - - std::string text; - - text = common_detokenize(ctx_tgt, prompt_tgt, true); - text = replace_to_dft(text); - - LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str()); - - prompt_cnv = common_tokenize(ctx_dft, text, false, true); - - - - int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false); - GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last"); - - text.resize(-n_chars); - llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false); - text = replace_to_dft(text); - - LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str()); - id_last = common_tokenize(ctx_dft, text, false, true)[0]; - } - - const llama_tokens & prompt_cur = spec->vocab_cmpt ? prompt_tgt : prompt_cnv; - - const int i_start = std::max(0, (int) prompt_cur.size() - n_ctx); - - // reuse as much as possible from the old draft context - // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt - for (int i = 0; i < (int) prompt_dft.size(); ++i) { - int cur = 0; - while (i_start + cur < (int) prompt_cur.size() && - i + cur < (int) prompt_dft.size() && - prompt_cur[i_start + cur] == prompt_dft[i + cur]) { - cur++; - } - - if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) { - reuse_i = i; - reuse_n = cur; - } - } - - LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size()); - - result.clear(); - result.reserve(params.n_max); - - if (reuse_n == 0) { - llama_kv_cache_clear(ctx_dft); - prompt_dft.clear(); - } else { - // this happens when a previous draft has been discarded (for example, due to being too small), but the - // target model agreed with it. in this case, we simply pass back the previous results to save compute - if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { - for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { - result.push_back(prompt_dft[i]); - - if (params.n_max <= (int) result.size()) { - break; - } - } - - return; - } - - if (reuse_i > 0) { - llama_kv_cache_seq_rm (ctx_dft, 0, 0, reuse_i); - llama_kv_cache_seq_add(ctx_dft, 0, reuse_i, -1, -reuse_i); - - prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); - } - - if (reuse_n < (int) prompt_dft.size()) { - llama_kv_cache_seq_rm (ctx_dft, 0, reuse_n, -1); - prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); - } - } - - // prepare a batch to evaluate any new tokens in the prompt - common_batch_clear(batch); - - for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) { - //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]); - common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false); - - prompt_dft.push_back(prompt_cur[i]); - } - - // we should rarely end-up here during normal decoding - if (batch.n_tokens > 0) { - //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); - - llama_decode(ctx_dft, batch); - } - - const llama_pos n_past = prompt_dft.size(); - - LOG_DBG("%s: n_past = %d\n", __func__, n_past); - - common_batch_clear(batch); - common_batch_add (batch, id_last, n_past, { 0 }, true); - - prompt_dft.push_back(id_last); - - //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); - - llama_decode(ctx_dft, batch); - - common_sampler_reset(smpl); - - // sample n_draft tokens from the draft model - for (int i = 0; i < params.n_max; ++i) { - common_batch_clear(batch); - - common_sampler_sample(smpl, ctx_dft, 0, true); - - const auto * cur_p = common_sampler_get_candidates(smpl, true); - - for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { - LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); - } - - // add drafted token for each sequence - const llama_token id = cur_p->data[0].id; - - common_sampler_accept(smpl, nullptr, id, true); - - // only collect very high-confidence draft tokens - if (cur_p->data[0].p < params.p_min) { - if (i == 0) { - result.push_back(id); - } - break; - } - - result.push_back(id); - - if (params.n_max <= (int) result.size()) { - break; - } - - - common_batch_add(batch, id, n_past + i + 1, { 0 }, true); - - // evaluate the drafted tokens on the draft model - llama_decode(ctx_dft, batch); - - prompt_dft.push_back(id); - } - - if (!spec->vocab_cmpt) { - std::string detokenized = common_detokenize(ctx_dft, result, true); - detokenized = replace_to_tgt(detokenized); - LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str()); - result = common_tokenize(ctx_tgt, detokenized, false, true); - if (result.size() > (size_t)params.n_max) { - result.resize(params.n_max); - } - } - } - - void accept(uint16_t n_accepted) override { - // noop - GGML_UNUSED(n_accepted); - } - - std::string replace_to_dft(const std::string & input) const { - std::string result = input; - - for (const auto & pair : this->vocab_map) { - size_t pos = result.find(pair.first); - while (pos != std::string::npos) { - result.replace(pos, pair.first.length(), pair.second); - pos = result.find(pair.first, pos + pair.second.length()); - } - } - - return result; - } - - std::string replace_to_tgt(const std::string & input) const { - std::string result = input; - - for (const auto & pair : this->vocab_map) { - size_t pos = result.find(pair.second); - while (pos != std::string::npos) { - result.replace(pos, pair.second.length(), pair.first); - pos = result.find(pair.second, pos + pair.first.length()); - } - } - - return result; - } -}; - -struct common_speculative_state_eagle3 : public common_speculative_state { - common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {} - - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & draft_tokens) override { - // TODO: implement - GGML_UNUSED(params); - GGML_UNUSED(prompt_tgt); - GGML_UNUSED(id_last); - GGML_UNUSED(draft_tokens); - } - - void accept(uint16_t n_accepted) override { - // noop - GGML_UNUSED(n_accepted); - } -}; - -// state of self-speculation (simple implementation, not ngram-map) -struct common_speculative_state_ngram_simple : public common_speculative_state { - common_ngram_simple_config config; - - common_speculative_state_ngram_simple( - enum common_speculative_type type, - common_ngram_simple_config config) - : common_speculative_state(type), config(config) {} - - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - - result = common_ngram_simple_draft(config, prompt_tgt, id_last); - GGML_UNUSED(params); - } - - void accept(uint16_t n_accepted) override { - // noop - GGML_UNUSED(n_accepted); - } -}; - -struct common_speculative_state_ngram_map_k : public common_speculative_state { - // draft ngram map for speculative decoding without draft model - common_ngram_map map; - - common_speculative_state_ngram_map_k( - enum common_speculative_type type, - common_ngram_map map) - : common_speculative_state(type), map(std::move(map)) {} - - void begin(const llama_tokens & prompt) override { - common_ngram_map_begin(map, prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - common_ngram_map_draft(map, prompt_tgt, id_last, result); - GGML_UNUSED(params); - } - - void accept(uint16_t n_accepted) override { - common_ngram_map_accept(map, n_accepted); - } -}; - -struct common_speculative_state_ngram_mod : public common_speculative_state { - common_ngram_mod & mod; - - // the last position in the prompt that was added to the ngram container - size_t i_last = 0; - - // length of the last drafted n‑gram (number of tokens returned by draft) - size_t n_draft_last = 0; - - // consecutive accept rounds with low acceptance fraction (< 0.5) - int n_low = 0; - - // enable trace logging if LLAMA_TRACE is set - const bool verbose; - - common_speculative_state_ngram_mod(enum common_speculative_type type, common_ngram_mod & mod) - : common_speculative_state(type), mod(mod), verbose(std::getenv("LLAMA_TRACE") != nullptr) { - static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t)); - } - - void begin(const llama_tokens & prompt) override { - i_last = 0; - - n_draft_last = 0; - n_low = 0; - - const size_t n = mod.get_n(); - - if (prompt.size() < n) { - return; - } - - for (size_t i = 0; i < prompt.size() - n; ++i) { - mod.add(prompt.data() + i); - } - - i_last = prompt.size() - n; - - const double f = (double)mod.get_used() / (double)mod.size(); - LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f); - - constexpr double f_thold = 0.25; - if (f > f_thold) { - LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold); - - mod.reset(); - } - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - GGML_UNUSED(params); - - n_draft_last = 0; - - const size_t cur_len = prompt_tgt.size(); - if (cur_len < mod.get_n()) { - return; - } - - const size_t n = mod.get_n(); - - // add new ngrams in chunks - if (i_last + 32 < cur_len) { - for (size_t i = i_last; i < cur_len - n; ++i) { - mod.add(prompt_tgt.data() + i); - } - - i_last = cur_len - n; - } - - result.resize(n + params.n_max); - for (size_t i = 0; i < n - 1; ++i) { - result[i] = prompt_tgt[cur_len - n + 1 + i]; - } - result[n - 1] = id_last; - - for (int i = 0; i < params.n_max; ++i) { - const llama_token token = mod.get(result.data() + i); - if (token == common_ngram_mod::EMPTY) { - if (i < params.n_min) { - result.clear(); - return; - } - - result.resize(n + i); - break; - } - result[n + i] = token; - } - - // only return the m tokens that were drafted - for (size_t i = 0; n + i < result.size(); ++i) { - result[i] = result[n + i]; - } - result.resize(result.size() - n); - - // store length of drafted n‑gram for later acceptance analysis - n_draft_last = result.size(); - } - - void accept(uint16_t n_accepted) override { - if (verbose) { - LOG_INF("%s: accepted %d tokens from %zu drafted tokens\n", __func__, n_accepted, n_draft_last); - } - - // compute acceptance fraction if we have a recorded draft length - if (n_draft_last > 0) { - const double f_acc = (double)n_accepted / (double)n_draft_last; - if (f_acc < 0.5) { - n_low++; - if (n_low >= 3) { - LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low); - - mod.reset(); - n_low = 0; - i_last = 0; - } - } else { - n_low = 0; - } - } - } -}; - -struct common_speculative_state_ngram_cache : public common_speculative_state { - uint16_t n_draft; - bool save_dynamic; - bool save_static; - - common_ngram_cache ngram_cache_context; - common_ngram_cache ngram_cache_dynamic; - common_ngram_cache ngram_cache_static; - - size_t cache_size = 0; // number of tokens in n-gram cache - - common_speculative_state_ngram_cache( - const enum common_speculative_type type, - const std::string & path_static, - const std::string & path_dynamic, - uint16_t n_draft, - bool save_dynamic, - bool save_static) - : common_speculative_state(type) - , n_draft(n_draft) - , save_dynamic(save_dynamic) - , save_static(save_static) - { - if (!path_static.empty()) { - try { - ngram_cache_static = common_ngram_cache_load(path_static); - } catch (...) { - LOG_ERR("failed to open static lookup cache: %s", path_static.c_str()); - GGML_ABORT("Couldn't read static lookup cache"); - } - } - - if (!path_dynamic.empty()) { - try { - ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); - } catch (...) { - LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str()); - GGML_ABORT("Couldn't read dynamic lookup cache"); - } - } - } - - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - GGML_UNUSED(params); - - if (cache_size < prompt_tgt.size() + 1) { - llama_tokens tokens_new; - tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); - for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { - tokens_new.push_back(prompt_tgt[j]); - } - tokens_new.push_back(id_last); // add the last token - - // Update context ngram cache with new prompt_tgt: - common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, - tokens_new, tokens_new.size(), false); - cache_size = prompt_tgt.size() + 1; - } - - llama_tokens inp; - inp.reserve(prompt_tgt.size() + 1); - for (size_t j = 0; j < prompt_tgt.size(); ++j) { - inp.push_back(prompt_tgt[j]); - } - inp.push_back(id_last); - - result.push_back(id_last); - - common_ngram_cache_draft(inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, - ngram_cache_context, - ngram_cache_dynamic, - ngram_cache_static); - - if (result.size() > 0) { - // delete first token in result (which is the id_last token) - result.erase(result.begin()); - } - } - - void accept(uint16_t n_accepted) override { - // TODO: noop - GGML_UNUSED(n_accepted); - } -}; - -struct common_speculative_state_suffix : public common_speculative_state { - common_suffix_tree tree; - common_suffix_tree corpus_tree; - bool has_corpus = false; - size_t cache_size = 0; - - // Acceptance feedback - size_t n_draft_last = 0; - bool had_accept = false; - int n_low = 0; - float base_p_min = 0.1f; - float eff_p_min = 0.1f; - - common_speculative_state_suffix( - enum common_speculative_type type, - int max_depth, - const std::string & corpus_path, - const llama_model * model) - : common_speculative_state(type) - , tree(max_depth) - , corpus_tree(max_depth) - { - if (!corpus_path.empty()) { - std::function(const std::string &)> tokenize_fn; - if (model) { - tokenize_fn = [model](const std::string & text) -> std::vector { - return common_tokenize(model, text, false, true); - }; - } - has_corpus = corpus_tree.load_corpus(corpus_path, tokenize_fn); - } - } - - void begin(const llama_tokens & prompt) override { - cache_size = 0; - n_draft_last = 0; - had_accept = false; - n_low = 0; - GGML_UNUSED(prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - - base_p_min = params.p_min; - if (n_draft_last > 0 && !had_accept) { - if (++n_low >= 3) { - eff_p_min = std::min(eff_p_min + 0.1f, 0.5f); - n_low = 0; - } - } - had_accept = false; - - if (cache_size < prompt_tgt.size() + 1) { - llama_tokens tokens_new; - tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); - for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { - tokens_new.push_back(prompt_tgt[j]); - } - tokens_new.push_back(id_last); - - tree.extend(tokens_new.data(), (int)tokens_new.size()); - cache_size = prompt_tgt.size() + 1; - } - - const int ctx_len = std::min((int)(prompt_tgt.size() + 1), tree.max_depth()); - llama_tokens context; - context.reserve(ctx_len); - const int ctx_start = (int)prompt_tgt.size() + 1 - ctx_len; - for (int j = ctx_start; j < (int)prompt_tgt.size(); ++j) { - context.push_back(prompt_tgt[j]); - } - context.push_back(id_last); - const int min_match_len = std::max(1, params.suffix_min_match_len); - - result = tree.speculate( - context.data(), (int)context.size(), - params.n_max, - eff_p_min, - 1, - min_match_len); - - if (has_corpus) { - auto corpus_result = corpus_tree.speculate( - context.data(), (int)context.size(), - params.n_max, - eff_p_min, - 1, - min_match_len); - if (corpus_result.size() > result.size()) { - result = std::move(corpus_result); - } - } - - n_draft_last = result.size(); - } - - void accept(uint16_t n_accepted) override { - if (n_draft_last == 0) { - return; - } - had_accept = true; - const double f_acc = (double)n_accepted / (double)n_draft_last; - if (f_acc < 0.5) { - if (++n_low >= 3) { - eff_p_min = std::min(eff_p_min + 0.1f, 0.5f); - n_low = 0; - } - } else { - n_low = 0; - if (eff_p_min > base_p_min) { - eff_p_min = std::max(eff_p_min - 0.05f, base_p_min); - } - } - } -}; - -struct common_speculative { - std::vector configs; // resolved stage config for each implementation - std::vector> impls; // list of implementations to use and their states - common_speculative_checkpoint checkpoint; - common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) - std::unique_ptr tuner; - int last_n_drafted = 0; - int64_t t_step_start_us = 0; -}; - -static bool common_speculative_stage_chain_matches( - const std::vector & stages, - const std::vector & configs) { - if (stages.size() != configs.size()) { - return false; - } - - for (size_t i = 0; i < stages.size(); ++i) { - if (stages[i].type != configs[i].type) { - return false; - } - } - - return true; -} - -static common_params_speculative common_speculative_get_runtime_params( - const common_speculative_config & config, - const common_params_speculative & params, - const common_speculative_stage_params & stage) { - common_params_speculative result = config.params; - - result.type = config.type; - result.n_max = stage.has_n_max_override() ? stage.n_max : params.n_max; - result.n_min = stage.has_n_min_override() ? stage.n_min : params.n_min; - result.p_min = stage.has_p_min_override() ? stage.p_min : params.p_min; - - if (config.type == COMMON_SPECULATIVE_TYPE_SUFFIX) { - result.suffix_min_match_len = stage.has_suffix_min_match_len_override() - ? stage.suffix_min_match_len - : params.suffix_min_match_len; - } - - result.n_max = std::max(result.n_max, 0); - result.n_min = std::max(0, std::min(result.n_min, result.n_max)); - result.stages.clear(); - - return result; -} - -static common_ngram_map get_common_ngram_map(const common_speculative_config & config) { - uint16_t size_key = config.params.ngram_size_n; - uint16_t size_value = config.params.ngram_size_m; - bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); - uint16_t min_hits = config.params.ngram_min_hits; - - return common_ngram_map(size_key, size_value, key_only, min_hits); -} - -static common_speculative_state_ngram_cache create_state_ngram_cache( - const std::string & path_static, const std::string & path_dynamic, - const common_speculative_config & config) { - uint16_t n_draft = 8; // TODO get from config? - - // TODO bool param in common/common.h to set save_static/save_dynamic? - bool save_static = false; - bool save_dynamic = false; - - common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic); - - return state; -} - -std::string common_speculative_type_name_str() { - std::string result; - for (size_t i = 0; i < common_speculative_types.size(); i++) { - if (i > 0) { - result += ", "; - } - result += common_speculative_type_to_str(common_speculative_types[i]); - } - return result; -} - -std::string common_speculative_type_to_str(enum common_speculative_type type) { - switch (type) { - case COMMON_SPECULATIVE_TYPE_NONE: return "none"; - case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft"; - case COMMON_SPECULATIVE_TYPE_DFLASH: return "dflash"; - case COMMON_SPECULATIVE_TYPE_MTP: return "mtp"; - case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3"; - case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod"; - case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache"; - case COMMON_SPECULATIVE_TYPE_SUFFIX: return "suffix"; - default: return "unknown"; - } -} - -enum common_speculative_type common_speculative_type_from_name(const std::string & name) { - std::string normalized = name; - std::replace(normalized.begin(), normalized.end(), '-', '_'); - - const auto it = common_speculative_type_from_name_map.find(normalized); - if (it == common_speculative_type_from_name_map.end()) { - return COMMON_SPECULATIVE_TYPE_COUNT; - } - return it->second; -} - -bool common_speculative_is_compat(llama_context * ctx_tgt) { - bool res = true; - - llama_kv_cache_clear(ctx_tgt); - - // eval 2 tokens to check if the context is compatible - std::vector tmp; - tmp.push_back(0); - tmp.push_back(0); - - int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0)); - if (ret != 0) { - LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret); - res = false; - goto done; - } - - // try to remove the last tokens - if (!llama_kv_cache_seq_rm(ctx_tgt, 0, 1, -1)) { - LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); - res = false; - goto done; - } - -done: - llama_kv_cache_clear(ctx_tgt); - llama_synchronize(ctx_tgt); - - return res; -} - -// initialization of the speculative decoding system -// -common_speculative * common_speculative_init( - common_params_speculative & params, - llama_context * ctx_tgt) { - std::string chain_error; - if (!common_speculative_validate_chain(params, &chain_error)) { - LOG_ERR("%s: invalid speculative stage chain: %s\n", __func__, chain_error.c_str()); - return nullptr; - } - - const auto stages = params.get_resolved_stages(); - if (params.model_dft && llama_model_is_gemma4_mtp_assistant(params.model_dft)) { - const bool has_draft_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) { - return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT; - }); - - if (has_draft_stage) { - LOG_ERR("%s: Gemma4 assistant models only support MTP stages; omit -md for self-spec-only runs or use -mtp/--spec-stage mtp for assistant-backed MTP\n", __func__); - return nullptr; - } - } - - const bool has_dflash_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) { - return stage.type == COMMON_SPECULATIVE_TYPE_DFLASH; - }); - - const bool needs_draft_ctx = std::any_of(stages.begin(), stages.end(), [¶ms](const common_speculative_stage_params & stage) { - return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT || - stage.type == COMMON_SPECULATIVE_TYPE_DFLASH || - (stage.type == COMMON_SPECULATIVE_TYPE_MTP && params.model_dft != nullptr); - }); - - llama_context * ctx_dft = nullptr; - if (needs_draft_ctx) { - if (!params.model_dft) { - LOG_ERR("%s: draft speculative stage requires a loaded draft model\n", __func__); - return nullptr; - } - - llama_context_params cparams_dft = params.cparams_dft; - - if (has_dflash_stage) { - if (!llama_model_share_dflash_io_tensors(params.model_dft, llama_get_model(ctx_tgt))) { - LOG_ERR("%s: failed to share target IO tensors with DFlash draft model\n", __func__); - return nullptr; - } - - int32_t max_cross_ctx = 0; - for (const auto & stage : stages) { - if (stage.type != COMMON_SPECULATIVE_TYPE_DFLASH) { - continue; - } - - max_cross_ctx = std::max(max_cross_ctx, params.with_stage_overrides(stage).dflash_cross_ctx); - } - - const int32_t block_size = llama_model_dflash_block_size(params.model_dft); - if (block_size <= 0) { - LOG_ERR("%s: invalid DFlash draft block size\n", __func__); - return nullptr; - } - - const int64_t required_n_ctx = (int64_t) max_cross_ctx + (int64_t) block_size; - if (required_n_ctx > std::numeric_limits::max()) { - LOG_ERR("%s: invalid DFlash draft context size cross_ctx=%d block_size=%d required_n_ctx=%lld\n", - __func__, max_cross_ctx, block_size, (long long) required_n_ctx); - return nullptr; - } - - cparams_dft.n_ctx = (uint32_t) required_n_ctx; - } - - ctx_dft = llama_init_from_model(params.model_dft, cparams_dft); - if (ctx_dft == nullptr) { - LOG_ERR("%s", "failed to create draft context\n"); - return nullptr; - } - } - - // Compute the implementations to use based on the resolved stage chain. - std::vector configs = {}; - configs.reserve(stages.size()); - - for (const auto & stage : stages) { - common_params_speculative stage_params = params.with_stage_overrides(stage); - - if (stage.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD && !stage_params.ngram_mod) { - stage_params.ngram_mod = std::make_shared(stage_params.ngram_size_n, 4*1024*1024); - - LOG_INF("%s: initialized ngram_mod with n=%d, size=%zu (%.3f MB)\n", __func__, - stage_params.ngram_size_n, stage_params.ngram_mod->size(), - (float)(stage_params.ngram_mod->size_bytes())/1024/1024); - - if (stage_params.ngram_size_n < 16) { - LOG_WRN("%s: ngram_mod n=%d is too small - poor quality is possible, see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, stage_params.ngram_size_n); - } - } - - configs.push_back(common_speculative_config(stage, stage_params)); - } - - if (!configs.empty() && llama_model_has_recurrent(llama_get_model(ctx_tgt))) { - const int ckpt_tokens = std::max(1, params.get_max_stage_n_max() + 1); - const int actual_mode = llama_spec_ckpt_init(ctx_tgt, params.recurrent_ckpt_mode, ckpt_tokens); - if (actual_mode == LLAMA_SPEC_CKPT_NONE) { - LOG_ERR("%s: failed to prepare recurrent checkpoint mode '%s' during speculative init (max_tokens=%d)\n", - __func__, - params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_PER_STEP ? "per-step" : - params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_GPU_FALLBACK ? "gpu-fallback" : - params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_CPU ? "cpu" : "auto", - ckpt_tokens); - if (ctx_dft != nullptr) { - llama_free(ctx_dft); - } - return nullptr; - } - llama_spec_ckpt_discard(ctx_tgt); - params.recurrent_ckpt_mode = actual_mode; - } - - std::vector> impls = {}; - - for (const common_speculative_config & config : configs) { - LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str()); - switch (config.type) { - case COMMON_SPECULATIVE_TYPE_NONE: - break; - case COMMON_SPECULATIVE_TYPE_DRAFT: { - impls.push_back(std::make_unique(config.type, - /* .ctx_tgt = */ ctx_tgt, - /* .ctx_dft = */ ctx_dft, - /* .replacements = */ config.params.replacements - )); - break; - } - case COMMON_SPECULATIVE_TYPE_DFLASH: { - auto state = std::make_unique( - config.type, - ctx_tgt, - ctx_dft, - config.params.dflash_cross_ctx); - if (!state->ready) { - LOG_ERR("%s: failed to initialize DFlash speculative state\n", __func__); - return nullptr; - } - impls.push_back(std::move(state)); - ctx_dft = nullptr; - break; - } - case COMMON_SPECULATIVE_TYPE_MTP: { - llama_context * ctx_mtp = ctx_dft; - if (!ctx_mtp) { - const llama_model * model = llama_get_model(ctx_tgt); - ctx_mtp = llama_init_from_model(const_cast(model), config.params.cparams_dft); - if (!ctx_mtp) { - LOG_ERR("%s: failed to create MTP context\n", __func__); - return nullptr; - } - } - ctx_dft = nullptr; - - const bool use_constant_draft_positions = llama_model_is_gemma4_mtp_assistant(llama_get_model(ctx_mtp)); - impls.push_back(std::make_unique( - config.type, ctx_tgt, ctx_mtp, use_constant_draft_positions)); - break; - } - case COMMON_SPECULATIVE_TYPE_EAGLE3: { - impls.push_back(std::make_unique(config.type)); - break; - } - case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { - common_ngram_map ngram_map = get_common_ngram_map(config); - - uint16_t ngram_size_key = ngram_map.size_key; - uint16_t mgram_size_value = ngram_map.size_value; - - auto config_simple = common_ngram_simple_config { - /* .size_ngram = */ ngram_size_key, - /* .size_mgram = */ mgram_size_value - }; - auto state = std::make_unique( - /* .type = */ config.type, - /* .state = */ config_simple - ); - impls.push_back(std::move(state)); - break; - } - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: { - impls.push_back(std::make_unique( - (config.type), - get_common_ngram_map(config) - )); - break; - } - case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: { - GGML_ASSERT(config.params.ngram_mod); - impls.push_back(std::make_unique(config.type, *config.params.ngram_mod)); - break; - } - case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: { - auto state = create_state_ngram_cache( - config.params.lookup_cache_static, config.params.lookup_cache_dynamic, config); - impls.push_back(std::make_unique(state)); - break; - } - case COMMON_SPECULATIVE_TYPE_SUFFIX: { - int depth = config.params.suffix_max_depth > 0 ? config.params.suffix_max_depth : 64; - const llama_model * model = llama_get_model(ctx_tgt); - impls.push_back(std::make_unique( - config.type, depth, config.params.suffix_corpus, model)); - break; - } - default: - break; - } - } - - if (impls.empty()) { - LOG_WRN("%s", "no implementations specified for speculative decoding\n"); - return nullptr; - } - - auto * result = new common_speculative { - /* .configs = */ std::move(configs), - /* .impls = */ std::move(impls) - }; - - // initialize autotune if requested - if (params.autotune && params.has_composite_stage_chain()) { - LOG_WRN("Autotune disabled — explicit speculative stage chains are not supported yet\n"); - } else if (params.autotune && !result->impls.empty()) { - auto actual_type = result->impls[0]->type; - if (actual_type != COMMON_SPECULATIVE_TYPE_NONE && - actual_type != COMMON_SPECULATIVE_TYPE_EAGLE3) { - result->tuner = std::make_unique(); - result->tuner->init(actual_type, params, llama_get_model(ctx_tgt)); - LOG_DBG("Autotune initialized for %s, tuning %zu parameters\n", - common_speculative_type_to_str(actual_type).c_str(), - result->tuner->coords.size()); - } else { - LOG_WRN("Autotune disabled — speculative type %s is not supported for autotuning\n", - common_speculative_type_to_str(actual_type).c_str()); - } - } - - return result; -} - -void common_speculative_free(common_speculative * spec) { - if (spec == nullptr) { - return; - } - - spec->checkpoint.clear(); - delete spec; -} - -void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt) { - if (spec == nullptr) { - return; - } - - for (auto & impl : spec->impls) { - common_time_meas tm(impl->t_begin_us, !impl->gen_perf); - impl->begin(prompt); - impl->n_call_begin++; - } -} - -llama_tokens common_speculative_draft( - common_speculative * spec, - common_params_speculative & params, - const llama_tokens & prompt_tgt, // specified in target model vocab - llama_token id_last, - llama_pos draft_base_pos, - llama_seq_id draft_seq_id) { - llama_tokens result; - - spec->t_step_start_us = ggml_time_us(); - - // apply autotune proposal if enabled - if (spec->tuner && spec->tuner->enabled) { - spec->tuner->propose(params); - } - - const auto runtime_stages = params.get_resolved_stages(); - const bool use_runtime_stage_overrides = common_speculative_stage_chain_matches(runtime_stages, spec->configs); - - spec->curr_impl = nullptr; // reset current implementation - - for (size_t i = 0; i < spec->impls.size(); ++i) { - auto & impl = spec->impls[i]; - const auto & runtime_stage = use_runtime_stage_overrides ? runtime_stages[i] : spec->configs[i].stage; - common_params_speculative impl_params = common_speculative_get_runtime_params(spec->configs[i], params, runtime_stage); - result.clear(); - - { - common_time_meas tm(impl->t_draft_us, !impl->gen_perf); - impl->draft(impl_params, prompt_tgt, id_last, draft_base_pos, draft_seq_id, result); - impl->n_call_draft++; - } - - if (result.empty()) { - continue; - } - - if (common_speculative_type_is_self_spec(impl->type) && impl_params.n_min > 0 && (int)result.size() < impl_params.n_min) { - LOG_DBG("%s: impl %s drafted %zu tokens, below fallback threshold %d - trying next implementation\n", - __func__, common_speculative_type_to_str(impl->type).c_str(), result.size(), impl_params.n_min); - result.clear(); - continue; - } - LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, - common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(), - impl.get()->n_call_draft, result.size()); - - spec->curr_impl = impl.get(); - impl->n_gen_drafts++; - impl->n_gen_tokens += result.size(); - - break; // We have a draft, so break out of the loop and return it. - } - - // store draft count for tuner feedback - if (spec->tuner && spec->tuner->enabled) { - spec->last_n_drafted = (int)result.size(); - } - - return result; -} - -void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { - if (spec->tuner && spec->tuner->enabled && spec->t_step_start_us > 0) { - int64_t step_time_us = ggml_time_us() - spec->t_step_start_us; - double step_tps = (step_time_us > 100) - ? (n_accepted + 1.0) * 1e6 / (double)step_time_us - : 0.0; - spec->tuner->accept_feedback(n_accepted, spec->last_n_drafted, step_tps); - spec->t_step_start_us = 0; - } - - common_speculative_state * impl = spec->curr_impl; - - if (!impl) { - return; - } - - { - common_time_meas tm(impl->t_accept_us, !impl->gen_perf); - if (n_accepted > 0) { - impl->n_acc_drafts++; - impl->n_acc_tokens += n_accepted; - } - - impl->accept(n_accepted); - impl->n_call_accept++; - } - - if (impl->type != COMMON_SPECULATIVE_TYPE_MTP) { - if (auto * mtp_state = common_speculative_get_mtp_state(spec); mtp_state != nullptr) { - mtp_invalidate_cached_drafts(*mtp_state); - } - } -} - -static bool common_speculative_has_type(const common_speculative * spec, common_speculative_type type) { - if (spec == nullptr) { - return false; - } - - return std::any_of(spec->configs.begin(), spec->configs.end(), [type](const common_speculative_config & config) { - return config.type == type; - }); -} - -static int common_speculative_ctx_mtp_n_embd(llama_context * ctx) { - return ctx ? (int) llama_mtp_state_n_embd(ctx) : 0; -} - -static bool common_speculative_batch_token_has_seq_id( - const llama_batch & batch, - int token_index, - llama_seq_id seq_id) { - if (batch.n_seq_id == nullptr || batch.seq_id == nullptr || batch.n_seq_id[token_index] <= 0 || batch.seq_id[token_index] == nullptr) { - return false; - } - - for (int i = 0; i < batch.n_seq_id[token_index]; ++i) { - if (batch.seq_id[token_index][i] == seq_id) { - return true; - } - } - - return false; -} - -static bool common_speculative_batch_is_exact_single_seq( - const llama_batch & batch, - llama_seq_id seq_id) { - if (batch.n_tokens <= 0 || batch.n_seq_id == nullptr || batch.seq_id == nullptr) { - return false; - } - - for (int i = 0; i < batch.n_tokens; ++i) { - if (batch.n_seq_id[i] != 1 || batch.seq_id[i] == nullptr || batch.seq_id[i][0] != seq_id) { - return false; - } - } - - return true; -} - -static int common_speculative_copy_seq_batch( - const llama_batch & batch, - llama_seq_id seq_id, - llama_batch & seq_batch) { - if (batch.token == nullptr || batch.pos == nullptr) { - return -1; - } - - if (batch.n_tokens < 1) { - return 0; - } - - std::vector token_indices; - token_indices.reserve(batch.n_tokens); - for (int i = 0; i < batch.n_tokens; ++i) { - if (common_speculative_batch_token_has_seq_id(batch, i, seq_id)) { - token_indices.push_back(i); - } - } - - if (token_indices.empty()) { - return 0; - } - - seq_batch = llama_batch_init((int) token_indices.size(), 0, 1); - for (const int i : token_indices) { - common_batch_add(seq_batch, batch.token[i], batch.pos[i], { seq_id }, batch.logits != nullptr && batch.logits[i]); - } - - return (int) token_indices.size(); -} - -static bool common_speculative_feature_view_copy_batch_rows( - const common_speculative_feature_view & view, - const llama_batch & batch, - llama_seq_id seq_id, - std::vector * hidden_rows) { - if (hidden_rows == nullptr || view.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || view.width <= 0 || batch.n_tokens <= 0 || batch.pos == nullptr) { - return false; - } - - std::unordered_map rows_by_pos; - rows_by_pos.reserve(view.rows.size()); - for (const auto & row : view.rows) { - if (row.seq_id == seq_id && row.data != nullptr) { - rows_by_pos[row.pos] = row.data; - } - } - - hidden_rows->clear(); - hidden_rows->reserve((size_t) batch.n_tokens * view.width); - for (int i = 0; i < batch.n_tokens; ++i) { - auto it = rows_by_pos.find(batch.pos[i]); - if (it == rows_by_pos.end()) { - hidden_rows->clear(); - return false; - } - - hidden_rows->insert(hidden_rows->end(), it->second, it->second + view.width); - } - - return hidden_rows->size() == (size_t) batch.n_tokens * view.width; -} diff --git a/common/speculative.cpp b/common/speculative.cpp index b491c244..8f15f8ef 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -135,69 +135,6 @@ static bool common_speculative_are_compatible( return true; } -static bool common_speculative_are_dflash_compatible( - const llama_model * model_tgt, - const llama_model * model_dft) { - const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); - const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); - - if (llama_vocab_type(vocab_tgt) != llama_vocab_type(vocab_dft)) { - LOG_DBG("%s: DFlash draft model vocab type must match the target model\n", __func__); - return false; - } - - const bool add_bos_tgt = llama_vocab_get_add_bos(vocab_tgt); - const bool add_bos_dft = llama_vocab_get_add_bos(vocab_dft); - const bool add_eos_tgt = llama_vocab_get_add_eos(vocab_tgt); - const bool add_eos_dft = llama_vocab_get_add_eos(vocab_dft); - const llama_token bos_tgt = llama_vocab_bos(vocab_tgt); - const llama_token bos_dft = llama_vocab_bos(vocab_dft); - const llama_token eos_tgt = llama_vocab_eos(vocab_tgt); - const llama_token eos_dft = llama_vocab_eos(vocab_dft); - - if (add_bos_tgt != add_bos_dft || add_eos_tgt != add_eos_dft || - (add_bos_tgt && bos_tgt != bos_dft) || - (add_eos_tgt && eos_tgt != eos_dft)) { - LOG_DBG("%s: DFlash draft special tokens must match the target model (add_bos=%d/%d add_eos=%d/%d bos=%d/%d eos=%d/%d)\n", - __func__, - (int) add_bos_tgt, - (int) add_bos_dft, - (int) add_eos_tgt, - (int) add_eos_dft, - (int) bos_tgt, - (int) bos_dft, - (int) eos_tgt, - (int) eos_dft); - return false; - } - - const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt); - const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft); - const int vocab_diff = n_vocab_tgt > n_vocab_dft - ? n_vocab_tgt - n_vocab_dft - : n_vocab_dft - n_vocab_tgt; - - if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { - LOG_DBG("%s: DFlash draft vocab size differs too much from the target model (%d vs %d)\n", - __func__, n_vocab_dft, n_vocab_tgt); - return false; - } - - for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { - const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); - const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); - - if (std::strcmp(token_text_tgt, token_text_dft) != 0) { - LOG_DBG("%s: DFlash draft token %d differs - target '%s', draft '%s'\n", __func__, i, - common_token_to_piece(vocab_tgt, i).c_str(), - common_token_to_piece(vocab_dft, i).c_str()); - return false; - } - } - - return true; -} - // state of an implementation of speculative decoding // // each implementation has a unique type and a state that is implementation-specific @@ -251,27 +188,11 @@ struct common_speculative_state { struct common_speculative_state_mtp; struct common_speculative_state_dflash; -static void dflash_contract_log_append( - const common_speculative_state_dflash & state, - llama_seq_id seq_id, - const std::vector & new_positions); -static void dflash_contract_log_draft( - const common_speculative_state_dflash & state, - int32_t n_keep, - size_t result_size); - static common_speculative_state_mtp * common_speculative_get_mtp_state(common_speculative * spec); static const common_speculative_state_mtp * common_speculative_get_mtp_state(const common_speculative * spec); static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec); static const common_speculative_state_dflash * common_speculative_get_dflash_state(const common_speculative * spec); static int32_t common_speculative_feature_width(const common_speculative * spec); -static void dflash_materialize_target_window_features(common_speculative_state_dflash & state); -static void dflash_ring_reset_rows(common_speculative_state_dflash & state, const float * rows, int32_t n_rows); -static void dflash_append_target_features( - common_speculative_state_dflash & state, - const float * feature_rows, - int32_t n_rows); -static void dflash_clear_target_features(common_speculative_state_dflash & state); static void mtp_invalidate_cached_drafts(common_speculative_state_mtp & state); static bool common_speculative_checkpoint_save( common_speculative_checkpoint & ckpt, @@ -298,92 +219,6 @@ static std::vector mtp_speculative_gen_draft( static int32_t mtp_update_kv_cache(struct llama_context * ctx, const llama_batch & batch, bool is_prompt_warmup); -static bool dflash_contract_log_enabled() { - const char * env = std::getenv("IK_DFLASH_CONTRACT_LOG"); - if (env == nullptr || *env == '\0') { - return false; - } - - return std::strcmp(env, "0") != 0 && - std::strcmp(env, "false") != 0 && - std::strcmp(env, "off") != 0; -} - -static bool dflash_stats_log_enabled() { - const char * env = std::getenv("IK_DFLASH_STATS_LOG"); - if (env == nullptr || *env == '\0') { - return false; - } - - return std::strcmp(env, "0") != 0 && - std::strcmp(env, "false") != 0 && - std::strcmp(env, "off") != 0; -} - -template -static std::string dflash_contract_format_values( - const std::vector & values, - size_t edge_count = 4) { - std::ostringstream oss; - oss << '['; - if (values.empty()) { - oss << ']'; - return oss.str(); - } - - const size_t head = std::min(edge_count, values.size()); - for (size_t i = 0; i < head; ++i) { - if (i > 0) { - oss << ','; - } - oss << values[i]; - } - - if (values.size() > edge_count * 2) { - oss << ",...,"; - for (size_t i = values.size() - edge_count; i < values.size(); ++i) { - if (i > values.size() - edge_count) { - oss << ','; - } - oss << values[i]; - } - } else { - for (size_t i = head; i < values.size(); ++i) { - oss << ',' << values[i]; - } - } - - oss << ']'; - return oss.str(); -} - -struct dflash_contract_pos_summary { - llama_pos first = -1; - llama_pos last = -1; - int32_t gap_count = 0; - int32_t nonmono_count = 0; -}; - -static dflash_contract_pos_summary dflash_contract_summarize_positions( - const std::vector & positions) { - dflash_contract_pos_summary summary; - if (positions.empty()) { - return summary; - } - - summary.first = positions.front(); - summary.last = positions.back(); - for (size_t i = 1; i < positions.size(); ++i) { - if (positions[i] <= positions[i - 1]) { - summary.nonmono_count++; - } else if (positions[i] != positions[i - 1] + 1) { - summary.gap_count++; - } - } - - return summary; -} - struct mtp_last_embd { std::vector embd; float prob = 0.0f; @@ -500,7 +335,1327 @@ struct common_speculative_state_mtp : public common_speculative_state { } }; -#include "speculative-impl.h" +#include "speculative-dflash-impl.h" + +struct common_speculative_state_draft : public common_speculative_state { + llama_context * ctx_tgt; // only used for retokenizing from ctx_dft + llama_context * ctx_dft; + + common_sampler * smpl; + + llama_batch batch; + llama_tokens prompt_dft; + + bool vocab_cmpt = true; // whether retokenization is needed + std::unordered_map vocab_map; + + common_speculative_state_draft( + enum common_speculative_type type, + llama_context * ctx_tgt, + llama_context * ctx_dft, + const std::vector> & replacements) + : common_speculative_state(type) + , ctx_tgt(ctx_tgt) + , ctx_dft(ctx_dft) + { + batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); + smpl = nullptr; + { + struct common_params_sampling params; + params.top_k = 10; + params.samplers_sequence = { + llama_sampler_type::TOP_K, + llama_sampler_type::DIST, // needed to get probabilities + }; + smpl = common_sampler_init(llama_get_model(ctx_dft), params); + } + + vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); + LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt); + + if (!vocab_cmpt) { + LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n"); + + for (const auto & pair : replacements) { + vocab_map[pair.first] = pair.second; + } + } + } + + ~common_speculative_state_draft() override { + llama_free(ctx_dft); + + common_sampler_free(smpl); + + llama_batch_free(batch); + } + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + auto * spec = this; + + auto & batch = spec->batch; + auto & ctx_tgt = spec->ctx_tgt; + auto & ctx_dft = spec->ctx_dft; + auto & smpl = spec->smpl; + auto & prompt_dft = spec->prompt_dft; + + int reuse_i = 0; + int reuse_n = 0; + + const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max; + + llama_tokens prompt_cnv; + if (!spec->vocab_cmpt) { + // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation + const auto * model_tgt = llama_get_model(ctx_tgt); + const auto * vocab_tgt = llama_model_get_vocab(model_tgt); + + std::string text; + + text = common_detokenize(ctx_tgt, prompt_tgt, true); + text = replace_to_dft(text); + + LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str()); + + prompt_cnv = common_tokenize(ctx_dft, text, false, true); + + + + int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false); + GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last"); + + text.resize(-n_chars); + llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false); + text = replace_to_dft(text); + + LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str()); + id_last = common_tokenize(ctx_dft, text, false, true)[0]; + } + + const llama_tokens & prompt_cur = spec->vocab_cmpt ? prompt_tgt : prompt_cnv; + + const int i_start = std::max(0, (int) prompt_cur.size() - n_ctx); + + // reuse as much as possible from the old draft context + // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt + for (int i = 0; i < (int) prompt_dft.size(); ++i) { + int cur = 0; + while (i_start + cur < (int) prompt_cur.size() && + i + cur < (int) prompt_dft.size() && + prompt_cur[i_start + cur] == prompt_dft[i + cur]) { + cur++; + } + + if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) { + reuse_i = i; + reuse_n = cur; + } + } + + LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size()); + + result.clear(); + result.reserve(params.n_max); + + if (reuse_n == 0) { + llama_kv_cache_clear(ctx_dft); + prompt_dft.clear(); + } else { + // this happens when a previous draft has been discarded (for example, due to being too small), but the + // target model agreed with it. in this case, we simply pass back the previous results to save compute + if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { + for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { + result.push_back(prompt_dft[i]); + + if (params.n_max <= (int) result.size()) { + break; + } + } + + return; + } + + if (reuse_i > 0) { + llama_kv_cache_seq_rm (ctx_dft, 0, 0, reuse_i); + llama_kv_cache_seq_add(ctx_dft, 0, reuse_i, -1, -reuse_i); + + prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); + } + + if (reuse_n < (int) prompt_dft.size()) { + llama_kv_cache_seq_rm (ctx_dft, 0, reuse_n, -1); + prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); + } + } + + // prepare a batch to evaluate any new tokens in the prompt + common_batch_clear(batch); + + for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) { + //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]); + common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false); + + prompt_dft.push_back(prompt_cur[i]); + } + + // we should rarely end-up here during normal decoding + if (batch.n_tokens > 0) { + //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); + + llama_decode(ctx_dft, batch); + } + + const llama_pos n_past = prompt_dft.size(); + + LOG_DBG("%s: n_past = %d\n", __func__, n_past); + + common_batch_clear(batch); + common_batch_add (batch, id_last, n_past, { 0 }, true); + + prompt_dft.push_back(id_last); + + //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); + + llama_decode(ctx_dft, batch); + + common_sampler_reset(smpl); + + // sample n_draft tokens from the draft model + for (int i = 0; i < params.n_max; ++i) { + common_batch_clear(batch); + + common_sampler_sample(smpl, ctx_dft, 0, true); + + const auto * cur_p = common_sampler_get_candidates(smpl, true); + + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + } + + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; + + common_sampler_accept(smpl, nullptr, id, true); + + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < params.p_min) { + if (i == 0) { + result.push_back(id); + } + break; + } + + result.push_back(id); + + if (params.n_max <= (int) result.size()) { + break; + } + + + common_batch_add(batch, id, n_past + i + 1, { 0 }, true); + + // evaluate the drafted tokens on the draft model + llama_decode(ctx_dft, batch); + + prompt_dft.push_back(id); + } + + if (!spec->vocab_cmpt) { + std::string detokenized = common_detokenize(ctx_dft, result, true); + detokenized = replace_to_tgt(detokenized); + LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str()); + result = common_tokenize(ctx_tgt, detokenized, false, true); + if (result.size() > (size_t)params.n_max) { + result.resize(params.n_max); + } + } + } + + void accept(uint16_t n_accepted) override { + // noop + GGML_UNUSED(n_accepted); + } + + std::string replace_to_dft(const std::string & input) const { + std::string result = input; + + for (const auto & pair : this->vocab_map) { + size_t pos = result.find(pair.first); + while (pos != std::string::npos) { + result.replace(pos, pair.first.length(), pair.second); + pos = result.find(pair.first, pos + pair.second.length()); + } + } + + return result; + } + + std::string replace_to_tgt(const std::string & input) const { + std::string result = input; + + for (const auto & pair : this->vocab_map) { + size_t pos = result.find(pair.second); + while (pos != std::string::npos) { + result.replace(pos, pair.second.length(), pair.first); + pos = result.find(pair.second, pos + pair.first.length()); + } + } + + return result; + } +}; + +struct common_speculative_state_eagle3 : public common_speculative_state { + common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {} + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & draft_tokens) override { + // TODO: implement + GGML_UNUSED(params); + GGML_UNUSED(prompt_tgt); + GGML_UNUSED(id_last); + GGML_UNUSED(draft_tokens); + } + + void accept(uint16_t n_accepted) override { + // noop + GGML_UNUSED(n_accepted); + } +}; + +// state of self-speculation (simple implementation, not ngram-map) +struct common_speculative_state_ngram_simple : public common_speculative_state { + common_ngram_simple_config config; + + common_speculative_state_ngram_simple( + enum common_speculative_type type, + common_ngram_simple_config config) + : common_speculative_state(type), config(config) {} + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + + result = common_ngram_simple_draft(config, prompt_tgt, id_last); + GGML_UNUSED(params); + } + + void accept(uint16_t n_accepted) override { + // noop + GGML_UNUSED(n_accepted); + } +}; + +struct common_speculative_state_ngram_map_k : public common_speculative_state { + // draft ngram map for speculative decoding without draft model + common_ngram_map map; + + common_speculative_state_ngram_map_k( + enum common_speculative_type type, + common_ngram_map map) + : common_speculative_state(type), map(std::move(map)) {} + + void begin(const llama_tokens & prompt) override { + common_ngram_map_begin(map, prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + common_ngram_map_draft(map, prompt_tgt, id_last, result); + GGML_UNUSED(params); + } + + void accept(uint16_t n_accepted) override { + common_ngram_map_accept(map, n_accepted); + } +}; + +struct common_speculative_state_ngram_mod : public common_speculative_state { + common_ngram_mod & mod; + + // the last position in the prompt that was added to the ngram container + size_t i_last = 0; + + // length of the last drafted n‑gram (number of tokens returned by draft) + size_t n_draft_last = 0; + + // consecutive accept rounds with low acceptance fraction (< 0.5) + int n_low = 0; + + // enable trace logging if LLAMA_TRACE is set + const bool verbose; + + common_speculative_state_ngram_mod(enum common_speculative_type type, common_ngram_mod & mod) + : common_speculative_state(type), mod(mod), verbose(std::getenv("LLAMA_TRACE") != nullptr) { + static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t)); + } + + void begin(const llama_tokens & prompt) override { + i_last = 0; + + n_draft_last = 0; + n_low = 0; + + const size_t n = mod.get_n(); + + if (prompt.size() < n) { + return; + } + + for (size_t i = 0; i < prompt.size() - n; ++i) { + mod.add(prompt.data() + i); + } + + i_last = prompt.size() - n; + + const double f = (double)mod.get_used() / (double)mod.size(); + LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f); + + constexpr double f_thold = 0.25; + if (f > f_thold) { + LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold); + + mod.reset(); + } + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + GGML_UNUSED(params); + + n_draft_last = 0; + + const size_t cur_len = prompt_tgt.size(); + if (cur_len < mod.get_n()) { + return; + } + + const size_t n = mod.get_n(); + + // add new ngrams in chunks + if (i_last + 32 < cur_len) { + for (size_t i = i_last; i < cur_len - n; ++i) { + mod.add(prompt_tgt.data() + i); + } + + i_last = cur_len - n; + } + + result.resize(n + params.n_max); + for (size_t i = 0; i < n - 1; ++i) { + result[i] = prompt_tgt[cur_len - n + 1 + i]; + } + result[n - 1] = id_last; + + for (int i = 0; i < params.n_max; ++i) { + const llama_token token = mod.get(result.data() + i); + if (token == common_ngram_mod::EMPTY) { + if (i < params.n_min) { + result.clear(); + return; + } + + result.resize(n + i); + break; + } + result[n + i] = token; + } + + // only return the m tokens that were drafted + for (size_t i = 0; n + i < result.size(); ++i) { + result[i] = result[n + i]; + } + result.resize(result.size() - n); + + // store length of drafted n‑gram for later acceptance analysis + n_draft_last = result.size(); + } + + void accept(uint16_t n_accepted) override { + if (verbose) { + LOG_INF("%s: accepted %d tokens from %zu drafted tokens\n", __func__, n_accepted, n_draft_last); + } + + // compute acceptance fraction if we have a recorded draft length + if (n_draft_last > 0) { + const double f_acc = (double)n_accepted / (double)n_draft_last; + if (f_acc < 0.5) { + n_low++; + if (n_low >= 3) { + LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low); + + mod.reset(); + n_low = 0; + i_last = 0; + } + } else { + n_low = 0; + } + } + } +}; + +struct common_speculative_state_ngram_cache : public common_speculative_state { + uint16_t n_draft; + bool save_dynamic; + bool save_static; + + common_ngram_cache ngram_cache_context; + common_ngram_cache ngram_cache_dynamic; + common_ngram_cache ngram_cache_static; + + size_t cache_size = 0; // number of tokens in n-gram cache + + common_speculative_state_ngram_cache( + const enum common_speculative_type type, + const std::string & path_static, + const std::string & path_dynamic, + uint16_t n_draft, + bool save_dynamic, + bool save_static) + : common_speculative_state(type) + , n_draft(n_draft) + , save_dynamic(save_dynamic) + , save_static(save_static) + { + if (!path_static.empty()) { + try { + ngram_cache_static = common_ngram_cache_load(path_static); + } catch (...) { + LOG_ERR("failed to open static lookup cache: %s", path_static.c_str()); + GGML_ABORT("Couldn't read static lookup cache"); + } + } + + if (!path_dynamic.empty()) { + try { + ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); + } catch (...) { + LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str()); + GGML_ABORT("Couldn't read dynamic lookup cache"); + } + } + } + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + GGML_UNUSED(params); + + if (cache_size < prompt_tgt.size() + 1) { + llama_tokens tokens_new; + tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); + for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { + tokens_new.push_back(prompt_tgt[j]); + } + tokens_new.push_back(id_last); // add the last token + + // Update context ngram cache with new prompt_tgt: + common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + tokens_new, tokens_new.size(), false); + cache_size = prompt_tgt.size() + 1; + } + + llama_tokens inp; + inp.reserve(prompt_tgt.size() + 1); + for (size_t j = 0; j < prompt_tgt.size(); ++j) { + inp.push_back(prompt_tgt[j]); + } + inp.push_back(id_last); + + result.push_back(id_last); + + common_ngram_cache_draft(inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + ngram_cache_context, + ngram_cache_dynamic, + ngram_cache_static); + + if (result.size() > 0) { + // delete first token in result (which is the id_last token) + result.erase(result.begin()); + } + } + + void accept(uint16_t n_accepted) override { + // TODO: noop + GGML_UNUSED(n_accepted); + } +}; + +struct common_speculative_state_suffix : public common_speculative_state { + common_suffix_tree tree; + common_suffix_tree corpus_tree; + bool has_corpus = false; + size_t cache_size = 0; + + // Acceptance feedback + size_t n_draft_last = 0; + bool had_accept = false; + int n_low = 0; + float base_p_min = 0.1f; + float eff_p_min = 0.1f; + + common_speculative_state_suffix( + enum common_speculative_type type, + int max_depth, + const std::string & corpus_path, + const llama_model * model) + : common_speculative_state(type) + , tree(max_depth) + , corpus_tree(max_depth) + { + if (!corpus_path.empty()) { + std::function(const std::string &)> tokenize_fn; + if (model) { + tokenize_fn = [model](const std::string & text) -> std::vector { + return common_tokenize(model, text, false, true); + }; + } + has_corpus = corpus_tree.load_corpus(corpus_path, tokenize_fn); + } + } + + void begin(const llama_tokens & prompt) override { + cache_size = 0; + n_draft_last = 0; + had_accept = false; + n_low = 0; + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + + base_p_min = params.p_min; + if (n_draft_last > 0 && !had_accept) { + if (++n_low >= 3) { + eff_p_min = std::min(eff_p_min + 0.1f, 0.5f); + n_low = 0; + } + } + had_accept = false; + + if (cache_size < prompt_tgt.size() + 1) { + llama_tokens tokens_new; + tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); + for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { + tokens_new.push_back(prompt_tgt[j]); + } + tokens_new.push_back(id_last); + + tree.extend(tokens_new.data(), (int)tokens_new.size()); + cache_size = prompt_tgt.size() + 1; + } + + const int ctx_len = std::min((int)(prompt_tgt.size() + 1), tree.max_depth()); + llama_tokens context; + context.reserve(ctx_len); + const int ctx_start = (int)prompt_tgt.size() + 1 - ctx_len; + for (int j = ctx_start; j < (int)prompt_tgt.size(); ++j) { + context.push_back(prompt_tgt[j]); + } + context.push_back(id_last); + const int min_match_len = std::max(1, params.suffix_min_match_len); + + result = tree.speculate( + context.data(), (int)context.size(), + params.n_max, + eff_p_min, + 1, + min_match_len); + + if (has_corpus) { + auto corpus_result = corpus_tree.speculate( + context.data(), (int)context.size(), + params.n_max, + eff_p_min, + 1, + min_match_len); + if (corpus_result.size() > result.size()) { + result = std::move(corpus_result); + } + } + + n_draft_last = result.size(); + } + + void accept(uint16_t n_accepted) override { + if (n_draft_last == 0) { + return; + } + had_accept = true; + const double f_acc = (double)n_accepted / (double)n_draft_last; + if (f_acc < 0.5) { + if (++n_low >= 3) { + eff_p_min = std::min(eff_p_min + 0.1f, 0.5f); + n_low = 0; + } + } else { + n_low = 0; + if (eff_p_min > base_p_min) { + eff_p_min = std::max(eff_p_min - 0.05f, base_p_min); + } + } + } +}; + +struct common_speculative { + std::vector configs; // resolved stage config for each implementation + std::vector> impls; // list of implementations to use and their states + common_speculative_checkpoint checkpoint; + common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) + std::unique_ptr tuner; + int last_n_drafted = 0; + int64_t t_step_start_us = 0; +}; + +static bool common_speculative_stage_chain_matches( + const std::vector & stages, + const std::vector & configs) { + if (stages.size() != configs.size()) { + return false; + } + + for (size_t i = 0; i < stages.size(); ++i) { + if (stages[i].type != configs[i].type) { + return false; + } + } + + return true; +} + +static common_params_speculative common_speculative_get_runtime_params( + const common_speculative_config & config, + const common_params_speculative & params, + const common_speculative_stage_params & stage) { + common_params_speculative result = config.params; + + result.type = config.type; + result.n_max = stage.has_n_max_override() ? stage.n_max : params.n_max; + result.n_min = stage.has_n_min_override() ? stage.n_min : params.n_min; + result.p_min = stage.has_p_min_override() ? stage.p_min : params.p_min; + + if (config.type == COMMON_SPECULATIVE_TYPE_SUFFIX) { + result.suffix_min_match_len = stage.has_suffix_min_match_len_override() + ? stage.suffix_min_match_len + : params.suffix_min_match_len; + } + + result.n_max = std::max(result.n_max, 0); + result.n_min = std::max(0, std::min(result.n_min, result.n_max)); + result.stages.clear(); + + return result; +} + +static common_ngram_map get_common_ngram_map(const common_speculative_config & config) { + uint16_t size_key = config.params.ngram_size_n; + uint16_t size_value = config.params.ngram_size_m; + bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); + uint16_t min_hits = config.params.ngram_min_hits; + + return common_ngram_map(size_key, size_value, key_only, min_hits); +} + +static common_speculative_state_ngram_cache create_state_ngram_cache( + const std::string & path_static, const std::string & path_dynamic, + const common_speculative_config & config) { + uint16_t n_draft = 8; // TODO get from config? + + // TODO bool param in common/common.h to set save_static/save_dynamic? + bool save_static = false; + bool save_dynamic = false; + + common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic); + + return state; +} + +std::string common_speculative_type_name_str() { + std::string result; + for (size_t i = 0; i < common_speculative_types.size(); i++) { + if (i > 0) { + result += ", "; + } + result += common_speculative_type_to_str(common_speculative_types[i]); + } + return result; +} + +std::string common_speculative_type_to_str(enum common_speculative_type type) { + switch (type) { + case COMMON_SPECULATIVE_TYPE_NONE: return "none"; + case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft"; + case COMMON_SPECULATIVE_TYPE_DFLASH: return "dflash"; + case COMMON_SPECULATIVE_TYPE_MTP: return "mtp"; + case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3"; + case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod"; + case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache"; + case COMMON_SPECULATIVE_TYPE_SUFFIX: return "suffix"; + default: return "unknown"; + } +} + +enum common_speculative_type common_speculative_type_from_name(const std::string & name) { + std::string normalized = name; + std::replace(normalized.begin(), normalized.end(), '-', '_'); + + const auto it = common_speculative_type_from_name_map.find(normalized); + if (it == common_speculative_type_from_name_map.end()) { + return COMMON_SPECULATIVE_TYPE_COUNT; + } + return it->second; +} + +bool common_speculative_is_compat(llama_context * ctx_tgt) { + bool res = true; + + llama_kv_cache_clear(ctx_tgt); + + // eval 2 tokens to check if the context is compatible + std::vector tmp; + tmp.push_back(0); + tmp.push_back(0); + + int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0)); + if (ret != 0) { + LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret); + res = false; + goto done; + } + + // try to remove the last tokens + if (!llama_kv_cache_seq_rm(ctx_tgt, 0, 1, -1)) { + LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); + res = false; + goto done; + } + +done: + llama_kv_cache_clear(ctx_tgt); + llama_synchronize(ctx_tgt); + + return res; +} + +// initialization of the speculative decoding system +// +common_speculative * common_speculative_init( + common_params_speculative & params, + llama_context * ctx_tgt) { + std::string chain_error; + if (!common_speculative_validate_chain(params, &chain_error)) { + LOG_ERR("%s: invalid speculative stage chain: %s\n", __func__, chain_error.c_str()); + return nullptr; + } + + const auto stages = params.get_resolved_stages(); + if (params.model_dft && llama_model_is_gemma4_mtp_assistant(params.model_dft)) { + const bool has_draft_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) { + return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT; + }); + + if (has_draft_stage) { + LOG_ERR("%s: Gemma4 assistant models only support MTP stages; omit -md for self-spec-only runs or use -mtp/--spec-stage mtp for assistant-backed MTP\n", __func__); + return nullptr; + } + } + + const bool has_dflash_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) { + return stage.type == COMMON_SPECULATIVE_TYPE_DFLASH; + }); + + const bool needs_draft_ctx = std::any_of(stages.begin(), stages.end(), [¶ms](const common_speculative_stage_params & stage) { + return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT || + stage.type == COMMON_SPECULATIVE_TYPE_DFLASH || + (stage.type == COMMON_SPECULATIVE_TYPE_MTP && params.model_dft != nullptr); + }); + + llama_context * ctx_dft = nullptr; + if (needs_draft_ctx) { + if (!params.model_dft) { + LOG_ERR("%s: draft speculative stage requires a loaded draft model\n", __func__); + return nullptr; + } + + llama_context_params cparams_dft = params.cparams_dft; + + if (has_dflash_stage) { + if (!llama_model_share_dflash_io_tensors(params.model_dft, llama_get_model(ctx_tgt))) { + LOG_ERR("%s: failed to share target IO tensors with DFlash draft model\n", __func__); + return nullptr; + } + + int32_t max_cross_ctx = 0; + for (const auto & stage : stages) { + if (stage.type != COMMON_SPECULATIVE_TYPE_DFLASH) { + continue; + } + + max_cross_ctx = std::max(max_cross_ctx, params.with_stage_overrides(stage).dflash_cross_ctx); + } + + const int32_t block_size = llama_model_dflash_block_size(params.model_dft); + if (block_size <= 0) { + LOG_ERR("%s: invalid DFlash draft block size\n", __func__); + return nullptr; + } + + const int64_t required_n_ctx = (int64_t) max_cross_ctx + (int64_t) block_size; + if (required_n_ctx > std::numeric_limits::max()) { + LOG_ERR("%s: invalid DFlash draft context size cross_ctx=%d block_size=%d required_n_ctx=%lld\n", + __func__, max_cross_ctx, block_size, (long long) required_n_ctx); + return nullptr; + } + + cparams_dft.n_ctx = (uint32_t) required_n_ctx; + } + + ctx_dft = llama_init_from_model(params.model_dft, cparams_dft); + if (ctx_dft == nullptr) { + LOG_ERR("%s", "failed to create draft context\n"); + return nullptr; + } + } + + // Compute the implementations to use based on the resolved stage chain. + std::vector configs = {}; + configs.reserve(stages.size()); + + for (const auto & stage : stages) { + common_params_speculative stage_params = params.with_stage_overrides(stage); + + if (stage.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD && !stage_params.ngram_mod) { + stage_params.ngram_mod = std::make_shared(stage_params.ngram_size_n, 4*1024*1024); + + LOG_INF("%s: initialized ngram_mod with n=%d, size=%zu (%.3f MB)\n", __func__, + stage_params.ngram_size_n, stage_params.ngram_mod->size(), + (float)(stage_params.ngram_mod->size_bytes())/1024/1024); + + if (stage_params.ngram_size_n < 16) { + LOG_WRN("%s: ngram_mod n=%d is too small - poor quality is possible, see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, stage_params.ngram_size_n); + } + } + + configs.push_back(common_speculative_config(stage, stage_params)); + } + + if (!configs.empty() && llama_model_has_recurrent(llama_get_model(ctx_tgt))) { + const int ckpt_tokens = std::max(1, params.get_max_stage_n_max() + 1); + const int actual_mode = llama_spec_ckpt_init(ctx_tgt, params.recurrent_ckpt_mode, ckpt_tokens); + if (actual_mode == LLAMA_SPEC_CKPT_NONE) { + LOG_ERR("%s: failed to prepare recurrent checkpoint mode '%s' during speculative init (max_tokens=%d)\n", + __func__, + params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_PER_STEP ? "per-step" : + params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_GPU_FALLBACK ? "gpu-fallback" : + params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_CPU ? "cpu" : "auto", + ckpt_tokens); + if (ctx_dft != nullptr) { + llama_free(ctx_dft); + } + return nullptr; + } + llama_spec_ckpt_discard(ctx_tgt); + params.recurrent_ckpt_mode = actual_mode; + } + + std::vector> impls = {}; + + for (const common_speculative_config & config : configs) { + LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str()); + switch (config.type) { + case COMMON_SPECULATIVE_TYPE_NONE: + break; + case COMMON_SPECULATIVE_TYPE_DRAFT: { + impls.push_back(std::make_unique(config.type, + /* .ctx_tgt = */ ctx_tgt, + /* .ctx_dft = */ ctx_dft, + /* .replacements = */ config.params.replacements + )); + break; + } + case COMMON_SPECULATIVE_TYPE_DFLASH: { + auto state = std::make_unique( + config.type, + ctx_tgt, + ctx_dft, + config.params.dflash_cross_ctx); + if (!state->ready) { + LOG_ERR("%s: failed to initialize DFlash speculative state\n", __func__); + return nullptr; + } + impls.push_back(std::move(state)); + ctx_dft = nullptr; + break; + } + case COMMON_SPECULATIVE_TYPE_MTP: { + llama_context * ctx_mtp = ctx_dft; + if (!ctx_mtp) { + const llama_model * model = llama_get_model(ctx_tgt); + ctx_mtp = llama_init_from_model(const_cast(model), config.params.cparams_dft); + if (!ctx_mtp) { + LOG_ERR("%s: failed to create MTP context\n", __func__); + return nullptr; + } + } + ctx_dft = nullptr; + + const bool use_constant_draft_positions = llama_model_is_gemma4_mtp_assistant(llama_get_model(ctx_mtp)); + impls.push_back(std::make_unique( + config.type, ctx_tgt, ctx_mtp, use_constant_draft_positions)); + break; + } + case COMMON_SPECULATIVE_TYPE_EAGLE3: { + impls.push_back(std::make_unique(config.type)); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { + common_ngram_map ngram_map = get_common_ngram_map(config); + + uint16_t ngram_size_key = ngram_map.size_key; + uint16_t mgram_size_value = ngram_map.size_value; + + auto config_simple = common_ngram_simple_config { + /* .size_ngram = */ ngram_size_key, + /* .size_mgram = */ mgram_size_value + }; + auto state = std::make_unique( + /* .type = */ config.type, + /* .state = */ config_simple + ); + impls.push_back(std::move(state)); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: { + impls.push_back(std::make_unique( + (config.type), + get_common_ngram_map(config) + )); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: { + GGML_ASSERT(config.params.ngram_mod); + impls.push_back(std::make_unique(config.type, *config.params.ngram_mod)); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: { + auto state = create_state_ngram_cache( + config.params.lookup_cache_static, config.params.lookup_cache_dynamic, config); + impls.push_back(std::make_unique(state)); + break; + } + case COMMON_SPECULATIVE_TYPE_SUFFIX: { + int depth = config.params.suffix_max_depth > 0 ? config.params.suffix_max_depth : 64; + const llama_model * model = llama_get_model(ctx_tgt); + impls.push_back(std::make_unique( + config.type, depth, config.params.suffix_corpus, model)); + break; + } + default: + break; + } + } + + if (impls.empty()) { + LOG_WRN("%s", "no implementations specified for speculative decoding\n"); + return nullptr; + } + + auto * result = new common_speculative { + /* .configs = */ std::move(configs), + /* .impls = */ std::move(impls) + }; + + // initialize autotune if requested + if (params.autotune && params.has_composite_stage_chain()) { + LOG_WRN("Autotune disabled — explicit speculative stage chains are not supported yet\n"); + } else if (params.autotune && !result->impls.empty()) { + auto actual_type = result->impls[0]->type; + if (actual_type != COMMON_SPECULATIVE_TYPE_NONE && + actual_type != COMMON_SPECULATIVE_TYPE_EAGLE3) { + result->tuner = std::make_unique(); + result->tuner->init(actual_type, params, llama_get_model(ctx_tgt)); + LOG_DBG("Autotune initialized for %s, tuning %zu parameters\n", + common_speculative_type_to_str(actual_type).c_str(), + result->tuner->coords.size()); + } else { + LOG_WRN("Autotune disabled — speculative type %s is not supported for autotuning\n", + common_speculative_type_to_str(actual_type).c_str()); + } + } + + return result; +} + +void common_speculative_free(common_speculative * spec) { + if (spec == nullptr) { + return; + } + + spec->checkpoint.clear(); + delete spec; +} + +void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt) { + if (spec == nullptr) { + return; + } + + for (auto & impl : spec->impls) { + common_time_meas tm(impl->t_begin_us, !impl->gen_perf); + impl->begin(prompt); + impl->n_call_begin++; + } +} + +llama_tokens common_speculative_draft( + common_speculative * spec, + common_params_speculative & params, + const llama_tokens & prompt_tgt, // specified in target model vocab + llama_token id_last, + llama_pos draft_base_pos, + llama_seq_id draft_seq_id) { + llama_tokens result; + + spec->t_step_start_us = ggml_time_us(); + + // apply autotune proposal if enabled + if (spec->tuner && spec->tuner->enabled) { + spec->tuner->propose(params); + } + + const auto runtime_stages = params.get_resolved_stages(); + const bool use_runtime_stage_overrides = common_speculative_stage_chain_matches(runtime_stages, spec->configs); + + spec->curr_impl = nullptr; // reset current implementation + + for (size_t i = 0; i < spec->impls.size(); ++i) { + auto & impl = spec->impls[i]; + const auto & runtime_stage = use_runtime_stage_overrides ? runtime_stages[i] : spec->configs[i].stage; + common_params_speculative impl_params = common_speculative_get_runtime_params(spec->configs[i], params, runtime_stage); + result.clear(); + + { + common_time_meas tm(impl->t_draft_us, !impl->gen_perf); + impl->draft(impl_params, prompt_tgt, id_last, draft_base_pos, draft_seq_id, result); + impl->n_call_draft++; + } + + if (result.empty()) { + continue; + } + + if (common_speculative_type_is_self_spec(impl->type) && impl_params.n_min > 0 && (int)result.size() < impl_params.n_min) { + LOG_DBG("%s: impl %s drafted %zu tokens, below fallback threshold %d - trying next implementation\n", + __func__, common_speculative_type_to_str(impl->type).c_str(), result.size(), impl_params.n_min); + result.clear(); + continue; + } + LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, + common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(), + impl.get()->n_call_draft, result.size()); + + spec->curr_impl = impl.get(); + impl->n_gen_drafts++; + impl->n_gen_tokens += result.size(); + + break; // We have a draft, so break out of the loop and return it. + } + + // store draft count for tuner feedback + if (spec->tuner && spec->tuner->enabled) { + spec->last_n_drafted = (int)result.size(); + } + + return result; +} + +void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { + if (spec->tuner && spec->tuner->enabled && spec->t_step_start_us > 0) { + int64_t step_time_us = ggml_time_us() - spec->t_step_start_us; + double step_tps = (step_time_us > 100) + ? (n_accepted + 1.0) * 1e6 / (double)step_time_us + : 0.0; + spec->tuner->accept_feedback(n_accepted, spec->last_n_drafted, step_tps); + spec->t_step_start_us = 0; + } + + common_speculative_state * impl = spec->curr_impl; + + if (!impl) { + return; + } + + { + common_time_meas tm(impl->t_accept_us, !impl->gen_perf); + if (n_accepted > 0) { + impl->n_acc_drafts++; + impl->n_acc_tokens += n_accepted; + } + + impl->accept(n_accepted); + impl->n_call_accept++; + } + + if (impl->type != COMMON_SPECULATIVE_TYPE_MTP) { + if (auto * mtp_state = common_speculative_get_mtp_state(spec); mtp_state != nullptr) { + mtp_invalidate_cached_drafts(*mtp_state); + } + } +} + +static bool common_speculative_has_type(const common_speculative * spec, common_speculative_type type) { + if (spec == nullptr) { + return false; + } + + return std::any_of(spec->configs.begin(), spec->configs.end(), [type](const common_speculative_config & config) { + return config.type == type; + }); +} + +static int common_speculative_ctx_mtp_n_embd(llama_context * ctx) { + return ctx ? (int) llama_mtp_state_n_embd(ctx) : 0; +} + +static bool common_speculative_batch_token_has_seq_id( + const llama_batch & batch, + int token_index, + llama_seq_id seq_id) { + if (batch.n_seq_id == nullptr || batch.seq_id == nullptr || batch.n_seq_id[token_index] <= 0 || batch.seq_id[token_index] == nullptr) { + return false; + } + + for (int i = 0; i < batch.n_seq_id[token_index]; ++i) { + if (batch.seq_id[token_index][i] == seq_id) { + return true; + } + } + + return false; +} + +static bool common_speculative_batch_is_exact_single_seq( + const llama_batch & batch, + llama_seq_id seq_id) { + if (batch.n_tokens <= 0 || batch.n_seq_id == nullptr || batch.seq_id == nullptr) { + return false; + } + + for (int i = 0; i < batch.n_tokens; ++i) { + if (batch.n_seq_id[i] != 1 || batch.seq_id[i] == nullptr || batch.seq_id[i][0] != seq_id) { + return false; + } + } + + return true; +} + +static int common_speculative_copy_seq_batch( + const llama_batch & batch, + llama_seq_id seq_id, + llama_batch & seq_batch) { + if (batch.token == nullptr || batch.pos == nullptr) { + return -1; + } + + if (batch.n_tokens < 1) { + return 0; + } + + std::vector token_indices; + token_indices.reserve(batch.n_tokens); + for (int i = 0; i < batch.n_tokens; ++i) { + if (common_speculative_batch_token_has_seq_id(batch, i, seq_id)) { + token_indices.push_back(i); + } + } + + if (token_indices.empty()) { + return 0; + } + + seq_batch = llama_batch_init((int) token_indices.size(), 0, 1); + for (const int i : token_indices) { + common_batch_add(seq_batch, batch.token[i], batch.pos[i], { seq_id }, batch.logits != nullptr && batch.logits[i]); + } + + return (int) token_indices.size(); +} + +static bool common_speculative_feature_view_copy_batch_rows( + const common_speculative_feature_view & view, + const llama_batch & batch, + llama_seq_id seq_id, + std::vector * hidden_rows) { + if (hidden_rows == nullptr || view.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || view.width <= 0 || batch.n_tokens <= 0 || batch.pos == nullptr) { + return false; + } + + std::unordered_map rows_by_pos; + rows_by_pos.reserve(view.rows.size()); + for (const auto & row : view.rows) { + if (row.seq_id == seq_id && row.data != nullptr) { + rows_by_pos[row.pos] = row.data; + } + } + + hidden_rows->clear(); + hidden_rows->reserve((size_t) batch.n_tokens * view.width); + for (int i = 0; i < batch.n_tokens; ++i) { + auto it = rows_by_pos.find(batch.pos[i]); + if (it == rows_by_pos.end()) { + hidden_rows->clear(); + return false; + } + + hidden_rows->insert(hidden_rows->end(), it->second, it->second + view.width); + } + + return hidden_rows->size() == (size_t) batch.n_tokens * view.width; +} static bool common_speculative_capture_target_features( common_speculative * spec, @@ -1767,286 +2922,6 @@ static void mtp_clear_target_hidden(common_speculative_state_mtp & state, llama_ state.draft_cache_by_seq.erase(seq_id); } -// DFlash target-window replay and maintenance helpers. -struct dflash_append_breakdown { - uint64_t filter_us = 0; - uint64_t window_alloc_us = 0; - uint64_t replace_us = 0; - uint64_t keep_old_us = 0; - uint64_t new_rows_us = 0; - uint64_t commit_us = 0; - uint64_t log_us = 0; - bool replace_call = false; -}; - -static void dflash_record_window_update( - common_speculative_state_dflash & state, - int32_t keep_rows, - int32_t append_rows, - bool replace) { - state.target_window_keep_rows = std::max(0, keep_rows); - state.target_window_append_rows = std::max(0, append_rows); - state.target_window_replace = replace; - state.target_window_version++; -} - -static void dflash_ring_reset_rows( - common_speculative_state_dflash & state, - const float * rows, - int32_t n_rows) { - const size_t row_width = (size_t) state.n_target_features; - if (n_rows <= 0 || rows == nullptr) { - state.target_window_ring_write_pos = 0; - state.target_window_ring_filled = 0; - return; - } - - if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) { - state.target_window_ring.resize((size_t) state.cross_ctx * row_width); - } - - std::memcpy(state.target_window_ring.data(), rows, (size_t) n_rows * row_width * sizeof(float)); - state.target_window_ring_write_pos = n_rows % state.cross_ctx; - state.target_window_ring_filled = n_rows; - state.target_window_materialized = false; -} - -static void dflash_ring_append_rows( - common_speculative_state_dflash & state, - const float * rows, - int32_t n_rows) { - const size_t row_width = (size_t) state.n_target_features; - if (n_rows <= 0 || rows == nullptr) { - return; - } - - if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) { - state.target_window_ring.resize((size_t) state.cross_ctx * row_width); - } - - int32_t write_pos = state.target_window_ring_write_pos; - int32_t remaining = n_rows; - const float * src = rows; - while (remaining > 0) { - const int32_t chunk_rows = std::min(remaining, state.cross_ctx - write_pos); - std::memcpy( - state.target_window_ring.data() + (size_t) write_pos * row_width, - src, - (size_t) chunk_rows * row_width * sizeof(float)); - src += (size_t) chunk_rows * row_width; - remaining -= chunk_rows; - write_pos = (write_pos + chunk_rows) % state.cross_ctx; - } - - state.target_window_ring_write_pos = write_pos; - state.target_window_ring_filled = std::min(state.cross_ctx, state.target_window_ring_filled + n_rows); - state.target_window_materialized = false; -} - -static void dflash_materialize_target_window_features(common_speculative_state_dflash & state) { - if (state.target_window_materialized || state.target_window_rows <= 0) { - return; - } - - const size_t row_width = (size_t) state.n_target_features; - state.target_window.resize((size_t) state.target_window_rows * row_width); - - const int32_t read_start = (state.target_window_ring_write_pos - state.target_window_rows + state.cross_ctx) % state.cross_ctx; - const int32_t first_rows = std::min(state.target_window_rows, state.cross_ctx - read_start); - std::memcpy( - state.target_window.data(), - state.target_window_ring.data() + (size_t) read_start * row_width, - (size_t) first_rows * row_width * sizeof(float)); - - const int32_t second_rows = state.target_window_rows - first_rows; - if (second_rows > 0) { - std::memcpy( - state.target_window.data() + (size_t) first_rows * row_width, - state.target_window_ring.data(), - (size_t) second_rows * row_width * sizeof(float)); - } - - state.target_window_materialized = true; -} - -static bool dflash_append_target_features( - common_speculative_state_dflash & state, - const common_speculative_feature_view & features, - const llama_batch & batch, - llama_seq_id seq_id, - dflash_append_breakdown * breakdown = nullptr) { - GGML_UNUSED(batch); - - if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || - features.width != state.n_target_features || - features.rows.empty() || - state.cross_ctx <= 0) { - return false; - } - - const size_t row_width = (size_t) state.n_target_features; - std::vector new_rows; - std::vector new_positions; - new_rows.reserve(features.rows.size() * row_width); - new_positions.reserve(features.rows.size()); - - const int64_t t_filter_us = ggml_time_us(); - for (const auto & row : features.rows) { - if (row.seq_id != seq_id || row.data == nullptr) { - continue; - } - - new_positions.push_back(row.pos); - new_rows.insert(new_rows.end(), row.data, row.data + row_width); - } - if (breakdown != nullptr) { - breakdown->filter_us += (uint64_t) (ggml_time_us() - t_filter_us); - } - - if (new_positions.empty()) { - return false; - } - - const int32_t n_rows = (int32_t) new_positions.size(); - state.n_window_updates++; - state.n_rows_seen += (size_t) n_rows; - if (n_rows >= state.cross_ctx) { - state.n_rows_dropped += (size_t) state.target_window_rows + (size_t) (n_rows - state.cross_ctx); - const int32_t keep_from = n_rows - state.cross_ctx; - const int64_t t_replace_us = ggml_time_us(); - state.target_window_pos.assign(new_positions.begin() + keep_from, new_positions.end()); - state.target_window_append_features.assign( - new_rows.begin() + (ptrdiff_t) keep_from * (ptrdiff_t) row_width, - new_rows.end()); - dflash_ring_reset_rows(state, state.target_window_append_features.data(), state.cross_ctx); - if (breakdown != nullptr) { - breakdown->replace_us += (uint64_t) (ggml_time_us() - t_replace_us); - breakdown->replace_call = true; - } - - const int64_t t_commit_us = ggml_time_us(); - state.target_window_rows = state.cross_ctx; - state.target_window_ring_filled = state.target_window_rows; - state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); - dflash_record_window_update(state, 0, state.target_window_rows, true); - if (breakdown != nullptr) { - breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us); - } - - const int64_t t_log_us = ggml_time_us(); - dflash_contract_log_append(state, seq_id, new_positions); - if (breakdown != nullptr) { - breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us); - } - return true; - } - - const int32_t keep_old_rows = std::min(state.target_window_rows, state.cross_ctx - n_rows); - state.n_rows_dropped += (size_t) std::max(0, state.target_window_rows - keep_old_rows); - const int64_t t_window_alloc_us = ggml_time_us(); - std::vector & next_window_pos = state.target_window_pos_stage; - next_window_pos.resize((size_t) (keep_old_rows + n_rows)); - if (breakdown != nullptr) { - breakdown->window_alloc_us += (uint64_t) (ggml_time_us() - t_window_alloc_us); - } - - if (keep_old_rows > 0) { - const int64_t t_keep_old_us = ggml_time_us(); - std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), next_window_pos.begin()); - if (breakdown != nullptr) { - breakdown->keep_old_us += (uint64_t) (ggml_time_us() - t_keep_old_us); - } - } - - const int64_t t_new_rows_us = ggml_time_us(); - state.target_window_append_features.assign(new_rows.begin(), new_rows.end()); - dflash_ring_append_rows(state, state.target_window_append_features.data(), n_rows); - std::copy(new_positions.begin(), new_positions.end(), next_window_pos.begin() + keep_old_rows); - if (breakdown != nullptr) { - breakdown->new_rows_us += (uint64_t) (ggml_time_us() - t_new_rows_us); - } - - const int64_t t_commit_us = ggml_time_us(); - state.target_window_pos.swap(next_window_pos); - next_window_pos.clear(); - state.target_window_rows = keep_old_rows + n_rows; - state.target_window_ring_filled = state.target_window_rows; - state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); - dflash_record_window_update(state, keep_old_rows, n_rows, false); - if (breakdown != nullptr) { - breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us); - } - - const int64_t t_log_us = ggml_time_us(); - dflash_contract_log_append(state, seq_id, new_positions); - if (breakdown != nullptr) { - breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us); - } - return true; -} - -static void dflash_clear_target_features(common_speculative_state_dflash & state) { - state.target_window.clear(); - state.target_window_pos.clear(); - state.target_window_stage.clear(); - state.target_window_pos_stage.clear(); - state.target_window_append_features.clear(); - state.target_window_rows = 0; - state.target_window_ring_write_pos = 0; - state.target_window_ring_filled = 0; - state.target_window_keep_rows = 0; - state.target_window_append_rows = 0; - state.target_window_replace = false; - state.target_window_materialized = false; - state.last_target_pos = -1; - llama_reset_dflash_kv_cache_state(state.ctx_dft); -} - -static void dflash_context_shift( - common_speculative_state_dflash & state, - llama_pos kv_keep, - llama_pos kv_discard, - llama_pos kv_past) { - if (kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window_pos.empty()) { - return; - } - - dflash_materialize_target_window_features(state); - - const size_t row_width = (size_t) state.n_target_features; - const llama_pos discard_begin = kv_keep; - const llama_pos discard_end = kv_keep + kv_discard; - - std::vector shifted_rows; - std::vector shifted_positions; - shifted_rows.reserve(state.target_window.size()); - shifted_positions.reserve(state.target_window_pos.size()); - - for (int32_t row = 0; row < state.target_window_rows; ++row) { - llama_pos pos = state.target_window_pos[(size_t) row]; - if (pos >= discard_begin && pos < discard_end) { - continue; - } - - if (pos >= discard_end && pos < kv_past) { - pos -= kv_discard; - } - - const float * row_src = state.target_window.data() + (size_t) row * row_width; - shifted_rows.insert(shifted_rows.end(), row_src, row_src + row_width); - shifted_positions.push_back(pos); - } - - state.target_window = std::move(shifted_rows); - state.target_window_pos = std::move(shifted_positions); - state.target_window_rows = (int32_t) state.target_window_pos.size(); - dflash_ring_reset_rows(state, state.target_window.data(), state.target_window_rows); - state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); - dflash_record_window_update(state, 0, state.target_window_rows, true); - llama_reset_dflash_kv_cache_state(state.ctx_dft); - state.n_context_shifts++; -} - static bool common_speculative_capture_target_features(common_speculative * spec, const common_speculative_feature_view & features) { auto * mtp_state = common_speculative_get_mtp_state(spec); if (mtp_state == nullptr || features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || features.width <= 0) { diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3cd8f9fe..e615d85d 100644 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -570,6 +570,9 @@ class Model: if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273": # ref: https://huggingface.co/alvarobartt/grok-2-tokenizer res = "grok-2" + if chkhsh == "972da7b59cec44d1f0a490a86c96df53859e486e481563e5dddac155013d87ac": + # ref: https://huggingface.co/poolside/Laguna-XS.2 + res = "laguna" if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B res = "llama-bpe" @@ -603,6 +606,9 @@ class Model: if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8": # ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01 res = "command-r" + if chkhsh == "52df12b4c8d4176e7481aab4b6e8454d1fd0a210a04a574f6d4e067d10e23c3e": + # ref: https://huggingface.co/CohereLabs/North-Mini-Code-1.0 + res = "cohere2_moe" if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea": # ref: https://huggingface.co/Qwen/Qwen1.5-7B res = "qwen2" @@ -684,6 +690,9 @@ class Model: if chkhsh == "f4f37b6c8eb9ea29b3eac6bb8c8487c5ab7885f8d8022e67edc1c68ce8403e95": # ref: https://huggingface.co/MiniMaxAI/MiniMax-M2 res = "minimax-m2" + if chkhsh == "9dcf830ee9990cdbf78cc523a5f7bd9ad8f3f9890c2d3581d2785ad10f07049d": + # ref: https://huggingface.co/JetBrains/Mellum2-12B-A2.5B-Base + res = "mellum2" if res is None: logger.warning("\n") logger.warning("**************************************************************************************") @@ -1580,7 +1589,12 @@ class LlamaModel(Model): special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): + saved_intermediate_size = self.hparams.get("intermediate_size") + saved_num_experts_per_tok = self.hparams.pop("num_experts_per_tok") + self.hparams["intermediate_size"] = self.hparams["prefix_dense_intermediate_size"] super().set_gguf_parameters() + self.hparams["intermediate_size"] = saved_intermediate_size + self.hparams["num_experts_per_tok"] = saved_num_experts_per_tok hparams = self.hparams self.gguf_writer.add_vocab_size(hparams["vocab_size"]) @@ -3735,7 +3749,7 @@ class Gemma4Model(Gemma4BaseModel): return [(self.map_tensor_name(name), data_torch)] -@Model.register("Gemma4AssistantForCausalLM") +@Model.register("Gemma4AssistantForCausalLM", "Gemma4UnifiedAssistantForCausalLM") class Gemma4AssistantModel(Gemma4BaseModel): model_arch = gguf.MODEL_ARCH.GEMMA4_MTP @@ -3950,6 +3964,86 @@ class CommandR2Model(Model): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) +@Model.register("Cohere2MoeForCausalLM") +class Cohere2MoeModel(Model): + model_arch = gguf.MODEL_ARCH.COHERE2_MOE + + _experts: list[dict[str, Tensor]] | None = None + + def set_gguf_parameters(self): + saved_intermediate_size = self.hparams["intermediate_size"] + saved_num_experts_per_tok = self.hparams.pop("num_experts_per_tok") + self.hparams["intermediate_size"] = self.hparams["prefix_dense_intermediate_size"] + super().set_gguf_parameters() + self.hparams["intermediate_size"] = saved_intermediate_size + self.hparams["num_experts_per_tok"] = saved_num_experts_per_tok + hparams = self.hparams + + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_logit_scale(hparams.get("logit_scale", 1.0)) + self.gguf_writer.add_sliding_window(hparams["sliding_window"]) + self.gguf_writer.add_sliding_window_pattern([ + layer_type == "sliding_attention" + for layer_type in hparams["layer_types"] + ]) + self.gguf_writer.add_rope_dimension_count(hparams.get("head_dim", hparams["hidden_size"] // hparams["num_attention_heads"])) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + + self.gguf_writer.add_expert_feed_forward_length(hparams["intermediate_size"]) + self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) + self.gguf_writer.add_expert_count(hparams["num_experts"]) + self.gguf_writer.add_expert_used_count(hparams["num_experts_per_tok"]) + self.gguf_writer.add_expert_weights_norm(bool(hparams.get("norm_topk_prob", False))) + + expert_selection_fn = hparams.get("expert_selection_fn", "softmax") + if expert_selection_fn != "sigmoid": + raise ValueError(f"Unsupported Cohere2-MoE expert_selection_fn={expert_selection_fn!r}") + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + + if hparams.get("num_shared_experts", 0) != 0: + raise ValueError("Cohere2-MoE shared experts are not supported in this GGUF converter yet") + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Cohere2-MoE HF tensors already use the interleaved RoPE layout expected here. + + if ".mlp.experts." in name: + n_experts = self.hparams["num_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) < n_experts * 3: + return [] + + tensors: list[tuple[str, Tensor]] = [] + for src, dst in [ + ("gate_proj", "gate_proj"), + ("down_proj", "down_proj"), + ("up_proj", "up_proj"), + ]: + datas: list[Tensor] = [] + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{src}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + merged_name = f"model.layers.{bid}.mlp.experts.{dst}.weight" + tensors.append((self.map_tensor_name(merged_name), torch.stack(datas, dim=0))) + yield from tensors + return + + if name == "model.embed_tokens.weight": + yield self.map_tensor_name(name), data_torch + if self.tensor_names is None or "lm_head.weight" not in self.tensor_names: + yield self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT, suffix=".weight"), data_torch + return + + yield self.map_tensor_name(name), data_torch + + @Model.register("OlmoForCausalLM") @Model.register("OLMoForCausalLM") class OlmoModel(Model): @@ -5307,6 +5401,138 @@ class BailingMoeV2Model(Model): raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("LagunaForCausalLM") +class LagunaModel(Model): + model_arch = gguf.MODEL_ARCH.LAGUNA + + _experts: list[dict[str, Tensor]] | None = None + + def set_gguf_parameters(self): + hparams = self.hparams + arch = gguf.MODEL_ARCH_NAMES[self.model_arch] + n_layers = int(hparams["num_hidden_layers"]) + n_head_base = int(hparams["num_attention_heads"]) + n_kv_base = int(hparams.get("num_key_value_heads", n_head_base)) + head_dim = int(hparams.get("head_dim", hparams["hidden_size"] // n_head_base)) + + heads_per_layer = hparams.get("num_attention_heads_per_layer") + kv_per_layer = hparams.get("num_key_value_heads_per_layer") + + head_arr: list[int] = [] + kv_arr: list[int] = [] + for i in range(n_layers): + head_arr.append(int(heads_per_layer[i]) if heads_per_layer is not None else n_head_base) + kv_arr.append(int(kv_per_layer[i]) if kv_per_layer is not None else n_kv_base) + + rope_params = hparams.get("rope_parameters", {}) + full_rope = rope_params.get("full_attention", rope_params) + swa_rope = rope_params.get("sliding_attention", {}) + + self.gguf_writer.add_context_length(int(hparams["max_position_embeddings"])) + self.gguf_writer.add_embedding_length(int(hparams["hidden_size"])) + self.gguf_writer.add_block_count(n_layers) + self.gguf_writer.add_feed_forward_length(int(hparams["intermediate_size"])) + self.gguf_writer.add_head_count(head_arr) + if all(n_kv == kv_arr[0] for n_kv in kv_arr): + self.gguf_writer.add_head_count_kv(kv_arr[0]) + else: + self.gguf_writer.add_head_count_kv(kv_arr) + self.gguf_writer.add_key_length(head_dim) + self.gguf_writer.add_value_length(head_dim) + self.gguf_writer.add_layer_norm_rms_eps(float(hparams["rms_norm_eps"])) + self.gguf_writer.add_file_type(self.ftype) + + self.gguf_writer.add_sliding_window(int(hparams["sliding_window"])) + self.gguf_writer.add_rope_dimension_count(head_dim // 2) + self.gguf_writer.add_uint32(f"{arch}.rope.dimension_count_swa", head_dim) + self.gguf_writer.add_rope_freq_base(float(full_rope.get("rope_theta", 500000.0))) + self.gguf_writer.add_float32(f"{arch}.rope.freq_base_swa", float(swa_rope.get("rope_theta", 10000.0))) + if full_rope.get("rope_type") == "yarn": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(float(full_rope.get("factor", 1.0))) + self.gguf_writer.add_rope_scaling_orig_ctx_len(int(full_rope.get( + "original_max_position_embeddings", + rope_params.get("original_max_position_embeddings", hparams["max_position_embeddings"]), + ))) + self.gguf_writer.add_rope_scaling_yarn_ext_factor(float(full_rope.get("factor", 1.0))) + self.gguf_writer.add_rope_scaling_yarn_attn_factor(float(full_rope.get("attention_factor", 1.0))) + self.gguf_writer.add_rope_scaling_yarn_beta_fast(float(full_rope.get("beta_fast", 32.0))) + self.gguf_writer.add_rope_scaling_yarn_beta_slow(float(full_rope.get("beta_slow", 1.0))) + + self.gguf_writer.add_expert_count(int(hparams["num_experts"])) + self.gguf_writer.add_expert_used_count(int(hparams["num_experts_per_tok"])) + self.gguf_writer.add_expert_feed_forward_length(int(hparams["moe_intermediate_size"])) + if (shared_dim := hparams.get("shared_expert_intermediate_size")) is not None and int(shared_dim) > 0: + self.gguf_writer.add_expert_shared_feed_forward_length(int(shared_dim)) + if (routing_scale := hparams.get("moe_routed_scaling_factor")) is not None: + self.gguf_writer.add_expert_weights_scale(float(routing_scale)) + self.gguf_writer.add_expert_weights_norm(True) + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + + leading_dense = 0 + for mlp_type in hparams.get("mlp_layer_types", []): + if mlp_type != "dense": + break + leading_dense += 1 + self.gguf_writer.add_uint32(f"{arch}.leading_dense_block_count", leading_dense) + + if hparams.get("moe_apply_router_weight_on_input", False): + raise ValueError("moe_apply_router_weight_on_input=True is not supported for Laguna") + + def set_vocab(self) -> None: + super().set_vocab() + if isinstance(eos_token_id := self.hparams.get("eos_token_id"), list) and len(eos_token_id) > 1: + # Poolside uses token 24 () as a turn boundary. + self.gguf_writer.add_eot_token_id(int(eos_token_id[1])) + template_file = self.dir_model / "chat_template.jinja" + if template_file.is_file(): + self.gguf_writer.add_chat_template(template_file.read_text(encoding="utf-8")) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if bid is not None and name in ( + f"model.layers.{bid}.mlp.experts.e_score_correction_bias", + f"model.layers.{bid}.mlp.experts.e_score_correction", + ): + # The C++ loader asks for this tensor through the ".bias" suffix. + # Keep the Laguna converter aligned with existing community GGUFs. + yield f"blk.{bid}.exp_probs_b.bias", data_torch + return + + if name.endswith(".self_attn.g_proj.weight"): + # HF stores the head-wise attention gate with a singleton dimension. + data_torch = data_torch.squeeze().contiguous() + + if bid is not None and re.match(r"model\.layers\.\d+\.mlp\.experts\.\d+\.(gate_proj|up_proj|down_proj)\.weight$", name): + n_experts = int(self.find_hparam(["num_experts"])) + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + if len(self._experts[bid]) < n_experts * 3: + return + + for w_name in ("down_proj", "gate_proj", "up_proj"): + datas: list[Tensor] = [] + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + merged = torch.stack(datas, dim=0) + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + yield from super().modify_tensors(merged, merged_name, bid) + return + + yield from super().modify_tensors(data_torch, name, bid) + + def prepare_tensors(self): + super().prepare_tensors() + if self._experts is not None: + experts = [k for d in self._experts for k in d.keys()] + if experts: + raise ValueError(f"Unprocessed experts: {experts}") + + ###### CONVERSION LOGIC ###### diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index dcf3469d..7be87f97 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -6,6 +6,7 @@ #include "common.h" #include "llama.h" +#include "llama-spec-features.h" #include "log.h" #include "sampling.h" #include "speculative.h" @@ -13,12 +14,8 @@ #include "mtmd-helper.h" #include -#include -#include -#include #include #include -#include #include static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1, int32_t offset = 0) { @@ -49,83 +46,6 @@ static void log_text(const gpt_params & params_base, const std::string & text) { } } -static bool server_dflash_contract_log_enabled() { - const char * env = std::getenv("IK_DFLASH_CONTRACT_LOG"); - if (env == nullptr || *env == '\0') { - return false; - } - - return std::strcmp(env, "0") != 0 && - std::strcmp(env, "false") != 0 && - std::strcmp(env, "off") != 0; -} - -static std::string server_dflash_contract_format_indices( - const std::vector & values, - size_t edge_count = 4) { - std::ostringstream oss; - oss << '['; - if (values.empty()) { - oss << ']'; - return oss.str(); - } - - const size_t head = std::min(edge_count, values.size()); - for (size_t i = 0; i < head; ++i) { - if (i > 0) { - oss << ','; - } - oss << values[i]; - } - - if (values.size() > edge_count * 2) { - oss << ",...,"; - for (size_t i = values.size() - edge_count; i < values.size(); ++i) { - if (i > values.size() - edge_count) { - oss << ','; - } - oss << values[i]; - } - } else { - for (size_t i = head; i < values.size(); ++i) { - oss << ',' << values[i]; - } - } - - oss << ']'; - return oss.str(); -} - -static void server_dflash_contract_log_accept( - const server_slot & slot, - common_speculative_type spec_type_used, - const char * path, - bool any_rejected, - size_t n_draft, - const std::vector & ids, - llama_pos pos_base, - const std::vector & output_indices) { - if (!server_dflash_contract_log_enabled() || spec_type_used != COMMON_SPECULATIVE_TYPE_DFLASH) { - return; - } - - static std::atomic counter = 0; - const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); - if (ordinal >= 8) { - return; - } - - LLAMA_LOG_INFO("dflash contract accept[%llu]: slot=%d path=%s rejected=%s drafted=%zu accepted=%zu pos_base=%d output_indices=%s\n", - (unsigned long long) (ordinal + 1), - slot.id, - path, - any_rejected ? "true" : "false", - n_draft, - ids.size(), - (int) pos_base, - server_dflash_contract_format_indices(output_indices).c_str()); -} - static bool server_slot_prompt_batch_overlaps( const server_slot & slot, int32_t batch_i0, @@ -158,12 +78,13 @@ static bool server_response_needs_chat_parse(oaicompat_type oaicompat) { oaicompat == OAICOMPAT_TYPE_RESP; } -static bool server_speculative_has_dflash(const common_params_speculative & spec) { - return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH); +static bool server_speculative_uses_target_features(const common_params_speculative & spec) { + return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) || + spec.has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH); } -static bool server_speculative_has_target_features(const common_params_speculative & spec) { - return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) || server_speculative_has_dflash(spec); +static bool server_speculative_requires_single_slot(const common_params_speculative & spec) { + return spec.has_stage_chain(); } static bool server_speculative_same_stage_types( @@ -295,8 +216,8 @@ bool server_context::load_model(const gpt_params& params_) { common_speculative_prepare_startup(params_base, false); - if (server_speculative_has_dflash(params_base.speculative) && params_base.n_parallel > 1) { - LOG_ERROR("DFlash is currently limited to a single server slot (-np 1).\n", { + if (server_speculative_requires_single_slot(params_base.speculative) && params_base.n_parallel > 1) { + LOG_ERROR("Speculative decoding is currently limited to a single server slot (-np 1).\n", { {"n_parallel", params_base.n_parallel}, }); return false; @@ -4127,7 +4048,7 @@ void server_context::speculative_decoding_accept() { } std::vector accepted_output_indices; - if (server_speculative_has_target_features(slot.params.speculative)) { + if (server_speculative_uses_target_features(slot.params.speculative)) { if (!ids.empty()) { accepted_output_indices.assign(slot.i_batch_dft.begin(), slot.i_batch_dft.begin() + ids.size()); } @@ -4160,14 +4081,14 @@ void server_context::speculative_decoding_accept() { const common_speculative_checkpoint * ckpt = common_speculative_get_checkpoint(slot.spec); const bool will_restore = any_rejected && ckpt != nullptr && ckpt->valid; - if (server_speculative_has_target_features(slot.params.speculative) && !accepted_output_indices.empty()) { - server_dflash_contract_log_accept( - slot, - spec_type_used, + if (server_speculative_uses_target_features(slot.params.speculative) && !accepted_output_indices.empty()) { + llama_dflash_contract_log_accept( + slot.id, + spec_type_used == COMMON_SPECULATIVE_TYPE_DFLASH, will_restore ? "restore" : "direct", any_rejected, n_draft, - ids, + ids.size(), spec_pos_base, accepted_output_indices); } @@ -4556,9 +4477,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) { continue; // continue loop of n_batch } - if (server_speculative_has_target_features(params_base.speculative)) { + if (server_speculative_uses_target_features(params_base.speculative)) { for (auto & slot : slots) { - if (!slot.spec || !server_speculative_has_target_features(slot.params.speculative)) { + if (!slot.spec || !server_speculative_uses_target_features(slot.params.speculative)) { continue; } diff --git a/src/llama-dflash-profile.h b/src/llama-dflash-profile.h new file mode 100644 index 00000000..4d998488 --- /dev/null +++ b/src/llama-dflash-profile.h @@ -0,0 +1,340 @@ +#pragma once + +#include +#include + +enum llama_dflash_kv_node_kind { + LLAMA_DFLASH_KV_NODE_NONE = 0, + LLAMA_DFLASH_KV_NODE_FUSED_TARGET, + LLAMA_DFLASH_KV_NODE_K_PROJ, + LLAMA_DFLASH_KV_NODE_K_NORM, + LLAMA_DFLASH_KV_NODE_K_ROPE, + LLAMA_DFLASH_KV_NODE_V_PROJ, + LLAMA_DFLASH_KV_NODE_K_STORE, + LLAMA_DFLASH_KV_NODE_V_STORE, +}; + +enum llama_dflash_main_node_kind { + LLAMA_DFLASH_MAIN_NODE_NONE = 0, + LLAMA_DFLASH_MAIN_NODE_QCUR, + LLAMA_DFLASH_MAIN_NODE_K_DRAFT, + LLAMA_DFLASH_MAIN_NODE_V_DRAFT, + LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW, + LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW, + LLAMA_DFLASH_MAIN_NODE_K_CONCAT, + LLAMA_DFLASH_MAIN_NODE_V_CONCAT, + LLAMA_DFLASH_MAIN_NODE_K_PAD, + LLAMA_DFLASH_MAIN_NODE_V_PAD, + LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT, + LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT, + LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN, + LLAMA_DFLASH_MAIN_NODE_ATTN_OUT, + LLAMA_DFLASH_MAIN_NODE_FFN, + LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS, + LLAMA_DFLASH_MAIN_NODE_RESULT_NORM, + LLAMA_DFLASH_MAIN_NODE_RESULT, +}; + +struct llama_dflash_kv_node_profiler { + llama_dflash_profile_stats * profile = nullptr; + int64_t t_start_us = 0; + llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE; +}; + +struct llama_dflash_main_node_profiler { + llama_dflash_profile_stats * profile = nullptr; + ggml_backend_sched_eval_callback prev_callback = nullptr; + void * prev_user_data = nullptr; + bool prev_active = false; + int64_t t_start_us = 0; + llama_dflash_main_node_kind active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; +}; + +static inline bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) { + if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') { + return false; + } + + return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0; +} + +static inline bool llama_dflash_tensor_name_matches_label(const struct ggml_tensor * tensor, const char * label) { + if (!llama_dflash_tensor_name_has_prefix(tensor, label)) { + return false; + } + + const size_t label_len = std::strlen(label); + const char next = tensor->name[label_len]; + return next == '\0' || next == '-'; +} + +static inline llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) { + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_fused_target")) { + return LLAMA_DFLASH_KV_NODE_FUSED_TARGET; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_proj")) { + return LLAMA_DFLASH_KV_NODE_K_PROJ; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_norm")) { + return LLAMA_DFLASH_KV_NODE_K_NORM; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_rope")) { + return LLAMA_DFLASH_KV_NODE_K_ROPE; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_proj")) { + return LLAMA_DFLASH_KV_NODE_V_PROJ; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_store")) { + return LLAMA_DFLASH_KV_NODE_K_STORE; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_store")) { + return LLAMA_DFLASH_KV_NODE_V_STORE; + } + + return LLAMA_DFLASH_KV_NODE_NONE; +} + +static inline void llama_dflash_kv_node_profile_add( + llama_dflash_profile_stats & profile, + llama_dflash_kv_node_kind kind, + uint64_t elapsed_us) { + switch (kind) { + case LLAMA_DFLASH_KV_NODE_FUSED_TARGET: + profile.graph_kv_node_fused_target_calls++; + profile.graph_kv_node_fused_target_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_PROJ: + profile.graph_kv_node_k_proj_calls++; + profile.graph_kv_node_k_proj_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_NORM: + profile.graph_kv_node_k_norm_calls++; + profile.graph_kv_node_k_norm_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_ROPE: + profile.graph_kv_node_k_rope_calls++; + profile.graph_kv_node_k_rope_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_V_PROJ: + profile.graph_kv_node_v_proj_calls++; + profile.graph_kv_node_v_proj_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_STORE: + profile.graph_kv_node_k_store_calls++; + profile.graph_kv_node_k_store_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_V_STORE: + profile.graph_kv_node_v_store_calls++; + profile.graph_kv_node_v_store_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_NONE: + break; + } +} + +static inline llama_dflash_main_node_kind llama_dflash_main_node_kind_from_tensor(const struct ggml_tensor * tensor) { + if (llama_dflash_tensor_name_has_prefix(tensor, "Qcur")) { + return LLAMA_DFLASH_MAIN_NODE_QCUR; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_noise")) { + return LLAMA_DFLASH_MAIN_NODE_K_DRAFT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_noise")) { + return LLAMA_DFLASH_MAIN_NODE_V_DRAFT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_ctx_cache")) { + return LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_ctx_cache")) { + return LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_concat")) { + return LLAMA_DFLASH_MAIN_NODE_K_CONCAT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_concat")) { + return LLAMA_DFLASH_MAIN_NODE_V_CONCAT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_pad")) { + return LLAMA_DFLASH_MAIN_NODE_K_PAD; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_pad")) { + return LLAMA_DFLASH_MAIN_NODE_V_PAD; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_perm_cont")) { + return LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_perm_cont")) { + return LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "flash_attn_reshaped")) { + return LLAMA_DFLASH_MAIN_NODE_NONE; + } + if (llama_dflash_tensor_name_matches_label(tensor, "flash_attn")) { + return LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "kqv_out")) { + return LLAMA_DFLASH_MAIN_NODE_ATTN_OUT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "ffn_out")) { + return LLAMA_DFLASH_MAIN_NODE_FFN; + } + if (llama_dflash_tensor_name_matches_label(tensor, "result_output_rows")) { + return LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS; + } + if (llama_dflash_tensor_name_matches_label(tensor, "result_norm")) { + return LLAMA_DFLASH_MAIN_NODE_RESULT_NORM; + } + if (llama_dflash_tensor_name_matches_label(tensor, "output")) { + return LLAMA_DFLASH_MAIN_NODE_RESULT; + } + if (llama_dflash_tensor_name_matches_label(tensor, "result_output")) { + return LLAMA_DFLASH_MAIN_NODE_RESULT; + } + + return LLAMA_DFLASH_MAIN_NODE_NONE; +} + +static inline void llama_dflash_main_node_profile_add( + llama_dflash_profile_stats & profile, + llama_dflash_main_node_kind kind, + uint64_t elapsed_us) { + switch (kind) { + case LLAMA_DFLASH_MAIN_NODE_QCUR: + profile.graph_main_node_qcur_calls++; + profile.graph_main_node_qcur_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_DRAFT: + profile.graph_main_node_k_draft_calls++; + profile.graph_main_node_k_draft_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_DRAFT: + profile.graph_main_node_v_draft_calls++; + profile.graph_main_node_v_draft_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW: + profile.graph_main_node_k_ctx_view_calls++; + profile.graph_main_node_k_ctx_view_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW: + profile.graph_main_node_v_ctx_view_calls++; + profile.graph_main_node_v_ctx_view_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_CONCAT: + profile.graph_main_node_k_concat_calls++; + profile.graph_main_node_k_concat_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_CONCAT: + profile.graph_main_node_v_concat_calls++; + profile.graph_main_node_v_concat_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_PAD: + profile.graph_main_node_k_pad_calls++; + profile.graph_main_node_k_pad_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_PAD: + profile.graph_main_node_v_pad_calls++; + profile.graph_main_node_v_pad_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT: + profile.graph_main_node_k_perm_cont_calls++; + profile.graph_main_node_k_perm_cont_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT: + profile.graph_main_node_v_perm_cont_calls++; + profile.graph_main_node_v_perm_cont_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN: + profile.graph_main_node_flash_attn_calls++; + profile.graph_main_node_flash_attn_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_ATTN_OUT: + profile.graph_main_node_attn_out_calls++; + profile.graph_main_node_attn_out_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_FFN: + profile.graph_main_node_ffn_calls++; + profile.graph_main_node_ffn_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS: + profile.graph_main_node_result_rows_calls++; + profile.graph_main_node_result_rows_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_RESULT_NORM: + profile.graph_main_node_result_norm_calls++; + profile.graph_main_node_result_norm_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_RESULT: + profile.graph_main_node_result_calls++; + profile.graph_main_node_result_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_NONE: + break; + } +} + +static inline bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { + auto * profiler = static_cast(user_data); + if (profiler == nullptr || profiler->profile == nullptr) { + return false; + } + + const llama_dflash_kv_node_kind kind = llama_dflash_kv_node_kind_from_tensor(tensor); + if (ask) { + if (kind == LLAMA_DFLASH_KV_NODE_NONE) { + return false; + } + + profiler->active_kind = kind; + profiler->t_start_us = ggml_time_us(); + return true; + } + + if (kind != LLAMA_DFLASH_KV_NODE_NONE && profiler->active_kind == kind && profiler->t_start_us > 0) { + llama_dflash_kv_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); + } + + profiler->active_kind = LLAMA_DFLASH_KV_NODE_NONE; + profiler->t_start_us = 0; + return true; +} + +static inline bool llama_dflash_main_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { + auto * profiler = static_cast(user_data); + if (profiler == nullptr || profiler->profile == nullptr) { + return false; + } + + const llama_dflash_main_node_kind kind = llama_dflash_main_node_kind_from_tensor(tensor); + if (ask) { + profiler->prev_active = profiler->prev_callback != nullptr + ? profiler->prev_callback(tensor, ask, profiler->prev_user_data) + : false; + + if (kind == LLAMA_DFLASH_MAIN_NODE_NONE) { + profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; + profiler->t_start_us = 0; + return profiler->prev_active; + } + + profiler->active_kind = kind; + profiler->t_start_us = ggml_time_us(); + return true; + } + + bool prev_result = false; + if (profiler->prev_active && profiler->prev_callback != nullptr) { + prev_result = profiler->prev_callback(tensor, ask, profiler->prev_user_data); + } + + const bool tracked = kind != LLAMA_DFLASH_MAIN_NODE_NONE && + profiler->active_kind == kind && + profiler->t_start_us > 0; + if (tracked) { + llama_dflash_main_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); + } + + profiler->prev_active = false; + profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; + profiler->t_start_us = 0; + return prev_result || tracked; +} diff --git a/src/llama-dflash.cpp b/src/llama-dflash.cpp index 9230840d..e32f53d1 100644 --- a/src/llama-dflash.cpp +++ b/src/llama-dflash.cpp @@ -5,6 +5,7 @@ #include "llama-context.h" #include "llama-model.h" #include "llama-spec-features.h" +#include "llama-dflash-profile.h" #include "ggml.h" #include "ggml-backend.h" @@ -27,342 +28,6 @@ static bool llama_dflash_stats_log_enabled() { return llama_env_flag_enabled_local("IK_DFLASH_STATS_LOG"); } -enum llama_dflash_kv_node_kind { - LLAMA_DFLASH_KV_NODE_NONE = 0, - LLAMA_DFLASH_KV_NODE_FUSED_TARGET, - LLAMA_DFLASH_KV_NODE_K_PROJ, - LLAMA_DFLASH_KV_NODE_K_NORM, - LLAMA_DFLASH_KV_NODE_K_ROPE, - LLAMA_DFLASH_KV_NODE_V_PROJ, - LLAMA_DFLASH_KV_NODE_K_STORE, - LLAMA_DFLASH_KV_NODE_V_STORE, -}; - -enum llama_dflash_main_node_kind { - LLAMA_DFLASH_MAIN_NODE_NONE = 0, - LLAMA_DFLASH_MAIN_NODE_QCUR, - LLAMA_DFLASH_MAIN_NODE_K_DRAFT, - LLAMA_DFLASH_MAIN_NODE_V_DRAFT, - LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW, - LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW, - LLAMA_DFLASH_MAIN_NODE_K_CONCAT, - LLAMA_DFLASH_MAIN_NODE_V_CONCAT, - LLAMA_DFLASH_MAIN_NODE_K_PAD, - LLAMA_DFLASH_MAIN_NODE_V_PAD, - LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT, - LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT, - LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN, - LLAMA_DFLASH_MAIN_NODE_ATTN_OUT, - LLAMA_DFLASH_MAIN_NODE_FFN, - LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS, - LLAMA_DFLASH_MAIN_NODE_RESULT_NORM, - LLAMA_DFLASH_MAIN_NODE_RESULT, -}; - -struct llama_dflash_kv_node_profiler { - llama_dflash_profile_stats * profile = nullptr; - int64_t t_start_us = 0; - llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE; -}; - -struct llama_dflash_main_node_profiler { - llama_dflash_profile_stats * profile = nullptr; - ggml_backend_sched_eval_callback prev_callback = nullptr; - void * prev_user_data = nullptr; - bool prev_active = false; - int64_t t_start_us = 0; - llama_dflash_main_node_kind active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; -}; - -static bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) { - if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') { - return false; - } - - return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0; -} - -static bool llama_dflash_tensor_name_matches_label(const struct ggml_tensor * tensor, const char * label) { - if (!llama_dflash_tensor_name_has_prefix(tensor, label)) { - return false; - } - - const size_t label_len = std::strlen(label); - const char next = tensor->name[label_len]; - return next == '\0' || next == '-'; -} - -static llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) { - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_fused_target")) { - return LLAMA_DFLASH_KV_NODE_FUSED_TARGET; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_proj")) { - return LLAMA_DFLASH_KV_NODE_K_PROJ; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_norm")) { - return LLAMA_DFLASH_KV_NODE_K_NORM; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_rope")) { - return LLAMA_DFLASH_KV_NODE_K_ROPE; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_proj")) { - return LLAMA_DFLASH_KV_NODE_V_PROJ; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_store")) { - return LLAMA_DFLASH_KV_NODE_K_STORE; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_store")) { - return LLAMA_DFLASH_KV_NODE_V_STORE; - } - - return LLAMA_DFLASH_KV_NODE_NONE; -} - -static void llama_dflash_kv_node_profile_add( - llama_dflash_profile_stats & profile, - llama_dflash_kv_node_kind kind, - uint64_t elapsed_us) { - switch (kind) { - case LLAMA_DFLASH_KV_NODE_FUSED_TARGET: - profile.graph_kv_node_fused_target_calls++; - profile.graph_kv_node_fused_target_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_PROJ: - profile.graph_kv_node_k_proj_calls++; - profile.graph_kv_node_k_proj_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_NORM: - profile.graph_kv_node_k_norm_calls++; - profile.graph_kv_node_k_norm_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_ROPE: - profile.graph_kv_node_k_rope_calls++; - profile.graph_kv_node_k_rope_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_V_PROJ: - profile.graph_kv_node_v_proj_calls++; - profile.graph_kv_node_v_proj_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_STORE: - profile.graph_kv_node_k_store_calls++; - profile.graph_kv_node_k_store_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_V_STORE: - profile.graph_kv_node_v_store_calls++; - profile.graph_kv_node_v_store_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_NONE: - break; - } -} - -static llama_dflash_main_node_kind llama_dflash_main_node_kind_from_tensor(const struct ggml_tensor * tensor) { - if (llama_dflash_tensor_name_has_prefix(tensor, "Qcur")) { - return LLAMA_DFLASH_MAIN_NODE_QCUR; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_noise")) { - return LLAMA_DFLASH_MAIN_NODE_K_DRAFT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_noise")) { - return LLAMA_DFLASH_MAIN_NODE_V_DRAFT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_ctx_cache")) { - return LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_ctx_cache")) { - return LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_concat")) { - return LLAMA_DFLASH_MAIN_NODE_K_CONCAT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_concat")) { - return LLAMA_DFLASH_MAIN_NODE_V_CONCAT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_pad")) { - return LLAMA_DFLASH_MAIN_NODE_K_PAD; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_pad")) { - return LLAMA_DFLASH_MAIN_NODE_V_PAD; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_perm_cont")) { - return LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_perm_cont")) { - return LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "flash_attn_reshaped")) { - return LLAMA_DFLASH_MAIN_NODE_NONE; - } - if (llama_dflash_tensor_name_matches_label(tensor, "flash_attn")) { - return LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "kqv_out")) { - return LLAMA_DFLASH_MAIN_NODE_ATTN_OUT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "ffn_out")) { - return LLAMA_DFLASH_MAIN_NODE_FFN; - } - if (llama_dflash_tensor_name_matches_label(tensor, "result_output_rows")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS; - } - if (llama_dflash_tensor_name_matches_label(tensor, "result_norm")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT_NORM; - } - if (llama_dflash_tensor_name_matches_label(tensor, "output")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT; - } - if (llama_dflash_tensor_name_matches_label(tensor, "result_output")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT; - } - - return LLAMA_DFLASH_MAIN_NODE_NONE; -} - -static void llama_dflash_main_node_profile_add( - llama_dflash_profile_stats & profile, - llama_dflash_main_node_kind kind, - uint64_t elapsed_us) { - switch (kind) { - case LLAMA_DFLASH_MAIN_NODE_QCUR: - profile.graph_main_node_qcur_calls++; - profile.graph_main_node_qcur_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_DRAFT: - profile.graph_main_node_k_draft_calls++; - profile.graph_main_node_k_draft_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_DRAFT: - profile.graph_main_node_v_draft_calls++; - profile.graph_main_node_v_draft_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW: - profile.graph_main_node_k_ctx_view_calls++; - profile.graph_main_node_k_ctx_view_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW: - profile.graph_main_node_v_ctx_view_calls++; - profile.graph_main_node_v_ctx_view_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_CONCAT: - profile.graph_main_node_k_concat_calls++; - profile.graph_main_node_k_concat_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_CONCAT: - profile.graph_main_node_v_concat_calls++; - profile.graph_main_node_v_concat_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_PAD: - profile.graph_main_node_k_pad_calls++; - profile.graph_main_node_k_pad_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_PAD: - profile.graph_main_node_v_pad_calls++; - profile.graph_main_node_v_pad_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT: - profile.graph_main_node_k_perm_cont_calls++; - profile.graph_main_node_k_perm_cont_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT: - profile.graph_main_node_v_perm_cont_calls++; - profile.graph_main_node_v_perm_cont_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN: - profile.graph_main_node_flash_attn_calls++; - profile.graph_main_node_flash_attn_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_ATTN_OUT: - profile.graph_main_node_attn_out_calls++; - profile.graph_main_node_attn_out_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_FFN: - profile.graph_main_node_ffn_calls++; - profile.graph_main_node_ffn_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS: - profile.graph_main_node_result_rows_calls++; - profile.graph_main_node_result_rows_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_RESULT_NORM: - profile.graph_main_node_result_norm_calls++; - profile.graph_main_node_result_norm_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_RESULT: - profile.graph_main_node_result_calls++; - profile.graph_main_node_result_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_NONE: - break; - } -} - -static bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { - auto * profiler = static_cast(user_data); - if (profiler == nullptr || profiler->profile == nullptr) { - return false; - } - - const llama_dflash_kv_node_kind kind = llama_dflash_kv_node_kind_from_tensor(tensor); - if (ask) { - if (kind == LLAMA_DFLASH_KV_NODE_NONE) { - return false; - } - - profiler->active_kind = kind; - profiler->t_start_us = ggml_time_us(); - return true; - } - - if (kind != LLAMA_DFLASH_KV_NODE_NONE && profiler->active_kind == kind && profiler->t_start_us > 0) { - llama_dflash_kv_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); - } - - profiler->active_kind = LLAMA_DFLASH_KV_NODE_NONE; - profiler->t_start_us = 0; - return true; -} - -static bool llama_dflash_main_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { - auto * profiler = static_cast(user_data); - if (profiler == nullptr || profiler->profile == nullptr) { - return false; - } - - const llama_dflash_main_node_kind kind = llama_dflash_main_node_kind_from_tensor(tensor); - if (ask) { - profiler->prev_active = profiler->prev_callback != nullptr - ? profiler->prev_callback(tensor, ask, profiler->prev_user_data) - : false; - - if (kind == LLAMA_DFLASH_MAIN_NODE_NONE) { - profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; - profiler->t_start_us = 0; - return profiler->prev_active; - } - - profiler->active_kind = kind; - profiler->t_start_us = ggml_time_us(); - return true; - } - - bool prev_result = false; - if (profiler->prev_active && profiler->prev_callback != nullptr) { - prev_result = profiler->prev_callback(tensor, ask, profiler->prev_user_data); - } - - const bool tracked = kind != LLAMA_DFLASH_MAIN_NODE_NONE && - profiler->active_kind == kind && - profiler->t_start_us > 0; - if (tracked) { - llama_dflash_main_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); - } - - profiler->prev_active = false; - profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; - profiler->t_start_us = 0; - return prev_result || tracked; -} - void llama_sync_dflash_workspace_if_pending(struct llama_context & lctx) { if (!lctx.dflash_kv_workspace_sync_pending || lctx.dflash_workspace_sched == nullptr) { return; diff --git a/src/llama-spec-features-dflash.cpp b/src/llama-spec-features-dflash.cpp index 2df71006..c8cd5181 100644 --- a/src/llama-spec-features-dflash.cpp +++ b/src/llama-spec-features-dflash.cpp @@ -880,6 +880,36 @@ static void llama_dflash_contract_log_output_indices( have_capture ? "true" : "false"); } +void llama_dflash_contract_log_accept( + int slot_id, + bool is_dflash, + const char * path, + bool any_rejected, + size_t n_draft, + size_t n_accepted, + llama_pos pos_base, + const std::vector & output_indices) { + if (!llama_dflash_contract_log_enabled() || !is_dflash) { + return; + } + + static std::atomic counter = 0; + const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); + if (ordinal >= 8) { + return; + } + + LLAMA_LOG_INFO("dflash contract accept[%llu]: slot=%d path=%s rejected=%s drafted=%zu accepted=%zu pos_base=%d output_indices=%s\n", + (unsigned long long) (ordinal + 1), + slot_id, + path, + any_rejected ? "true" : "false", + n_draft, + n_accepted, + (int) pos_base, + llama_dflash_contract_format_values(output_indices).c_str()); +} + static bool llama_spec_materialize_dflash_rows_prepared( struct llama_context * ctx, int32_t row_count, diff --git a/src/llama-spec-features-dflash.h b/src/llama-spec-features-dflash.h index 02e709d1..e05f2a91 100644 --- a/src/llama-spec-features-dflash.h +++ b/src/llama-spec-features-dflash.h @@ -277,3 +277,13 @@ bool llama_spec_copy_dflash_rows_from_output_indices( struct llama_context * ctx, const std::vector & output_indices, std::vector & hidden_rows); + +void llama_dflash_contract_log_accept( + int slot_id, + bool is_dflash, + const char * path, + bool any_rejected, + size_t n_draft, + size_t n_accepted, + llama_pos pos_base, + const std::vector & output_indices); diff --git a/src/llama-spec-features.cpp b/src/llama-spec-features.cpp index 86e00399..2e93b5fa 100644 --- a/src/llama-spec-features.cpp +++ b/src/llama-spec-features.cpp @@ -1,11 +1,6 @@ #include "llama-spec-features.h" -#include -#include -#include -#include #include -#include #include "llama-model.h" #include "llama-context.h" diff --git a/src/llama.cpp b/src/llama.cpp index 48b6815c..8adab80e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -19,6 +19,7 @@ #include "llama-context.h" #include "llama-spec-features.h" #include "llama-dflash.h" +#include "llama-dflash-profile.h" #include "llama-quantize.h" #include "unicode.h" @@ -180,343 +181,6 @@ static bool llama_env_flag_enabled(const char * name) { std::strcmp(env, "off") != 0; } -enum llama_dflash_kv_node_kind { - LLAMA_DFLASH_KV_NODE_NONE = 0, - LLAMA_DFLASH_KV_NODE_FUSED_TARGET, - LLAMA_DFLASH_KV_NODE_K_PROJ, - LLAMA_DFLASH_KV_NODE_K_NORM, - LLAMA_DFLASH_KV_NODE_K_ROPE, - LLAMA_DFLASH_KV_NODE_V_PROJ, - LLAMA_DFLASH_KV_NODE_K_STORE, - LLAMA_DFLASH_KV_NODE_V_STORE, -}; - -enum llama_dflash_main_node_kind { - LLAMA_DFLASH_MAIN_NODE_NONE = 0, - LLAMA_DFLASH_MAIN_NODE_QCUR, - LLAMA_DFLASH_MAIN_NODE_K_DRAFT, - LLAMA_DFLASH_MAIN_NODE_V_DRAFT, - LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW, - LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW, - LLAMA_DFLASH_MAIN_NODE_K_CONCAT, - LLAMA_DFLASH_MAIN_NODE_V_CONCAT, - LLAMA_DFLASH_MAIN_NODE_K_PAD, - LLAMA_DFLASH_MAIN_NODE_V_PAD, - LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT, - LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT, - LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN, - LLAMA_DFLASH_MAIN_NODE_ATTN_OUT, - LLAMA_DFLASH_MAIN_NODE_FFN, - LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS, - LLAMA_DFLASH_MAIN_NODE_RESULT_NORM, - LLAMA_DFLASH_MAIN_NODE_RESULT, -}; - -struct llama_dflash_kv_node_profiler { - llama_dflash_profile_stats * profile = nullptr; - int64_t t_start_us = 0; - llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE; -}; - -struct llama_dflash_main_node_profiler { - llama_dflash_profile_stats * profile = nullptr; - ggml_backend_sched_eval_callback prev_callback = nullptr; - void * prev_user_data = nullptr; - bool prev_active = false; - int64_t t_start_us = 0; - llama_dflash_main_node_kind active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; -}; - -static bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) { - if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') { - return false; - } - - return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0; -} - -static bool llama_dflash_tensor_name_matches_label(const struct ggml_tensor * tensor, const char * label) { - if (!llama_dflash_tensor_name_has_prefix(tensor, label)) { - return false; - } - - const size_t label_len = std::strlen(label); - const char next = tensor->name[label_len]; - return next == '\0' || next == '-'; -} - -static llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) { - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_fused_target")) { - return LLAMA_DFLASH_KV_NODE_FUSED_TARGET; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_proj")) { - return LLAMA_DFLASH_KV_NODE_K_PROJ; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_norm")) { - return LLAMA_DFLASH_KV_NODE_K_NORM; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_rope")) { - return LLAMA_DFLASH_KV_NODE_K_ROPE; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_proj")) { - return LLAMA_DFLASH_KV_NODE_V_PROJ; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_store")) { - return LLAMA_DFLASH_KV_NODE_K_STORE; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_store")) { - return LLAMA_DFLASH_KV_NODE_V_STORE; - } - - return LLAMA_DFLASH_KV_NODE_NONE; -} - -static void llama_dflash_kv_node_profile_add( - llama_dflash_profile_stats & profile, - llama_dflash_kv_node_kind kind, - uint64_t elapsed_us) { - switch (kind) { - case LLAMA_DFLASH_KV_NODE_FUSED_TARGET: - profile.graph_kv_node_fused_target_calls++; - profile.graph_kv_node_fused_target_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_PROJ: - profile.graph_kv_node_k_proj_calls++; - profile.graph_kv_node_k_proj_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_NORM: - profile.graph_kv_node_k_norm_calls++; - profile.graph_kv_node_k_norm_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_ROPE: - profile.graph_kv_node_k_rope_calls++; - profile.graph_kv_node_k_rope_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_V_PROJ: - profile.graph_kv_node_v_proj_calls++; - profile.graph_kv_node_v_proj_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_STORE: - profile.graph_kv_node_k_store_calls++; - profile.graph_kv_node_k_store_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_V_STORE: - profile.graph_kv_node_v_store_calls++; - profile.graph_kv_node_v_store_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_NONE: - break; - } -} - -static llama_dflash_main_node_kind llama_dflash_main_node_kind_from_tensor(const struct ggml_tensor * tensor) { - if (llama_dflash_tensor_name_has_prefix(tensor, "Qcur")) { - return LLAMA_DFLASH_MAIN_NODE_QCUR; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_noise")) { - return LLAMA_DFLASH_MAIN_NODE_K_DRAFT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_noise")) { - return LLAMA_DFLASH_MAIN_NODE_V_DRAFT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_ctx_cache")) { - return LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_ctx_cache")) { - return LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_concat")) { - return LLAMA_DFLASH_MAIN_NODE_K_CONCAT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_concat")) { - return LLAMA_DFLASH_MAIN_NODE_V_CONCAT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_pad")) { - return LLAMA_DFLASH_MAIN_NODE_K_PAD; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_pad")) { - return LLAMA_DFLASH_MAIN_NODE_V_PAD; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_perm_cont")) { - return LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_perm_cont")) { - return LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "flash_attn_reshaped")) { - return LLAMA_DFLASH_MAIN_NODE_NONE; - } - if (llama_dflash_tensor_name_matches_label(tensor, "flash_attn")) { - return LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "kqv_out")) { - return LLAMA_DFLASH_MAIN_NODE_ATTN_OUT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "ffn_out")) { - return LLAMA_DFLASH_MAIN_NODE_FFN; - } - if (llama_dflash_tensor_name_matches_label(tensor, "result_output_rows")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS; - } - if (llama_dflash_tensor_name_matches_label(tensor, "result_norm")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT_NORM; - } - if (llama_dflash_tensor_name_matches_label(tensor, "output")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT; - } - if (llama_dflash_tensor_name_matches_label(tensor, "result_output")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT; - } - - return LLAMA_DFLASH_MAIN_NODE_NONE; -} - -static void llama_dflash_main_node_profile_add( - llama_dflash_profile_stats & profile, - llama_dflash_main_node_kind kind, - uint64_t elapsed_us) { - switch (kind) { - case LLAMA_DFLASH_MAIN_NODE_QCUR: - profile.graph_main_node_qcur_calls++; - profile.graph_main_node_qcur_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_DRAFT: - profile.graph_main_node_k_draft_calls++; - profile.graph_main_node_k_draft_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_DRAFT: - profile.graph_main_node_v_draft_calls++; - profile.graph_main_node_v_draft_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW: - profile.graph_main_node_k_ctx_view_calls++; - profile.graph_main_node_k_ctx_view_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW: - profile.graph_main_node_v_ctx_view_calls++; - profile.graph_main_node_v_ctx_view_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_CONCAT: - profile.graph_main_node_k_concat_calls++; - profile.graph_main_node_k_concat_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_CONCAT: - profile.graph_main_node_v_concat_calls++; - profile.graph_main_node_v_concat_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_PAD: - profile.graph_main_node_k_pad_calls++; - profile.graph_main_node_k_pad_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_PAD: - profile.graph_main_node_v_pad_calls++; - profile.graph_main_node_v_pad_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT: - profile.graph_main_node_k_perm_cont_calls++; - profile.graph_main_node_k_perm_cont_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT: - profile.graph_main_node_v_perm_cont_calls++; - profile.graph_main_node_v_perm_cont_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN: - profile.graph_main_node_flash_attn_calls++; - profile.graph_main_node_flash_attn_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_ATTN_OUT: - profile.graph_main_node_attn_out_calls++; - profile.graph_main_node_attn_out_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_FFN: - profile.graph_main_node_ffn_calls++; - profile.graph_main_node_ffn_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS: - profile.graph_main_node_result_rows_calls++; - profile.graph_main_node_result_rows_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_RESULT_NORM: - profile.graph_main_node_result_norm_calls++; - profile.graph_main_node_result_norm_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_RESULT: - profile.graph_main_node_result_calls++; - profile.graph_main_node_result_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_NONE: - break; - } -} - -static bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { - auto * profiler = static_cast(user_data); - if (profiler == nullptr || profiler->profile == nullptr) { - return false; - } - - const llama_dflash_kv_node_kind kind = llama_dflash_kv_node_kind_from_tensor(tensor); - if (ask) { - if (kind == LLAMA_DFLASH_KV_NODE_NONE) { - return false; - } - - profiler->active_kind = kind; - profiler->t_start_us = ggml_time_us(); - return true; - } - - if (kind != LLAMA_DFLASH_KV_NODE_NONE && profiler->active_kind == kind && profiler->t_start_us > 0) { - llama_dflash_kv_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); - } - - profiler->active_kind = LLAMA_DFLASH_KV_NODE_NONE; - profiler->t_start_us = 0; - return true; -} - -static bool llama_dflash_main_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { - auto * profiler = static_cast(user_data); - if (profiler == nullptr || profiler->profile == nullptr) { - return false; - } - - const llama_dflash_main_node_kind kind = llama_dflash_main_node_kind_from_tensor(tensor); - if (ask) { - profiler->prev_active = profiler->prev_callback != nullptr - ? profiler->prev_callback(tensor, ask, profiler->prev_user_data) - : false; - - if (kind == LLAMA_DFLASH_MAIN_NODE_NONE) { - profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; - profiler->t_start_us = 0; - return profiler->prev_active; - } - - profiler->active_kind = kind; - profiler->t_start_us = ggml_time_us(); - return true; - } - - bool prev_result = false; - if (profiler->prev_active && profiler->prev_callback != nullptr) { - prev_result = profiler->prev_callback(tensor, ask, profiler->prev_user_data); - } - - const bool tracked = kind != LLAMA_DFLASH_MAIN_NODE_NONE && - profiler->active_kind == kind && - profiler->t_start_us > 0; - if (tracked) { - llama_dflash_main_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); - } - - profiler->prev_active = false; - profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; - profiler->t_start_us = 0; - return prev_result || tracked; -} - - // extract ip and port from RPC[ip:port] for rpc and keep other device names static std::vector extract_device_from_rpc_device(std::vector devices) { std::vector rpc_servers;