From 82cff238fea079104fab95eb6c2dd6e322c39b0d Mon Sep 17 00:00:00 2001 From: SamuelOliveirads Date: Thu, 28 May 2026 18:57:58 -0300 Subject: [PATCH 01/13] Initial dflash implementation --- .flake8 | 2 + common/common.cpp | 26 +- common/common.h | 6 +- common/speculative.cpp | 396 +++++++++++++++++++++++++++-- convert_hf_to_gguf.py | 160 +++++++++++- convert_hf_to_gguf_update.py | 63 +++-- examples/server/server-context.cpp | 120 +++++---- gguf-py/gguf/constants.py | 205 +++++++++------ src/CMakeLists.txt | 1 + src/graphs/build_dflash.cpp | 144 +++++++++++ src/llama-arch.cpp | 6 +- src/llama-arch.h | 7 + src/llama-build-context.cpp | 7 + src/llama-build-context.h | 2 + src/llama-context.h | 23 +- src/llama-hparams.cpp | 96 +++++++ src/llama-hparams.h | 14 + src/llama-load-tensors.cpp | 41 +++ src/llama-model.cpp | 21 ++ src/llama-model.h | 3 +- src/llama-spec-features.cpp | 366 +++++++++++++++++++++++++- src/llama-spec-features.h | 48 +++- src/llama.cpp | 66 +++++ 23 files changed, 1618 insertions(+), 205 deletions(-) create mode 100644 src/graphs/build_dflash.cpp diff --git a/.flake8 b/.flake8 index a8bd3a81..d8f049f7 100644 --- a/.flake8 +++ b/.flake8 @@ -17,3 +17,5 @@ exclude = # This contains builds that we don't want to check dist # This is generated with `python build .` for package releases # max-complexity = 10 +per-file-ignores = + gguf-py/gguf/constants.py: E201, E222 diff --git a/common/common.cpp b/common/common.cpp index ffb8d5fd..99c517e4 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -148,6 +148,9 @@ common_params_speculative common_params_speculative::with_stage_overrides(const if (stage.has_p_min_override()) { result.p_min = stage.p_min; } + if (stage.has_dflash_cross_ctx_override()) { + result.dflash_cross_ctx = stage.dflash_cross_ctx; + } if (stage.has_ngram_size_n_override()) { result.ngram_size_n = stage.ngram_size_n; result.ngram_mod.reset(); @@ -247,8 +250,12 @@ bool common_speculative_validate_chain(const common_params_speculative 
& params, return fail("speculative stage has n_min greater than n_max"); } - if (stage.type == COMMON_SPECULATIVE_TYPE_DRAFT && !params.has_dft()) { - return fail("draft speculative stage requires a draft model or draft params"); + if ((stage.type == COMMON_SPECULATIVE_TYPE_DRAFT || stage.type == COMMON_SPECULATIVE_TYPE_DFLASH) && !params.has_dft()) { + return fail(common_speculative_type_to_str(stage.type) + " speculative stage requires a draft model or draft params"); + } + + if (stage.type == COMMON_SPECULATIVE_TYPE_DFLASH && stage_params.dflash_cross_ctx < 1) { + return fail("dflash speculative stage requires cross_ctx >= 1"); } } @@ -871,6 +878,13 @@ static void common_speculative_stage_apply_kv( } return; } + if (key == "cross_ctx" || key == "dflash_cross_ctx") { + stage.dflash_cross_ctx = std::stoi(value_raw); + if (stage.dflash_cross_ctx < 1) { + throw std::invalid_argument("speculative stage dflash cross_ctx must be at least 1"); + } + return; + } if (key == "ngram_size_n") { stage.ngram_size_n = std::stoi(value_raw); if (stage.ngram_size_n < 1 || stage.ngram_size_n > 1024) { @@ -1468,8 +1482,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa throw std::invalid_argument("--spec-type cannot be combined with --spec-stage; use only --spec-stage for explicit stage chains"); } - const auto type = common_speculative_type_from_name(argv[i]); - if (type == COMMON_SPECULATIVE_TYPE_NONE || type == COMMON_SPECULATIVE_TYPE_MTP || common_speculative_type_is_self_spec(type)) { + const auto stage = common_speculative_stage_from_arg(argv[i]); + const auto type = stage.type; + if (type == COMMON_SPECULATIVE_TYPE_NONE || type == COMMON_SPECULATIVE_TYPE_DFLASH || type == COMMON_SPECULATIVE_TYPE_MTP || common_speculative_type_is_self_spec(type)) { + params.speculative = params.speculative.with_stage_overrides(stage); params.speculative.type = type; if (type == COMMON_SPECULATIVE_TYPE_MTP) { params.has_mtp = true; @@ -3178,7 +3194,7 @@ void 
gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "--spec-stage SPEC[:k=v,...]", "explicit speculative stage. repeat once for a supported two-stage chain.\n" "examples: --spec-stage ngram-mod:n_max=64,n_min=2 --spec-stage mtp:n_max=1\n" "supported two-stage shape in this PR: self-spec first, then mtp or draft fallback" }); - options.push_back({ "*", "--spec-type Name [none | mtp | ngram-cache | ngram-simple | ngram-map-k | ngram-map-k4v | ngram-mod | suffix]", "single-stage speculative selection when --spec-stage is not used (default: %d)\n", (int)params.speculative.type}); + options.push_back({ "*", "--spec-type Name[:k=v,...] [none | dflash | mtp | ngram-cache | ngram-simple | ngram-map-k | ngram-map-k4v | ngram-mod | suffix]", "single-stage speculative selection when --spec-stage is not used (default: %d)\n", (int)params.speculative.type}); options.push_back({ "*", "--spec-ngram-size-n N", "ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)\n",params.speculative.ngram_size_n }); options.push_back({ "*", "--spec-ngram-size-m N", "ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)\n", params.speculative.ngram_size_m }); diff --git a/common/common.h b/common/common.h index bc68ca0f..87de68d9 100644 --- a/common/common.h +++ b/common/common.h @@ -140,6 +140,7 @@ thinking_tokens thinking_tokens_from_string(const std::string& format); enum common_speculative_type { COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding COMMON_SPECULATIVE_TYPE_DRAFT, // draft model + COMMON_SPECULATIVE_TYPE_DFLASH, // DFlash draft model COMMON_SPECULATIVE_TYPE_MTP, // MTP model COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding @@ -162,6 +163,7 @@ struct common_speculative_stage_params { int32_t n_max = -1; int32_t n_min = -1; float p_min = -1.0f; + int32_t 
dflash_cross_ctx = -1; uint16_t ngram_size_n = 0; uint16_t ngram_size_m = 0; @@ -173,6 +175,7 @@ struct common_speculative_stage_params { bool has_n_max_override() const { return n_max >= 0; } bool has_n_min_override() const { return n_min >= 0; } bool has_p_min_override() const { return p_min >= 0.0f; } + bool has_dflash_cross_ctx_override() const { return dflash_cross_ctx >= 0; } bool has_ngram_size_n_override() const { return ngram_size_n > 0; } bool has_ngram_size_m_override() const { return ngram_size_m > 0; } bool has_ngram_min_hits_override() const { return ngram_min_hits > 0; } @@ -204,6 +207,7 @@ struct common_params_speculative { int32_t n_max = 16; // number of tokens to draft during speculative decoding int32_t n_min = 0; // minimum number of tokens to draft during speculative decoding std::vector stages; // explicit stage chain for single-spec or self-spec + model fallback + int32_t dflash_cross_ctx = 512; // target-feature context window for DFlash float p_split = 0.1f; // speculative decoding split probability float p_min = 0.75f; // minimum speculative decoding probability (greedy) @@ -516,7 +520,7 @@ struct gpt_params { bool do_checkpoint = false; // do checkpoint for recurrent models only int32_t ctx_checkpoints_n = 32; // max number of context checkpoints per slot int32_t ctx_checkpoints_interval = 512; // minimum number of tokens between each context checkpoints - int32_t ctx_checkpoints_tolerance = 5; // the number of tokens before the full prompt to create the checkpoint + int32_t ctx_checkpoints_tolerance = 5; // the number of tokens before the full prompt to create the checkpoint int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. 
int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens diff --git a/common/speculative.cpp b/common/speculative.cpp index 2341bb6c..3b08b26a 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -24,6 +24,7 @@ void llama_set_mtp_target_context(struct llama_context * ctx, struct llama_conte const std::vector common_speculative_types = { COMMON_SPECULATIVE_TYPE_NONE, COMMON_SPECULATIVE_TYPE_DRAFT, + COMMON_SPECULATIVE_TYPE_DFLASH, COMMON_SPECULATIVE_TYPE_MTP, COMMON_SPECULATIVE_TYPE_EAGLE3, COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, @@ -37,6 +38,7 @@ const std::vector common_speculative_types = { const std::map common_speculative_type_from_name_map = { {"none", COMMON_SPECULATIVE_TYPE_NONE}, {"draft", COMMON_SPECULATIVE_TYPE_DRAFT}, + {"dflash", COMMON_SPECULATIVE_TYPE_DFLASH}, {"mtp", COMMON_SPECULATIVE_TYPE_MTP}, {"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3}, {"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, @@ -117,6 +119,44 @@ static bool common_speculative_are_compatible( return true; } +static bool common_speculative_are_dflash_compatible( + const llama_model * model_tgt, + const llama_model * model_dft) { + const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); + const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); + + if (llama_vocab_type(vocab_tgt) != llama_vocab_type(vocab_dft)) { + LOG_DBG("%s: DFlash draft model vocab type must match the target model\n", __func__); + return false; + } + + const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt); + const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft); + const int vocab_diff = n_vocab_tgt > n_vocab_dft + ? 
n_vocab_tgt - n_vocab_dft + : n_vocab_dft - n_vocab_tgt; + + if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { + LOG_DBG("%s: DFlash draft vocab size differs too much from the target model (%d vs %d)\n", + __func__, n_vocab_dft, n_vocab_tgt); + return false; + } + + for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { + const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); + const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); + + if (std::strcmp(token_text_tgt, token_text_dft) != 0) { + LOG_DBG("%s: DFlash draft token %d differs - target '%s', draft '%s'\n", __func__, i, + common_token_to_piece(vocab_tgt, i).c_str(), + common_token_to_piece(vocab_dft, i).c_str()); + return false; + } + } + + return true; +} + // state of an implementation of speculative decoding // // each implementation has a unique type and a state that is implementation-specific @@ -168,9 +208,18 @@ struct common_speculative_state { }; struct common_speculative_state_mtp; +struct common_speculative_state_dflash; static common_speculative_state_mtp * common_speculative_get_mtp_state(common_speculative * spec); static const common_speculative_state_mtp * common_speculative_get_mtp_state(const common_speculative * spec); +static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec); +static const common_speculative_state_dflash * common_speculative_get_dflash_state(const common_speculative * spec); +static int32_t common_speculative_feature_width(const common_speculative * spec); +static void dflash_append_target_features( + common_speculative_state_dflash & state, + const float * feature_rows, + int32_t n_rows); +static void dflash_clear_target_features(common_speculative_state_dflash & state); static void mtp_invalidate_cached_drafts(common_speculative_state_mtp & state); static std::vector mtp_speculative_gen_draft( @@ -302,6 +351,134 @@ struct common_speculative_state_mtp : public 
common_speculative_state { } }; +struct common_speculative_state_dflash : public common_speculative_state { + llama_context * ctx_tgt; + llama_context * ctx_dft; + + llama_batch batch = {}; + + int32_t block_size = 0; + int32_t mask_token_id = -1; + int32_t n_target_features = 0; + int32_t cross_ctx = 0; + bool ready = false; + + std::vector target_layer_ids; + std::vector target_window; + int32_t target_window_rows = 0; + + common_speculative_state_dflash( + enum common_speculative_type type, + llama_context * ctx_tgt, + llama_context * ctx_dft, + int32_t cross_ctx) + : common_speculative_state(type) + , ctx_tgt(ctx_tgt) + , ctx_dft(ctx_dft) + , cross_ctx(std::max(1, cross_ctx)) + { + const llama_model * model_dft = llama_get_model(ctx_dft); + + if (!common_speculative_are_dflash_compatible(llama_get_model(ctx_tgt), model_dft)) { + LOG_ERR("%s: DFlash draft model vocab/tokenizer is incompatible with the target model\n", __func__); + return; + } + + block_size = llama_model_dflash_block_size(model_dft); + mask_token_id = llama_model_dflash_mask_token_id(model_dft); + n_target_features = llama_model_dflash_n_target_features(model_dft); + const int32_t n_target_layers = llama_model_dflash_n_target_layers(model_dft); + + if (block_size <= 0 || mask_token_id < 0 || n_target_features <= 0 || n_target_layers <= 0) { + LOG_ERR("%s: invalid DFlash metadata (block_size=%d, mask_token_id=%d, n_target_features=%d, n_target_layers=%d)\n", + __func__, block_size, mask_token_id, n_target_features, n_target_layers); + return; + } + + target_layer_ids.resize((size_t) n_target_layers); + if (llama_model_dflash_target_layer_ids(model_dft, target_layer_ids.data(), n_target_layers) != n_target_layers) { + LOG_ERR("%s: failed to read DFlash target layer ids\n", __func__); + target_layer_ids.clear(); + return; + } + + if (!llama_set_dflash_capture_layers(ctx_tgt, target_layer_ids.data(), (int32_t) target_layer_ids.size())) { + LOG_ERR("%s: failed to configure DFlash target capture 
callback\n", __func__); + return; + } + + batch = llama_batch_init(std::max(1, block_size), 0, 1); + ready = true; + + LOG_INF("%s: DFlash context ready (n_ctx=%d, block_size=%d, cross_ctx=%d, n_target_features=%d)\n", + __func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features); + } + + ~common_speculative_state_dflash() override { + llama_clear_dflash_capture(ctx_tgt); + if (ctx_dft) { + llama_free(ctx_dft); + } + if (batch.token != nullptr) { + llama_batch_free(batch); + } + } + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + target_window.clear(); + target_window_rows = 0; + llama_kv_cache_clear(ctx_dft); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + GGML_UNUSED(prompt_tgt); + GGML_UNUSED(id_last); + + result.clear(); + if (!ready || target_window_rows <= 0) { + return; + } + + const int32_t n_draft = std::min(params.n_max, block_size); + if (n_draft <= 0) { + return; + } + + if (!llama_set_dflash_target_features_copy(ctx_dft, target_window.data(), target_window.size(), target_window_rows)) { + LOG_ERR("%s: failed to set DFlash target features\n", __func__); + return; + } + + llama_kv_cache_clear(ctx_dft); + batch.n_tokens = 0; + for (int32_t i = 0; i < n_draft; ++i) { + common_batch_add(batch, mask_token_id, cross_ctx + i, { 0 }, true); + } + + if (llama_decode(ctx_dft, batch) != 0) { + LOG_ERR("%s: llama_decode() failed for DFlash draft batch\n", __func__); + batch.n_tokens = 0; + return; + } + + result.reserve((size_t) n_draft); + for (int32_t i = 0; i < n_draft; ++i) { + result.push_back(common_sampler_sample_speculative(nullptr, ctx_dft, i, nullptr)); + } + + batch.n_tokens = 0; + } + + void accept(uint16_t n_accepted) override { + GGML_UNUSED(n_accepted); + } +}; + struct common_speculative_state_draft : public common_speculative_state { llama_context * ctx_tgt; // only used for retokenizing 
from ctx_dft llama_context * ctx_dft; @@ -1088,6 +1265,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) { switch (type) { case COMMON_SPECULATIVE_TYPE_NONE: return "none"; case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft"; + case COMMON_SPECULATIVE_TYPE_DFLASH: return "dflash"; case COMMON_SPECULATIVE_TYPE_MTP: return "mtp"; case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3"; case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple"; @@ -1165,8 +1343,13 @@ common_speculative * common_speculative_init( } } + const bool has_dflash_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) { + return stage.type == COMMON_SPECULATIVE_TYPE_DFLASH; + }); + const bool needs_draft_ctx = std::any_of(stages.begin(), stages.end(), [¶ms](const common_speculative_stage_params & stage) { return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT || + stage.type == COMMON_SPECULATIVE_TYPE_DFLASH || (stage.type == COMMON_SPECULATIVE_TYPE_MTP && params.model_dft != nullptr); }); @@ -1177,7 +1360,33 @@ common_speculative * common_speculative_init( return nullptr; } - ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft); + llama_context_params cparams_dft = params.cparams_dft; + + if (has_dflash_stage) { + if (!llama_model_share_dflash_io_tensors(params.model_dft, llama_get_model(ctx_tgt))) { + LOG_ERR("%s: failed to share target IO tensors with DFlash draft model\n", __func__); + return nullptr; + } + + int32_t max_cross_ctx = 0; + for (const auto & stage : stages) { + if (stage.type != COMMON_SPECULATIVE_TYPE_DFLASH) { + continue; + } + + max_cross_ctx = std::max(max_cross_ctx, params.with_stage_overrides(stage).dflash_cross_ctx); + } + + const int32_t block_size = llama_model_dflash_block_size(params.model_dft); + if (block_size <= 0) { + LOG_ERR("%s: invalid DFlash draft block size\n", __func__); + return nullptr; + } + + cparams_dft.n_ctx = (uint32_t) (max_cross_ctx + block_size); + } + 
+ ctx_dft = llama_init_from_model(params.model_dft, cparams_dft); if (ctx_dft == nullptr) { LOG_ERR("%s", "failed to create draft context\n"); return nullptr; @@ -1240,6 +1449,20 @@ common_speculative * common_speculative_init( )); break; } + case COMMON_SPECULATIVE_TYPE_DFLASH: { + auto state = std::make_unique( + config.type, + ctx_tgt, + ctx_dft, + config.params.dflash_cross_ctx); + if (!state->ready) { + LOG_ERR("%s: failed to initialize DFlash speculative state\n", __func__); + return nullptr; + } + impls.push_back(std::move(state)); + ctx_dft = nullptr; + break; + } case COMMON_SPECULATIVE_TYPE_MTP: { llama_context * ctx_mtp = ctx_dft; if (!ctx_mtp) { @@ -1604,6 +1827,10 @@ static bool common_speculative_collect_target_batch_features( const llama_batch & batch, common_speculative_feature_view & features) { features = {}; + if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) { + return llama_spec_get_dflash_feature_view(ctx, batch, features); + } + if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) { return true; } @@ -1622,6 +1849,10 @@ static bool common_speculative_collect_target_seq_batch_features( llama_seq_id seq_id, common_speculative_feature_view & features) { features = {}; + if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) { + return llama_spec_get_dflash_feature_view_for_seq(ctx, batch, seq_id, features); + } + if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) { return true; } @@ -1669,21 +1900,27 @@ int32_t common_speculative_on_target_seq_batch( const llama_batch & batch, llama_seq_id seq_id, bool is_prompt_warmup) { - llama_context * ctx_mtp = common_speculative_get_companion_ctx(spec); - ctx_mtp = ctx_mtp ? 
ctx_mtp : ctx_tgt; - if (ctx_tgt == nullptr || ctx_mtp == nullptr || batch.n_tokens <= 0) { + if (ctx_tgt == nullptr || batch.n_tokens <= 0) { return 0; } - const int n_embd_src = common_speculative_ctx_mtp_n_embd(ctx_tgt); - const int n_embd_dst = common_speculative_ctx_mtp_n_embd(ctx_mtp); - if (n_embd_src <= 0 || n_embd_dst <= 0) { - return -1; - } + if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) { + llama_context * ctx_mtp = common_speculative_get_companion_ctx(spec); + ctx_mtp = ctx_mtp ? ctx_mtp : ctx_tgt; + if (ctx_mtp == nullptr) { + return 0; + } - if (n_embd_src != n_embd_dst) { - LOG_ERR("MTP warmup hidden state width mismatch: n_embd_src = %d, n_embd_dst = %d\n", n_embd_src, n_embd_dst); - return -1; + const int n_embd_src = common_speculative_ctx_mtp_n_embd(ctx_tgt); + const int n_embd_dst = common_speculative_ctx_mtp_n_embd(ctx_mtp); + if (n_embd_src <= 0 || n_embd_dst <= 0) { + return -1; + } + + if (n_embd_src != n_embd_dst) { + LOG_ERR("MTP warmup hidden state width mismatch: n_embd_src = %d, n_embd_dst = %d\n", n_embd_src, n_embd_dst); + return -1; + } } common_speculative_feature_view feature_view; @@ -1723,6 +1960,10 @@ bool common_speculative_copy_output_hidden_rows( const std::vector & output_indices, std::vector & hidden_rows) { hidden_rows.clear(); + if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) { + return llama_spec_copy_dflash_rows_from_output_indices(ctx, output_indices, hidden_rows); + } + if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) { return true; } @@ -1760,13 +2001,13 @@ static bool common_speculative_apply_hidden_rows( llama_pos pos_base, const std::vector & ids, const std::vector & hidden_rows) { - auto * mtp_state = common_speculative_get_mtp_state(spec); - if (mtp_state == nullptr || ids.empty()) { + const int32_t feature_width = common_speculative_feature_width(spec); + if (feature_width <= 0 || ids.empty()) { return true; } - const size_t expected_floats = 
ids.size() * (size_t) mtp_state->n_embd; - if (mtp_state->n_embd <= 0 || hidden_rows.size() != expected_floats) { + const size_t expected_floats = ids.size() * (size_t) feature_width; + if (hidden_rows.size() != expected_floats) { return false; } @@ -1777,7 +2018,7 @@ static bool common_speculative_apply_hidden_rows( common_speculative_feature_view feature_view; const bool have_feature_view = common_speculative_feature_view_from_hidden_rows( - hidden_rows, mtp_state->n_embd, seq_id, pos_base, feature_view); + hidden_rows, feature_width, seq_id, pos_base, feature_view); const int32_t ret = have_feature_view ? common_speculative_on_target_batch(spec, accepted_batch, feature_view, false) : -1; @@ -1794,7 +2035,7 @@ bool common_speculative_commit_accepted_hidden_rows( llama_token sampled_before, const std::vector & ids, const std::vector & hidden_rows) { - if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) || ids.empty()) { + if (common_speculative_feature_width(spec) <= 0 || ids.empty()) { return true; } @@ -1815,7 +2056,7 @@ bool common_speculative_commit_accepted_output( llama_token sampled_before, const std::vector & ids, const std::vector & output_indices) { - if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) || ids.empty()) { + if (common_speculative_feature_width(spec) <= 0 || ids.empty()) { return true; } @@ -1898,6 +2139,40 @@ static const common_speculative_state_mtp * common_speculative_get_mtp_state(con return common_speculative_get_mtp_state(const_cast(spec)); } +static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec) { + if (!spec) { + return nullptr; + } + + for (auto & impl : spec->impls) { + if (impl->type != COMMON_SPECULATIVE_TYPE_DFLASH) { + continue; + } + + if (auto * dflash_state = dynamic_cast(impl.get())) { + return dflash_state; + } + } + + return nullptr; +} + +static const common_speculative_state_dflash * common_speculative_get_dflash_state(const 
common_speculative * spec) { + return common_speculative_get_dflash_state(const_cast(spec)); +} + +static int32_t common_speculative_feature_width(const common_speculative * spec) { + if (const auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) { + return dflash_state->n_target_features; + } + + if (const auto * mtp_state = common_speculative_get_mtp_state(spec); mtp_state != nullptr) { + return mtp_state->n_embd; + } + + return 0; +} + static mtp_last_embd & mtp_get_last_embd(common_speculative_state_mtp & state, llama_seq_id seq_id) { auto & last = state.draft_cache_by_seq[seq_id]; if ((int) last.embd.size() != state.n_embd) { @@ -1941,6 +2216,44 @@ static void mtp_clear_target_hidden(common_speculative_state_mtp & state, llama_ state.draft_cache_by_seq.erase(seq_id); } +static void dflash_append_target_features( + common_speculative_state_dflash & state, + const float * feature_rows, + int32_t n_rows) { + if (feature_rows == nullptr || n_rows <= 0 || state.n_target_features <= 0 || state.cross_ctx <= 0) { + return; + } + + const size_t row_width = (size_t) state.n_target_features; + if (n_rows >= state.cross_ctx) { + const float * src = feature_rows + (size_t) (n_rows - state.cross_ctx) * row_width; + state.target_window.assign(src, src + (size_t) state.cross_ctx * row_width); + state.target_window_rows = state.cross_ctx; + return; + } + + const int32_t keep_old_rows = std::min(state.target_window_rows, state.cross_ctx - n_rows); + std::vector next_window((size_t) (keep_old_rows + n_rows) * row_width); + + if (keep_old_rows > 0) { + const float * old_src = state.target_window.data() + (size_t) (state.target_window_rows - keep_old_rows) * row_width; + std::memcpy(next_window.data(), old_src, (size_t) keep_old_rows * row_width * sizeof(float)); + } + + std::memcpy( + next_window.data() + (size_t) keep_old_rows * row_width, + feature_rows, + (size_t) n_rows * row_width * sizeof(float)); + + state.target_window = 
std::move(next_window); + state.target_window_rows = keep_old_rows + n_rows; +} + +static void dflash_clear_target_features(common_speculative_state_dflash & state) { + state.target_window.clear(); + state.target_window_rows = 0; +} + static bool common_speculative_capture_target_features(common_speculative * spec, const common_speculative_feature_view & features) { auto * mtp_state = common_speculative_get_mtp_state(spec); if (mtp_state == nullptr || features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || features.width <= 0) { @@ -1973,11 +2286,13 @@ bool common_speculative_has_sequence_hidden(const common_speculative * spec, lla void common_speculative_clear_sequence_hidden(common_speculative * spec, llama_seq_id seq_id) { auto * mtp_state = common_speculative_get_mtp_state(spec); - if (mtp_state == nullptr) { - return; + if (mtp_state != nullptr) { + mtp_clear_target_hidden(*mtp_state, seq_id); } - mtp_clear_target_hidden(*mtp_state, seq_id); + if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) { + dflash_clear_target_features(*dflash_state); + } } llama_context * common_speculative_get_companion_ctx(common_speculative * spec) { @@ -1985,6 +2300,10 @@ llama_context * common_speculative_get_companion_ctx(common_speculative * spec) return mtp_state->ctx_mtp; } + if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) { + return dflash_state->ctx_dft; + } + return nullptr; } @@ -2023,6 +2342,39 @@ int32_t common_speculative_on_target_batch( const llama_batch & batch, const common_speculative_feature_view & features, bool is_prompt_warmup) { + if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) { + GGML_UNUSED(is_prompt_warmup); + + if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || batch.n_tokens <= 0) { + return 0; + } + + if (features.width != dflash_state->n_target_features) { + LOG_ERR("%s: DFlash feature width mismatch: 
got %d expected %d\n", + __func__, features.width, dflash_state->n_target_features); + return -1; + } + + if (batch.n_seq_id == nullptr || batch.seq_id == nullptr || batch.n_seq_id[0] <= 0 || batch.seq_id[0] == nullptr) { + return -1; + } + + const llama_seq_id seq_id = batch.seq_id[0][0]; + for (int i = 0; i < batch.n_tokens; ++i) { + if (batch.n_seq_id[i] != 1 || batch.seq_id[i] == nullptr || batch.seq_id[i][0] != seq_id) { + return -1; + } + } + + std::vector hidden_rows_storage; + if (!common_speculative_feature_view_copy_batch_rows(features, batch, seq_id, &hidden_rows_storage)) { + return -1; + } + + dflash_append_target_features(*dflash_state, hidden_rows_storage.data(), batch.n_tokens); + return 0; + } + auto * mtp_state = common_speculative_get_mtp_state(spec); if (mtp_state == nullptr) { return 0; diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0664e5aa..232ba706 100644 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -64,6 +64,7 @@ class Model: model_name: str | None metadata_override: Path | None dir_model_card: Path + target_model_dir: Path | None # subclasses should define this! 
model_arch: gguf.MODEL_ARCH @@ -71,7 +72,8 @@ class Model: def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, - split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False): + split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, + target_model_dir: Path | None = None): if type(self) is Model: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") @@ -93,6 +95,7 @@ class Model: self.metadata_override = metadata_override self.model_name = model_name self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py + self.target_model_dir = target_model_dir # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type if self.ftype == gguf.LlamaFileType.GUESSED: @@ -459,6 +462,14 @@ class Model: with open(dir_model / "config.json", "r", encoding="utf-8") as f: return json.load(f) + @staticmethod + def load_text_hparams(dir_model: Path) -> dict[str, Any]: + hparams = Model.load_hparams(dir_model) + text_config = hparams.get("text_config") + if isinstance(text_config, dict): + return {**hparams, **text_config} + return hparams + @classmethod def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: assert names @@ -500,13 +511,14 @@ class Model: return seems_special # used for GPT-2 BPE and WordPiece vocabs - def get_vocab_base(self) -> tuple[list[str], list[int], str]: + def get_vocab_base(self, dir_model: Path | None = None, vocab_size: int | None = None) -> tuple[list[str], list[int], str]: tokens: list[str] = [] toktypes: list[int] = [] from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(self.dir_model) - vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) + dir_model = 
dir_model or self.dir_model + tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) + vocab_size = vocab_size or self.hparams.get("vocab_size", len(tokenizer.vocab)) assert max(tokenizer.vocab.values()) < vocab_size tokpre = self.get_vocab_base_pre(tokenizer) @@ -594,6 +606,18 @@ class Model: if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea": # ref: https://huggingface.co/Qwen/Qwen1.5-7B res = "qwen2" + if chkhsh == "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4": + # ref: https://huggingface.co/Qwen/Qwen3.5-9B-Instruct + res = "qwen35" + if chkhsh == "99cc61242f7106804ce24fdf3a6451e4a55251078dffd5453c806e11b2310db3": + # ref: https://huggingface.co/Qwen/Qwen3.5-27B + res = "qwen35" + if chkhsh == "1444df51289cfa8063b96f0e62b1125440111bc79a52003ea14b6eac7016fd5f": + # ref: https://huggingface.co/z-lab/Qwen3.5-27B-DFlash (uses Qwen3.5 tokenizer) + res = "qwen35" + if chkhsh == "4f53cda18c2baa0c0354bb5f9a3ecbe5ed12ab4d8e11ba873c2f11161202b945": + # ref: https://huggingface.co/Qwen/Qwen3.6-35B-A3B (identical pre-tokenizer regex to qwen35) + res = "qwen35" if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166": # ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf res = "olmo" @@ -681,19 +705,20 @@ class Model: return res # Marker: End get_vocab_base_pre - def _set_vocab_gpt2(self) -> None: - tokens, toktypes, tokpre = self.get_vocab_base() + def _set_vocab_gpt2(self, dir_model: Path | None = None, vocab_size: int | None = None) -> None: + dir_model = dir_model or self.dir_model + tokens, toktypes, tokpre = self.get_vocab_base(dir_model=dir_model, vocab_size=vocab_size) self.gguf_writer.add_tokenizer_model("gpt2") self.gguf_writer.add_tokenizer_pre(tokpre) self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) - special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_vocab = gguf.SpecialVocab(dir_model, load_merges=True) 
special_vocab.add_to_gguf(self.gguf_writer) - def _set_vocab_qwen(self): - dir_model = self.dir_model - hparams = self.hparams + def _set_vocab_qwen(self, dir_model: Path | None = None, hparams: dict[str, Any] | None = None): + dir_model = dir_model or self.dir_model + hparams = hparams or self.hparams tokens: list[str] = [] toktypes: list[int] = [] @@ -2246,15 +2271,118 @@ class Qwen2MoeModel(Model): if len(experts) > 0: raise ValueError(f"Unprocessed experts: {experts}") + @Model.register("Qwen3ForCausalLM") class Qwen3Model(Qwen2Model): model_arch = gguf.MODEL_ARCH.QWEN3 + @Model.register("Qwen3MoeForCausalLM") class Qwen3MoeModel(Qwen2MoeModel): model_arch = gguf.MODEL_ARCH.QWEN3MOE +@Model.register("DFlashDraftModel") +class DFlashDraftModel(Qwen3Model): + model_arch = gguf.MODEL_ARCH.DFLASH_DRAFT + + _target_hparams: dict[str, Any] | None = None + + def _require_target_model_dir(self) -> Path: + if self.target_model_dir is None: + raise ValueError("DFlashDraftModel conversion requires --target-model-dir ") + return self.target_model_dir + + def _get_target_hparams(self) -> dict[str, Any]: + if self._target_hparams is None: + self._target_hparams = Model.load_text_hparams(self._require_target_model_dir()) + return self._target_hparams + + def set_vocab(self): + target_hparams = self._get_target_hparams() + self._set_vocab_gpt2( + dir_model=self._require_target_model_dir(), + vocab_size=target_hparams.get("vocab_size"), + ) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + self.gguf_writer.add_causal_attention(False) + self.gguf_writer.add_rope_dimension_count(self.hparams.get("head_dim", 128)) + + arch = self.gguf_writer.arch + dflash_cfg = self.hparams.get("dflash_config") + dflash_cfg = dflash_cfg if isinstance(dflash_cfg, dict) else {} + + def dflash_required_value(name: str) -> Any: + if name in dflash_cfg: + return dflash_cfg[name] + if name in self.hparams: + return self.hparams[name] + raise ValueError(f"DFlashDraftModel conversion 
requires explicit {name} metadata") + + block_size = int(dflash_required_value("block_size")) + self.gguf_writer.add_uint32(f"{arch}.dflash.block_size", block_size) + + mask_token_id = int(dflash_required_value("mask_token_id")) + self.gguf_writer.add_uint32(f"{arch}.dflash.mask_token_id", mask_token_id) + + target_layer_ids = [int(layer_id) for layer_id in dflash_required_value("target_layer_ids")] + if len(target_layer_ids) == 0: + raise ValueError("DFlashDraftModel conversion requires at least one target_layer_id") + self.gguf_writer.add_array(f"{arch}.dflash.target_layer_ids", target_layer_ids) + + if "n_target_features" in dflash_cfg: + n_target_features = int(dflash_cfg["n_target_features"]) + elif "n_target_features" in self.hparams: + n_target_features = int(self.hparams["n_target_features"]) + else: + draft_hidden_size = self.hparams.get("hidden_size") + if draft_hidden_size is None: + raise ValueError("DFlashDraftModel: draft config is missing hidden_size") + + n_target_features = int(draft_hidden_size) * len(target_layer_ids) + + target_hparams = self._get_target_hparams() + target_hidden_size = target_hparams.get("hidden_size") + if target_hidden_size is not None and int(target_hidden_size) != int(draft_hidden_size): + logger.warning( + "DFlashDraftModel: target hidden_size=%d differs from draft hidden_size=%d; using draft hidden width for n_target_features", + int(target_hidden_size), + int(draft_hidden_size), + ) + + logger.info( + "DFlashDraftModel: inferred n_target_features=%d from draft hidden_size=%d and n_target_layers=%d", + n_target_features, + int(draft_hidden_size), + len(target_layer_ids), + ) + + self.gguf_writer.add_uint32(f"{arch}.dflash.n_target_features", n_target_features) + + logger.info( + "DFlashDraftModel metadata: block_size=%s mask_token_id=%s target_layer_ids=%s n_target_features=%s", + block_size, + mask_token_id, + target_layer_ids, + n_target_features, + ) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int 
| None) -> Iterable[tuple[str, Tensor]]: + if name == "fc.weight": + return [(f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.DFLASH_FC]}.weight", data_torch)] + if name == "hidden_norm.weight": + return [(f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.DFLASH_HIDDEN_NORM]}.weight", data_torch)] + if name == "norm.weight": + name = "model.norm.weight" + elif name.startswith("layers."): + name = f"model.{name}" + + return super().modify_tensors(data_torch, name, bid) + + @Model.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM") class Ernie4_5Model(Model): model_arch = gguf.MODEL_ARCH.ERNIE4_5 @@ -4385,6 +4513,7 @@ class JaisModel(Model): super().prepare_tensors() self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias) + @Model.register("MiniMaxM2ForCausalLM") class MiniMaxM2Model(Model): model_arch = gguf.MODEL_ARCH.MINIMAXM2 @@ -4457,10 +4586,12 @@ class SmolLM3Model(LlamaModel): chat_template = tokenizer.chat_template.replace("[:]", "") self.gguf_writer.add_chat_template(chat_template) + @Model.register("SeedOssForCausalLM") class SeedOssModel(Model): model_arch = gguf.MODEL_ARCH.SEED_OSS + @Model.register("Dots1ForCausalLM") class Dots1Model(Qwen2MoeModel): model_arch = gguf.MODEL_ARCH.DOTS1 @@ -4621,6 +4752,7 @@ class Glm4MoeModel(Model): if len(experts) > 0: raise ValueError(f"Unprocessed experts: {experts}") + @Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration") class ChatGLMModel(Model): model_arch = gguf.MODEL_ARCH.CHATGLM @@ -4803,6 +4935,7 @@ class ChatGLMModel(Model): name = name.removeprefix("transformer.") return [(self.map_tensor_name(name), data_torch)] + @Model.register("BailingMoeV2ForCausalLM") class BailingMoeV2Model(Model): model_arch = gguf.MODEL_ARCH.BAILINGMOE2 @@ -5028,6 +5161,10 @@ def parse_args() -> argparse.Namespace: "--metadata", type=Path, help="Specify the path for an authorship metadata override file" ) + parser.add_argument( + "--target-model-dir", type=Path, + help="matching target model directory; required for DFlash conversion 
to reuse tokenizer and infer target feature width", + ) return parser.parse_args() @@ -5107,7 +5244,8 @@ def main() -> None: metadata_override=args.metadata, model_name=args.model_name, split_max_tensors=args.split_max_tensors, split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run, - small_first_shard=args.no_tensor_first_split) + small_first_shard=args.no_tensor_first_split, + target_model_dir=args.target_model_dir) if args.vocab_only: logger.info("Exporting model vocab...") diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 96936717..da894158 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -78,6 +78,10 @@ models = [ {"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", }, {"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", }, {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", }, + {"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", "chkhsh": "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4", }, + {"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-27B", "chkhsh": "99cc61242f7106804ce24fdf3a6451e4a55251078dffd5453c806e11b2310db3", }, + {"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/z-lab/Qwen3.5-27B-DFlash", "chkhsh": "1444df51289cfa8063b96f0e62b1125440111bc79a52003ea14b6eac7016fd5f", }, + {"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.6-35B-A3B", "chkhsh": "4f53cda18c2baa0c0354bb5f9a3ecbe5ed12ab4d8e11ba873c2f11161202b945", }, {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", }, {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", }, 
{"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM! @@ -154,39 +158,46 @@ for model in models: if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM: continue - # Skip if the tokenizer folder does not exist or there are other download issues previously - if not os.path.exists(f"models/tokenizers/{name}"): - logger.warning(f"Directory for tokenizer {name} not found. Skipping...") - continue + chkhsh = model.get("chkhsh") - # create the tokenizer - try: - if name == "t5": - tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False) - else: - tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") - except (OSError, TypeError) as e: - logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}") - continue # Skip to the next model if the tokenizer can't be loaded + if chkhsh is None: + # Skip if the tokenizer folder does not exist or there are other download issues previously + if not os.path.exists(f"models/tokenizers/{name}"): + logger.warning(f"Directory for tokenizer {name} not found. Skipping...") + continue - chktok = tokenizer.encode(CHK_TXT) - chkhsh = sha256(str(chktok).encode()).hexdigest() + # create the tokenizer + try: + if name == "t5": + tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False) + else: + tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") + except (OSError, TypeError) as e: + logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. 
Error: {e}") + continue # Skip to the next model if the tokenizer can't be loaded + + chktok = tokenizer.encode(CHK_TXT) + chkhsh = sha256(str(chktok).encode()).hexdigest() logger.info(f"model: {name}") logger.info(f"tokt: {tokt}") logger.info(f"repo: {model['repo']}") - logger.info(f"chktok: {chktok}") logger.info(f"chkhsh: {chkhsh}") - # print the "pre_tokenizer" content from the tokenizer.json - with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f: - cfg = json.load(f) - normalizer = cfg["normalizer"] - logger.info("normalizer: " + json.dumps(normalizer, indent=4)) - pre_tokenizer = cfg["pre_tokenizer"] - logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4)) - if "ignore_merges" in cfg["model"]: - logger.info("ignore_merges: " + json.dumps(cfg["model"]["ignore_merges"], indent=4)) + if model.get("chkhsh") is None: + logger.info(f"chktok: {chktok}") + + # print the "pre_tokenizer" content from the tokenizer.json + with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f: + cfg = json.load(f) + normalizer = cfg["normalizer"] + logger.info("normalizer: " + json.dumps(normalizer, indent=4)) + pre_tokenizer = cfg["pre_tokenizer"] + logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4)) + if "ignore_merges" in cfg["model"]: + logger.info("ignore_merges: " + json.dumps(cfg["model"]["ignore_merges"], indent=4)) + else: + logger.info("using manually provided tokenizer hash") logger.info("") @@ -353,6 +364,6 @@ logger.info("\nRun the following commands to generate the vocab files for testin for model in models: name = model["name"] - print(f"python3 convert_hf_to_gguf.py models/tokenizers/{name}/ --outfile models/ggml-vocab-{name}.gguf --vocab-only") # noqa: NP100 + logger.info(f"python3 convert_hf_to_gguf.py models/tokenizers/{name}/ --outfile models/ggml-vocab-{name}.gguf --vocab-only") # noqa: NP100 logger.info("\n") diff --git a/examples/server/server-context.cpp 
b/examples/server/server-context.cpp index 99345458..af6924fe 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -134,6 +134,14 @@ static bool server_speculative_has_mtp(const common_params_speculative & spec) { return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP); } +static bool server_speculative_has_dflash(const common_params_speculative & spec) { + return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH); +} + +static bool server_speculative_has_target_features(const common_params_speculative & spec) { + return server_speculative_has_mtp(spec) || server_speculative_has_dflash(spec); +} + static bool server_speculative_same_stage_types( const common_params_speculative & lhs, const common_params_speculative & rhs) { @@ -217,6 +225,18 @@ static common_speculative_stage_params server_parse_speculative_stage_json(const } server_context::~server_context() { + // Speculative state may reference the live target context during teardown. + for (server_slot& slot : slots) { + if (slot.ctx_sampling != nullptr) { + common_sampler_free(slot.ctx_sampling); + } + slot.spec_ckpt.clear(); + common_speculative_free(slot.spec); + slot.spec = nullptr; + slot.ctx_dft = nullptr; + llama_batch_free(slot.batch_spec); + } + if (ctx) { llama_free(ctx); ctx = nullptr; @@ -238,19 +258,6 @@ server_context::~server_context() { model_draft = nullptr; } - // Clear any sampling context - for (server_slot& slot : slots) { - if (slot.ctx_sampling != nullptr) { - common_sampler_free(slot.ctx_sampling); - } - slot.spec_ckpt.clear(); - if (slot.ctx_dft) { - llama_free(slot.ctx_dft); - } - common_speculative_free(slot.spec); - llama_batch_free(slot.batch_spec); - } - llama_batch_free(batch); } @@ -286,6 +293,13 @@ bool server_context::load_model(const gpt_params& params_) { params_base.speculative.model_dft = nullptr; } + if (server_speculative_has_dflash(params_base.speculative) && params_base.n_parallel > 1) { + LOG_ERROR("DFlash is currently limited to 
a single server slot (-np 1).\n", { + {"n_parallel", params_base.n_parallel}, + }); + return false; + } + bool has_draft_model = !params_base.speculative.model.empty() || !params_base.speculative.params.empty(); std::string& mmproj_path = params_base.mmproj.path; if (!mmproj_path.empty()) { @@ -470,7 +484,7 @@ void server_context::init() { bool can_spec = true; if (!params_base.dry_run) { can_spec = common_speculative_is_compat(ctx); - } + } if (!can_spec) { SRV_WRN("%s", "speculative decoding not supported by this context\n"); } @@ -1656,7 +1670,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) int32_t banbuffer_size = json_value(data, "banbuffer_size", 0); slot.n_buffer = 0; // Ensure buffer calculation starts fresh for this slot slot.rewind_count_max = json_value(data, "rewind_count_max", -1); - + const auto& banned_strings = data.find("banned_strings"); if (banned_strings != data.end() && banned_strings->is_array()) { slot.ban_phrases.clear(); @@ -2805,7 +2819,7 @@ static size_t load_server_tokens_from_file(const std::string & filename, server size_t pos = 0; json token_json; if (file.is_open()) { - file >> token_json; + file >> token_json; pos = file.tellg(); file.close(); } @@ -3727,7 +3741,7 @@ bool server_context::create_checkpoint(server_slot & slot) { slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin()); } - + auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(); server_prompt_checkpoint_update(cur, ctx, slot.id, slot.cache_tokens.n_tokens(), pos_min, pos_max, slot.n_past_offset); @@ -4060,7 +4074,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t slot.do_checkpoint = true; break; } - + } LOG_VERBOSE("prompt processing progress", { {"id_slot", slot.id}, @@ -4143,7 +4157,7 @@ static void restore_speculative_checkpoint( common_speculative_type spec_type_used, llama_token sampled_before, const std::vector & ids, int n_draft, - const 
std::vector & mtp_hidden_state_pre, int32_t mtp_n_past_base) { + const std::vector & spec_feature_rows_pre, int32_t spec_n_past_base) { if (slot.spec_ckpt.per_step_enabled) { const int step = (int)ids.size() - 1; llama_spec_ckpt_restore(ctx, slot.id, slot.spec_ckpt.n_past, step); @@ -4155,16 +4169,16 @@ static void restore_speculative_checkpoint( common_sampler_accept(slot.ctx_sampling, ctx, id, true); } - // Update MTP KV cache and hidden state using embeddings collected before checkpoint restore. - if (slot.has_mtp && !mtp_hidden_state_pre.empty()) { + // Update speculative target features using rows collected before checkpoint restore. + if (server_speculative_has_target_features(slot.params.speculative) && !spec_feature_rows_pre.empty()) { if (!common_speculative_commit_accepted_hidden_rows( slot.spec, spec_type_used, slot.id, - mtp_n_past_base, + spec_n_past_base, sampled_before, ids, - mtp_hidden_state_pre)) { + spec_feature_rows_pre)) { common_speculative_clear_sequence_hidden(slot.spec, slot.id); } else if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) { SLT_DBG(slot, "%s", "synced MTP target hidden state from accepted-prefix rows after per-step restore"); @@ -4201,7 +4215,7 @@ static void restore_speculative_checkpoint( if (ret != 0) { SLT_ERR(slot, "failed to re-decode accepted tokens after checkpoint restore: %d\n", ret); } - if (slot.has_mtp) { + if (server_speculative_has_target_features(slot.params.speculative)) { const int n_accepted = (int)ids.size(); std::vector redecoded_indices(n_accepted); for (int j = 0; j < n_accepted; ++j) { @@ -4272,20 +4286,20 @@ void server_context::speculative_decoding_accept() { } const bool any_rejected = (ids.size() - 1) < n_draft; - int32_t mtp_n_past_base = 0; - std::vector mtp_hidden_state_pre; + int32_t spec_n_past_base = 0; + std::vector spec_feature_rows_pre; std::vector accepted_output_indices; - if (slot.has_mtp) { + if (server_speculative_has_target_features(slot.params.speculative)) { const int32_t 
n_pre_spec_tokens = slot.cache_tokens.n_tokens() - (int32_t)(slot.drafted.size() + 1); - mtp_n_past_base = slot.cache_tokens.pos_next(n_pre_spec_tokens); + spec_n_past_base = slot.cache_tokens.pos_next(n_pre_spec_tokens); if (!ids.empty()) { accepted_output_indices.assign(slot.i_batch_dft.begin(), slot.i_batch_dft.begin() + ids.size()); } if (any_rejected && slot.spec_ckpt.valid && !accepted_output_indices.empty()) { - if (!common_speculative_copy_output_hidden_rows(slot.spec, ctx, accepted_output_indices, mtp_hidden_state_pre)) { - mtp_hidden_state_pre.clear(); + if (!common_speculative_copy_output_hidden_rows(slot.spec, ctx, accepted_output_indices, spec_feature_rows_pre)) { + spec_feature_rows_pre.clear(); } } } @@ -4317,15 +4331,15 @@ void server_context::speculative_decoding_accept() { // for recurrent/hybrid models: if any drafts were rejected, restore recurrent state if (any_rejected && slot.spec_ckpt.valid) { - restore_speculative_checkpoint(slot, ctx, model, spec_type_used, sampled_before, ids, n_draft, mtp_hidden_state_pre, mtp_n_past_base); + restore_speculative_checkpoint(slot, ctx, model, spec_type_used, sampled_before, ids, n_draft, spec_feature_rows_pre, spec_n_past_base); } else { - if (slot.has_mtp && !accepted_output_indices.empty()) { + if (server_speculative_has_target_features(slot.params.speculative) && !accepted_output_indices.empty()) { if (!common_speculative_commit_accepted_output( slot.spec, ctx, spec_type_used, slot.id, - mtp_n_past_base, + spec_n_past_base, sampled_before, ids, accepted_output_indices)) { @@ -4395,15 +4409,15 @@ void server_context::release_slot_after_final_response(server_slot & slot) { void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) { int count = 0; bool released = false; - + int32_t start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1; for (auto& it : results) { bool has_next = process_token(it, slot); - + // Clean up positional bans for the token we 
just confirmed/sent slot.positional_bans.erase(start_pos + count); - + count++; if (!has_next) { if (slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) { @@ -4436,7 +4450,7 @@ inline int32_t check_ban_phrase(server_slot& slot) { std::string string_buffer; std::vector token_offsets; - + for (const auto& it : slot.token_buffer) { token_offsets.push_back(string_buffer.size()); string_buffer += it.text_to_send; @@ -4488,10 +4502,10 @@ inline int32_t check_ban_phrase(server_slot& slot) { if (found) { int32_t token_idx = -1; for (size_t i = 0; i < token_offsets.size(); ++i) { - size_t len = (i == token_offsets.size() - 1) - ? string_buffer.size() - token_offsets[i] + size_t len = (i == token_offsets.size() - 1) + ? string_buffer.size() - token_offsets[i] : token_offsets[i+1] - token_offsets[i]; - + if (best_start >= token_offsets[i] && best_start < token_offsets[i] + len) { token_idx = (int32_t)i; break; @@ -4509,7 +4523,7 @@ inline int32_t check_ban_phrase(server_slot& slot) { inline void rewind_context(server_slot& slot, int32_t ban_pos) { slot.rewind_count++; - + int32_t buffer_start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1; int32_t n_keep_buffer = ban_pos - buffer_start_pos; if (n_keep_buffer < 0) n_keep_buffer = 0; @@ -4518,9 +4532,9 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) { int32_t n = 0; for (auto result = slot.token_buffer.begin() + n_keep_buffer; result != slot.token_buffer.end(); result++) { llama_token banned_tok = result->tok; - + if (n == 0) { - LLAMA_LOG_DEBUG("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n", + LLAMA_LOG_DEBUG("Banned pattern detected at pos %d. 
Banning token %d ('%s') and rewinding.\n", ban_pos, banned_tok, result->text_to_send.c_str()); } @@ -4533,7 +4547,7 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) { } int32_t n_rewind_total = (slot.n_past + 1) - ban_pos; - + size_t n_keep_cache = 0; if (ban_pos > 0) { n_keep_cache = (size_t)(ban_pos - 1); @@ -4546,13 +4560,13 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) { if (n_keep_cache < slot.cache_tokens.size()) { slot.sampled = slot.cache_tokens[n_keep_cache]; } else { - slot.sampled = 0; + slot.sampled = 0; } // Truncate cache slot.cache_tokens.keep_first(n_keep_cache); slot.n_past = slot.cache_tokens.n_tokens(); - + // Remove from KV cache llama_kv_cache_seq_rm(slot.ctx, slot.id, slot.cache_tokens.pos_next(slot.n_past), -1); @@ -4590,13 +4604,13 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_ // Automatic / Heuristic logic // Account for strings + regex + regex_ci size_t total_bans = slot.ban_phrases.size() + slot.ban_regex.size() + slot.ban_regex_ci.size(); - + // Heuristic: Allow if under 20 OR under 2 * total_bans // Conversely: Stop if >= 20 AND > 2 * total_bans if (slot.rewind_count >= 20 && slot.rewind_count > 2 * total_bans) { allow_rewind = false; } - } + } else if (slot.rewind_count_max > 0) { // Strict limit logic if (slot.rewind_count >= slot.rewind_count_max) { @@ -4613,7 +4627,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_ else if (buffer_full || !next_token) { slot.rewind_status = false; slot.rewind_count = 0; - + if (!next_token) { // send all remaining tokens send_token_results(slot.token_buffer, slot); @@ -4625,7 +4639,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_ } else { // buffer the result, wait for more tokens to validate string - slot.sampled = result.tok; + slot.sampled = result.tok; } } @@ -4710,9 +4724,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) { continue; 
// continue loop of n_batch } - if (server_speculative_has_mtp(params_base.speculative)) { + if (server_speculative_has_target_features(params_base.speculative)) { for (auto & slot : slots) { - if (!slot.spec || !slot.has_mtp) { + if (!slot.spec || !server_speculative_has_target_features(slot.params.speculative)) { continue; } @@ -4722,7 +4736,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) { } if (common_speculative_on_target_seq_batch(slot.spec, ctx, batch_view, slot.id, true) != 0) { - LOG_ERROR("failed to warm up MTP state from prompt batch for slot %d\n", slot.id); + LOG_ERROR("failed to warm up speculative target-feature state from prompt batch for slot %d\n", slot.id); } } } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 9219f0e5..232b664c 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -236,6 +236,8 @@ class MODEL_ARCH(IntEnum): GEMMA3 = auto() GEMMA4 = auto() GEMMA4_MTP = auto() + DFLASH = auto() + DFLASH_DRAFT = auto() STARCODER2 = auto() MAMBA = auto() XVERSE = auto() @@ -260,6 +262,7 @@ class MODEL_ARCH(IntEnum): SMOLLM3 = auto() SEED_OSS = auto() + class MODEL_TENSOR(IntEnum): TOKEN_EMBD = auto() TOKEN_EMBD_NORM = auto() @@ -366,6 +369,8 @@ class MODEL_TENSOR(IntEnum): MTP_POST_PROJ = auto() MTP_TOKEN_ORDERING = auto() MTP_CENTROIDS = auto() + DFLASH_FC = auto() + DFLASH_HIDDEN_NORM = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -402,6 +407,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.GEMMA3: "gemma3", MODEL_ARCH.GEMMA4: "gemma4", MODEL_ARCH.GEMMA4_MTP: "gemma4_mtp", + MODEL_ARCH.DFLASH: "dflash", + MODEL_ARCH.DFLASH_DRAFT: "dflash-draft", MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.XVERSE: "xverse", @@ -534,6 +541,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.MTP_POST_PROJ: "mtp_post_proj", MODEL_TENSOR.MTP_TOKEN_ORDERING: "mtp_token_ordering", MODEL_TENSOR.MTP_CENTROIDS: "mtp_centroids", + MODEL_TENSOR.DFLASH_FC: 
"dflash_fc", + MODEL_TENSOR.DFLASH_HIDDEN_NORM: "dflash_hidden_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -1235,6 +1244,38 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], + MODEL_ARCH.DFLASH: [ + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.DFLASH_FC, + MODEL_TENSOR.DFLASH_HIDDEN_NORM, + ], + MODEL_ARCH.DFLASH_DRAFT: [ + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.DFLASH_FC, + MODEL_TENSOR.DFLASH_HIDDEN_NORM, + ], MODEL_ARCH.BITNET: [ MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, @@ -1644,89 +1685,89 @@ class ExpertGatingFuncType(IntEnum): # ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE. 
class LlamaFileType(IntEnum): ALL_F32 = 0 - MOSTLY_F16 = 1 #except 1d tensors - MOSTLY_Q4_0 = 2 #except 1d tensors - MOSTLY_Q4_1 = 3 #except 1d tensors - MOSTLY_Q8_0 = 7 #except 1d tensors - MOSTLY_Q5_0 = 8 #except 1d tensors - MOSTLY_Q5_1 = 9 #except 1d tensors - MOSTLY_Q2_K = 10 #except 1d tensors - MOSTLY_Q3_K_S = 11 #except 1d tensors - MOSTLY_Q3_K_M = 12 #except 1d tensors - MOSTLY_Q3_K_L = 13 #except 1d tensors - MOSTLY_Q4_K_S = 14 #except 1d tensors - MOSTLY_Q4_K_M = 15 #except 1d tensors - MOSTLY_Q5_K_S = 16 #except 1d tensors - MOSTLY_Q5_K_M = 17 #except 1d tensors - MOSTLY_Q6_K = 18 #except 1d tensors - MOSTLY_IQ2_XXS = 19 #except 1d tensors - MOSTLY_IQ2_XS = 20 #except 1d tensors - MOSTLY_Q2_K_S = 21 #except 1d tensors - MOSTLY_IQ3_XS = 22 #except 1d tensors - MOSTLY_IQ3_XXS = 23 #except 1d tensors - MOSTLY_IQ1_S = 24 #except 1d tensors - MOSTLY_IQ4_NL = 25 #except 1d tensors - MOSTLY_IQ3_S = 26 #except 1d tensors - MOSTLY_IQ3_M = 27 #except 1d tensors - MOSTLY_IQ2_S = 28 #except 1d tensors - MOSTLY_IQ2_M = 29 #except 1d tensors - MOSTLY_IQ4_XS = 30 #except 1d tensors - MOSTLY_IQ1_M = 31 #except 1d tensors - MOSTLY_BF16 = 32 #except 1d tensors - MOSTLY_Q4_0_4_4 = 33 #except 1d tensors - MOSTLY_Q4_0_4_8 = 34 #except 1d tensors - MOSTLY_Q4_0_8_8 = 35 #except 1d tensors - MOSTLY_MXFP4 = 38 #except 1d tensors, 38 to be compatible with mainline + MOSTLY_F16 = 1 # except 1d tensors + MOSTLY_Q4_0 = 2 # except 1d tensors + MOSTLY_Q4_1 = 3 # except 1d tensors + MOSTLY_Q8_0 = 7 # except 1d tensors + MOSTLY_Q5_0 = 8 # except 1d tensors + MOSTLY_Q5_1 = 9 # except 1d tensors + MOSTLY_Q2_K = 10 # except 1d tensors + MOSTLY_Q3_K_S = 11 # except 1d tensors + MOSTLY_Q3_K_M = 12 # except 1d tensors + MOSTLY_Q3_K_L = 13 # except 1d tensors + MOSTLY_Q4_K_S = 14 # except 1d tensors + MOSTLY_Q4_K_M = 15 # except 1d tensors + MOSTLY_Q5_K_S = 16 # except 1d tensors + MOSTLY_Q5_K_M = 17 # except 1d tensors + MOSTLY_Q6_K = 18 # except 1d tensors + MOSTLY_IQ2_XXS = 19 # except 1d 
tensors + MOSTLY_IQ2_XS = 20 # except 1d tensors + MOSTLY_Q2_K_S = 21 # except 1d tensors + MOSTLY_IQ3_XS = 22 # except 1d tensors + MOSTLY_IQ3_XXS = 23 # except 1d tensors + MOSTLY_IQ1_S = 24 # except 1d tensors + MOSTLY_IQ4_NL = 25 # except 1d tensors + MOSTLY_IQ3_S = 26 # except 1d tensors + MOSTLY_IQ3_M = 27 # except 1d tensors + MOSTLY_IQ2_S = 28 # except 1d tensors + MOSTLY_IQ2_M = 29 # except 1d tensors + MOSTLY_IQ4_XS = 30 # except 1d tensors + MOSTLY_IQ1_M = 31 # except 1d tensors + MOSTLY_BF16 = 32 # except 1d tensors + MOSTLY_Q4_0_4_4 = 33 # except 1d tensors + MOSTLY_Q4_0_4_8 = 34 # except 1d tensors + MOSTLY_Q4_0_8_8 = 35 # except 1d tensors + MOSTLY_MXFP4 = 38 # except 1d tensors, 38 to be compatible with mainline - MOSTLY_Q6_0 = 135 #except 1d tensors - MOSTLY_IQ1_BN = 136 #except 1d tensors - MOSTLY_IQ2_BN = 137 #except 1d tensors - MOSTLY_IQ2_K = 138 #except 1d tensors - MOSTLY_IQ3_K = 139 #except 1d tensors - MOSTLY_IQ4_K = 140 #except 1d tensors - MOSTLY_IQ5_K = 141 #except 1d tensors - MOSTLY_IQ6_K = 142 #except 1d tensors - MOSTLY_IQ4_KS = 145 #except 1d tensors - MOSTLY_IQ3_KL = 146 #except 1d tensors - MOSTLY_IQ2_KS = 147 #except 1d tensors - MOSTLY_IQ4_KSS = 148 #except 1d tensors - MOSTLY_Q8_KV = 149 #except 1d tensors - MOSTLY_IQ5_KS = 150 #except 1d tensors - MOSTLY_IQ2_KT = 151 #except 1d tensors - MOSTLY_IQ3_KT = 152 #except 1d tensors - MOSTLY_IQ4_KT = 153 #except 1d tensors - MOSTLY_IQ3_KS = 154 #except 1d tensors - MOSTLY_IQ2_KL = 155 #except 1d tensors - MOSTLY_IQ1_KT = 156 #except 1d tensors + MOSTLY_Q6_0 = 135 # except 1d tensors + MOSTLY_IQ1_BN = 136 # except 1d tensors + MOSTLY_IQ2_BN = 137 # except 1d tensors + MOSTLY_IQ2_K = 138 # except 1d tensors + MOSTLY_IQ3_K = 139 # except 1d tensors + MOSTLY_IQ4_K = 140 # except 1d tensors + MOSTLY_IQ5_K = 141 # except 1d tensors + MOSTLY_IQ6_K = 142 # except 1d tensors + MOSTLY_IQ4_KS = 145 # except 1d tensors + MOSTLY_IQ3_KL = 146 # except 1d tensors + MOSTLY_IQ2_KS = 147 # except 1d 
tensors + MOSTLY_IQ4_KSS = 148 # except 1d tensors + MOSTLY_Q8_KV = 149 # except 1d tensors + MOSTLY_IQ5_KS = 150 # except 1d tensors + MOSTLY_IQ2_KT = 151 # except 1d tensors + MOSTLY_IQ3_KT = 152 # except 1d tensors + MOSTLY_IQ4_KT = 153 # except 1d tensors + MOSTLY_IQ3_KS = 154 # except 1d tensors + MOSTLY_IQ2_KL = 155 # except 1d tensors + MOSTLY_IQ1_KT = 156 # except 1d tensors - MOSTLY_Q4_0_R8 = 202 #except 1d tensors - MOSTLY_Q8_0_R8 = 207 #except 1d tensors - MOSTLY_Q5_0_R4 = 208 #except 1d tensors - MOSTLY_Q2_K_R4 = 210 #except 1d tensors - MOSTLY_Q3_K_R4 = 211 #except 1d tensors - MOSTLY_Q4_K_R4 = 214 #except 1d tensors - MOSTLY_Q5_K_R4 = 216 #except 1d tensors - MOSTLY_Q6_K_R4 = 218 #except 1d tensors - MOSTLY_IQ2_XXS_R4 = 219 #except 1d tensors - MOSTLY_IQ2_XS_R4 = 220 #except 1d tensors - MOSTLY_IQ3_XXS_R4 = 223 #except 1d tensors - MOSTLY_IQ1_S_R4 = 224 #except 1d tensors - MOSTLY_IQ4_NL_R4 = 225 #except 1d tensors - MOSTLY_IQ3_S_R4 = 226 #except 1d tensors - MOSTLY_IQ2_M_R4 = 229 #except 1d tensors - MOSTLY_IQ4_XS_R8 = 230 #except 1d tensors - MOSTLY_IQ1_M_R4 = 231 #except 1d tensors - MOSTLY_Q6_0_R4 = 335 #except 1d tensors - MOSTLY_BF16_R16 = 232 #except 1d tensors - MOSTLY_IQ2_BN_R4 = 337 #except 1d tensors - MOSTLY_IQ2_K_R4 = 338 #except 1d tensors - MOSTLY_IQ3_K_R4 = 339 #except 1d tensors - MOSTLY_IQ4_K_R4 = 340 #except 1d tensors - MOSTLY_IQ5_K_R4 = 341 #except 1d tensors - MOSTLY_IQ4_KS_R4 = 345 #except 1d tensors - MOSTLY_IQ5_KS_R4 = 350 #except 1d tensors - MOSTLY_Q8_KV_R8 = 398 #except 1d tensors - MOSTLY_Q8_K_R8 = 399 #except 1d tensors + MOSTLY_Q4_0_R8 = 202 # except 1d tensors + MOSTLY_Q8_0_R8 = 207 # except 1d tensors + MOSTLY_Q5_0_R4 = 208 # except 1d tensors + MOSTLY_Q2_K_R4 = 210 # except 1d tensors + MOSTLY_Q3_K_R4 = 211 # except 1d tensors + MOSTLY_Q4_K_R4 = 214 # except 1d tensors + MOSTLY_Q5_K_R4 = 216 # except 1d tensors + MOSTLY_Q6_K_R4 = 218 # except 1d tensors + MOSTLY_IQ2_XXS_R4 = 219 # except 1d tensors + MOSTLY_IQ2_XS_R4 
= 220 # except 1d tensors + MOSTLY_IQ3_XXS_R4 = 223 # except 1d tensors + MOSTLY_IQ1_S_R4 = 224 # except 1d tensors + MOSTLY_IQ4_NL_R4 = 225 # except 1d tensors + MOSTLY_IQ3_S_R4 = 226 # except 1d tensors + MOSTLY_IQ2_M_R4 = 229 # except 1d tensors + MOSTLY_IQ4_XS_R8 = 230 # except 1d tensors + MOSTLY_IQ1_M_R4 = 231 # except 1d tensors + MOSTLY_Q6_0_R4 = 335 # except 1d tensors + MOSTLY_BF16_R16 = 232 # except 1d tensors + MOSTLY_IQ2_BN_R4 = 337 # except 1d tensors + MOSTLY_IQ2_K_R4 = 338 # except 1d tensors + MOSTLY_IQ3_K_R4 = 339 # except 1d tensors + MOSTLY_IQ4_K_R4 = 340 # except 1d tensors + MOSTLY_IQ5_K_R4 = 341 # except 1d tensors + MOSTLY_IQ4_KS_R4 = 345 # except 1d tensors + MOSTLY_IQ5_KS_R4 = 350 # except 1d tensors + MOSTLY_Q8_KV_R8 = 398 # except 1d tensors + MOSTLY_Q8_K_R8 = 399 # except 1d tensors GUESSED = 1024 # not specified in the model file @@ -1771,7 +1812,7 @@ class GGUFValueType(IntEnum): # Items here are (block size, type size) QK_K = 256 -#Values generated programatically +# Values generated programmatically GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { GGMLQuantizationType.F32 : ( 1, 4), GGMLQuantizationType.F16 : ( 1, 2), diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 59f9b0a0..035dd8e6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -97,6 +97,7 @@ add_library(llama graphs/build_gemma2.cpp graphs/build_gemma3.cpp graphs/build_gemma4.cpp + graphs/build_dflash.cpp graphs/build_mamba.cpp graphs/build_command_r.cpp graphs/build_olmo.cpp diff --git a/src/graphs/build_dflash.cpp b/src/graphs/build_dflash.cpp new file mode 100644 index 00000000..fe1cec15 --- /dev/null +++ b/src/graphs/build_dflash.cpp @@ -0,0 +1,144 @@ +#include "../llama-build-context.h" +#include "../llama-context.h" +#include "../llama-model.h" + +#include + +ggml_cgraph * llm_build_context::build_dflash() { + const int64_t n_embd_head_k = hparams.n_embd_head_k(0); + const int64_t n_embd_head_v = hparams.n_embd_head_v(0); + const
int64_t n_target_features = hparams.dflash_n_target_features; + const int64_t ctx_len = std::max(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size); + const int64_t n_kv_total = ctx_len + n_tokens; + + GGML_ASSERT(n_embd_head_k == n_embd_head_v); + GGML_ASSERT(n_target_features > 0); + + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max(n_tokens, ctx_len)) + 32 * n_layer, false); + + lctx.inp_dflash_target_features = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_target_features, ctx_len); + ggml_set_input(lctx.inp_dflash_target_features); + cb(lctx.inp_dflash_target_features, "dflash_target_features", -1); + + lctx.inp_dflash_pos_ctx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctx_len); + ggml_set_input(lctx.inp_dflash_pos_ctx); + cb(lctx.inp_dflash_pos_ctx, "dflash_pos_ctx", -1); + + lctx.inp_dflash_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + ggml_set_input(lctx.inp_dflash_kq_mask); + cb(lctx.inp_dflash_kq_mask, "dflash_kq_mask", -1); + + ggml_tensor * tok_embd = model.tok_embd; + if (tok_embd == nullptr) { + tok_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab); + } + + ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, tok_embd, cb); + ggml_tensor * inp_pos = build_inp_pos(); + + ggml_tensor * fused_target = llm_build_lora_mm(lctx, ctx0, model.dflash_fc, lctx.inp_dflash_target_features); + fused_target = llm_build_norm(ctx0, fused_target, hparams, model.dflash_hidden_norm, nullptr, LLM_NORM_RMS, cb, -1); + cb(fused_target, "dflash_target_fused", -1); + + const float kq_scale = 1.0f / std::sqrt((float) n_embd_head_k); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + ggml_tensor * cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + Qcur = 
ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens); + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il); + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur_noise = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + Kcur_noise = ggml_reshape_3d(ctx0, Kcur_noise, n_embd_head_k, n_head_kv, n_tokens); + Kcur_noise = llm_build_norm(ctx0, Kcur_noise, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il); + Kcur_noise = ggml_rope_ext(ctx0, Kcur_noise, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Kcur_noise, "Kcur_noise", il); + + ggml_tensor * Vcur_noise = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + Vcur_noise = ggml_reshape_3d(ctx0, Vcur_noise, n_embd_head_v, n_head_kv, n_tokens); + cb(Vcur_noise, "Vcur_noise", il); + + ggml_tensor * Kcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused_target); + Kcur_ctx = ggml_reshape_3d(ctx0, Kcur_ctx, n_embd_head_k, n_head_kv, ctx_len); + Kcur_ctx = llm_build_norm(ctx0, Kcur_ctx, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il); + Kcur_ctx = ggml_rope_ext(ctx0, Kcur_ctx, lctx.inp_dflash_pos_ctx, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Kcur_ctx, "Kcur_ctx", il); + + ggml_tensor * Vcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, fused_target); + Vcur_ctx = ggml_reshape_3d(ctx0, Vcur_ctx, n_embd_head_v, n_head_kv, ctx_len); + cb(Vcur_ctx, "Vcur_ctx", il); + + ggml_tensor * Kcur = ggml_concat(ctx0, Kcur_ctx, Kcur_noise, 2); + ggml_tensor * Vcur = ggml_concat(ctx0, Vcur_ctx, Vcur_noise, 2); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_cast(ctx0, Qcur, 
GGML_TYPE_F16); + Kcur = ggml_cast(ctx0, Kcur, GGML_TYPE_F16); + Vcur = ggml_cast(ctx0, Vcur, GGML_TYPE_F16); + cb(Qcur, "Qcur_f16", il); + cb(Kcur, "Kcur_f16", il); + cb(Vcur, "Vcur_f16", il); + + ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + ggml_tensor * v = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 0, 2, 1, 3)); + cb(q, "q", il); + cb(k, "k", il); + cb(v, "v", il); + + cur = ggml_flash_attn_ext(ctx0, q, k, v, lctx.inp_dflash_kq_mask, kq_scale, hparams.f_max_alibi_bias, + hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); + cb(cur, "flash_attn", il); + ggml_build_forward_expand(gf, cur); + + cur = ggml_reshape_2d(ctx0, cur, model.layers[il].wo->ne[0], n_tokens); + cb(cur, "flash_attn_reshaped", il); + + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); + cb(cur, "kqv_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, cb, il); + cb(cur, "attn_post_norm", il); + + cur = llm_build_ffn(ctx0, lctx, nullptr, cur, + model.layers[il].ffn_up, nullptr, nullptr, + model.layers[il].ffn_gate, nullptr, nullptr, + model.layers[il].ffn_down, nullptr, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, false, false); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "l_out", il); + + inpL = cur; + } + + ggml_tensor * output = model.output; + if (output == nullptr) { + output = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab); + } + + ggml_tensor * result = build_output(lctx, ctx0, inpL, output, model.output_norm, cb); + cb(result, "result_output", -1); + ggml_build_forward_expand(gf, result); + + return gf; +} diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 7e2bb4c4..3cff96dc 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ 
-79,6 +79,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MISTRAL4, "mistral4" }, { LLM_ARCH_GEMMA4, "gemma4" }, { LLM_ARCH_GEMMA4_MTP, "gemma4_mtp" }, + { LLM_ARCH_DFLASH_DRAFT, "dflash-draft" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -145,6 +146,10 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_MTP_USE_ORDERED_EMBEDDINGS, "%s.use_ordered_embeddings" }, { LLM_KV_MTP_CENTROID_COUNT, "%s.centroid_count" }, { LLM_KV_MTP_CENTROID_TOP_K, "%s.centroid_top_k" }, + { LLM_KV_DFLASH_BLOCK_SIZE, "%s.dflash.block_size" }, + { LLM_KV_DFLASH_MASK_TOKEN_ID, "%s.dflash.mask_token_id" }, + { LLM_KV_DFLASH_TARGET_LAYER_IDS, "%s.dflash.target_layer_ids" }, + { LLM_KV_DFLASH_N_TARGET_FEATURES, "%s.dflash.n_target_features" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -279,4 +284,3 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { return false; } } - diff --git a/src/llama-arch.h b/src/llama-arch.h index 5a148ad7..d6f8dae4 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -78,6 +78,7 @@ enum llm_arch { LLM_ARCH_MISTRAL4, LLM_ARCH_GEMMA4, LLM_ARCH_GEMMA4_MTP, + LLM_ARCH_DFLASH_DRAFT, LLM_ARCH_UNKNOWN, }; @@ -138,6 +139,10 @@ enum llm_kv { LLM_KV_MTP_USE_ORDERED_EMBEDDINGS, LLM_KV_MTP_CENTROID_COUNT, LLM_KV_MTP_CENTROID_TOP_K, + LLM_KV_DFLASH_BLOCK_SIZE, + LLM_KV_DFLASH_MASK_TOKEN_ID, + LLM_KV_DFLASH_TARGET_LAYER_IDS, + LLM_KV_DFLASH_N_TARGET_FEATURES, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -367,6 +372,8 @@ enum llm_tensor { LLM_TENSOR_MTP_POST_PROJ, LLM_TENSOR_MTP_TOKEN_ORDERING, LLM_TENSOR_MTP_CENTROIDS, + LLM_TENSOR_DFLASH_FC, + LLM_TENSOR_DFLASH_HIDDEN_NORM, LLM_TENSOR_UNKNOWN, }; diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 54a649ab..ad89e6e2 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -112,6 +112,9 @@ void llm_build_context::init() { lctx.inp_pos_bucket = nullptr; lctx.inp_embd_enc = 
nullptr; lctx.inp_KQ_mask_cross = nullptr; + lctx.inp_dflash_target_features = nullptr; + lctx.inp_dflash_pos_ctx = nullptr; + lctx.inp_dflash_kq_mask = nullptr; } void llm_build_context::free() { @@ -2372,6 +2375,10 @@ ggml_cgraph * llm_build_context::llama_build_graph( { result = llm.build_gemma4_mtp(); } break; + case LLM_ARCH_DFLASH_DRAFT: + { + result = llm.build_dflash(); + } break; case LLM_ARCH_STARCODER2: { result = llm.build_starcoder2(); diff --git a/src/llama-build-context.h b/src/llama-build-context.h index 73490c3a..7542aff6 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -242,6 +242,8 @@ struct llm_build_context { ggml_cgraph * build_gemma4_mtp(); + ggml_cgraph * build_dflash(); + ggml_cgraph * build_starcoder2(); ggml_cgraph * build_mamba(); diff --git a/src/llama-context.h b/src/llama-context.h index db0018d7..d0d9fe61 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -278,6 +278,25 @@ struct llama_context { size_t draft_input_hidden_state_n_floats = 0; std::vector draft_input_hidden_state_owned; + const float * dflash_target_features = nullptr; + size_t dflash_target_features_n_floats = 0; + int32_t dflash_target_features_n_rows = 0; + std::vector dflash_target_features_owned; + std::vector dflash_target_features_padded; + std::vector dflash_feature_view_buffer; + std::vector dflash_pos_ctx_data; + std::vector dflash_kq_mask_data; + + struct dflash_capture_state { + std::vector layer_ids; + std::vector> layer_rows; + int32_t row_count = 0; + int32_t row_width = 0; + ggml_backend_sched_eval_callback prev_cb_eval = nullptr; + void * prev_cb_eval_user_data = nullptr; + }; + std::unique_ptr dflash_capture; + // input tensors struct ggml_tensor * inp_tokens; // I32 [n_batch] struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] @@ -297,6 +316,9 @@ struct llama_context { struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] struct ggml_tensor * inp_scale = nullptr; // F32 [n_tokens] struct 
ggml_tensor * inp_mtp_states = nullptr; + struct ggml_tensor * inp_dflash_target_features = nullptr; // F32 [n_target_features, cross_ctx] + struct ggml_tensor * inp_dflash_pos_ctx = nullptr; // I32 [cross_ctx] + struct ggml_tensor * inp_dflash_kq_mask = nullptr; // F32 [cross_ctx + n_batch, GGML_PAD(n_batch)] ggml_backend_t ggml_backend_by_name(const char * name); @@ -320,4 +342,3 @@ struct llama_context { void set_mtp_op_type(llama_mtp_op_type value); }; - diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index f3be1c11..271633c6 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -3,6 +3,7 @@ #include "llama-model-loader.h" #include "llama-model.h" +#include #include #define LLAMA_MAX_EXPERTS 512 // Qwen3 Next @@ -36,6 +37,89 @@ static inline const char * llm_expert_gating_func_name(llm_expert_gating_func_ty } } +static bool load_dflash_target_layer_ids( + llama_model_loader & ml, + const std::string & key, + llama_hparams & hparams, + bool required) { + const int kid = gguf_find_key(ml.meta, key.c_str()); + if (kid < 0 || gguf_get_kv_type(ml.meta, kid) != GGUF_TYPE_ARRAY) { + if (required) { + throw std::runtime_error(format("array key not found in model: %s", key.c_str())); + } + return false; + } + + const enum gguf_type type = gguf_get_arr_type(ml.meta, kid); + if (type != GGUF_TYPE_UINT32 && type != GGUF_TYPE_INT32) { + throw std::runtime_error(format("dflash: %s must be a uint32/int32 array", key.c_str())); + } + + const size_t n = gguf_get_arr_n(ml.meta, kid); + if (n == 0) { + throw std::runtime_error(format("dflash: %s must not be empty", key.c_str())); + } + if (n > 8) { + throw std::runtime_error(format("dflash: %s has %zu entries, max is 8", key.c_str(), n)); + } + + hparams.dflash_n_target_layers = (uint32_t) n; + for (uint32_t & id : hparams.dflash_target_layer_ids) { + id = 0; + } + + const void * data = gguf_get_arr_data(ml.meta, kid); + for (uint32_t i = 0; i < hparams.dflash_n_target_layers; ++i) { + if (type == 
GGUF_TYPE_INT32) { + const int32_t id = ((const int32_t *) data)[i]; + if (id < 0) { + throw std::runtime_error(format("dflash: %s contains negative layer id %d", key.c_str(), id)); + } + hparams.dflash_target_layer_ids[i] = (uint32_t) id; + } else { + hparams.dflash_target_layer_ids[i] = ((const uint32_t *) data)[i]; + } + } + + return true; +} + +static void validate_dflash_hparams(llama_hparams & hparams, llm_arch arch) { + if (hparams.dflash_block_size <= 1) { + throw std::runtime_error(format("%s: dflash block_size must be > 1", llama_model_arch_name(arch))); + } + if (hparams.dflash_n_target_layers == 0) { + throw std::runtime_error(format("%s: dflash target_layer_ids are required", llama_model_arch_name(arch))); + } + + if (arch == LLM_ARCH_DFLASH_DRAFT && hparams.n_embd > 0) { + const uint32_t expected_n_target_features = hparams.n_embd * hparams.dflash_n_target_layers; + if (expected_n_target_features > 0 && hparams.dflash_n_target_features != expected_n_target_features) { + LLAMA_LOG_WARN( + "%s: overriding dflash n_target_features from %u to %u based on n_embd=%u and n_target_layers=%u\n", + llama_model_arch_name(arch), + hparams.dflash_n_target_features, + expected_n_target_features, + hparams.n_embd, + hparams.dflash_n_target_layers); + hparams.dflash_n_target_features = expected_n_target_features; + } + } + + if (hparams.dflash_n_target_features == 0) { + throw std::runtime_error(format( + "%s: dflash n_target_features must be > 0", + llama_model_arch_name(arch))); + } + if (hparams.dflash_n_target_features % hparams.dflash_n_target_layers != 0) { + throw std::runtime_error(format( + "%s: dflash n_target_features=%u must be divisible by n_target_layers=%u", + llama_model_arch_name(arch), + hparams.dflash_n_target_features, + hparams.dflash_n_target_layers)); + } +} + void llm_load_hparams( llama_model_loader & ml, @@ -774,6 +858,18 @@ void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_DFLASH_DRAFT: + { + 
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_DFLASH_BLOCK_SIZE, hparams.dflash_block_size, false); + ml.get_key(LLM_KV_DFLASH_MASK_TOKEN_ID, hparams.dflash_mask_token_id, false); + ml.get_key(LLM_KV_DFLASH_N_TARGET_FEATURES, hparams.dflash_n_target_features, false); + load_dflash_target_layer_ids(ml, LLM_KV(model.arch)(LLM_KV_DFLASH_TARGET_LAYER_IDS), hparams, false); + validate_dflash_hparams(hparams, model.arch); + + hparams.n_layer_kv_from_start = hparams.n_layer; + model.type = e_model::MODEL_UNKNOWN; + } break; case LLM_ARCH_STARCODER2: { diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 13fe1811..65fbb7ce 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -140,6 +140,13 @@ struct llama_hparams { uint32_t mtp_num_centroids = 0; uint32_t mtp_centroid_top_k = 0; + // DFlash draft model metadata + uint32_t dflash_block_size = 16; + uint32_t dflash_mask_token_id = 0; + uint32_t dflash_n_target_features = 0; + uint32_t dflash_n_target_layers = 0; + uint32_t dflash_target_layer_ids[8] = {}; + // needed by encoder-decoder models (e.g. 
T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = -1; @@ -159,6 +166,10 @@ struct llama_hparams { if (this->n_ctx_train != other.n_ctx_train) return true; if (this->n_embd != other.n_embd) return true; if (this->mtp_backbone_n_embd != other.mtp_backbone_n_embd) return true; + if (this->dflash_block_size != other.dflash_block_size) return true; + if (this->dflash_mask_token_id != other.dflash_mask_token_id) return true; + if (this->dflash_n_target_features != other.dflash_n_target_features) return true; + if (this->dflash_n_target_layers != other.dflash_n_target_layers) return true; if (this->n_layer != other.n_layer) return true; if (this->n_rot != other.n_rot) return true; if (this->n_swa != other.n_swa) return true; @@ -189,6 +200,9 @@ struct llama_hparams { if (this->ssm_dt_rank != other.ssm_dt_rank) return true; if (this->ssm_n_group != other.ssm_n_group) return true; if (this->recurrent_layer_arr != other.recurrent_layer_arr) return true; + for (int i = 0; i < 8; ++i) { + if (this->dflash_target_layer_ids[i] != other.dflash_target_layer_ids[i]) return true; + } if (this->dec_start_token_id != other.dec_start_token_id) return true; diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index b3f1ff06..bcd08ff4 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -98,6 +98,8 @@ struct create_tensors_helper : public create_tensors_helper_interface { bool create_gemma4_mtp_tensors(const LLM_TN & tn); + bool create_dflash_tensors(const LLM_TN & tn); + bool create_starcoder2_tensors(const LLM_TN & tn); bool create_mamba_tensors(const LLM_TN & tn); @@ -2192,6 +2194,43 @@ bool create_tensors_helper::create_gemma4_mtp_tensors(const LLM_TN & tn) { return use_mmap_buffer; } +bool create_tensors_helper::create_dflash_tensors(const LLM_TN & tn) { + LOADING_PRELUDE + + const bool use_split_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN; + 
+ model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + if (model.output == nullptr && model.tok_embd != nullptr) { + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + model.dflash_fc = create_tensor(ctx_output, tn(LLM_TENSOR_DFLASH_FC, "weight"), {(int64_t) hparams.dflash_n_target_features, n_embd}, 0); + model.dflash_hidden_norm = create_tensor(ctx_output, tn(LLM_TENSOR_DFLASH_HIDDEN_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_split = use_split_ctx ? ctx_for_layer_split(i) : ctx_for_layer(i); + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head, n_embd}, 0); + + layer.attn_q_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = 
create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + return use_mmap_buffer; +} + bool create_tensors_helper::create_starcoder2_tensors(const LLM_TN & tn) { LOADING_PRELUDE @@ -4263,6 +4302,8 @@ bool create_tensors_helper::create_tensors() { use_mmap_buffer = create_gemma4_tensors(tn); break; case LLM_ARCH_GEMMA4_MTP: use_mmap_buffer = create_gemma4_mtp_tensors(tn); break; + case LLM_ARCH_DFLASH_DRAFT: + use_mmap_buffer = create_dflash_tensors(tn); break; case LLM_ARCH_STARCODER2: use_mmap_buffer = create_starcoder2_tensors(tn); break; case LLM_ARCH_MAMBA: diff --git a/src/llama-model.cpp b/src/llama-model.cpp index fef0069d..553c9520 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -825,6 +825,27 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_MTP_CENTROIDS, "mtp_centroids" }, }, }, + { + LLM_ARCH_DFLASH_DRAFT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_DFLASH_FC, "dflash_fc" }, + { LLM_TENSOR_DFLASH_HIDDEN_NORM, "dflash_hidden_norm" }, + }, + }, { LLM_ARCH_STARCODER2, { diff --git a/src/llama-model.h b/src/llama-model.h index 5ff084fe..6a11ba41 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -428,6 +428,8 @@ struct llama_model { struct ggml_tensor * mtp_post_proj = nullptr; struct 
ggml_tensor * mtp_token_ordering = nullptr; struct ggml_tensor * mtp_centroids = nullptr; + struct ggml_tensor * dflash_fc = nullptr; + struct ggml_tensor * dflash_hidden_norm = nullptr; struct ggml_tensor * output_norm; struct ggml_tensor * output_norm_b; @@ -621,4 +623,3 @@ struct LLM_TN { std::string llama_model_ftype_name(llama_ftype ftype); const char * llama_model_type_name(e_model type); - diff --git a/src/llama-spec-features.cpp b/src/llama-spec-features.cpp index 5a32b848..827d536f 100644 --- a/src/llama-spec-features.cpp +++ b/src/llama-spec-features.cpp @@ -1,5 +1,8 @@ #include "llama-spec-features.h" +#include +#include +#include #include #include "llama-model.h" @@ -18,6 +21,63 @@ uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx) { return hparams.n_embd; } +int32_t llama_model_dflash_block_size(const struct llama_model * model) { + return model ? (int32_t) model->hparams.dflash_block_size : 0; +} + +int32_t llama_model_dflash_mask_token_id(const struct llama_model * model) { + return model ? (int32_t) model->hparams.dflash_mask_token_id : -1; +} + +int32_t llama_model_dflash_n_target_layers(const struct llama_model * model) { + return model ? (int32_t) model->hparams.dflash_n_target_layers : 0; +} + +int32_t llama_model_dflash_n_target_features(const struct llama_model * model) { + return model ? 
(int32_t) model->hparams.dflash_n_target_features : 0; +} + +int32_t llama_model_dflash_target_layer_ids( + const struct llama_model * model, + int32_t * layer_ids, + int32_t capacity) { + if (model == nullptr || layer_ids == nullptr || capacity <= 0) { + return 0; + } + + const int32_t n_layers = std::min((int32_t) model->hparams.dflash_n_target_layers, capacity); + for (int32_t i = 0; i < n_layers; ++i) { + layer_ids[i] = (int32_t) model->hparams.dflash_target_layer_ids[i]; + } + + return n_layers; +} + +bool llama_model_share_dflash_io_tensors( + struct llama_model * draft_model, + const struct llama_model * target_model) { + if (draft_model == nullptr || target_model == nullptr) { + return false; + } + + if (draft_model->arch != LLM_ARCH_DFLASH_DRAFT) { + return true; + } + + if (draft_model->tok_embd == nullptr) { + draft_model->tok_embd = target_model->tok_embd; + } + + if (draft_model->output == nullptr) { + draft_model->output = target_model->output ? target_model->output : target_model->tok_embd; + if (draft_model->output == nullptr) { + draft_model->output = draft_model->tok_embd; + } + } + + return draft_model->tok_embd != nullptr && draft_model->output != nullptr; +} + bool llama_set_draft_input_hidden_state_copy( struct llama_context * ctx, const float * hidden_state, @@ -32,6 +92,211 @@ bool llama_set_draft_input_hidden_state_copy( return true; } +bool llama_set_dflash_target_features_copy( + struct llama_context * ctx, + const float * target_features, + size_t n_floats, + int32_t n_rows) { + if (ctx == nullptr || target_features == nullptr || n_floats == 0 || n_rows <= 0) { + return false; + } + + ctx->dflash_target_features_owned.assign(target_features, target_features + n_floats); + ctx->dflash_target_features = ctx->dflash_target_features_owned.data(); + ctx->dflash_target_features_n_floats = n_floats; + ctx->dflash_target_features_n_rows = n_rows; + return true; +} + +static bool llama_dflash_parse_layer_id(const struct ggml_tensor * tensor, 
int32_t & layer_id) { + if (tensor == nullptr) { + return false; + } + + static constexpr const char * prefix = "l_out-"; + if (std::strncmp(tensor->name, prefix, std::strlen(prefix)) != 0) { + return false; + } + + char * end = nullptr; + const long raw = std::strtol(tensor->name + std::strlen(prefix), &end, 10); + if (end == tensor->name + std::strlen(prefix) || *end != '\0') { + return false; + } + + layer_id = (int32_t) raw; + if (layer_id >= 1000) { + layer_id %= 1000; + } + + return layer_id >= 0; +} + +static int32_t llama_dflash_find_layer_index(const struct llama_context * ctx, int32_t layer_id) { + if (ctx == nullptr || !ctx->dflash_capture) { + return -1; + } + + const auto & layer_ids = ctx->dflash_capture->layer_ids; + const auto it = std::find(layer_ids.begin(), layer_ids.end(), layer_id); + return it == layer_ids.end() ? -1 : (int32_t) std::distance(layer_ids.begin(), it); +} + +static bool llama_dflash_capture_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { + auto * ctx = static_cast(user_data); + if (ctx == nullptr || !ctx->dflash_capture) { + return false; + } + + int32_t layer_id = -1; + if (!llama_dflash_parse_layer_id(tensor, layer_id)) { + return false; + } + + const int32_t layer_idx = llama_dflash_find_layer_index(ctx, layer_id); + if (layer_idx < 0) { + return false; + } + + if (ask) { + return true; + } + + const int32_t row_width = (int32_t) tensor->ne[0]; + const int32_t row_count = row_width > 0 ? 
(int32_t) (ggml_nelements(tensor) / (int64_t) row_width) : 0; + if (row_width <= 0 || row_count <= 0) { + return false; + } + + auto & capture = *ctx->dflash_capture; + auto & rows = capture.layer_rows[(size_t) layer_idx]; + rows.resize((size_t) row_count * (size_t) row_width); + ggml_backend_tensor_get(tensor, rows.data(), 0, ggml_nbytes(tensor)); + capture.row_width = row_width; + capture.row_count = row_count; + return true; +} + +bool llama_set_dflash_capture_layers( + struct llama_context * ctx, + const int32_t * layer_ids, + int32_t n_layers) { + if (ctx == nullptr || layer_ids == nullptr || n_layers <= 0) { + return false; + } + + auto capture = std::make_unique(); + capture->layer_ids.assign(layer_ids, layer_ids + n_layers); + capture->layer_rows.resize((size_t) n_layers); + capture->prev_cb_eval = ctx->cparams.cb_eval; + capture->prev_cb_eval_user_data = ctx->cparams.cb_eval_user_data; + ctx->dflash_capture = std::move(capture); + ctx->dflash_feature_view_buffer.clear(); + + ctx->cparams.cb_eval = llama_dflash_capture_eval_callback; + ctx->cparams.cb_eval_user_data = ctx; + if (ctx->sched != nullptr) { + ggml_backend_sched_set_eval_callback(ctx->sched, ctx->cparams.cb_eval, ctx->cparams.cb_eval_user_data); + } + + return true; +} + +void llama_clear_dflash_capture(struct llama_context * ctx) { + if (ctx == nullptr) { + return; + } + + ggml_backend_sched_eval_callback prev_cb_eval = nullptr; + void * prev_cb_eval_user_data = nullptr; + if (ctx->dflash_capture) { + prev_cb_eval = ctx->dflash_capture->prev_cb_eval; + prev_cb_eval_user_data = ctx->dflash_capture->prev_cb_eval_user_data; + } + + ctx->dflash_capture.reset(); + ctx->dflash_feature_view_buffer.clear(); + + if (ctx->cparams.cb_eval == llama_dflash_capture_eval_callback && ctx->cparams.cb_eval_user_data == ctx) { + ctx->cparams.cb_eval = prev_cb_eval; + ctx->cparams.cb_eval_user_data = prev_cb_eval_user_data; + if (ctx->sched != nullptr) { + ggml_backend_sched_set_eval_callback(ctx->sched, 
prev_cb_eval, prev_cb_eval_user_data); + } + } +} + +static bool llama_spec_prepare_dflash_capture( + struct llama_context * ctx, + int32_t & row_count, + int32_t & row_width, + int32_t & n_layers) { + if (ctx == nullptr || !ctx->dflash_capture) { + return false; + } + + llama_synchronize(ctx); + + auto & capture = *ctx->dflash_capture; + row_count = capture.row_count; + row_width = capture.row_width; + n_layers = (int32_t) capture.layer_ids.size(); + if (row_count <= 0 || row_width <= 0 || n_layers <= 0 || capture.layer_rows.size() != (size_t) n_layers) { + return false; + } + + for (const auto & rows : capture.layer_rows) { + if (rows.size() != (size_t) row_count * (size_t) row_width) { + return false; + } + } + + return true; +} + +static bool llama_spec_materialize_dflash_rows( + struct llama_context * ctx, + const std::vector & row_indices, + std::vector & rows_out, + int32_t & combined_width) { + rows_out.clear(); + combined_width = 0; + if (ctx == nullptr || row_indices.empty()) { + return false; + } + + int32_t row_count = 0; + int32_t row_width = 0; + int32_t n_layers = 0; + if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) { + return false; + } + + combined_width = row_width * n_layers; + rows_out.resize((size_t) row_indices.size() * (size_t) combined_width); + + const auto & layer_rows = ctx->dflash_capture->layer_rows; + for (size_t out_row = 0; out_row < row_indices.size(); ++out_row) { + int32_t row_index = row_indices[out_row]; + if (row_index < 0) { + row_index += row_count; + } + if (row_index < 0 || row_index >= row_count) { + rows_out.clear(); + combined_width = 0; + return false; + } + + float * dst = rows_out.data() + out_row * (size_t) combined_width; + for (int32_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) { + const float * src = layer_rows[(size_t) layer_idx].data() + (size_t) row_index * (size_t) row_width; + std::memcpy(dst + (size_t) layer_idx * (size_t) row_width, src, (size_t) row_width * 
sizeof(float)); + } + } + + return true; +} + static bool llama_spec_prepare_hidden_feature_view( struct llama_context * ctx, int32_t n_rows, @@ -88,6 +353,92 @@ bool llama_spec_get_hidden_feature_view( return true; } +bool llama_spec_get_dflash_feature_view( + struct llama_context * ctx, + const llama_batch & batch, + llama_spec_feature_view & view) { + if (ctx == nullptr || batch.n_tokens <= 0 || batch.pos == nullptr || batch.n_seq_id == nullptr || batch.seq_id == nullptr) { + return false; + } + + std::vector row_indices((size_t) batch.n_tokens); + for (int32_t i = 0; i < batch.n_tokens; ++i) { + row_indices[(size_t) i] = i; + } + + view = {}; + view.kind = LLAMA_SPEC_FEATURE_HIDDEN_STATE; + if (!llama_spec_materialize_dflash_rows(ctx, row_indices, ctx->dflash_feature_view_buffer, view.width)) { + return false; + } + + view.rows.reserve((size_t) batch.n_tokens); + for (int32_t i = 0; i < batch.n_tokens; ++i) { + if (batch.n_seq_id[i] <= 0 || batch.seq_id[i] == nullptr) { + view.rows.clear(); + return false; + } + + view.rows.push_back({ + /* .seq_id = */ batch.seq_id[i][0], + /* .pos = */ batch.pos[i], + /* .data = */ ctx->dflash_feature_view_buffer.data() + (size_t) i * (size_t) view.width, + }); + } + + return true; +} + +bool llama_spec_get_dflash_feature_view_for_seq( + struct llama_context * ctx, + const llama_batch & batch, + llama_seq_id seq_id, + llama_spec_feature_view & view) { + if (ctx == nullptr || batch.n_tokens <= 0 || batch.pos == nullptr || batch.n_seq_id == nullptr || batch.seq_id == nullptr) { + return false; + } + + std::vector row_indices; + row_indices.reserve((size_t) batch.n_tokens); + std::vector batch_indices; + batch_indices.reserve((size_t) batch.n_tokens); + for (int32_t i = 0; i < batch.n_tokens; ++i) { + if (batch.n_seq_id[i] <= 0 || batch.seq_id[i] == nullptr) { + return false; + } + + for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) { + if (batch.seq_id[i][j] == seq_id) { + row_indices.push_back(i); + batch_indices.push_back(i); + 
break; + } + } + } + + if (row_indices.empty()) { + return false; + } + + view = {}; + view.kind = LLAMA_SPEC_FEATURE_HIDDEN_STATE; + if (!llama_spec_materialize_dflash_rows(ctx, row_indices, ctx->dflash_feature_view_buffer, view.width)) { + return false; + } + + view.rows.reserve(row_indices.size()); + for (size_t i = 0; i < batch_indices.size(); ++i) { + const int32_t batch_index = batch_indices[i]; + view.rows.push_back({ + /* .seq_id = */ seq_id, + /* .pos = */ batch.pos[batch_index], + /* .data = */ ctx->dflash_feature_view_buffer.data() + i * (size_t) view.width, + }); + } + + return true; +} + bool llama_spec_get_hidden_feature_view_for_seq( struct llama_context * ctx, const llama_batch & batch, @@ -179,4 +530,17 @@ bool llama_spec_copy_hidden_rows_from_output_indices( } return hidden_rows.size() == (size_t) output_indices.size() * view.width; -} \ No newline at end of file +} + +bool llama_spec_copy_dflash_rows_from_output_indices( + struct llama_context * ctx, + const std::vector & output_indices, + std::vector & hidden_rows) { + int32_t combined_width = 0; + if (!llama_spec_materialize_dflash_rows(ctx, output_indices, hidden_rows, combined_width)) { + hidden_rows.clear(); + return false; + } + + return hidden_rows.size() == (size_t) output_indices.size() * (size_t) combined_width; +} diff --git a/src/llama-spec-features.h b/src/llama-spec-features.h index 7634197b..ea177c1e 100644 --- a/src/llama-spec-features.h +++ b/src/llama-spec-features.h @@ -25,16 +25,57 @@ struct llama_spec_feature_view { uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx); +int32_t llama_model_dflash_block_size(const struct llama_model * model); + +int32_t llama_model_dflash_mask_token_id(const struct llama_model * model); + +int32_t llama_model_dflash_n_target_layers(const struct llama_model * model); + +int32_t llama_model_dflash_n_target_features(const struct llama_model * model); + +int32_t llama_model_dflash_target_layer_ids( + const struct llama_model * model, 
+ int32_t * layer_ids, + int32_t capacity); + +bool llama_model_share_dflash_io_tensors( + struct llama_model * draft_model, + const struct llama_model * target_model); + bool llama_set_draft_input_hidden_state_copy( struct llama_context * ctx, const float * hidden_state, size_t n_floats); +bool llama_set_dflash_target_features_copy( + struct llama_context * ctx, + const float * target_features, + size_t n_floats, + int32_t n_rows); + +bool llama_set_dflash_capture_layers( + struct llama_context * ctx, + const int32_t * layer_ids, + int32_t n_layers); + +void llama_clear_dflash_capture(struct llama_context * ctx); + bool llama_spec_get_hidden_feature_view( struct llama_context * ctx, const llama_batch & batch, llama_spec_feature_view & view); +bool llama_spec_get_dflash_feature_view( + struct llama_context * ctx, + const llama_batch & batch, + llama_spec_feature_view & view); + +bool llama_spec_get_dflash_feature_view_for_seq( + struct llama_context * ctx, + const llama_batch & batch, + llama_seq_id seq_id, + llama_spec_feature_view & view); + bool llama_spec_get_hidden_feature_view_for_seq( struct llama_context * ctx, const llama_batch & batch, @@ -51,4 +92,9 @@ bool llama_spec_get_hidden_feature_view_from_output_index( bool llama_spec_copy_hidden_rows_from_output_indices( struct llama_context * ctx, const std::vector & output_indices, - std::vector & hidden_rows); \ No newline at end of file + std::vector & hidden_rows); + +bool llama_spec_copy_dflash_rows_from_output_indices( + struct llama_context * ctx, + const std::vector & output_indices, + std::vector & hidden_rows); diff --git a/src/llama.cpp b/src/llama.cpp index a14259cc..0bc960c3 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3125,6 +3125,10 @@ static std::pair, double> get_layer_sizes(const llama_model_ name == "mtp_centroids.weight" || name == "mtp_token_ordering.weight") { continue; } + if (name == "dflash_fc.weight" || name == "dflash_hidden_norm.weight") { + output_misc_size += size; + 
continue; + } auto pos = name.find("blk."); if (pos != 0) { LLAMA_LOG_WARN("Oops: tensor with strange name %s\n", name.c_str()); @@ -4977,6 +4981,61 @@ static bool prepare_mtp_graph_inputs( return true; } +static bool prepare_dflash_graph_inputs( + struct llama_context & lctx, + uint32_t n_tokens) { + ggml_tensor * target_hidden = lctx.inp_dflash_target_features; + ggml_tensor * pos_ctx = lctx.inp_dflash_pos_ctx; + ggml_tensor * kq_mask = lctx.inp_dflash_kq_mask; + + if (target_hidden == nullptr || pos_ctx == nullptr || kq_mask == nullptr) { + LLAMA_LOG_ERROR("%s: DFlash graph inputs are not initialized\n", __func__); + return false; + } + + const float * src = lctx.dflash_target_features; + const size_t total_floats = lctx.dflash_target_features_n_floats; + const int32_t n_rows = lctx.dflash_target_features_n_rows; + const int32_t width = (int32_t) target_hidden->ne[0]; + const int32_t cross_ctx = (int32_t) target_hidden->ne[1]; + const int32_t n_mask_tokens = (int32_t) kq_mask->ne[1]; + const int32_t n_kv_total = (int32_t) kq_mask->ne[0]; + + if (src == nullptr || total_floats == 0 || n_rows <= 0) { + LLAMA_LOG_ERROR("%s: missing DFlash target features\n", __func__); + return false; + } + + if (n_rows > cross_ctx || total_floats != (size_t) n_rows * (size_t) width) { + LLAMA_LOG_ERROR("%s: invalid DFlash target feature shape (rows=%d width=%d floats=%zu cross_ctx=%d)\n", + __func__, n_rows, width, total_floats, cross_ctx); + return false; + } + + lctx.dflash_target_features_padded.assign((size_t) cross_ctx * (size_t) width, 0.0f); + const size_t dst_offset = (size_t) (cross_ctx - n_rows) * (size_t) width; + std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset); + ggml_backend_tensor_set(target_hidden, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(target_hidden)); + + lctx.dflash_pos_ctx_data.resize((size_t) cross_ctx); + for (int32_t i = 0; i < cross_ctx; ++i) { + lctx.dflash_pos_ctx_data[i] = i; + } + 
ggml_backend_tensor_set(pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(pos_ctx)); + + lctx.dflash_kq_mask_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY); + const int32_t left_pad = cross_ctx - n_rows; + for (uint32_t j = 0; j < n_tokens; ++j) { + float * row = lctx.dflash_kq_mask_data.data() + (size_t) j * (size_t) n_kv_total; + for (int32_t i = left_pad; i < n_kv_total; ++i) { + row[i] = 0.0f; + } + } + ggml_backend_tensor_set(kq_mask, lctx.dflash_kq_mask_data.data(), 0, ggml_nbytes(kq_mask)); + + return true; +} + // decode a batch of tokens by evaluating the transformer // // - lctx: llama context @@ -5269,6 +5328,12 @@ static int llama_decode_internal( } } + if (lctx.model.arch == LLM_ARCH_DFLASH_DRAFT) { + if (!prepare_dflash_graph_inputs(lctx, n_tokens)) { + return GGML_STATUS_FAILED; + } + } + // the output is always the last tensor in the graph struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; struct ggml_tensor * embd = nullptr; @@ -7371,6 +7436,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_STEP35: case LLM_ARCH_GEMMA4: case LLM_ARCH_GEMMA4_MTP: + case LLM_ARCH_DFLASH_DRAFT: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: From 9f5f70cf7e97971d5f3ea94228653689601ccc0a Mon Sep 17 00:00:00 2001 From: SamuelOliveirads Date: Fri, 29 May 2026 23:11:38 -0300 Subject: [PATCH 02/13] implement target position tracking and context management --- common/speculative.cpp | 122 +++++++++++++++++++++++------ examples/server/server-context.cpp | 14 +++- examples/server/server-context.h | 2 + src/graphs/build_dflash.cpp | 16 +++- src/llama-context.h | 3 + src/llama-spec-features.cpp | 58 +++++++++++--- src/llama-spec-features.h | 3 +- src/llama.cpp | 16 +++- 8 files changed, 189 insertions(+), 45 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 3b08b26a..d740ded9 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -365,7 +365,9 @@ struct 
common_speculative_state_dflash : public common_speculative_state { std::vector target_layer_ids; std::vector target_window; + std::vector target_window_pos; int32_t target_window_rows = 0; + llama_pos last_target_pos = -1; common_speculative_state_dflash( enum common_speculative_type type, @@ -426,8 +428,6 @@ struct common_speculative_state_dflash : public common_speculative_state { void begin(const llama_tokens & prompt) override { GGML_UNUSED(prompt); - target_window.clear(); - target_window_rows = 0; llama_kv_cache_clear(ctx_dft); } @@ -444,20 +444,21 @@ struct common_speculative_state_dflash : public common_speculative_state { return; } - const int32_t n_draft = std::min(params.n_max, block_size); - if (n_draft <= 0) { + const int32_t n_keep = std::min(params.n_max, block_size); + if (n_keep <= 0) { return; } - if (!llama_set_dflash_target_features_copy(ctx_dft, target_window.data(), target_window.size(), target_window_rows)) { + if (!llama_set_dflash_target_features_copy(ctx_dft, target_window.data(), target_window.size(), target_window_rows, target_window_pos.data())) { LOG_ERR("%s: failed to set DFlash target features\n", __func__); return; } llama_kv_cache_clear(ctx_dft); batch.n_tokens = 0; - for (int32_t i = 0; i < n_draft; ++i) { - common_batch_add(batch, mask_token_id, cross_ctx + i, { 0 }, true); + const llama_pos draft_pos_base = last_target_pos >= 0 ? 
last_target_pos + 1 : (llama_pos) target_window_rows; + for (int32_t i = 0; i < block_size; ++i) { + common_batch_add(batch, mask_token_id, draft_pos_base + i, { 0 }, i < n_keep); } if (llama_decode(ctx_dft, batch) != 0) { @@ -466,8 +467,8 @@ struct common_speculative_state_dflash : public common_speculative_state { return; } - result.reserve((size_t) n_draft); - for (int32_t i = 0; i < n_draft; ++i) { + result.reserve((size_t) n_keep); + for (int32_t i = 0; i < n_keep; ++i) { result.push_back(common_sampler_sample_speculative(nullptr, ctx_dft, i, nullptr)); } @@ -2216,42 +2217,118 @@ static void mtp_clear_target_hidden(common_speculative_state_mtp & state, llama_ state.draft_cache_by_seq.erase(seq_id); } -static void dflash_append_target_features( +static bool dflash_append_target_features( common_speculative_state_dflash & state, - const float * feature_rows, - int32_t n_rows) { - if (feature_rows == nullptr || n_rows <= 0 || state.n_target_features <= 0 || state.cross_ctx <= 0) { - return; + const common_speculative_feature_view & features, + const llama_batch & batch, + llama_seq_id seq_id) { + GGML_UNUSED(batch); + + if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || + features.width != state.n_target_features || + features.rows.empty() || + state.cross_ctx <= 0) { + return false; } const size_t row_width = (size_t) state.n_target_features; + std::vector new_rows; + std::vector new_positions; + new_rows.reserve(features.rows.size() * row_width); + new_positions.reserve(features.rows.size()); + + for (const auto & row : features.rows) { + if (row.seq_id != seq_id || row.data == nullptr) { + continue; + } + + new_positions.push_back(row.pos); + new_rows.insert(new_rows.end(), row.data, row.data + row_width); + } + + if (new_positions.empty()) { + return false; + } + + const int32_t n_rows = (int32_t) new_positions.size(); if (n_rows >= state.cross_ctx) { - const float * src = feature_rows + (size_t) (n_rows - state.cross_ctx) * row_width; - 
state.target_window.assign(src, src + (size_t) state.cross_ctx * row_width); + const int32_t keep_from = n_rows - state.cross_ctx; + state.target_window.assign( + new_rows.begin() + (ptrdiff_t) keep_from * (ptrdiff_t) row_width, + new_rows.end()); + state.target_window_pos.assign(new_positions.begin() + keep_from, new_positions.end()); state.target_window_rows = state.cross_ctx; - return; + state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); + return true; } const int32_t keep_old_rows = std::min(state.target_window_rows, state.cross_ctx - n_rows); std::vector next_window((size_t) (keep_old_rows + n_rows) * row_width); + std::vector next_window_pos((size_t) (keep_old_rows + n_rows)); if (keep_old_rows > 0) { const float * old_src = state.target_window.data() + (size_t) (state.target_window_rows - keep_old_rows) * row_width; std::memcpy(next_window.data(), old_src, (size_t) keep_old_rows * row_width * sizeof(float)); + std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), next_window_pos.begin()); } std::memcpy( next_window.data() + (size_t) keep_old_rows * row_width, - feature_rows, + new_rows.data(), (size_t) n_rows * row_width * sizeof(float)); + std::copy(new_positions.begin(), new_positions.end(), next_window_pos.begin() + keep_old_rows); state.target_window = std::move(next_window); + state.target_window_pos = std::move(next_window_pos); state.target_window_rows = keep_old_rows + n_rows; + state.last_target_pos = state.target_window_pos.empty() ? 
-1 : state.target_window_pos.back(); + return true; } static void dflash_clear_target_features(common_speculative_state_dflash & state) { state.target_window.clear(); + state.target_window_pos.clear(); state.target_window_rows = 0; + state.last_target_pos = -1; +} + +static void dflash_context_shift( + common_speculative_state_dflash & state, + llama_pos kv_keep, + llama_pos kv_discard, + llama_pos kv_past) { + if (kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window.empty() || state.target_window_pos.empty()) { + return; + } + + const size_t row_width = (size_t) state.n_target_features; + const llama_pos discard_begin = kv_keep; + const llama_pos discard_end = kv_keep + kv_discard; + + std::vector shifted_rows; + std::vector shifted_positions; + shifted_rows.reserve(state.target_window.size()); + shifted_positions.reserve(state.target_window_pos.size()); + + for (int32_t row = 0; row < state.target_window_rows; ++row) { + llama_pos pos = state.target_window_pos[(size_t) row]; + if (pos >= discard_begin && pos < discard_end) { + continue; + } + + if (pos >= discard_end && pos < kv_past) { + pos -= kv_discard; + } + + const float * row_src = state.target_window.data() + (size_t) row * row_width; + shifted_rows.insert(shifted_rows.end(), row_src, row_src + row_width); + shifted_positions.push_back(pos); + } + + state.target_window = std::move(shifted_rows); + state.target_window_pos = std::move(shifted_positions); + state.target_window_rows = (int32_t) state.target_window_pos.size(); + state.last_target_pos = state.target_window_pos.empty() ? 
-1 : state.target_window_pos.back(); } static bool common_speculative_capture_target_features(common_speculative * spec, const common_speculative_feature_view & features) { @@ -2366,12 +2443,9 @@ int32_t common_speculative_on_target_batch( } } - std::vector hidden_rows_storage; - if (!common_speculative_feature_view_copy_batch_rows(features, batch, seq_id, &hidden_rows_storage)) { + if (!dflash_append_target_features(*dflash_state, features, batch, seq_id)) { return -1; } - - dflash_append_target_features(*dflash_state, hidden_rows_storage.data(), batch.n_tokens); return 0; } @@ -2439,6 +2513,10 @@ void common_speculative_context_shift( llama_kv_cache_seq_rm (ctx_mtp, seq_id, kv_keep, kv_keep + kv_discard); llama_kv_cache_seq_add(ctx_mtp, seq_id, kv_keep + kv_discard, kv_past, -kv_discard); } + + if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) { + dflash_context_shift(*dflash_state, kv_keep, kv_discard, kv_past); + } } std::vector mtp_speculative_gen_draft( diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index af6924fe..2a4e8c95 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -367,6 +367,9 @@ bool server_context::load_model(const gpt_params& params_) { if (params_dft.n_ctx == 0) { params_dft.n_ctx = params_base.speculative.n_ctx; } + if (server_speculative_has_dflash(params_base.speculative) && params_dft.n_gpu_layers < 0) { + params_dft.n_gpu_layers = params_base.n_gpu_layers; + } params_dft.n_ctx = params_dft.n_ctx == 0 ? 
params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx; params_dft.n_parallel = 1; params_dft.n_batch = params_dft.n_ctx; @@ -629,6 +632,7 @@ void server_slot::reset() { drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE; i_batch_dft.clear(); spec_ckpt.clear(); + spec_prompt_warmup_failed = false; n_sent_token_probs = 0; infill = false; ga_i = 0; @@ -717,7 +721,7 @@ void server_slot::add_token_string(const completion_token_output& token) { } bool server_slot::can_speculate() const { - return (!!spec || has_mtp); + return !spec_prompt_warmup_failed && (!!spec || has_mtp); } int server_slot::get_n_draft_max() const { @@ -3347,7 +3351,7 @@ void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_sl const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(kv_keep), slot.cache_tokens.pos_next(kv_keep + kv_discard)); llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard); - if (slot.has_mtp && slot.spec) { + if (slot.spec) { common_speculative_context_shift(slot.spec, slot.id, kv_keep, kv_discard, kv_past); } if (slot.params.cache_prompt) { @@ -4730,12 +4734,18 @@ void server_context::process_batch_tokens(int32_t & n_batch) { continue; } + if (slot.spec_prompt_warmup_failed) { + continue; + } + if ((slot.state != SLOT_STATE_PROCESSING || slot.n_decoded != 0) && (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_LOAD_PROMPT)) { continue; } if (common_speculative_on_target_seq_batch(slot.spec, ctx, batch_view, slot.id, true) != 0) { + common_speculative_clear_sequence_hidden(slot.spec, slot.id); + slot.spec_prompt_warmup_failed = true; LOG_ERROR("failed to warm up speculative target-feature state from prompt batch for slot %d\n", slot.id); } } diff --git a/examples/server/server-context.h b/examples/server/server-context.h index a33c2113..d4d0913c 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ 
-176,6 +176,8 @@ struct server_slot { // saves recurrent state before a speculative batch so it can be restored on rejection server_speculative_checkpoint spec_ckpt; + bool spec_prompt_warmup_failed = false; + // speculative decoding stats int32_t n_draft_total = 0; // Total draft tokens generated int32_t n_draft_accepted = 0; // Draft tokens actually accepted diff --git a/src/graphs/build_dflash.cpp b/src/graphs/build_dflash.cpp index fe1cec15..542821ad 100644 --- a/src/graphs/build_dflash.cpp +++ b/src/graphs/build_dflash.cpp @@ -28,6 +28,8 @@ ggml_cgraph * llm_build_context::build_dflash() { ggml_set_input(lctx.inp_dflash_kq_mask); cb(lctx.inp_dflash_kq_mask, "dflash_kq_mask", -1); + ggml_tensor * dflash_kq_mask = flash_attn ? ggml_cast(ctx0, lctx.inp_dflash_kq_mask, GGML_TYPE_F16) : lctx.inp_dflash_kq_mask; + ggml_tensor * tok_embd = model.tok_embd; if (tok_embd == nullptr) { tok_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab); @@ -35,6 +37,7 @@ ggml_cgraph * llm_build_context::build_dflash() { ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, tok_embd, cb); ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = (n_tokens > 1 && n_outputs < n_tokens) ? 
build_inp_out_ids() : nullptr; ggml_tensor * fused_target = llm_build_lora_mm(lctx, ctx0, model.dflash_fc, lctx.inp_dflash_target_features); fused_target = llm_build_norm(ctx0, fused_target, hparams, model.dflash_hidden_norm, nullptr, LLM_NORM_RMS, cb, -1); @@ -85,10 +88,9 @@ ggml_cgraph * llm_build_context::build_dflash() { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_cast(ctx0, Qcur, GGML_TYPE_F16); Kcur = ggml_cast(ctx0, Kcur, GGML_TYPE_F16); Vcur = ggml_cast(ctx0, Vcur, GGML_TYPE_F16); - cb(Qcur, "Qcur_f16", il); + cb(Qcur, "Qcur", il); cb(Kcur, "Kcur_f16", il); cb(Vcur, "Vcur_f16", il); @@ -99,7 +101,7 @@ ggml_cgraph * llm_build_context::build_dflash() { cb(k, "k", il); cb(v, "v", il); - cur = ggml_flash_attn_ext(ctx0, q, k, v, lctx.inp_dflash_kq_mask, kq_scale, hparams.f_max_alibi_bias, + cur = ggml_flash_attn_ext(ctx0, q, k, v, dflash_kq_mask, kq_scale, hparams.f_max_alibi_bias, hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); cb(cur, "flash_attn", il); ggml_build_forward_expand(gf, cur); @@ -136,7 +138,13 @@ ggml_cgraph * llm_build_context::build_dflash() { output = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab); } - ggml_tensor * result = build_output(lctx, ctx0, inpL, output, model.output_norm, cb); + ggml_tensor * result_input = inpL; + if (inp_out_ids) { + result_input = ggml_get_rows(ctx0, result_input, inp_out_ids); + cb(result_input, "result_output_rows", -1); + } + + ggml_tensor * result = build_output(lctx, ctx0, result_input, output, model.output_norm, cb); cb(result, "result_output", -1); ggml_build_forward_expand(gf, result); diff --git a/src/llama-context.h b/src/llama-context.h index d0d9fe61..c4f62ac1 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -281,7 +281,10 @@ struct llama_context { const float * dflash_target_features = nullptr; size_t dflash_target_features_n_floats = 0; int32_t dflash_target_features_n_rows = 0; + const llama_pos * dflash_target_positions = nullptr; + 
size_t dflash_target_positions_n = 0; std::vector dflash_target_features_owned; + std::vector dflash_target_positions_owned; std::vector dflash_target_features_padded; std::vector dflash_feature_view_buffer; std::vector dflash_pos_ctx_data; diff --git a/src/llama-spec-features.cpp b/src/llama-spec-features.cpp index 827d536f..ccc6fb5d 100644 --- a/src/llama-spec-features.cpp +++ b/src/llama-spec-features.cpp @@ -96,7 +96,8 @@ bool llama_set_dflash_target_features_copy( struct llama_context * ctx, const float * target_features, size_t n_floats, - int32_t n_rows) { + int32_t n_rows, + const llama_pos * target_positions) { if (ctx == nullptr || target_features == nullptr || n_floats == 0 || n_rows <= 0) { return false; } @@ -105,6 +106,15 @@ bool llama_set_dflash_target_features_copy( ctx->dflash_target_features = ctx->dflash_target_features_owned.data(); ctx->dflash_target_features_n_floats = n_floats; ctx->dflash_target_features_n_rows = n_rows; + if (target_positions != nullptr) { + ctx->dflash_target_positions_owned.assign(target_positions, target_positions + n_rows); + ctx->dflash_target_positions = ctx->dflash_target_positions_owned.data(); + ctx->dflash_target_positions_n = (size_t) n_rows; + } else { + ctx->dflash_target_positions_owned.clear(); + ctx->dflash_target_positions = nullptr; + ctx->dflash_target_positions_n = 0; + } return true; } @@ -361,9 +371,25 @@ bool llama_spec_get_dflash_feature_view( return false; } - std::vector row_indices((size_t) batch.n_tokens); - for (int32_t i = 0; i < batch.n_tokens; ++i) { - row_indices[(size_t) i] = i; + int32_t row_count = 0; + int32_t row_width = 0; + int32_t n_layers = 0; + if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) { + return false; + } + + const int32_t batch_row_offset = std::max(0, batch.n_tokens - row_count); + std::vector row_indices; + std::vector batch_indices; + row_indices.reserve((size_t) (batch.n_tokens - batch_row_offset)); + batch_indices.reserve((size_t) 
(batch.n_tokens - batch_row_offset)); + for (int32_t i = batch_row_offset; i < batch.n_tokens; ++i) { + row_indices.push_back(i - batch_row_offset); + batch_indices.push_back(i); + } + + if (row_indices.empty()) { + return false; } view = {}; @@ -372,17 +398,17 @@ bool llama_spec_get_dflash_feature_view( return false; } - view.rows.reserve((size_t) batch.n_tokens); - for (int32_t i = 0; i < batch.n_tokens; ++i) { - if (batch.n_seq_id[i] <= 0 || batch.seq_id[i] == nullptr) { + view.rows.reserve(batch_indices.size()); + for (int32_t batch_index : batch_indices) { + if (batch.n_seq_id[batch_index] <= 0 || batch.seq_id[batch_index] == nullptr) { view.rows.clear(); return false; } view.rows.push_back({ - /* .seq_id = */ batch.seq_id[i][0], - /* .pos = */ batch.pos[i], - /* .data = */ ctx->dflash_feature_view_buffer.data() + (size_t) i * (size_t) view.width, + /* .seq_id = */ batch.seq_id[batch_index][0], + /* .pos = */ batch.pos[batch_index], + /* .data = */ ctx->dflash_feature_view_buffer.data() + view.rows.size() * (size_t) view.width, }); } @@ -398,18 +424,26 @@ bool llama_spec_get_dflash_feature_view_for_seq( return false; } + int32_t row_count = 0; + int32_t row_width = 0; + int32_t n_layers = 0; + if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) { + return false; + } + + const int32_t batch_row_offset = std::max(0, batch.n_tokens - row_count); std::vector row_indices; row_indices.reserve((size_t) batch.n_tokens); std::vector batch_indices; batch_indices.reserve((size_t) batch.n_tokens); - for (int32_t i = 0; i < batch.n_tokens; ++i) { + for (int32_t i = batch_row_offset; i < batch.n_tokens; ++i) { if (batch.n_seq_id[i] <= 0 || batch.seq_id[i] == nullptr) { return false; } for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) { if (batch.seq_id[i][j] == seq_id) { - row_indices.push_back(i); + row_indices.push_back(i - batch_row_offset); batch_indices.push_back(i); break; } diff --git a/src/llama-spec-features.h b/src/llama-spec-features.h 
index ea177c1e..130d3895 100644 --- a/src/llama-spec-features.h +++ b/src/llama-spec-features.h @@ -51,7 +51,8 @@ bool llama_set_dflash_target_features_copy( struct llama_context * ctx, const float * target_features, size_t n_floats, - int32_t n_rows); + int32_t n_rows, + const llama_pos * target_positions); bool llama_set_dflash_capture_layers( struct llama_context * ctx, diff --git a/src/llama.cpp b/src/llama.cpp index 0bc960c3..a9b443c7 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4994,7 +4994,9 @@ static bool prepare_dflash_graph_inputs( } const float * src = lctx.dflash_target_features; + const llama_pos * src_pos = lctx.dflash_target_positions; const size_t total_floats = lctx.dflash_target_features_n_floats; + const size_t total_positions = lctx.dflash_target_positions_n; const int32_t n_rows = lctx.dflash_target_features_n_rows; const int32_t width = (int32_t) target_hidden->ne[0]; const int32_t cross_ctx = (int32_t) target_hidden->ne[1]; @@ -5014,20 +5016,26 @@ static bool prepare_dflash_graph_inputs( lctx.dflash_target_features_padded.assign((size_t) cross_ctx * (size_t) width, 0.0f); const size_t dst_offset = (size_t) (cross_ctx - n_rows) * (size_t) width; + const int32_t left_pad = cross_ctx - n_rows; std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset); ggml_backend_tensor_set(target_hidden, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(target_hidden)); lctx.dflash_pos_ctx_data.resize((size_t) cross_ctx); - for (int32_t i = 0; i < cross_ctx; ++i) { - lctx.dflash_pos_ctx_data[i] = i; + std::fill(lctx.dflash_pos_ctx_data.begin(), lctx.dflash_pos_ctx_data.end(), 0); + if (src_pos != nullptr && total_positions == (size_t) n_rows) { + std::copy(src_pos, src_pos + n_rows, lctx.dflash_pos_ctx_data.begin() + (ptrdiff_t) left_pad); + } else { + for (int32_t i = 0; i < n_rows; ++i) { + lctx.dflash_pos_ctx_data[(size_t) left_pad + (size_t) i] = i; + } } ggml_backend_tensor_set(pos_ctx, 
lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(pos_ctx)); lctx.dflash_kq_mask_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY); - const int32_t left_pad = cross_ctx - n_rows; for (uint32_t j = 0; j < n_tokens; ++j) { float * row = lctx.dflash_kq_mask_data.data() + (size_t) j * (size_t) n_kv_total; - for (int32_t i = left_pad; i < n_kv_total; ++i) { + const int32_t visible_kv = cross_ctx + (int32_t) j + 1; + for (int32_t i = left_pad; i < visible_kv; ++i) { row[i] = 0.0f; } } From 532499836efb11c723dc152633239e6697e2489a Mon Sep 17 00:00:00 2001 From: SamuelOliveirads Date: Sat, 30 May 2026 21:36:10 -0300 Subject: [PATCH 03/13] improve DFlash caching and profiling capabilities --- common/speculative.cpp | 276 ++++++++++++++++++++- examples/server/server-context.cpp | 111 ++++++++- src/graphs/build_dflash.cpp | 130 ++++++++-- src/llama-build-context.cpp | 64 +++-- src/llama-build-context.h | 9 +- src/llama-context.h | 14 ++ src/llama-spec-features.cpp | 379 ++++++++++++++++++++++++++++- src/llama-spec-features.h | 77 ++++++ src/llama.cpp | 274 ++++++++++++++++++++- 9 files changed, 1261 insertions(+), 73 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index d740ded9..bd854a7f 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -11,9 +11,12 @@ #include "suffix-tree.h" #include +#include +#include #include #include #include +#include #include #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 @@ -210,6 +213,15 @@ struct common_speculative_state { struct common_speculative_state_mtp; struct common_speculative_state_dflash; +static void dflash_contract_log_append( + const common_speculative_state_dflash & state, + llama_seq_id seq_id, + const std::vector & new_positions); +static void dflash_contract_log_draft( + const common_speculative_state_dflash & state, + int32_t n_keep, + size_t result_size); + static common_speculative_state_mtp * common_speculative_get_mtp_state(common_speculative * spec); static const 
common_speculative_state_mtp * common_speculative_get_mtp_state(const common_speculative * spec); static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec); @@ -235,6 +247,81 @@ static std::vector mtp_speculative_gen_draft( static int32_t mtp_update_kv_cache(struct llama_context * ctx, const llama_batch & batch, bool is_prompt_warmup); +static bool dflash_contract_log_enabled() { + const char * env = std::getenv("IK_DFLASH_CONTRACT_LOG"); + if (env == nullptr || *env == '\0') { + return false; + } + + return std::strcmp(env, "0") != 0 && + std::strcmp(env, "false") != 0 && + std::strcmp(env, "off") != 0; +} + +template +static std::string dflash_contract_format_values( + const std::vector & values, + size_t edge_count = 4) { + std::ostringstream oss; + oss << '['; + if (values.empty()) { + oss << ']'; + return oss.str(); + } + + const size_t head = std::min(edge_count, values.size()); + for (size_t i = 0; i < head; ++i) { + if (i > 0) { + oss << ','; + } + oss << values[i]; + } + + if (values.size() > edge_count * 2) { + oss << ",...,"; + for (size_t i = values.size() - edge_count; i < values.size(); ++i) { + if (i > values.size() - edge_count) { + oss << ','; + } + oss << values[i]; + } + } else { + for (size_t i = head; i < values.size(); ++i) { + oss << ',' << values[i]; + } + } + + oss << ']'; + return oss.str(); +} + +struct dflash_contract_pos_summary { + llama_pos first = -1; + llama_pos last = -1; + int32_t gap_count = 0; + int32_t nonmono_count = 0; +}; + +static dflash_contract_pos_summary dflash_contract_summarize_positions( + const std::vector & positions) { + dflash_contract_pos_summary summary; + if (positions.empty()) { + return summary; + } + + summary.first = positions.front(); + summary.last = positions.back(); + for (size_t i = 1; i < positions.size(); ++i) { + if (positions[i] <= positions[i - 1]) { + summary.nonmono_count++; + } else if (positions[i] != positions[i - 1] + 1) { + summary.gap_count++; 
+ } + } + + return summary; +} + struct mtp_last_embd { std::vector embd; float prob = 0.0f; @@ -368,6 +455,14 @@ struct common_speculative_state_dflash : public common_speculative_state { std::vector target_window_pos; int32_t target_window_rows = 0; llama_pos last_target_pos = -1; + size_t n_window_updates = 0; + size_t n_rows_seen = 0; + size_t n_rows_dropped = 0; + size_t n_context_shifts = 0; + size_t n_draft_empty = 0; + size_t n_set_target_fail = 0; + size_t n_decode_fail = 0; + llama_pos last_draft_pos_base = -1; common_speculative_state_dflash( enum common_speculative_type type, @@ -412,8 +507,20 @@ struct common_speculative_state_dflash : public common_speculative_state { batch = llama_batch_init(std::max(1, block_size), 0, 1); ready = true; - LOG_INF("%s: DFlash context ready (n_ctx=%d, block_size=%d, cross_ctx=%d, n_target_features=%d)\n", - __func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features); + llama_set_dflash_visible_cross_ctx(ctx_dft, this->cross_ctx); + llama_dflash_profile_reset(ctx_tgt); + llama_dflash_profile_reset(ctx_dft); + + std::ostringstream layers_oss; + for (size_t i = 0; i < target_layer_ids.size(); ++i) { + if (i > 0) { + layers_oss << ","; + } + layers_oss << target_layer_ids[i]; + } + + LOG_INF("%s: DFlash context ready (n_ctx=%d, block_size=%d, cross_ctx=%d, n_target_features=%d, target_layer_ids=[%s])\n", + __func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features, layers_oss.str().c_str()); } ~common_speculative_state_dflash() override { @@ -429,6 +536,16 @@ struct common_speculative_state_dflash : public common_speculative_state { void begin(const llama_tokens & prompt) override { GGML_UNUSED(prompt); llama_kv_cache_clear(ctx_dft); + n_window_updates = 0; + n_rows_seen = 0; + n_rows_dropped = 0; + n_context_shifts = 0; + n_draft_empty = 0; + n_set_target_fail = 0; + n_decode_fail = 0; + last_draft_pos_base = -1; + llama_dflash_profile_reset(ctx_tgt); + 
llama_dflash_profile_reset(ctx_dft); } void draft( @@ -441,6 +558,7 @@ struct common_speculative_state_dflash : public common_speculative_state { result.clear(); if (!ready || target_window_rows <= 0) { + n_draft_empty++; return; } @@ -449,20 +567,23 @@ struct common_speculative_state_dflash : public common_speculative_state { return; } - if (!llama_set_dflash_target_features_copy(ctx_dft, target_window.data(), target_window.size(), target_window_rows, target_window_pos.data())) { + if (!llama_set_dflash_target_features_view(ctx_dft, target_window.data(), target_window.size(), target_window_rows, target_window_pos.data())) { LOG_ERR("%s: failed to set DFlash target features\n", __func__); + n_set_target_fail++; return; } llama_kv_cache_clear(ctx_dft); batch.n_tokens = 0; const llama_pos draft_pos_base = last_target_pos >= 0 ? last_target_pos + 1 : (llama_pos) target_window_rows; + last_draft_pos_base = draft_pos_base; for (int32_t i = 0; i < block_size; ++i) { common_batch_add(batch, mask_token_id, draft_pos_base + i, { 0 }, i < n_keep); } if (llama_decode(ctx_dft, batch) != 0) { LOG_ERR("%s: llama_decode() failed for DFlash draft batch\n", __func__); + n_decode_fail++; batch.n_tokens = 0; return; } @@ -473,6 +594,7 @@ struct common_speculative_state_dflash : public common_speculative_state { } batch.n_tokens = 0; + dflash_contract_log_draft(*this, n_keep, result.size()); } void accept(uint16_t n_accepted) override { @@ -480,6 +602,83 @@ struct common_speculative_state_dflash : public common_speculative_state { } }; +static void dflash_contract_log_append( + const common_speculative_state_dflash & state, + llama_seq_id seq_id, + const std::vector & new_positions) { + if (!dflash_contract_log_enabled()) { + return; + } + + static std::atomic counter = 0; + const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); + if (ordinal >= 8) { + return; + } + + const dflash_contract_pos_summary incoming = dflash_contract_summarize_positions(new_positions); + 
const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos); + + LOG_INF("dflash contract append[%llu]: seq=%d incoming_rows=%zu incoming_pos=%s pos=[%d..%d] gaps=%d nonmono=%d window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d\n", + (unsigned long long) (ordinal + 1), + (int) seq_id, + new_positions.size(), + dflash_contract_format_values(new_positions).c_str(), + (int) incoming.first, + (int) incoming.last, + incoming.gap_count, + incoming.nonmono_count, + state.target_window_rows, + dflash_contract_format_values(state.target_window_pos).c_str(), + (int) window.first, + (int) window.last, + window.gap_count, + window.nonmono_count, + (int) state.last_target_pos); +} + +static void dflash_contract_log_draft( + const common_speculative_state_dflash & state, + int32_t n_keep, + size_t result_size) { + if (!dflash_contract_log_enabled()) { + return; + } + + static std::atomic counter = 0; + const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); + if (ordinal >= 8) { + return; + } + + const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos); + llama_dflash_profile_stats graph_stats = {}; + llama_dflash_profile_get_stats(state.ctx_dft, &graph_stats); + const int draft_delta = (state.last_target_pos >= 0 && state.last_draft_pos_base >= 0) + ? 
(int) (state.last_draft_pos_base - state.last_target_pos) + : -1; + + LOG_INF("dflash contract draft[%llu]: window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d draft_pos_base=%d delta=%d n_keep=%d result=%zu set_target(missing/nonmono)=%llu/%llu graph(fallback/nonmono)=%llu/%llu graph_pos=[%d..%d]\n", + (unsigned long long) (ordinal + 1), + state.target_window_rows, + dflash_contract_format_values(state.target_window_pos).c_str(), + (int) window.first, + (int) window.last, + window.gap_count, + window.nonmono_count, + (int) state.last_target_pos, + (int) state.last_draft_pos_base, + draft_delta, + n_keep, + result_size, + (unsigned long long) graph_stats.set_target_missing_positions, + (unsigned long long) graph_stats.set_target_non_monotonic_positions, + (unsigned long long) graph_stats.graph_pos_fallbacks, + (unsigned long long) graph_stats.graph_pos_non_monotonic, + (int) graph_stats.last_pos_first, + (int) graph_stats.last_pos_last); +} + struct common_speculative_state_draft : public common_speculative_state { llama_context * ctx_tgt; // only used for retokenizing from ctx_dft llama_context * ctx_dft; @@ -2101,6 +2300,70 @@ void common_speculative_print_stats(const common_speculative * spec, double slot impl->n_gen_tokens, impl->n_acc_tokens, str_perf.c_str()); + + if (impl->type == COMMON_SPECULATIVE_TYPE_DFLASH) { + const auto * dflash_state = dynamic_cast(impl.get()); + if (dflash_state != nullptr) { + llama_dflash_profile_stats capture_stats; + llama_dflash_profile_stats graph_stats; + const bool have_capture = llama_dflash_profile_get_stats(dflash_state->ctx_tgt, &capture_stats); + const bool have_graph = llama_dflash_profile_get_stats(dflash_state->ctx_dft, &graph_stats); + + LOG_INF("statistics dflash detail: cross_ctx=%d, window_rows=%d, pos=[%d..%d], window_updates=%zu, rows_seen=%zu, rows_dropped=%zu, shifts=%zu, draft_fail(empty/set/decode)=%zu/%zu/%zu, next_draft_pos=%d\n", + dflash_state->cross_ctx, + 
dflash_state->target_window_rows, + dflash_state->target_window_pos.empty() ? -1 : (int) dflash_state->target_window_pos.front(), + dflash_state->target_window_pos.empty() ? -1 : (int) dflash_state->target_window_pos.back(), + dflash_state->n_window_updates, + dflash_state->n_rows_seen, + dflash_state->n_rows_dropped, + dflash_state->n_context_shifts, + dflash_state->n_draft_empty, + dflash_state->n_set_target_fail, + dflash_state->n_decode_fail, + (int) dflash_state->last_draft_pos_base); + + if (have_capture || have_graph) { + LOG_INF("statistics dflash profile: capture(sync/materialize)=%.3f/%.3f ms calls=%llu/%llu bytes=%llu phase(prompt/verify batches changes)=%llu/%llu %llu/%llu, set_target=%.3f ms rows=%llu bytes=%llu, prep(total/features/pos/mask)=%.3f/%.3f/%.3f/%.3f ms kv_cache=%.3f ms calls=%llu/%llu bytes=%llu/%llu/%llu, fallback_pos(copy/graph)=%llu/%llu, nonmono(copy/graph)=%llu/%llu, capture_fail=%llu/%llu, visible_kv_max=%llu, last(rows=%d width=%d left_pad=%d n_tokens=%d n_kv=%d pos=[%d..%d])\n", + (double) capture_stats.capture_prepare_sync_us / 1000.0, + (double) capture_stats.capture_materialize_us / 1000.0, + (unsigned long long) capture_stats.capture_prepare_calls, + (unsigned long long) capture_stats.capture_materialize_calls, + (unsigned long long) capture_stats.capture_materialize_bytes, + (unsigned long long) capture_stats.capture_prompt_batches, + (unsigned long long) capture_stats.capture_prompt_shape_changes, + (unsigned long long) capture_stats.capture_verify_batches, + (unsigned long long) capture_stats.capture_verify_shape_changes, + (double) graph_stats.set_target_copy_us / 1000.0, + (unsigned long long) graph_stats.set_target_rows, + (unsigned long long) graph_stats.set_target_copy_bytes, + (double) graph_stats.graph_prepare_total_us / 1000.0, + (double) graph_stats.graph_feature_copy_us / 1000.0, + (double) graph_stats.graph_pos_copy_us / 1000.0, + (double) graph_stats.graph_mask_build_us / 1000.0, + (double) 
graph_stats.graph_kv_cache_compute_us / 1000.0, + (unsigned long long) graph_stats.graph_prepare_calls, + (unsigned long long) graph_stats.graph_kv_cache_calls, + (unsigned long long) graph_stats.graph_feature_bytes, + (unsigned long long) graph_stats.graph_pos_bytes, + (unsigned long long) graph_stats.graph_mask_bytes, + (unsigned long long) graph_stats.set_target_missing_positions, + (unsigned long long) graph_stats.graph_pos_fallbacks, + (unsigned long long) graph_stats.set_target_non_monotonic_positions, + (unsigned long long) graph_stats.graph_pos_non_monotonic, + (unsigned long long) capture_stats.capture_prepare_failures, + (unsigned long long) capture_stats.capture_materialize_failures, + (unsigned long long) graph_stats.graph_visible_kv_max, + graph_stats.last_n_rows, + graph_stats.last_width, + graph_stats.last_left_pad, + graph_stats.last_n_tokens, + graph_stats.last_n_kv_total, + (int) graph_stats.last_pos_first, + (int) graph_stats.last_pos_last); + } + } + } } if (spec->tuner && spec->tuner->enabled && slot_tps > 0.0 && n_decoded > 0) { @@ -2251,7 +2514,10 @@ static bool dflash_append_target_features( } const int32_t n_rows = (int32_t) new_positions.size(); + state.n_window_updates++; + state.n_rows_seen += (size_t) n_rows; if (n_rows >= state.cross_ctx) { + state.n_rows_dropped += (size_t) state.target_window_rows + (size_t) (n_rows - state.cross_ctx); const int32_t keep_from = n_rows - state.cross_ctx; state.target_window.assign( new_rows.begin() + (ptrdiff_t) keep_from * (ptrdiff_t) row_width, @@ -2259,10 +2525,12 @@ static bool dflash_append_target_features( state.target_window_pos.assign(new_positions.begin() + keep_from, new_positions.end()); state.target_window_rows = state.cross_ctx; state.last_target_pos = state.target_window_pos.empty() ? 
-1 : state.target_window_pos.back(); + dflash_contract_log_append(state, seq_id, new_positions); return true; } const int32_t keep_old_rows = std::min(state.target_window_rows, state.cross_ctx - n_rows); + state.n_rows_dropped += (size_t) std::max(0, state.target_window_rows - keep_old_rows); std::vector next_window((size_t) (keep_old_rows + n_rows) * row_width); std::vector next_window_pos((size_t) (keep_old_rows + n_rows)); @@ -2282,6 +2550,7 @@ static bool dflash_append_target_features( state.target_window_pos = std::move(next_window_pos); state.target_window_rows = keep_old_rows + n_rows; state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); + dflash_contract_log_append(state, seq_id, new_positions); return true; } @@ -2329,6 +2598,7 @@ static void dflash_context_shift( state.target_window_pos = std::move(shifted_positions); state.target_window_rows = (int32_t) state.target_window_pos.size(); state.last_target_pos = state.target_window_pos.empty() ? 
-1 : state.target_window_pos.back(); + state.n_context_shifts++; } static bool common_speculative_capture_target_features(common_speculative * spec, const common_speculative_feature_view & features) { diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 2a4e8c95..c3171d76 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -13,8 +13,12 @@ #include "mtmd-helper.h" #include +#include +#include +#include #include #include +#include #include static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1, int32_t offset = 0) { @@ -45,6 +49,83 @@ static void log_text(const gpt_params & params_base, const std::string & text) { } } +static bool server_dflash_contract_log_enabled() { + const char * env = std::getenv("IK_DFLASH_CONTRACT_LOG"); + if (env == nullptr || *env == '\0') { + return false; + } + + return std::strcmp(env, "0") != 0 && + std::strcmp(env, "false") != 0 && + std::strcmp(env, "off") != 0; +} + +static std::string server_dflash_contract_format_indices( + const std::vector & values, + size_t edge_count = 4) { + std::ostringstream oss; + oss << '['; + if (values.empty()) { + oss << ']'; + return oss.str(); + } + + const size_t head = std::min(edge_count, values.size()); + for (size_t i = 0; i < head; ++i) { + if (i > 0) { + oss << ','; + } + oss << values[i]; + } + + if (values.size() > edge_count * 2) { + oss << ",...,"; + for (size_t i = values.size() - edge_count; i < values.size(); ++i) { + if (i > values.size() - edge_count) { + oss << ','; + } + oss << values[i]; + } + } else { + for (size_t i = head; i < values.size(); ++i) { + oss << ',' << values[i]; + } + } + + oss << ']'; + return oss.str(); +} + +static void server_dflash_contract_log_accept( + const server_slot & slot, + common_speculative_type spec_type_used, + const char * path, + bool any_rejected, + size_t 
n_draft, + const std::vector & ids, + llama_pos pos_base, + const std::vector & output_indices) { + if (!server_dflash_contract_log_enabled() || spec_type_used != COMMON_SPECULATIVE_TYPE_DFLASH) { + return; + } + + static std::atomic counter = 0; + const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); + if (ordinal >= 8) { + return; + } + + LLAMA_LOG_INFO("dflash contract accept[%llu]: slot=%d path=%s rejected=%s drafted=%zu accepted=%zu pos_base=%d output_indices=%s\n", + (unsigned long long) (ordinal + 1), + slot.id, + path, + any_rejected ? "true" : "false", + n_draft, + ids.size(), + (int) pos_base, + server_dflash_contract_format_indices(output_indices).c_str()); +} + static bool params_use_gemma4_external_mtp(const gpt_params & params_base) { return params_base.has_mtp && llama_model_is_gemma4_mtp_assistant(params_base.speculative.model_dft); @@ -4226,6 +4307,16 @@ static void restore_speculative_checkpoint( redecoded_indices[j] = j; } + server_dflash_contract_log_accept( + slot, + spec_type_used, + "restore", + true, + n_draft, + ids, + slot.spec_ckpt.n_past, + redecoded_indices); + if (!common_speculative_commit_accepted_output( slot.spec, ctx, @@ -4338,6 +4429,16 @@ void server_context::speculative_decoding_accept() { restore_speculative_checkpoint(slot, ctx, model, spec_type_used, sampled_before, ids, n_draft, spec_feature_rows_pre, spec_n_past_base); } else { if (server_speculative_has_target_features(slot.params.speculative) && !accepted_output_indices.empty()) { + server_dflash_contract_log_accept( + slot, + spec_type_used, + "direct", + false, + n_draft, + ids, + spec_n_past_base, + accepted_output_indices); + if (!common_speculative_commit_accepted_output( slot.spec, ctx, @@ -4729,6 +4830,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) { } if (server_speculative_has_target_features(params_base.speculative)) { + bool finished_prompt_warmup_batch = false; for (auto & slot : slots) { if (!slot.spec || 
!server_speculative_has_target_features(slot.params.speculative)) { continue; @@ -4738,8 +4840,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) { continue; } - if ((slot.state != SLOT_STATE_PROCESSING || slot.n_decoded != 0) && - (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_LOAD_PROMPT)) { + if (slot.command != SLOT_COMMAND_LOAD_PROMPT) { continue; } @@ -4747,8 +4848,14 @@ void server_context::process_batch_tokens(int32_t & n_batch) { common_speculative_clear_sequence_hidden(slot.spec, slot.id); slot.spec_prompt_warmup_failed = true; LOG_ERROR("failed to warm up speculative target-feature state from prompt batch for slot %d\n", slot.id); + } else { + finished_prompt_warmup_batch = true; } } + + if (finished_prompt_warmup_batch) { + llama_finish_dflash_capture_batch(ctx, true); + } } for (auto& slot : slots) { diff --git a/src/graphs/build_dflash.cpp b/src/graphs/build_dflash.cpp index 542821ad..ef50f868 100644 --- a/src/graphs/build_dflash.cpp +++ b/src/graphs/build_dflash.cpp @@ -3,33 +3,103 @@ #include "../llama-model.h" #include +#include + +static bool dflash_use_kv_cache_experiment() { + const char * env = std::getenv("IK_DFLASH_KV_CACHE"); + if (env == nullptr || *env == '\0') { + return false; + } + + return std::strcmp(env, "0") != 0 && + std::strcmp(env, "false") != 0 && + std::strcmp(env, "off") != 0; +} + +ggml_cgraph * llm_build_context::build_dflash_kv_cache() { + const int64_t n_embd_head_k = hparams.n_embd_head_k(0); + const int64_t n_embd_head_v = hparams.n_embd_head_v(0); + const int64_t n_target_features = hparams.dflash_n_target_features; + const int64_t ctx_len = lctx.dflash_visible_cross_ctx > 0 + ? 
(int64_t) lctx.dflash_visible_cross_ctx + : std::max(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size); + + GGML_ASSERT(n_embd_head_k == n_embd_head_v); + GGML_ASSERT(n_target_features > 0); + GGML_ASSERT(lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len)); + + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max(1, ctx_len)) + 24 * n_layer, false); + + lctx.dflash_kv_input_target_features = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_target_features, ctx_len); + ggml_set_input(lctx.dflash_kv_input_target_features); + + lctx.dflash_kv_input_pos_ctx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctx_len); + ggml_set_input(lctx.dflash_kv_input_pos_ctx); + + ggml_tensor * fused_target = llm_build_lora_mm(lctx, ctx0, model.dflash_fc, lctx.dflash_kv_input_target_features); + fused_target = llm_build_norm(ctx0, fused_target, hparams, model.dflash_hidden_norm, nullptr, LLM_NORM_RMS, cb, -1); + + for (int il = 0; il < n_layer; ++il) { + GGML_ASSERT((size_t) il < lctx.dflash_k_ctx_cache.size()); + GGML_ASSERT((size_t) il < lctx.dflash_v_ctx_cache.size()); + + ggml_tensor * Kcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused_target); + Kcur_ctx = ggml_reshape_3d(ctx0, Kcur_ctx, n_embd_head_k, n_head_kv, ctx_len); + Kcur_ctx = llm_build_norm(ctx0, Kcur_ctx, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il); + Kcur_ctx = ggml_rope_ext(ctx0, Kcur_ctx, lctx.dflash_kv_input_pos_ctx, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + ggml_tensor * Vcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, fused_target); + Vcur_ctx = ggml_reshape_3d(ctx0, Vcur_ctx, n_embd_head_v, n_head_kv, ctx_len); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur_ctx, lctx.dflash_k_ctx_cache[(size_t) il])); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur_ctx, lctx.dflash_v_ctx_cache[(size_t) il])); + } + + return gf; +} ggml_cgraph * 
llm_build_context::build_dflash() { const int64_t n_embd_head_k = hparams.n_embd_head_k(0); const int64_t n_embd_head_v = hparams.n_embd_head_v(0); const int64_t n_target_features = hparams.dflash_n_target_features; - const int64_t ctx_len = std::max(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size); - const int64_t n_kv_total = ctx_len + n_tokens; + const bool use_kv_cache = dflash_use_kv_cache_experiment(); + const int64_t ctx_len = lctx.dflash_visible_cross_ctx > 0 + ? (int64_t) lctx.dflash_visible_cross_ctx + : std::max(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size); + const int64_t n_kv_total = GGML_PAD(ctx_len + n_tokens, flash_attn ? 256 : 32); + const int64_t n_kv_pad = n_kv_total - (ctx_len + n_tokens); GGML_ASSERT(n_embd_head_k == n_embd_head_v); GGML_ASSERT(n_target_features > 0); + GGML_ASSERT(!use_kv_cache || lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len)); ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max(n_tokens, ctx_len)) + 32 * n_layer, false); - lctx.inp_dflash_target_features = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_target_features, ctx_len); - ggml_set_input(lctx.inp_dflash_target_features); - cb(lctx.inp_dflash_target_features, "dflash_target_features", -1); - - lctx.inp_dflash_pos_ctx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctx_len); - ggml_set_input(lctx.inp_dflash_pos_ctx); - cb(lctx.inp_dflash_pos_ctx, "dflash_pos_ctx", -1); - lctx.inp_dflash_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + lctx.dflash_kq_mask_tensor = lctx.inp_dflash_kq_mask; ggml_set_input(lctx.inp_dflash_kq_mask); cb(lctx.inp_dflash_kq_mask, "dflash_kq_mask", -1); ggml_tensor * dflash_kq_mask = flash_attn ? 
ggml_cast(ctx0, lctx.inp_dflash_kq_mask, GGML_TYPE_F16) : lctx.inp_dflash_kq_mask; + ggml_tensor * fused_target = nullptr; + ggml_tensor * pos_ctx = nullptr; + if (!use_kv_cache) { + lctx.inp_dflash_target_features = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_target_features, ctx_len); + ggml_set_input(lctx.inp_dflash_target_features); + cb(lctx.inp_dflash_target_features, "dflash_target_features", -1); + + lctx.inp_dflash_pos_ctx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctx_len); + ggml_set_input(lctx.inp_dflash_pos_ctx); + cb(lctx.inp_dflash_pos_ctx, "dflash_pos_ctx", -1); + + fused_target = llm_build_lora_mm(lctx, ctx0, model.dflash_fc, lctx.inp_dflash_target_features); + fused_target = llm_build_norm(ctx0, fused_target, hparams, model.dflash_hidden_norm, nullptr, LLM_NORM_RMS, cb, -1); + pos_ctx = lctx.inp_dflash_pos_ctx; + } + ggml_tensor * tok_embd = model.tok_embd; if (tok_embd == nullptr) { tok_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab); @@ -39,10 +109,6 @@ ggml_cgraph * llm_build_context::build_dflash() { ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = (n_tokens > 1 && n_outputs < n_tokens) ? 
build_inp_out_ids() : nullptr; - ggml_tensor * fused_target = llm_build_lora_mm(lctx, ctx0, model.dflash_fc, lctx.inp_dflash_target_features); - fused_target = llm_build_norm(ctx0, fused_target, hparams, model.dflash_hidden_norm, nullptr, LLM_NORM_RMS, cb, -1); - cb(fused_target, "dflash_target_fused", -1); - const float kq_scale = 1.0f / std::sqrt((float) n_embd_head_k); for (int il = 0; il < n_layer; ++il) { @@ -71,25 +137,35 @@ ggml_cgraph * llm_build_context::build_dflash() { Vcur_noise = ggml_reshape_3d(ctx0, Vcur_noise, n_embd_head_v, n_head_kv, n_tokens); cb(Vcur_noise, "Vcur_noise", il); - ggml_tensor * Kcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused_target); - Kcur_ctx = ggml_reshape_3d(ctx0, Kcur_ctx, n_embd_head_k, n_head_kv, ctx_len); - Kcur_ctx = llm_build_norm(ctx0, Kcur_ctx, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il); - Kcur_ctx = ggml_rope_ext(ctx0, Kcur_ctx, lctx.inp_dflash_pos_ctx, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - cb(Kcur_ctx, "Kcur_ctx", il); + ggml_tensor * Kcur_ctx = nullptr; + ggml_tensor * Vcur_ctx = nullptr; + if (use_kv_cache) { + Kcur_ctx = lctx.dflash_k_ctx_cache[(size_t) il]; + Vcur_ctx = lctx.dflash_v_ctx_cache[(size_t) il]; + cb(Kcur_ctx, "Kcur_ctx_cache", il); + cb(Vcur_ctx, "Vcur_ctx_cache", il); + } else { + Kcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused_target); + Kcur_ctx = ggml_reshape_3d(ctx0, Kcur_ctx, n_embd_head_k, n_head_kv, ctx_len); + Kcur_ctx = llm_build_norm(ctx0, Kcur_ctx, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il); + Kcur_ctx = ggml_rope_ext(ctx0, Kcur_ctx, pos_ctx, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); - ggml_tensor * Vcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, fused_target); - Vcur_ctx = ggml_reshape_3d(ctx0, Vcur_ctx, n_embd_head_v, n_head_kv, 
ctx_len); - cb(Vcur_ctx, "Vcur_ctx", il); + Vcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, fused_target); + Vcur_ctx = ggml_reshape_3d(ctx0, Vcur_ctx, n_embd_head_v, n_head_kv, ctx_len); + cb(Kcur_ctx, "Kcur_ctx", il); + cb(Vcur_ctx, "Vcur_ctx", il); + } ggml_tensor * Kcur = ggml_concat(ctx0, Kcur_ctx, Kcur_noise, 2); ggml_tensor * Vcur = ggml_concat(ctx0, Vcur_ctx, Vcur_noise, 2); + if (n_kv_pad > 0) { + Kcur = ggml_pad(ctx0, Kcur, 0, 0, (int) n_kv_pad, 0); + Vcur = ggml_pad(ctx0, Vcur, 0, 0, (int) n_kv_pad, 0); + } cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - - Kcur = ggml_cast(ctx0, Kcur, GGML_TYPE_F16); - Vcur = ggml_cast(ctx0, Vcur, GGML_TYPE_F16); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur_f16", il); cb(Vcur, "Vcur_f16", il); diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index ad89e6e2..081215dd 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -35,7 +35,9 @@ llm_build_context::llm_build_context( const llm_build_cb & cb, bool worst_case, bool warmup, - int n_outputs_) : + int n_outputs_, + bool clear_lctx_inputs, + std::vector * buf_compute_meta_override) : model (lctx.model), lctx (lctx), hparams (model.hparams), @@ -82,8 +84,9 @@ llm_build_context::llm_build_context( thresh_experts (cparams.thresh_experts), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), + clear_lctx_inputs(clear_lctx_inputs), cb (cb), - buf_compute_meta (lctx.buf_compute_meta) { + buf_compute_meta (buf_compute_meta_override ? 
*buf_compute_meta_override : lctx.buf_compute_meta) { // all initializations should be done in init() } @@ -96,25 +99,27 @@ void llm_build_context::init() { ctx0 = ggml_init(params); - lctx.inp_tokens = nullptr; - lctx.inp_embd = nullptr; - lctx.inp_pos = nullptr; - lctx.inp_out_ids = nullptr; - lctx.inp_KQ_mask = nullptr; - lctx.inp_KQ_mask_swa = nullptr; - lctx.inp_K_shift = nullptr; - lctx.inp_mean = nullptr; - lctx.inp_cls = nullptr; - lctx.inp_s_copy = nullptr; - lctx.inp_s_mask = nullptr; - lctx.inp_s_seq = nullptr; - lctx.inp_s_seq_qnext = nullptr; - lctx.inp_pos_bucket = nullptr; - lctx.inp_embd_enc = nullptr; - lctx.inp_KQ_mask_cross = nullptr; - lctx.inp_dflash_target_features = nullptr; - lctx.inp_dflash_pos_ctx = nullptr; - lctx.inp_dflash_kq_mask = nullptr; + if (clear_lctx_inputs) { + lctx.inp_tokens = nullptr; + lctx.inp_embd = nullptr; + lctx.inp_pos = nullptr; + lctx.inp_out_ids = nullptr; + lctx.inp_KQ_mask = nullptr; + lctx.inp_KQ_mask_swa = nullptr; + lctx.inp_K_shift = nullptr; + lctx.inp_mean = nullptr; + lctx.inp_cls = nullptr; + lctx.inp_s_copy = nullptr; + lctx.inp_s_mask = nullptr; + lctx.inp_s_seq = nullptr; + lctx.inp_s_seq_qnext = nullptr; + lctx.inp_pos_bucket = nullptr; + lctx.inp_embd_enc = nullptr; + lctx.inp_KQ_mask_cross = nullptr; + lctx.inp_dflash_target_features = nullptr; + lctx.inp_dflash_pos_ctx = nullptr; + lctx.inp_dflash_kq_mask = nullptr; + } } void llm_build_context::free() { @@ -2164,6 +2169,23 @@ struct ggml_cgraph * llm_build_context::llama_build_graph_s_copy(llama_context & return result; } +struct ggml_cgraph * llm_build_context::llama_build_graph_dflash_kv_cache(llama_context & lctx) { + llama_batch dummy; + dummy.n_tokens = 0; + + llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; + + struct llm_build_context llm(lctx, dummy, cb, false, false, 0, false, &lctx.dflash_buf_compute_meta); + + llm.init(); + + struct ggml_cgraph * result = llm.build_dflash_kv_cache(); + + llm.free(); + + return 
result; +} + ggml_cgraph * llm_build_context::llama_build_graph( llama_context & lctx, const llama_batch & batch, diff --git a/src/llama-build-context.h b/src/llama-build-context.h index 7542aff6..ec7cbbb9 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -89,6 +89,7 @@ struct llm_build_context { const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; + const bool clear_lctx_inputs; const llm_build_cb & cb; @@ -103,7 +104,9 @@ struct llm_build_context { const llm_build_cb & cb, bool worst_case, bool warmup, - int n_outputs = 0); + int n_outputs = 0, + bool clear_lctx_inputs = true, + std::vector * buf_compute_meta_override = nullptr); void init(); @@ -244,6 +247,8 @@ struct llm_build_context { ggml_cgraph * build_dflash(); + ggml_cgraph * build_dflash_kv_cache(); + ggml_cgraph * build_starcoder2(); ggml_cgraph * build_mamba(); @@ -459,6 +464,8 @@ llm_expert_gating_func_type gating_op, static ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx); + static ggml_cgraph * llama_build_graph_dflash_kv_cache(llama_context & lctx); + static ggml_cgraph * llama_build_graph(llama_context & lctx, const llama_batch & batch, bool worst_case, int n_outputs = 0); ggml_tensor * build_std_attention(ggml_cgraph * gf, ggml_tensor * attn_norm, ggml_tensor * cur, diff --git a/src/llama-context.h b/src/llama-context.h index c4f62ac1..9e902003 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -289,6 +289,16 @@ struct llama_context { std::vector dflash_feature_view_buffer; std::vector dflash_pos_ctx_data; std::vector dflash_kq_mask_data; + int32_t dflash_visible_cross_ctx = 0; + std::vector dflash_k_ctx_cache; + std::vector dflash_v_ctx_cache; + struct ggml_context * dflash_cache_ctx = nullptr; + ggml_backend_buffer_t dflash_cache_buf = nullptr; + std::vector dflash_buf_compute_meta; + ggml_backend_sched_t dflash_sched = nullptr; + struct ggml_tensor * dflash_kv_input_target_features = nullptr; + struct ggml_tensor * 
dflash_kv_input_pos_ctx = nullptr; + struct ggml_tensor * dflash_kq_mask_tensor = nullptr; struct dflash_capture_state { std::vector layer_ids; @@ -299,6 +309,7 @@ struct llama_context { void * prev_cb_eval_user_data = nullptr; }; std::unique_ptr dflash_capture; + llama_dflash_profile_stats dflash_profile; // input tensors struct ggml_tensor * inp_tokens; // I32 [n_batch] @@ -340,6 +351,9 @@ struct llama_context { bool update_cache_copies(); + bool ensure_dflash_kv_cache_tensors(int32_t cross_ctx); + void free_dflash_kv_cache_tensors(); + bool prepare_mtp_graph_inputs( struct llama_context & lctx); void set_mtp_op_type(llama_mtp_op_type value); diff --git a/src/llama-spec-features.cpp b/src/llama-spec-features.cpp index ccc6fb5d..91b5d41d 100644 --- a/src/llama-spec-features.cpp +++ b/src/llama-spec-features.cpp @@ -1,13 +1,39 @@ #include "llama-spec-features.h" #include +#include #include #include #include +#include #include "llama-model.h" #include "llama-context.h" +static bool llama_dflash_positions_strictly_increasing( + const llama_pos * positions, + int32_t n_rows, + llama_pos & first_pos, + llama_pos & last_pos) { + first_pos = -1; + last_pos = -1; + + if (positions == nullptr || n_rows <= 0) { + return false; + } + + first_pos = positions[0]; + last_pos = positions[n_rows - 1]; + + for (int32_t i = 1; i < n_rows; ++i) { + if (positions[i] <= positions[i - 1]) { + return false; + } + } + + return true; +} + uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx) { if (ctx == nullptr) { return 0; @@ -21,6 +47,40 @@ uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx) { return hparams.n_embd; } +void llama_dflash_profile_reset(struct llama_context * ctx) { + if (ctx == nullptr) { + return; + } + + ctx->dflash_profile = {}; +} + +void llama_set_dflash_visible_cross_ctx( + struct llama_context * ctx, + int32_t cross_ctx) { + if (ctx == nullptr) { + return; + } + + ctx->dflash_visible_cross_ctx = std::max(0, cross_ctx); +} + +int32_t 
llama_get_dflash_visible_cross_ctx( + const struct llama_context * ctx) { + return ctx != nullptr ? ctx->dflash_visible_cross_ctx : 0; +} + +bool llama_dflash_profile_get_stats( + const struct llama_context * ctx, + llama_dflash_profile_stats * stats) { + if (ctx == nullptr || stats == nullptr) { + return false; + } + + *stats = ctx->dflash_profile; + return true; +} + int32_t llama_model_dflash_block_size(const struct llama_model * model) { return model ? (int32_t) model->hparams.dflash_block_size : 0; } @@ -92,32 +152,116 @@ bool llama_set_draft_input_hidden_state_copy( return true; } -bool llama_set_dflash_target_features_copy( +static bool llama_set_dflash_target_features_impl( struct llama_context * ctx, const float * target_features, size_t n_floats, int32_t n_rows, - const llama_pos * target_positions) { + const llama_pos * target_positions, + bool copy_data) { if (ctx == nullptr || target_features == nullptr || n_floats == 0 || n_rows <= 0) { return false; } - ctx->dflash_target_features_owned.assign(target_features, target_features + n_floats); - ctx->dflash_target_features = ctx->dflash_target_features_owned.data(); + auto & profile = ctx->dflash_profile; + const int64_t t_start_us = ggml_time_us(); + const int32_t row_width = n_rows > 0 ? 
(int32_t) (n_floats / (size_t) n_rows) : 0; + llama_pos first_pos = -1; + llama_pos last_pos = -1; + + if (copy_data) { + ctx->dflash_target_features_owned.assign(target_features, target_features + n_floats); + ctx->dflash_target_features = ctx->dflash_target_features_owned.data(); + } else { + ctx->dflash_target_features_owned.clear(); + ctx->dflash_target_features = target_features; + } ctx->dflash_target_features_n_floats = n_floats; ctx->dflash_target_features_n_rows = n_rows; if (target_positions != nullptr) { - ctx->dflash_target_positions_owned.assign(target_positions, target_positions + n_rows); - ctx->dflash_target_positions = ctx->dflash_target_positions_owned.data(); + if (copy_data) { + ctx->dflash_target_positions_owned.assign(target_positions, target_positions + n_rows); + ctx->dflash_target_positions = ctx->dflash_target_positions_owned.data(); + } else { + ctx->dflash_target_positions_owned.clear(); + ctx->dflash_target_positions = target_positions; + } ctx->dflash_target_positions_n = (size_t) n_rows; } else { ctx->dflash_target_positions_owned.clear(); ctx->dflash_target_positions = nullptr; ctx->dflash_target_positions_n = 0; } + + profile.set_target_copy_calls++; + profile.set_target_copy_us += (uint64_t) (ggml_time_us() - t_start_us); + profile.set_target_rows += (uint64_t) n_rows; + profile.set_target_copy_bytes += n_floats * sizeof(float) + (target_positions ? 
(size_t) n_rows * sizeof(llama_pos) : 0); + profile.last_n_rows = n_rows; + profile.last_width = row_width; + + if (target_positions == nullptr) { + profile.set_target_missing_positions++; + profile.last_pos_first = -1; + profile.last_pos_last = -1; + } else { + if (!llama_dflash_positions_strictly_increasing(target_positions, n_rows, first_pos, last_pos)) { + profile.set_target_non_monotonic_positions++; + } + profile.last_pos_first = first_pos; + profile.last_pos_last = last_pos; + } + return true; } +bool llama_set_dflash_target_features_copy( + struct llama_context * ctx, + const float * target_features, + size_t n_floats, + int32_t n_rows, + const llama_pos * target_positions) { + return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, true); +} + +bool llama_set_dflash_target_features_view( + struct llama_context * ctx, + const float * target_features, + size_t n_floats, + int32_t n_rows, + const llama_pos * target_positions) { + return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, false); +} + +static void llama_record_dflash_capture_phase( + struct llama_context * ctx, + bool is_prompt_warmup, + int32_t row_count, + int32_t row_width) { + if (ctx == nullptr || row_count <= 0 || row_width <= 0) { + return; + } + + auto & profile = ctx->dflash_profile; + if (is_prompt_warmup) { + profile.capture_prompt_batches++; + if (profile.capture_prompt_last_rows > 0 && profile.capture_prompt_last_width > 0 && + (profile.capture_prompt_last_rows != row_count || profile.capture_prompt_last_width != row_width)) { + profile.capture_prompt_shape_changes++; + } + profile.capture_prompt_last_rows = row_count; + profile.capture_prompt_last_width = row_width; + } else { + profile.capture_verify_batches++; + if (profile.capture_verify_last_rows > 0 && profile.capture_verify_last_width > 0 && + (profile.capture_verify_last_rows != row_count || profile.capture_verify_last_width != 
row_width)) { + profile.capture_verify_shape_changes++; + } + profile.capture_verify_last_rows = row_count; + profile.capture_verify_last_width = row_width; + } +} + static bool llama_dflash_parse_layer_id(const struct ggml_tensor * tensor, int32_t & layer_id) { if (tensor == nullptr) { return false; @@ -236,6 +380,22 @@ void llama_clear_dflash_capture(struct llama_context * ctx) { } } +void llama_finish_dflash_capture_batch( + struct llama_context * ctx, + bool is_prompt_warmup) { + if (ctx == nullptr || !ctx->dflash_capture) { + return; + } + + auto & capture = *ctx->dflash_capture; + llama_record_dflash_capture_phase(ctx, is_prompt_warmup, capture.row_count, capture.row_width); + + // Reset the batch-local reference shape so the next decode only compares layers within + // the same batch, not against the previous prompt/verify batch. + capture.row_count = 0; + capture.row_width = 0; +} + static bool llama_spec_prepare_dflash_capture( struct llama_context * ctx, int32_t & row_count, @@ -245,18 +405,31 @@ static bool llama_spec_prepare_dflash_capture( return false; } + auto & profile = ctx->dflash_profile; + profile.capture_prepare_calls++; + const int64_t t_sync_us = ggml_time_us(); llama_synchronize(ctx); + profile.capture_prepare_sync_us += (uint64_t) (ggml_time_us() - t_sync_us); auto & capture = *ctx->dflash_capture; row_count = capture.row_count; row_width = capture.row_width; n_layers = (int32_t) capture.layer_ids.size(); if (row_count <= 0 || row_width <= 0 || n_layers <= 0 || capture.layer_rows.size() != (size_t) n_layers) { + profile.capture_prepare_failures++; return false; } - for (const auto & rows : capture.layer_rows) { + for (int32_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) { + const auto & rows = capture.layer_rows[(size_t) layer_idx]; if (rows.size() != (size_t) row_count * (size_t) row_width) { + profile.capture_prepare_failures++; + profile.capture_layer_shape_mismatch++; + if (profile.capture_layer_shape_mismatch <= 3) { + 
LLAMA_LOG_WARN("%s: DFlash capture rows mismatch for layer %d: got=%zu expected=%zu (rows=%d width=%d)\n", + __func__, capture.layer_ids[(size_t) layer_idx], rows.size(), + (size_t) row_count * (size_t) row_width, row_count, row_width); + } return false; } } @@ -264,6 +437,164 @@ static bool llama_spec_prepare_dflash_capture( return true; } +static bool llama_dflash_contract_log_enabled() { + const char * env = std::getenv("IK_DFLASH_CONTRACT_LOG"); + if (env == nullptr || *env == '\0') { + return false; + } + + return std::strcmp(env, "0") != 0 && + std::strcmp(env, "false") != 0 && + std::strcmp(env, "off") != 0; +} + +template +static std::string llama_dflash_contract_format_values( + const std::vector & values, + size_t edge_count = 4) { + std::ostringstream oss; + oss << '['; + if (values.empty()) { + oss << ']'; + return oss.str(); + } + + const size_t head = std::min(edge_count, values.size()); + for (size_t i = 0; i < head; ++i) { + if (i > 0) { + oss << ','; + } + oss << values[i]; + } + + if (values.size() > edge_count * 2) { + oss << ",...,"; + for (size_t i = values.size() - edge_count; i < values.size(); ++i) { + if (i > values.size() - edge_count) { + oss << ','; + } + oss << values[i]; + } + } else { + for (size_t i = head; i < values.size(); ++i) { + oss << ',' << values[i]; + } + } + + oss << ']'; + return oss.str(); +} + +static std::vector llama_dflash_contract_collect_batch_positions( + const llama_batch & batch, + const std::vector & batch_indices) { + std::vector positions; + positions.reserve(batch_indices.size()); + for (int32_t batch_index : batch_indices) { + positions.push_back(batch.pos[batch_index]); + } + return positions; +} + +static void llama_dflash_contract_summarize_positions( + const std::vector & positions, + llama_pos & first_pos, + llama_pos & last_pos, + int32_t & gap_count, + int32_t & nonmono_count) { + first_pos = -1; + last_pos = -1; + gap_count = 0; + nonmono_count = 0; + if (positions.empty()) { + return; + } + + 
first_pos = positions.front(); + last_pos = positions.back(); + for (size_t i = 1; i < positions.size(); ++i) { + if (positions[i] <= positions[i - 1]) { + nonmono_count++; + } else if (positions[i] != positions[i - 1] + 1) { + gap_count++; + } + } +} + +static void llama_dflash_contract_log_feature_view( + const char * kind, + llama_seq_id seq_id, + const llama_batch & batch, + int32_t row_count, + int32_t row_width, + int32_t n_layers, + int32_t batch_row_offset, + const std::vector & row_indices, + const std::vector & batch_indices) { + if (!llama_dflash_contract_log_enabled()) { + return; + } + + static std::atomic counter = 0; + const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); + if (ordinal >= 8) { + return; + } + + const std::vector positions = llama_dflash_contract_collect_batch_positions(batch, batch_indices); + llama_pos first_pos = -1; + llama_pos last_pos = -1; + int32_t gap_count = 0; + int32_t nonmono_count = 0; + llama_dflash_contract_summarize_positions(positions, first_pos, last_pos, gap_count, nonmono_count); + + LLAMA_LOG_INFO("%s[%llu]: kind=%s seq=%d batch_tokens=%d capture_rows=%d row_width=%d layers=%d batch_row_offset=%d row_indices=%s batch_indices=%s batch_pos=%s pos=[%d..%d] gaps=%d nonmono=%d\n", + __func__, + (unsigned long long) (ordinal + 1), + kind, + (int) seq_id, + batch.n_tokens, + row_count, + row_width, + n_layers, + batch_row_offset, + llama_dflash_contract_format_values(row_indices).c_str(), + llama_dflash_contract_format_values(batch_indices).c_str(), + llama_dflash_contract_format_values(positions).c_str(), + (int) first_pos, + (int) last_pos, + gap_count, + nonmono_count); +} + +static void llama_dflash_contract_log_output_indices( + struct llama_context * ctx, + const std::vector & output_indices) { + if (!llama_dflash_contract_log_enabled()) { + return; + } + + static std::atomic counter = 0; + const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); + if (ordinal >= 8) { + return; 
+ } + + int32_t row_count = 0; + int32_t row_width = 0; + int32_t n_layers = 0; + const bool have_capture = llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers); + + LLAMA_LOG_INFO("%s[%llu]: output_indices=%s capture_rows=%d row_width=%d layers=%d have_capture=%s\n", + __func__, + (unsigned long long) (ordinal + 1), + llama_dflash_contract_format_values(output_indices).c_str(), + row_count, + row_width, + n_layers, + have_capture ? "true" : "false"); +} + static bool llama_spec_materialize_dflash_rows( struct llama_context * ctx, const std::vector & row_indices, @@ -275,10 +606,15 @@ static bool llama_spec_materialize_dflash_rows( return false; } + auto & profile = ctx->dflash_profile; + profile.capture_materialize_calls++; + const int64_t t_start_us = ggml_time_us(); + int32_t row_count = 0; int32_t row_width = 0; int32_t n_layers = 0; if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) { + profile.capture_materialize_failures++; return false; } @@ -294,6 +630,7 @@ static bool llama_spec_materialize_dflash_rows( if (row_index < 0 || row_index >= row_count) { rows_out.clear(); combined_width = 0; + profile.capture_materialize_failures++; return false; } @@ -304,6 +641,10 @@ static bool llama_spec_materialize_dflash_rows( } } + profile.capture_materialize_us += (uint64_t) (ggml_time_us() - t_start_us); + profile.capture_materialize_rows += (uint64_t) row_indices.size(); + profile.capture_materialize_bytes += rows_out.size() * sizeof(float); + return true; } @@ -412,6 +753,17 @@ bool llama_spec_get_dflash_feature_view( }); } + llama_dflash_contract_log_feature_view( + "batch", + view.rows.empty() ? 
-1 : view.rows.front().seq_id, + batch, + row_count, + row_width, + n_layers, + batch_row_offset, + row_indices, + batch_indices); + return true; } @@ -470,6 +822,17 @@ bool llama_spec_get_dflash_feature_view_for_seq( }); } + llama_dflash_contract_log_feature_view( + "seq", + seq_id, + batch, + row_count, + row_width, + n_layers, + batch_row_offset, + row_indices, + batch_indices); + return true; } @@ -576,5 +939,7 @@ bool llama_spec_copy_dflash_rows_from_output_indices( return false; } + llama_dflash_contract_log_output_indices(ctx, output_indices); + return hidden_rows.size() == (size_t) output_indices.size() * (size_t) combined_width; } diff --git a/src/llama-spec-features.h b/src/llama-spec-features.h index 130d3895..20f0ff51 100644 --- a/src/llama-spec-features.h +++ b/src/llama-spec-features.h @@ -23,8 +23,74 @@ struct llama_spec_feature_view { std::vector rows; }; +struct llama_dflash_profile_stats { + uint64_t set_target_copy_calls = 0; + uint64_t set_target_copy_us = 0; + uint64_t set_target_rows = 0; + uint64_t set_target_copy_bytes = 0; + uint64_t set_target_missing_positions = 0; + uint64_t set_target_non_monotonic_positions = 0; + + uint64_t capture_prepare_calls = 0; + uint64_t capture_prepare_sync_us = 0; + uint64_t capture_prepare_failures = 0; + uint64_t capture_layer_shape_mismatch = 0; + uint64_t capture_prompt_batches = 0; + uint64_t capture_prompt_shape_changes = 0; + uint64_t capture_verify_batches = 0; + uint64_t capture_verify_shape_changes = 0; + uint64_t capture_materialize_calls = 0; + uint64_t capture_materialize_rows = 0; + uint64_t capture_materialize_bytes = 0; + uint64_t capture_materialize_us = 0; + uint64_t capture_materialize_failures = 0; + + uint64_t graph_prepare_calls = 0; + uint64_t graph_prepare_total_us = 0; + uint64_t graph_feature_copy_us = 0; + uint64_t graph_pos_copy_us = 0; + uint64_t graph_mask_build_us = 0; + uint64_t graph_kv_cache_compute_us = 0; + uint64_t graph_kv_cache_calls = 0; + uint64_t graph_feature_bytes = 
0; + uint64_t graph_pos_bytes = 0; + uint64_t graph_mask_bytes = 0; + uint64_t graph_visible_kv_sum = 0; + uint64_t graph_visible_kv_max = 0; + uint64_t graph_pos_fallbacks = 0; + uint64_t graph_pos_non_monotonic = 0; + uint64_t graph_shape_failures = 0; + uint64_t graph_mask_overflow = 0; + + int32_t last_n_rows = 0; + int32_t last_width = 0; + int32_t last_cross_ctx = 0; + int32_t last_left_pad = 0; + int32_t last_n_tokens = 0; + int32_t last_n_kv_total = 0; + int32_t capture_prompt_last_rows = 0; + int32_t capture_prompt_last_width = 0; + int32_t capture_verify_last_rows = 0; + int32_t capture_verify_last_width = 0; + llama_pos last_pos_first = -1; + llama_pos last_pos_last = -1; +}; + uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx); +void llama_dflash_profile_reset(struct llama_context * ctx); + +void llama_set_dflash_visible_cross_ctx( + struct llama_context * ctx, + int32_t cross_ctx); + +int32_t llama_get_dflash_visible_cross_ctx( + const struct llama_context * ctx); + +bool llama_dflash_profile_get_stats( + const struct llama_context * ctx, + llama_dflash_profile_stats * stats); + int32_t llama_model_dflash_block_size(const struct llama_model * model); int32_t llama_model_dflash_mask_token_id(const struct llama_model * model); @@ -54,6 +120,13 @@ bool llama_set_dflash_target_features_copy( int32_t n_rows, const llama_pos * target_positions); +bool llama_set_dflash_target_features_view( + struct llama_context * ctx, + const float * target_features, + size_t n_floats, + int32_t n_rows, + const llama_pos * target_positions); + bool llama_set_dflash_capture_layers( struct llama_context * ctx, const int32_t * layer_ids, @@ -61,6 +134,10 @@ bool llama_set_dflash_capture_layers( void llama_clear_dflash_capture(struct llama_context * ctx); +void llama_finish_dflash_capture_batch( + struct llama_context * ctx, + bool is_prompt_warmup); + bool llama_spec_get_hidden_feature_view( struct llama_context * ctx, const llama_batch & batch, diff --git 
a/src/llama.cpp b/src/llama.cpp index a9b443c7..af37211e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -565,6 +565,84 @@ void llama_context::reset_scheduler() { prev_mtp.reset(); } +bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { + const int32_t target_cross_ctx = std::max(1, cross_ctx); + const int32_t n_layer = model.hparams.n_layer; + const int64_t n_embd_head_k = model.hparams.n_embd_head_k(0); + const int64_t n_embd_head_v = model.hparams.n_embd_head_v(0); + const int64_t n_head_kv = model.hparams.n_head_kv(); + + if (dflash_cache_ctx != nullptr && !dflash_k_ctx_cache.empty()) { + if ((int32_t) dflash_k_ctx_cache.size() == n_layer && + dflash_k_ctx_cache.front() != nullptr && + (int32_t) dflash_k_ctx_cache.front()->ne[2] == target_cross_ctx) { + return true; + } + + free_dflash_kv_cache_tensors(); + if (dflash_sched != nullptr) { + ggml_backend_sched_free(dflash_sched); + dflash_sched = nullptr; + } + dflash_buf_compute_meta.clear(); + } + + ggml_backend_buffer_type_t buft = llama_default_buffer_type_cpu(true); + + ggml_init_params params = { + /*.mem_size =*/ (size_t) (2 * std::max(1, n_layer)) * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + dflash_cache_ctx = ggml_init(params); + if (dflash_cache_ctx == nullptr) { + return false; + } + + dflash_k_ctx_cache.resize((size_t) n_layer); + dflash_v_ctx_cache.resize((size_t) n_layer); + for (int32_t il = 0; il < n_layer; ++il) { + dflash_k_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_k, n_head_kv, target_cross_ctx); + dflash_v_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_v, n_head_kv, target_cross_ctx); + if (dflash_k_ctx_cache[(size_t) il] == nullptr || dflash_v_ctx_cache[(size_t) il] == nullptr) { + free_dflash_kv_cache_tensors(); + return false; + } + + ggml_set_input(dflash_k_ctx_cache[(size_t) il]); + ggml_set_input(dflash_v_ctx_cache[(size_t) il]); + 
ggml_format_name(dflash_k_ctx_cache[(size_t) il], "dflash_k_ctx_cache_%d", il); + ggml_format_name(dflash_v_ctx_cache[(size_t) il], "dflash_v_ctx_cache_%d", il); + } + + dflash_cache_buf = ggml_backend_alloc_ctx_tensors_from_buft(dflash_cache_ctx, buft); + if (dflash_cache_buf == nullptr) { + free_dflash_kv_cache_tensors(); + return false; + } + + ggml_backend_buffer_clear(dflash_cache_buf, 0); + return true; +} + +void llama_context::free_dflash_kv_cache_tensors() { + dflash_k_ctx_cache.clear(); + dflash_v_ctx_cache.clear(); + dflash_kv_input_target_features = nullptr; + dflash_kv_input_pos_ctx = nullptr; + dflash_kq_mask_tensor = nullptr; + + if (dflash_cache_buf != nullptr) { + ggml_backend_buffer_free(dflash_cache_buf); + dflash_cache_buf = nullptr; + } + if (dflash_cache_ctx != nullptr) { + ggml_free(dflash_cache_ctx); + dflash_cache_ctx = nullptr; + } +} + bool llama_context::can_reuse_graph(const llama_batch & u_batch) { if (!cparams.graph_reuse) return false; //if (kv_self.save_per_step_ssm) return false; @@ -687,6 +765,10 @@ void llama_context::set_mtp_op_type(llama_mtp_op_type value) { } llama_context::~llama_context() { + if (dflash_sched != nullptr) { + ggml_backend_sched_free(dflash_sched); + } + free_dflash_kv_cache_tensors(); ggml_backend_sched_free(sched); for (ggml_backend_t backend : backends) { @@ -4934,6 +5016,30 @@ static void llama_graph_compute( // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched)); } +static void llama_graph_compute_sched( + llama_context & lctx, + ggml_backend_sched_t sched, + ggml_cgraph * gf, + int n_threads) { +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(lctx.backend_metal)) { + ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads); + } +#endif + + if (lctx.backend_cpu != nullptr) { + ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads); + ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data); + } +#ifdef GGML_USE_BLAS + if 
(lctx.backend_blas != nullptr) { + ggml_backend_blas_set_n_threads(lctx.backend_blas, n_threads); + } +#endif + + ggml_backend_sched_graph_compute_async(sched, gf); +} + static bool prepare_mtp_graph_inputs( struct llama_context & lctx, uint32_t cur_token, @@ -4984,62 +5090,206 @@ static bool prepare_mtp_graph_inputs( static bool prepare_dflash_graph_inputs( struct llama_context & lctx, uint32_t n_tokens) { - ggml_tensor * target_hidden = lctx.inp_dflash_target_features; - ggml_tensor * pos_ctx = lctx.inp_dflash_pos_ctx; - ggml_tensor * kq_mask = lctx.inp_dflash_kq_mask; + const char * dflash_kv_cache_env = std::getenv("IK_DFLASH_KV_CACHE"); + const bool use_kv_cache = dflash_kv_cache_env != nullptr && *dflash_kv_cache_env != '\0' && + std::strcmp(dflash_kv_cache_env, "0") != 0 && + std::strcmp(dflash_kv_cache_env, "false") != 0 && + std::strcmp(dflash_kv_cache_env, "off") != 0; + const int32_t cross_ctx = lctx.dflash_visible_cross_ctx > 0 + ? lctx.dflash_visible_cross_ctx + : std::max(1, (int32_t) lctx.cparams.n_ctx - (int32_t) lctx.model.hparams.dflash_block_size); + ggml_tensor * kq_mask = lctx.dflash_kq_mask_tensor; - if (target_hidden == nullptr || pos_ctx == nullptr || kq_mask == nullptr) { + if (kq_mask == nullptr) { LLAMA_LOG_ERROR("%s: DFlash graph inputs are not initialized\n", __func__); return false; } + if (use_kv_cache) { + if (!lctx.ensure_dflash_kv_cache_tensors(cross_ctx) || lctx.dflash_k_ctx_cache.empty() || lctx.dflash_v_ctx_cache.empty()) { + LLAMA_LOG_ERROR("%s: DFlash K/V cache inputs are not initialized\n", __func__); + return false; + } + } else if (lctx.inp_dflash_target_features == nullptr || lctx.inp_dflash_pos_ctx == nullptr) { + LLAMA_LOG_ERROR("%s: DFlash inline inputs are not initialized\n", __func__); + return false; + } + const float * src = lctx.dflash_target_features; const llama_pos * src_pos = lctx.dflash_target_positions; const size_t total_floats = lctx.dflash_target_features_n_floats; const size_t total_positions = 
lctx.dflash_target_positions_n; const int32_t n_rows = lctx.dflash_target_features_n_rows; - const int32_t width = (int32_t) target_hidden->ne[0]; - const int32_t cross_ctx = (int32_t) target_hidden->ne[1]; + const int32_t width = (int32_t) lctx.model.hparams.dflash_n_target_features; + const int32_t graph_cross_ctx = use_kv_cache + ? (lctx.dflash_k_ctx_cache.front() != nullptr ? (int32_t) lctx.dflash_k_ctx_cache.front()->ne[2] : 0) + : (lctx.inp_dflash_target_features != nullptr ? (int32_t) lctx.inp_dflash_target_features->ne[1] : 0); const int32_t n_mask_tokens = (int32_t) kq_mask->ne[1]; const int32_t n_kv_total = (int32_t) kq_mask->ne[0]; + auto & profile = lctx.dflash_profile; + const int64_t t_total_us = ggml_time_us(); + profile.graph_prepare_calls++; + profile.last_n_rows = n_rows; + profile.last_width = width; + profile.last_cross_ctx = cross_ctx; + profile.last_n_tokens = (int32_t) n_tokens; + profile.last_n_kv_total = n_kv_total; + + if (graph_cross_ctx != cross_ctx) { + profile.graph_shape_failures++; + + LLAMA_LOG_ERROR("%s: DFlash graph cross_ctx drift (graph=%d configured=%d)\n", + __func__, graph_cross_ctx, cross_ctx); + return false; + } if (src == nullptr || total_floats == 0 || n_rows <= 0) { + profile.graph_shape_failures++; LLAMA_LOG_ERROR("%s: missing DFlash target features\n", __func__); return false; } if (n_rows > cross_ctx || total_floats != (size_t) n_rows * (size_t) width) { + profile.graph_shape_failures++; LLAMA_LOG_ERROR("%s: invalid DFlash target feature shape (rows=%d width=%d floats=%zu cross_ctx=%d)\n", __func__, n_rows, width, total_floats, cross_ctx); return false; } - lctx.dflash_target_features_padded.assign((size_t) cross_ctx * (size_t) width, 0.0f); - const size_t dst_offset = (size_t) (cross_ctx - n_rows) * (size_t) width; - const int32_t left_pad = cross_ctx - n_rows; - std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset); - ggml_backend_tensor_set(target_hidden, 
lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(target_hidden)); + if (n_kv_total < cross_ctx + (int32_t) n_tokens) { + profile.graph_mask_overflow++; + LLAMA_LOG_ERROR("%s: invalid DFlash mask shape (n_kv_total=%d < cross_ctx+n_tokens=%d)\n", + __func__, n_kv_total, cross_ctx + (int32_t) n_tokens); + return false; + } + const int32_t left_pad = cross_ctx - n_rows; + const size_t padded_floats = (size_t) cross_ctx * (size_t) width; + const size_t dst_offset = (size_t) left_pad * (size_t) width; + const int64_t t_feature_us = ggml_time_us(); + profile.last_left_pad = left_pad; + if (lctx.dflash_target_features_padded.size() != padded_floats) { + lctx.dflash_target_features_padded.resize(padded_floats); + } + if (left_pad == 0 && total_floats == padded_floats) { + std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin()); + } else { + if (dst_offset > 0) { + std::fill(lctx.dflash_target_features_padded.begin(), + lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset, 0.0f); + } + std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset); + } + profile.graph_feature_copy_us += (uint64_t) (ggml_time_us() - t_feature_us); + profile.graph_feature_bytes += padded_floats * sizeof(float); + + const int64_t t_pos_us = ggml_time_us(); lctx.dflash_pos_ctx_data.resize((size_t) cross_ctx); std::fill(lctx.dflash_pos_ctx_data.begin(), lctx.dflash_pos_ctx_data.end(), 0); if (src_pos != nullptr && total_positions == (size_t) n_rows) { + bool monotonic = true; + for (int32_t i = 1; i < n_rows; ++i) { + if (src_pos[i] <= src_pos[i - 1]) { + monotonic = false; + break; + } + } + if (!monotonic) { + profile.graph_pos_non_monotonic++; + if (profile.graph_pos_non_monotonic <= 3) { + LLAMA_LOG_WARN("%s: DFlash target positions are not strictly increasing (rows=%d first=%d last=%d)\n", + __func__, n_rows, (int) src_pos[0], (int) src_pos[n_rows - 1]); + } + } + profile.last_pos_first = src_pos[0]; + 
profile.last_pos_last = src_pos[n_rows - 1]; std::copy(src_pos, src_pos + n_rows, lctx.dflash_pos_ctx_data.begin() + (ptrdiff_t) left_pad); } else { + profile.graph_pos_fallbacks++; + profile.last_pos_first = -1; + profile.last_pos_last = -1; + if (profile.graph_pos_fallbacks <= 3) { + LLAMA_LOG_WARN("%s: using synthetic DFlash positions (rows=%d positions=%zu cross_ctx=%d)\n", + __func__, n_rows, total_positions, cross_ctx); + } for (int32_t i = 0; i < n_rows; ++i) { lctx.dflash_pos_ctx_data[(size_t) left_pad + (size_t) i] = i; } } - ggml_backend_tensor_set(pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(pos_ctx)); + profile.graph_pos_copy_us += (uint64_t) (ggml_time_us() - t_pos_us); + profile.graph_pos_bytes += lctx.dflash_pos_ctx_data.size() * sizeof(llama_pos); + if (use_kv_cache) { + const size_t max_nodes = lctx.model.max_nodes((int) std::max(1, cross_ctx)) + 24 * lctx.model.hparams.n_layer; + const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false); + if (lctx.dflash_buf_compute_meta.size() != meta_size) { + lctx.dflash_buf_compute_meta.resize(meta_size); + } + + ggml_cgraph * gf_kv = llm_build_context::llama_build_graph_dflash_kv_cache(lctx); + if (gf_kv == nullptr || lctx.dflash_kv_input_target_features == nullptr || lctx.dflash_kv_input_pos_ctx == nullptr) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__); + return false; + } + + if (lctx.dflash_sched == nullptr) { + std::vector backend_buft; + backend_buft.reserve(lctx.backends.size()); + for (auto * backend : lctx.backends) { + if (ggml_backend_is_cpu(backend)) { + backend_buft.push_back(llama_default_buffer_type_cpu(true)); + } else { + backend_buft.push_back(ggml_backend_get_default_buffer_type(backend)); + } + } + + lctx.dflash_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false); + if (lctx.dflash_sched == nullptr || 
!ggml_backend_sched_reserve(lctx.dflash_sched, gf_kv)) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V scheduler\n", __func__); + return false; + } + } + + ggml_backend_sched_reset(lctx.dflash_sched); + ggml_backend_sched_alloc_graph(lctx.dflash_sched, gf_kv); + ggml_backend_tensor_set(lctx.dflash_kv_input_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); + ggml_backend_tensor_set(lctx.dflash_kv_input_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); + const int64_t t_kv_cache_us = ggml_time_us(); + llama_graph_compute_sched(lctx, lctx.dflash_sched, gf_kv, lctx.cparams.n_threads); + llama_synchronize(&lctx); + profile.graph_kv_cache_compute_us += (uint64_t) (ggml_time_us() - t_kv_cache_us); + profile.graph_kv_cache_calls++; + } else { + ggml_backend_tensor_set(lctx.inp_dflash_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.inp_dflash_target_features)); + ggml_backend_tensor_set(lctx.inp_dflash_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.inp_dflash_pos_ctx)); + } + + const int64_t t_mask_us = ggml_time_us(); lctx.dflash_kq_mask_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY); + int32_t visible_kv_max = 0; for (uint32_t j = 0; j < n_tokens; ++j) { float * row = lctx.dflash_kq_mask_data.data() + (size_t) j * (size_t) n_kv_total; const int32_t visible_kv = cross_ctx + (int32_t) j + 1; + visible_kv_max = std::max(visible_kv_max, visible_kv); + profile.graph_visible_kv_sum += (uint64_t) visible_kv; for (int32_t i = left_pad; i < visible_kv; ++i) { row[i] = 0.0f; } } ggml_backend_tensor_set(kq_mask, lctx.dflash_kq_mask_data.data(), 0, ggml_nbytes(kq_mask)); + profile.graph_mask_build_us += (uint64_t) (ggml_time_us() - t_mask_us); + profile.graph_mask_bytes += ggml_nbytes(kq_mask); + profile.graph_visible_kv_max = std::max(profile.graph_visible_kv_max, 
(uint64_t) visible_kv_max); + profile.graph_prepare_total_us += (uint64_t) (ggml_time_us() - t_total_us); + + if (profile.graph_prepare_calls == 1) { + LLAMA_LOG_INFO("%s: DFlash graph contract rows=%d width=%d cross_ctx=%d n_tokens=%u left_pad=%d n_kv_total=%d draft_n_ctx=%u pos=%s [%d..%d]\n", + __func__, n_rows, width, cross_ctx, n_tokens, left_pad, n_kv_total, lctx.cparams.n_ctx, + (src_pos != nullptr && total_positions == (size_t) n_rows) ? "target" : "synthetic", + (int) profile.last_pos_first, (int) profile.last_pos_last); + } return true; } From 1369e684711ff5901e80e23238a293189a4fc10b Mon Sep 17 00:00:00 2001 From: SamuelOliveirads Date: Sun, 31 May 2026 11:12:03 -0300 Subject: [PATCH 04/13] fix graph mask, swa layers and tokens positions --- common/speculative.cpp | 285 ++++++++++++++++++++-- convert_hf_to_gguf.py | 58 ++++- examples/server/server-context.cpp | 39 +++- examples/server/server-context.h | 2 + gguf-py/gguf/constants.py | 2 + src/graphs/build_dflash.cpp | 32 ++- src/llama-context.h | 7 +- src/llama-hparams.cpp | 26 +-- src/llama-spec-features.cpp | 135 ++++++++++- src/llama-spec-features.h | 40 ++++ src/llama.cpp | 363 +++++++++++++++++++++++++---- 11 files changed, 894 insertions(+), 95 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index bd854a7f..911526a8 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -133,6 +134,31 @@ static bool common_speculative_are_dflash_compatible( return false; } + const bool add_bos_tgt = llama_vocab_get_add_bos(vocab_tgt); + const bool add_bos_dft = llama_vocab_get_add_bos(vocab_dft); + const bool add_eos_tgt = llama_vocab_get_add_eos(vocab_tgt); + const bool add_eos_dft = llama_vocab_get_add_eos(vocab_dft); + const llama_token bos_tgt = llama_vocab_bos(vocab_tgt); + const llama_token bos_dft = llama_vocab_bos(vocab_dft); + const llama_token eos_tgt = llama_vocab_eos(vocab_tgt); + 
const llama_token eos_dft = llama_vocab_eos(vocab_dft); + + if (add_bos_tgt != add_bos_dft || add_eos_tgt != add_eos_dft || + (add_bos_tgt && bos_tgt != bos_dft) || + (add_eos_tgt && eos_tgt != eos_dft)) { + LOG_DBG("%s: DFlash draft special tokens must match the target model (add_bos=%d/%d add_eos=%d/%d bos=%d/%d eos=%d/%d)\n", + __func__, + (int) add_bos_tgt, + (int) add_bos_dft, + (int) add_eos_tgt, + (int) add_eos_dft, + (int) bos_tgt, + (int) bos_dft, + (int) eos_tgt, + (int) eos_dft); + return false; + } + const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt); const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft); const int vocab_diff = n_vocab_tgt > n_vocab_dft @@ -464,6 +490,24 @@ struct common_speculative_state_dflash : public common_speculative_state { size_t n_decode_fail = 0; llama_pos last_draft_pos_base = -1; + uint64_t t_draft_decode_us = 0; + uint64_t t_draft_sample_us = 0; + uint64_t t_warmup_collect_us = 0; + uint64_t t_warmup_append_us = 0; + uint64_t t_accept_output_copy_us = 0; + uint64_t t_accept_commit_us = 0; + uint64_t t_accept_append_us = 0; + size_t n_warmup_collect_calls = 0; + size_t n_warmup_collect_rows = 0; + size_t n_warmup_append_calls = 0; + size_t n_warmup_append_rows = 0; + size_t n_accept_output_copy_calls = 0; + size_t n_accept_output_copy_rows = 0; + size_t n_accept_commit_calls = 0; + size_t n_accept_commit_rows = 0; + size_t n_accept_append_calls = 0; + size_t n_accept_append_rows = 0; + common_speculative_state_dflash( enum common_speculative_type type, llama_context * ctx_tgt, @@ -474,9 +518,10 @@ struct common_speculative_state_dflash : public common_speculative_state { , ctx_dft(ctx_dft) , cross_ctx(std::max(1, cross_ctx)) { + const llama_model * model_tgt = llama_get_model(ctx_tgt); const llama_model * model_dft = llama_get_model(ctx_dft); - if (!common_speculative_are_dflash_compatible(llama_get_model(ctx_tgt), model_dft)) { + if (!common_speculative_are_dflash_compatible(model_tgt, model_dft)) { LOG_ERR("%s: 
DFlash draft model vocab/tokenizer is incompatible with the target model\n", __func__); return; } @@ -499,6 +544,70 @@ struct common_speculative_state_dflash : public common_speculative_state { return; } + const auto * vocab_tgt = llama_model_get_vocab(model_tgt); + const auto * vocab_dft = llama_model_get_vocab(model_dft); + const int32_t target_vocab_size = llama_vocab_n_tokens(vocab_tgt); + const int32_t draft_vocab_size = llama_vocab_n_tokens(vocab_dft); + const int32_t target_hidden_size = llama_model_n_embd(model_tgt); + const int32_t draft_hidden_size = llama_model_n_embd(model_dft); + const int32_t target_mask_token_id = llama_model_dflash_target_mask_token_id(model_tgt); + const int32_t expected_n_target_features = target_hidden_size > 0 ? target_hidden_size * n_target_layers : 0; + + if (target_mask_token_id != (int32_t) LLAMA_TOKEN_NULL && mask_token_id != target_mask_token_id) { + LOG_ERR("%s: DFlash mask token mismatch (draft=%d target=%d)\n", + __func__, mask_token_id, target_mask_token_id); + return; + } + + if (target_hidden_size <= 0 || draft_hidden_size <= 0) { + LOG_ERR("%s: invalid DFlash hidden sizes (draft=%d target=%d)\n", + __func__, draft_hidden_size, target_hidden_size); + return; + } + + if (expected_n_target_features <= 0 || n_target_features != expected_n_target_features) { + LOG_ERR("%s: DFlash target feature width mismatch (metadata=%d expected=%d target_hidden=%d target_layers=%d)\n", + __func__, n_target_features, expected_n_target_features, target_hidden_size, n_target_layers); + return; + } + + std::vector sorted_target_layer_ids = target_layer_ids; + std::sort(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()); + if (std::adjacent_find(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()) != sorted_target_layer_ids.end()) { + LOG_ERR("%s: duplicate DFlash target layer ids survived into runtime validation\n", __func__); + target_layer_ids.clear(); + return; + } + + const int32_t n_target_model_layers = 
llama_n_layer(model_tgt); + for (int32_t layer_id : target_layer_ids) { + if (layer_id < 0 || layer_id >= n_target_model_layers) { + LOG_ERR("%s: invalid DFlash target layer id %d for target model with %d layers\n", + __func__, layer_id, n_target_model_layers); + target_layer_ids.clear(); + return; + } + } + + const int32_t io_mode = llama_model_dflash_io_mode(model_dft, model_tgt); + if (io_mode == LLAMA_DFLASH_IO_MODE_INVALID) { + LOG_ERR("%s: DFlash draft is missing required IO tensors after target sharing\n", __func__); + return; + } + + if (io_mode == LLAMA_DFLASH_IO_MODE_MIXED) { + LOG_ERR("%s: DFlash IO contract must be fully shared or fully self-contained, but resolved to mixed mode\n", __func__); + return; + } + + if (io_mode == LLAMA_DFLASH_IO_MODE_SELF_CONTAINED && !llama_model_dflash_io_tensors_match(model_dft, target_hidden_size, target_vocab_size)) { + LOG_ERR("%s: DFlash self-contained IO tensors do not match the target hidden/vocab contract (target_hidden=%d target_vocab=%d)\n", + __func__, + target_hidden_size, + target_vocab_size); + return; + } + if (!llama_set_dflash_capture_layers(ctx_tgt, target_layer_ids.data(), (int32_t) target_layer_ids.size())) { LOG_ERR("%s: failed to configure DFlash target capture callback\n", __func__); return; @@ -519,8 +628,11 @@ struct common_speculative_state_dflash : public common_speculative_state { layers_oss << target_layer_ids[i]; } + const char * io_mode_name = io_mode == LLAMA_DFLASH_IO_MODE_SHARED ? 
"shared" : "self-contained"; LOG_INF("%s: DFlash context ready (n_ctx=%d, block_size=%d, cross_ctx=%d, n_target_features=%d, target_layer_ids=[%s])\n", - __func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features, layers_oss.str().c_str()); + __func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features, layers_oss.str().c_str()); + LOG_INF("%s: DFlash artifact io=%s draft_vocab=%d target_vocab=%d draft_hidden=%d target_hidden=%d mask_token_id=%d target_mask_token_id=%d\n", + __func__, io_mode_name, draft_vocab_size, target_vocab_size, draft_hidden_size, target_hidden_size, mask_token_id, target_mask_token_id); } ~common_speculative_state_dflash() override { @@ -544,6 +656,23 @@ struct common_speculative_state_dflash : public common_speculative_state { n_set_target_fail = 0; n_decode_fail = 0; last_draft_pos_base = -1; + t_draft_decode_us = 0; + t_draft_sample_us = 0; + t_warmup_collect_us = 0; + t_warmup_append_us = 0; + t_accept_output_copy_us = 0; + t_accept_commit_us = 0; + t_accept_append_us = 0; + n_warmup_collect_calls = 0; + n_warmup_collect_rows = 0; + n_warmup_append_calls = 0; + n_warmup_append_rows = 0; + n_accept_output_copy_calls = 0; + n_accept_output_copy_rows = 0; + n_accept_commit_calls = 0; + n_accept_commit_rows = 0; + n_accept_append_calls = 0; + n_accept_append_rows = 0; llama_dflash_profile_reset(ctx_tgt); llama_dflash_profile_reset(ctx_dft); } @@ -554,7 +683,6 @@ struct common_speculative_state_dflash : public common_speculative_state { llama_token id_last, llama_tokens & result) override { GGML_UNUSED(prompt_tgt); - GGML_UNUSED(id_last); result.clear(); if (!ready || target_window_rows <= 0) { @@ -562,7 +690,7 @@ struct common_speculative_state_dflash : public common_speculative_state { return; } - const int32_t n_keep = std::min(params.n_max, block_size); + const int32_t n_keep = std::min(params.n_max, block_size - 1); if (n_keep <= 0) { return; } @@ -575,23 +703,30 @@ struct 
common_speculative_state_dflash : public common_speculative_state { llama_kv_cache_clear(ctx_dft); batch.n_tokens = 0; + const int32_t batch_len = n_keep + 1; const llama_pos draft_pos_base = last_target_pos >= 0 ? last_target_pos + 1 : (llama_pos) target_window_rows; + const llama_pos seed_pos = last_target_pos >= 0 ? last_target_pos : draft_pos_base - 1; last_draft_pos_base = draft_pos_base; - for (int32_t i = 0; i < block_size; ++i) { - common_batch_add(batch, mask_token_id, draft_pos_base + i, { 0 }, i < n_keep); + common_batch_add(batch, id_last, seed_pos, { 0 }, false); + for (int32_t i = 1; i < batch_len; ++i) { + common_batch_add(batch, mask_token_id, draft_pos_base + (i - 1), { 0 }, i <= n_keep); } + const int64_t t_decode_us = ggml_time_us(); if (llama_decode(ctx_dft, batch) != 0) { LOG_ERR("%s: llama_decode() failed for DFlash draft batch\n", __func__); n_decode_fail++; batch.n_tokens = 0; return; } + t_draft_decode_us += (uint64_t) (ggml_time_us() - t_decode_us); result.reserve((size_t) n_keep); + const int64_t t_sample_us = ggml_time_us(); for (int32_t i = 0; i < n_keep; ++i) { - result.push_back(common_sampler_sample_speculative(nullptr, ctx_dft, i, nullptr)); + result.push_back(common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr)); } + t_draft_sample_us += (uint64_t) (ggml_time_us() - t_sample_us); batch.n_tokens = 0; dflash_contract_log_draft(*this, n_keep, result.size()); @@ -657,8 +792,13 @@ static void dflash_contract_log_draft( const int draft_delta = (state.last_target_pos >= 0 && state.last_draft_pos_base >= 0) ? (int) (state.last_draft_pos_base - state.last_target_pos) : -1; + const llama_pos seed_pos = state.last_target_pos; + const llama_pos mask_first_pos = state.last_draft_pos_base; + const llama_pos mask_last_pos = state.last_draft_pos_base >= 0 + ? 
state.last_draft_pos_base + n_keep - 1 + : -1; - LOG_INF("dflash contract draft[%llu]: window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d draft_pos_base=%d delta=%d n_keep=%d result=%zu set_target(missing/nonmono)=%llu/%llu graph(fallback/nonmono)=%llu/%llu graph_pos=[%d..%d]\n", + LOG_INF("dflash contract draft[%llu]: window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d seed_pos=%d mask_pos=[%d..%d] sample_rows=[1..%d] output_rows=[1..%d] draft_pos_base=%d delta=%d n_keep=%d result=%zu set_target(missing/nonmono)=%llu/%llu graph(fallback/nonmono)=%llu/%llu graph_pos=[%d..%d]\n", (unsigned long long) (ordinal + 1), state.target_window_rows, dflash_contract_format_values(state.target_window_pos).c_str(), @@ -667,6 +807,11 @@ static void dflash_contract_log_draft( window.gap_count, window.nonmono_count, (int) state.last_target_pos, + (int) seed_pos, + (int) mask_first_pos, + (int) mask_last_pos, + n_keep, + n_keep, (int) state.last_draft_pos_base, draft_delta, n_keep, @@ -1583,7 +1728,14 @@ common_speculative * common_speculative_init( return nullptr; } - cparams_dft.n_ctx = (uint32_t) (max_cross_ctx + block_size); + const int64_t required_n_ctx = (int64_t) max_cross_ctx + (int64_t) block_size; + if (required_n_ctx > std::numeric_limits::max()) { + LOG_ERR("%s: invalid DFlash draft context size cross_ctx=%d block_size=%d required_n_ctx=%lld\n", + __func__, max_cross_ctx, block_size, (long long) required_n_ctx); + return nullptr; + } + + cparams_dft.n_ctx = (uint32_t) required_n_ctx; } ctx_dft = llama_init_from_model(params.model_dft, cparams_dft); @@ -2127,6 +2279,8 @@ int32_t common_speculative_on_target_seq_batch( const llama_batch * batch_for_spec = &batch; llama_batch seq_batch = {}; const bool needs_seq_split = is_prompt_warmup && !common_speculative_batch_is_exact_single_seq(batch, seq_id); + auto * dflash_state = common_speculative_get_dflash_state(spec); + const bool measure_dflash_warmup_collect = 
dflash_state != nullptr && is_prompt_warmup; if (needs_seq_split) { const int n_seq_tokens = common_speculative_copy_seq_batch(batch, seq_id, seq_batch); @@ -2134,16 +2288,28 @@ int32_t common_speculative_on_target_seq_batch( return n_seq_tokens < 0 ? -1 : 0; } + const int64_t t_collect_us = measure_dflash_warmup_collect ? ggml_time_us() : 0; if (!common_speculative_collect_target_seq_batch_features(spec, ctx_tgt, batch, seq_id, feature_view)) { llama_batch_free(seq_batch); return -1; } + if (measure_dflash_warmup_collect) { + dflash_state->t_warmup_collect_us += (uint64_t) (ggml_time_us() - t_collect_us); + dflash_state->n_warmup_collect_calls++; + dflash_state->n_warmup_collect_rows += (size_t) n_seq_tokens; + } batch_for_spec = &seq_batch; } else { + const int64_t t_collect_us = measure_dflash_warmup_collect ? ggml_time_us() : 0; if (!common_speculative_collect_target_batch_features(spec, ctx_tgt, batch, feature_view)) { return -1; } + if (measure_dflash_warmup_collect) { + dflash_state->t_warmup_collect_us += (uint64_t) (ggml_time_us() - t_collect_us); + dflash_state->n_warmup_collect_calls++; + dflash_state->n_warmup_collect_rows += (size_t) batch.n_tokens; + } } const int32_t ret = common_speculative_on_target_batch(spec, *batch_for_spec, feature_view, is_prompt_warmup); @@ -2244,7 +2410,16 @@ bool common_speculative_commit_accepted_hidden_rows( return false; } - return common_speculative_apply_hidden_rows(spec, seq_id, pos_base, commit_tokens, hidden_rows); + auto * dflash_state = common_speculative_get_dflash_state(spec); + const int64_t t_commit_us = dflash_state != nullptr ? 
ggml_time_us() : 0; + const bool ok = common_speculative_apply_hidden_rows(spec, seq_id, pos_base, commit_tokens, hidden_rows); + if (dflash_state != nullptr) { + dflash_state->t_accept_commit_us += (uint64_t) (ggml_time_us() - t_commit_us); + dflash_state->n_accept_commit_calls++; + dflash_state->n_accept_commit_rows += commit_tokens.size(); + } + + return ok; } bool common_speculative_commit_accepted_output( @@ -2261,9 +2436,16 @@ bool common_speculative_commit_accepted_output( } std::vector hidden_rows; + auto * dflash_state = common_speculative_get_dflash_state(spec); + const int64_t t_copy_us = dflash_state != nullptr ? ggml_time_us() : 0; if (!common_speculative_copy_output_hidden_rows(spec, ctx, output_indices, hidden_rows)) { return false; } + if (dflash_state != nullptr) { + dflash_state->t_accept_output_copy_us += (uint64_t) (ggml_time_us() - t_copy_us); + dflash_state->n_accept_output_copy_calls++; + dflash_state->n_accept_output_copy_rows += output_indices.size(); + } return common_speculative_commit_accepted_hidden_rows( spec, @@ -2324,7 +2506,24 @@ void common_speculative_print_stats(const common_speculative * spec, double slot (int) dflash_state->last_draft_pos_base); if (have_capture || have_graph) { - LOG_INF("statistics dflash profile: capture(sync/materialize)=%.3f/%.3f ms calls=%llu/%llu bytes=%llu phase(prompt/verify batches changes)=%llu/%llu %llu/%llu, set_target=%.3f ms rows=%llu bytes=%llu, prep(total/features/pos/mask)=%.3f/%.3f/%.3f/%.3f ms kv_cache=%.3f ms calls=%llu/%llu bytes=%llu/%llu/%llu, fallback_pos(copy/graph)=%llu/%llu, nonmono(copy/graph)=%llu/%llu, capture_fail=%llu/%llu, visible_kv_max=%llu, last(rows=%d width=%d left_pad=%d n_tokens=%d n_kv=%d pos=[%d..%d])\n", + const double kv_cache_total_ms = (double) ( + graph_stats.graph_kv_cache_build_us + + graph_stats.graph_kv_cache_reserve_us + + graph_stats.graph_kv_cache_reset_us + + graph_stats.graph_kv_cache_alloc_us + + graph_stats.graph_kv_cache_feature_upload_us + + 
graph_stats.graph_kv_cache_pos_upload_us + + graph_stats.graph_kv_cache_compute_us + + graph_stats.graph_kv_cache_sync_us + + graph_stats.graph_kv_cache_read_concat_pad_us) / 1000.0; + const double kv_upload_feature_ms = (double) graph_stats.graph_kv_cache_feature_upload_us / 1000.0; + const double kv_upload_pos_ms = (double) graph_stats.graph_kv_cache_pos_upload_us / 1000.0; + const double kv_upload_total_ms = kv_upload_feature_ms + kv_upload_pos_ms; + const double kv_compute_ms = (double) graph_stats.graph_kv_cache_compute_us / 1000.0; + const double kv_sync_ms = (double) graph_stats.graph_kv_cache_sync_us / 1000.0; + const double replay_append_ms = (double) dflash_state->t_accept_append_us / 1000.0; + + LOG_INF("statistics dflash profile: capture(sync/materialize)=%.3f/%.3f ms calls=%llu/%llu bytes=%llu phase(prompt/verify batches changes)=%llu/%llu %llu/%llu, set_target=%.3f ms rows=%llu bytes=%llu, decode(llama_output_reserve/prepare)=%.3f/%.3f ms calls=%llu/%llu realloc(bytes)=%llu/%llu, prep(total/features/pos/mask)=%.3f/%.3f/%.3f/%.3f ms kv_cache(total/build/reserve/reset/alloc/up_f/up_p/compute/sync/read)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls(prepare/cache/read)=%llu/%llu/%llu bytes(feature/pos/mask/read)=%llu/%llu/%llu/%llu host_layers=%d, fallback_pos(copy/graph)=%llu/%llu, nonmono(copy/graph)=%llu/%llu, capture_fail=%llu/%llu decode_prepare_fail=%llu, visible_kv_max=%llu, last(rows=%d width=%d left_pad=%d n_tokens=%d n_kv=%d pos=[%d..%d])\n", (double) capture_stats.capture_prepare_sync_us / 1000.0, (double) capture_stats.capture_materialize_us / 1000.0, (unsigned long long) capture_stats.capture_prepare_calls, @@ -2337,22 +2536,41 @@ void common_speculative_print_stats(const common_speculative * spec, double slot (double) graph_stats.set_target_copy_us / 1000.0, (unsigned long long) graph_stats.set_target_rows, (unsigned long long) graph_stats.set_target_copy_bytes, + (double) graph_stats.decode_output_reserve_us / 1000.0, + 
(double) graph_stats.decode_prepare_us / 1000.0, + (unsigned long long) graph_stats.decode_output_reserve_calls, + (unsigned long long) graph_stats.decode_prepare_calls, + (unsigned long long) graph_stats.decode_output_reserve_reallocs, + (unsigned long long) graph_stats.decode_output_reserve_realloc_bytes, (double) graph_stats.graph_prepare_total_us / 1000.0, (double) graph_stats.graph_feature_copy_us / 1000.0, (double) graph_stats.graph_pos_copy_us / 1000.0, (double) graph_stats.graph_mask_build_us / 1000.0, + kv_cache_total_ms, + (double) graph_stats.graph_kv_cache_build_us / 1000.0, + (double) graph_stats.graph_kv_cache_reserve_us / 1000.0, + (double) graph_stats.graph_kv_cache_reset_us / 1000.0, + (double) graph_stats.graph_kv_cache_alloc_us / 1000.0, + (double) graph_stats.graph_kv_cache_feature_upload_us / 1000.0, + (double) graph_stats.graph_kv_cache_pos_upload_us / 1000.0, (double) graph_stats.graph_kv_cache_compute_us / 1000.0, + (double) graph_stats.graph_kv_cache_sync_us / 1000.0, + (double) graph_stats.graph_kv_cache_read_concat_pad_us / 1000.0, (unsigned long long) graph_stats.graph_prepare_calls, (unsigned long long) graph_stats.graph_kv_cache_calls, + (unsigned long long) graph_stats.graph_kv_cache_read_concat_pad_calls, (unsigned long long) graph_stats.graph_feature_bytes, (unsigned long long) graph_stats.graph_pos_bytes, (unsigned long long) graph_stats.graph_mask_bytes, + (unsigned long long) graph_stats.graph_kv_cache_cached_bytes, + graph_stats.last_kv_cache_host_layers, (unsigned long long) graph_stats.set_target_missing_positions, (unsigned long long) graph_stats.graph_pos_fallbacks, (unsigned long long) graph_stats.set_target_non_monotonic_positions, (unsigned long long) graph_stats.graph_pos_non_monotonic, (unsigned long long) capture_stats.capture_prepare_failures, (unsigned long long) capture_stats.capture_materialize_failures, + (unsigned long long) graph_stats.decode_prepare_failures, (unsigned long long) 
graph_stats.graph_visible_kv_max, graph_stats.last_n_rows, graph_stats.last_width, @@ -2361,6 +2579,36 @@ void common_speculative_print_stats(const common_speculative * spec, double slot graph_stats.last_n_kv_total, (int) graph_stats.last_pos_first, (int) graph_stats.last_pos_last); + + LOG_INF("statistics dflash hot: kv(upload_f/upload_p/upload/compute/sync)=%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu replay(accepted_prefix_append)=%.3f ms calls=%zu rows=%zu\n", + kv_upload_feature_ms, + kv_upload_pos_ms, + kv_upload_total_ms, + kv_compute_ms, + kv_sync_ms, + (unsigned long long) graph_stats.graph_kv_cache_calls, + replay_append_ms, + dflash_state->n_accept_append_calls, + dflash_state->n_accept_append_rows); + + LOG_INF("statistics dflash stages: draft(decode/sample)=%.3f/%.3f ms warmup(collect/append)=%.3f/%.3f ms calls=%zu/%zu rows=%zu/%zu accept(total/output_copy/append)=%.3f/%.3f/%.3f ms calls=%zu/%zu/%zu rows=%zu/%zu/%zu\n", + (double) dflash_state->t_draft_decode_us / 1000.0, + (double) dflash_state->t_draft_sample_us / 1000.0, + (double) dflash_state->t_warmup_collect_us / 1000.0, + (double) dflash_state->t_warmup_append_us / 1000.0, + dflash_state->n_warmup_collect_calls, + dflash_state->n_warmup_append_calls, + dflash_state->n_warmup_collect_rows, + dflash_state->n_warmup_append_rows, + (double) dflash_state->t_accept_commit_us / 1000.0, + (double) dflash_state->t_accept_output_copy_us / 1000.0, + (double) dflash_state->t_accept_append_us / 1000.0, + dflash_state->n_accept_commit_calls, + dflash_state->n_accept_output_copy_calls, + dflash_state->n_accept_append_calls, + dflash_state->n_accept_commit_rows, + dflash_state->n_accept_output_copy_rows, + dflash_state->n_accept_append_rows); } } } @@ -2690,8 +2938,6 @@ int32_t common_speculative_on_target_batch( const common_speculative_feature_view & features, bool is_prompt_warmup) { if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) { - GGML_UNUSED(is_prompt_warmup); - 
if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || batch.n_tokens <= 0) { return 0; } @@ -2713,9 +2959,22 @@ int32_t common_speculative_on_target_batch( } } + const int64_t t_append_us = ggml_time_us(); if (!dflash_append_target_features(*dflash_state, features, batch, seq_id)) { return -1; } + + const uint64_t append_us = (uint64_t) (ggml_time_us() - t_append_us); + if (is_prompt_warmup) { + dflash_state->t_warmup_append_us += append_us; + dflash_state->n_warmup_append_calls++; + dflash_state->n_warmup_append_rows += (size_t) batch.n_tokens; + } else { + dflash_state->t_accept_append_us += append_us; + dflash_state->n_accept_append_calls++; + dflash_state->n_accept_append_rows += (size_t) batch.n_tokens; + } + return 0; } diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 232ba706..acf3ecb7 100644 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2287,6 +2287,8 @@ class DFlashDraftModel(Qwen3Model): model_arch = gguf.MODEL_ARCH.DFLASH_DRAFT _target_hparams: dict[str, Any] | None = None + _saw_token_embd = False + _saw_output = False def _require_target_model_dir(self) -> Path: if self.target_model_dir is None: @@ -2338,25 +2340,28 @@ class DFlashDraftModel(Qwen3Model): elif "n_target_features" in self.hparams: n_target_features = int(self.hparams["n_target_features"]) else: + target_hparams = self._get_target_hparams() + target_hidden_size = target_hparams.get("hidden_size") + if target_hidden_size is None: + raise ValueError("DFlashDraftModel: target config is missing hidden_size") + draft_hidden_size = self.hparams.get("hidden_size") if draft_hidden_size is None: raise ValueError("DFlashDraftModel: draft config is missing hidden_size") - n_target_features = int(draft_hidden_size) * len(target_layer_ids) + n_target_features = int(target_hidden_size) * len(target_layer_ids) - target_hparams = self._get_target_hparams() - target_hidden_size = target_hparams.get("hidden_size") if target_hidden_size is not None and 
int(target_hidden_size) != int(draft_hidden_size): logger.warning( - "DFlashDraftModel: target hidden_size=%d differs from draft hidden_size=%d; using draft hidden width for n_target_features", + "DFlashDraftModel: target hidden_size=%d differs from draft hidden_size=%d; using target hidden width for n_target_features", int(target_hidden_size), int(draft_hidden_size), ) logger.info( - "DFlashDraftModel: inferred n_target_features=%d from draft hidden_size=%d and n_target_layers=%d", + "DFlashDraftModel: inferred n_target_features=%d from target hidden_size=%d and n_target_layers=%d", n_target_features, - int(draft_hidden_size), + int(target_hidden_size), len(target_layer_ids), ) @@ -2370,17 +2375,52 @@ class DFlashDraftModel(Qwen3Model): n_target_features, ) + def prepare_tensors(self): + super().prepare_tensors() + + if self._saw_output and not self._saw_token_embd: + raise ValueError( + "DFlashDraftModel conversion requires token_embd.weight when output.weight is present" + ) + + if self._saw_token_embd and self._saw_output: + io_mode = "self-contained" + elif self._saw_token_embd: + io_mode = "self-contained-tied" + else: + io_mode = "shared-target" + + logger.info( + "DFlashDraftModel IO contract: io=%s token_embd=%s output=%s target_model_dir=%s", + io_mode, + self._saw_token_embd, + self._saw_output, + self._require_target_model_dir(), + ) + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - if name == "fc.weight": + top_level_name = name[6:] if name.startswith("model.") else name + + if top_level_name == "fc.weight": return [(f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.DFLASH_FC]}.weight", data_torch)] - if name == "hidden_norm.weight": + if top_level_name == "hidden_norm.weight": return [(f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.DFLASH_HIDDEN_NORM]}.weight", data_torch)] if name == "norm.weight": name = "model.norm.weight" elif name.startswith("layers."): name = f"model.{name}" - return 
super().modify_tensors(data_torch, name, bid) + tensors = list(super().modify_tensors(data_torch, name, bid)) + token_embd_name = f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD]}.weight" + output_name = f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT]}.weight" + + for tensor_name, _ in tensors: + if tensor_name == token_embd_name: + self._saw_token_embd = True + elif tensor_name == output_name: + self._saw_output = True + + return tensors @Model.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM") diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index c3171d76..81139ac1 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -126,6 +126,17 @@ static void server_dflash_contract_log_accept( server_dflash_contract_format_indices(output_indices).c_str()); } +static bool server_slot_prompt_batch_overlaps( + const server_slot & slot, + int32_t batch_i0, + int32_t batch_i1) { + if (slot.prompt_batch_i0 < 0 || slot.prompt_batch_i1 <= slot.prompt_batch_i0) { + return false; + } + + return slot.prompt_batch_i0 < batch_i1 && batch_i0 < slot.prompt_batch_i1; +} + static bool params_use_gemma4_external_mtp(const gpt_params & params_base) { return params_base.has_mtp && llama_model_is_gemma4_mtp_assistant(params_base.speculative.model_dft); @@ -708,6 +719,8 @@ void server_slot::reset() { n_past_prompt = 0; n_discarded_prompt = 0; n_kept_prompt = 0; + prompt_batch_i0 = -1; + prompt_batch_i1 = -1; n_sent_text = 0; drafted.clear(); drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE; @@ -3840,6 +3853,9 @@ bool server_context::create_checkpoint(server_slot & slot) { void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) { if (params_base.cont_batching || batch.n_tokens == 0) { for (auto& slot : slots) { + slot.prompt_batch_i0 = -1; + slot.prompt_batch_i1 = -1; + // this slot still has a prompt to be processed if (slot.state == SLOT_STATE_IDLE && 
slot.command == SLOT_COMMAND_LOAD_PROMPT) { auto& prompt_tokens = slot.prompt_tokens; @@ -4127,6 +4143,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t int32_t ga_i = slot.ga_i; int32_t ga_n = slot.ga_n; int32_t ga_w = slot.ga_w; + const int32_t prompt_batch_i0 = batch.n_tokens; // add prompt tokens for processing in the current batch // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow @@ -4161,6 +4178,9 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t } } + slot.prompt_batch_i0 = prompt_batch_i0; + slot.prompt_batch_i1 = batch.n_tokens; + LOG_VERBOSE("prompt processing progress", { {"id_slot", slot.id}, {"n_past", slot.n_past}, @@ -4770,6 +4790,7 @@ void server_context::update_allowlist_state(server_slot& slot) { void server_context::process_batch_tokens(int32_t & n_batch) { for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); + bool finish_prompt_warmup_batch = false; extend_context(n_tokens); llama_batch batch_view = { @@ -4830,7 +4851,6 @@ void server_context::process_batch_tokens(int32_t & n_batch) { } if (server_speculative_has_target_features(params_base.speculative)) { - bool finished_prompt_warmup_batch = false; for (auto & slot : slots) { if (!slot.spec || !server_speculative_has_target_features(slot.params.speculative)) { continue; @@ -4840,7 +4860,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) { continue; } - if (slot.command != SLOT_COMMAND_LOAD_PROMPT) { + if (!server_slot_prompt_batch_overlaps(slot, i, i + n_tokens)) { continue; } @@ -4849,13 +4869,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) { slot.spec_prompt_warmup_failed = true; LOG_ERROR("failed to warm up speculative target-feature state from prompt batch for slot %d\n", slot.id); } else { - finished_prompt_warmup_batch = true; + finish_prompt_warmup_batch = true; } } - - if 
(finished_prompt_warmup_batch) { - llama_finish_dflash_capture_batch(ctx, true); - } } for (auto& slot : slots) { @@ -4974,6 +4990,10 @@ void server_context::process_batch_tokens(int32_t & n_batch) { } // speculative decoding - main model sample and accept speculative_decoding_accept(); + + if (finish_prompt_warmup_batch) { + llama_finish_dflash_capture_batch(ctx, true); + } } } @@ -5005,6 +5025,11 @@ void server_context::update_slots() { // start populating the batch for this iteration common_batch_clear(batch); + for (auto & slot : slots) { + slot.prompt_batch_i0 = -1; + slot.prompt_batch_i1 = -1; + } + // first, add sampled tokens from any ongoing sequences add_sampled_tokens(); // Prepare batch for inference diff --git a/examples/server/server-context.h b/examples/server/server-context.h index d4d0913c..f1e25ecd 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ -64,6 +64,8 @@ struct server_slot { int32_t i_batch = -1; int32_t n_predict = -1; // TODO: disambiguate from params.n_predict + int32_t prompt_batch_i0 = -1; + int32_t prompt_batch_i1 = -1; int32_t n_prompt_tokens = 0; int32_t n_prompt_tokens_cache = 0; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 232b664c..d0a8de4d 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1261,7 +1261,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.DFLASH_HIDDEN_NORM, ], MODEL_ARCH.DFLASH_DRAFT: [ + MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, diff --git a/src/graphs/build_dflash.cpp b/src/graphs/build_dflash.cpp index ef50f868..b9862c2a 100644 --- a/src/graphs/build_dflash.cpp +++ b/src/graphs/build_dflash.cpp @@ -64,6 +64,7 @@ ggml_cgraph * llm_build_context::build_dflash() { const int64_t n_embd_head_k = hparams.n_embd_head_k(0); const int64_t n_embd_head_v = hparams.n_embd_head_v(0); const int64_t 
n_target_features = hparams.dflash_n_target_features; + auto & profile = lctx.dflash_profile; const bool use_kv_cache = dflash_use_kv_cache_experiment(); const int64_t ctx_len = lctx.dflash_visible_cross_ctx > 0 ? (int64_t) lctx.dflash_visible_cross_ctx @@ -77,12 +78,30 @@ ggml_cgraph * llm_build_context::build_dflash() { ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max(n_tokens, ctx_len)) + 32 * n_layer, false); + bool have_swa_layers = false; + for (int il = 0; il < n_layer; ++il) { + if (hparams.swa_layers[il]) { + have_swa_layers = true; + break; + } + } + lctx.inp_dflash_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); lctx.dflash_kq_mask_tensor = lctx.inp_dflash_kq_mask; ggml_set_input(lctx.inp_dflash_kq_mask); cb(lctx.inp_dflash_kq_mask, "dflash_kq_mask", -1); - ggml_tensor * dflash_kq_mask = flash_attn ? ggml_cast(ctx0, lctx.inp_dflash_kq_mask, GGML_TYPE_F16) : lctx.inp_dflash_kq_mask; + ggml_tensor * dflash_kq_mask_full = flash_attn ? ggml_cast(ctx0, lctx.inp_dflash_kq_mask, GGML_TYPE_F16) : lctx.inp_dflash_kq_mask; + ggml_tensor * dflash_kq_mask_swa = nullptr; + lctx.inp_dflash_kq_mask_swa = nullptr; + lctx.dflash_kq_mask_swa_tensor = nullptr; + if (have_swa_layers && hparams.n_swa > 0) { + lctx.inp_dflash_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + lctx.dflash_kq_mask_swa_tensor = lctx.inp_dflash_kq_mask_swa; + ggml_set_input(lctx.inp_dflash_kq_mask_swa); + cb(lctx.inp_dflash_kq_mask_swa, "dflash_kq_mask_swa", -1); + dflash_kq_mask_swa = flash_attn ? 
ggml_cast(ctx0, lctx.inp_dflash_kq_mask_swa, GGML_TYPE_F16) : lctx.inp_dflash_kq_mask_swa; + } ggml_tensor * fused_target = nullptr; ggml_tensor * pos_ctx = nullptr; @@ -137,6 +156,7 @@ ggml_cgraph * llm_build_context::build_dflash() { Vcur_noise = ggml_reshape_3d(ctx0, Vcur_noise, n_embd_head_v, n_head_kv, n_tokens); cb(Vcur_noise, "Vcur_noise", il); + const int64_t t_cache_read_us = use_kv_cache ? ggml_time_us() : 0; ggml_tensor * Kcur_ctx = nullptr; ggml_tensor * Vcur_ctx = nullptr; if (use_kv_cache) { @@ -164,6 +184,11 @@ ggml_cgraph * llm_build_context::build_dflash() { Kcur = ggml_pad(ctx0, Kcur, 0, 0, (int) n_kv_pad, 0); Vcur = ggml_pad(ctx0, Vcur, 0, 0, (int) n_kv_pad, 0); } + if (use_kv_cache) { + profile.graph_kv_cache_read_concat_pad_us += (uint64_t) (ggml_time_us() - t_cache_read_us); + profile.graph_kv_cache_read_concat_pad_calls++; + profile.graph_kv_cache_cached_bytes += ggml_nbytes(lctx.dflash_k_ctx_cache[(size_t) il]) + ggml_nbytes(lctx.dflash_v_ctx_cache[(size_t) il]); + } cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); cb(Qcur, "Qcur", il); @@ -173,11 +198,14 @@ ggml_cgraph * llm_build_context::build_dflash() { ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); ggml_tensor * v = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 0, 2, 1, 3)); + ggml_tensor * dflash_kq_mask_l = (hparams.swa_layers[il] && dflash_kq_mask_swa != nullptr) + ? dflash_kq_mask_swa + : dflash_kq_mask_full; cb(q, "q", il); cb(k, "k", il); cb(v, "v", il); - cur = ggml_flash_attn_ext(ctx0, q, k, v, dflash_kq_mask, kq_scale, hparams.f_max_alibi_bias, + cur = ggml_flash_attn_ext(ctx0, q, k, v, dflash_kq_mask_l, kq_scale, hparams.f_max_alibi_bias, hparams.attn_soft_cap ? 
hparams.f_attn_logit_softcapping : 0.0f); cb(cur, "flash_attn", il); ggml_build_forward_expand(gf, cur); diff --git a/src/llama-context.h b/src/llama-context.h index 9e902003..cc207a36 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -289,22 +289,26 @@ struct llama_context { std::vector dflash_feature_view_buffer; std::vector dflash_pos_ctx_data; std::vector dflash_kq_mask_data; + std::vector dflash_kq_mask_swa_data; int32_t dflash_visible_cross_ctx = 0; std::vector dflash_k_ctx_cache; std::vector dflash_v_ctx_cache; struct ggml_context * dflash_cache_ctx = nullptr; - ggml_backend_buffer_t dflash_cache_buf = nullptr; + std::vector dflash_cache_bufs; std::vector dflash_buf_compute_meta; ggml_backend_sched_t dflash_sched = nullptr; struct ggml_tensor * dflash_kv_input_target_features = nullptr; struct ggml_tensor * dflash_kv_input_pos_ctx = nullptr; struct ggml_tensor * dflash_kq_mask_tensor = nullptr; + struct ggml_tensor * dflash_kq_mask_swa_tensor = nullptr; struct dflash_capture_state { std::vector layer_ids; std::vector> layer_rows; int32_t row_count = 0; int32_t row_width = 0; + uint64_t capture_batch_id = 0; + std::vector layer_seen_batch_id; ggml_backend_sched_eval_callback prev_cb_eval = nullptr; void * prev_cb_eval_user_data = nullptr; }; @@ -333,6 +337,7 @@ struct llama_context { struct ggml_tensor * inp_dflash_target_features = nullptr; // F32 [n_target_features, cross_ctx] struct ggml_tensor * inp_dflash_pos_ctx = nullptr; // I32 [cross_ctx] struct ggml_tensor * inp_dflash_kq_mask = nullptr; // F32 [cross_ctx + n_batch, GGML_PAD(n_batch)] + struct ggml_tensor * inp_dflash_kq_mask_swa = nullptr; // F32 [cross_ctx + n_batch, GGML_PAD(n_batch)] ggml_backend_t ggml_backend_by_name(const char * name); diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 271633c6..3ebc2459 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -79,6 +79,17 @@ static bool load_dflash_target_layer_ids( } else { 
hparams.dflash_target_layer_ids[i] = ((const uint32_t *) data)[i]; } + + const uint32_t id = hparams.dflash_target_layer_ids[i]; + + for (uint32_t j = 0; j < i; ++j) { + if (hparams.dflash_target_layer_ids[j] == id) { + throw std::runtime_error(format( + "dflash: %s contains duplicate layer id %u", + key.c_str(), + id)); + } + } } return true; @@ -92,19 +103,8 @@ static void validate_dflash_hparams(llama_hparams & hparams, llm_arch arch) { throw std::runtime_error(format("%s: dflash target_layer_ids are required", llama_model_arch_name(arch))); } - if (arch == LLM_ARCH_DFLASH_DRAFT && hparams.n_embd > 0) { - const uint32_t expected_n_target_features = hparams.n_embd * hparams.dflash_n_target_layers; - if (expected_n_target_features > 0 && hparams.dflash_n_target_features != expected_n_target_features) { - LLAMA_LOG_WARN( - "%s: overriding dflash n_target_features from %u to %u based on n_embd=%u and n_target_layers=%u\n", - llama_model_arch_name(arch), - hparams.dflash_n_target_features, - expected_n_target_features, - hparams.n_embd, - hparams.dflash_n_target_layers); - hparams.dflash_n_target_features = expected_n_target_features; - } - } + // DFlash feature width is target-model specific. Keep the serialized metadata intact here + // and validate it against the live target model during DFlash init. 
if (hparams.dflash_n_target_features == 0) { throw std::runtime_error(format( diff --git a/src/llama-spec-features.cpp b/src/llama-spec-features.cpp index 91b5d41d..bcf1ca89 100644 --- a/src/llama-spec-features.cpp +++ b/src/llama-spec-features.cpp @@ -113,6 +113,53 @@ int32_t llama_model_dflash_target_layer_ids( return n_layers; } +int32_t llama_model_dflash_target_mask_token_id(const struct llama_model * model) { + if (model == nullptr) { + return (int32_t) LLAMA_TOKEN_NULL; + } + + return (int32_t) model->vocab.token_mask(); +} + +int32_t llama_model_dflash_io_mode( + const struct llama_model * draft_model, + const struct llama_model * target_model) { + if (draft_model == nullptr || target_model == nullptr || draft_model->arch != LLM_ARCH_DFLASH_DRAFT) { + return LLAMA_DFLASH_IO_MODE_INVALID; + } + + const ggml_tensor * target_output = target_model->output != nullptr ? target_model->output : target_model->tok_embd; + if (draft_model->tok_embd == nullptr || draft_model->output == nullptr || target_model->tok_embd == nullptr || target_output == nullptr) { + return LLAMA_DFLASH_IO_MODE_INVALID; + } + + const bool shared_tok = draft_model->tok_embd == target_model->tok_embd; + const bool shared_output = draft_model->output == target_output; + if (shared_tok && shared_output) { + return LLAMA_DFLASH_IO_MODE_SHARED; + } + + if (!shared_tok && !shared_output) { + return LLAMA_DFLASH_IO_MODE_SELF_CONTAINED; + } + + return LLAMA_DFLASH_IO_MODE_MIXED; +} + +bool llama_model_dflash_io_tensors_match( + const struct llama_model * draft_model, + int32_t n_embd, + int32_t n_vocab) { + if (draft_model == nullptr || draft_model->tok_embd == nullptr || draft_model->output == nullptr || n_embd <= 0 || n_vocab <= 0) { + return false; + } + + return (int32_t) draft_model->tok_embd->ne[0] == n_embd && + (int32_t) draft_model->tok_embd->ne[1] == n_vocab && + (int32_t) draft_model->output->ne[0] == n_embd && + (int32_t) draft_model->output->ne[1] == n_vocab; +} + bool 
llama_model_share_dflash_io_tensors( struct llama_model * draft_model, const struct llama_model * target_model) { @@ -323,11 +370,19 @@ static bool llama_dflash_capture_eval_callback(struct ggml_tensor * tensor, bool } auto & capture = *ctx->dflash_capture; + if (capture.capture_batch_id == 0) { + capture.capture_batch_id = 1; + } + if (capture.layer_seen_batch_id.size() != capture.layer_ids.size()) { + capture.layer_seen_batch_id.assign(capture.layer_ids.size(), 0); + } + auto & rows = capture.layer_rows[(size_t) layer_idx]; rows.resize((size_t) row_count * (size_t) row_width); ggml_backend_tensor_get(tensor, rows.data(), 0, ggml_nbytes(tensor)); capture.row_width = row_width; capture.row_count = row_count; + capture.layer_seen_batch_id[(size_t) layer_idx] = capture.capture_batch_id; return true; } @@ -342,6 +397,7 @@ bool llama_set_dflash_capture_layers( auto capture = std::make_unique(); capture->layer_ids.assign(layer_ids, layer_ids + n_layers); capture->layer_rows.resize((size_t) n_layers); + capture->layer_seen_batch_id.assign((size_t) n_layers, 0); capture->prev_cb_eval = ctx->cparams.cb_eval; capture->prev_cb_eval_user_data = ctx->cparams.cb_eval_user_data; ctx->dflash_capture = std::move(capture); @@ -380,6 +436,18 @@ void llama_clear_dflash_capture(struct llama_context * ctx) { } } +void llama_begin_dflash_capture_batch(struct llama_context * ctx) { + if (ctx == nullptr || !ctx->dflash_capture) { + return; + } + + auto & capture = *ctx->dflash_capture; + capture.capture_batch_id++; + capture.row_count = 0; + capture.row_width = 0; + std::fill(capture.layer_seen_batch_id.begin(), capture.layer_seen_batch_id.end(), 0); +} + void llama_finish_dflash_capture_batch( struct llama_context * ctx, bool is_prompt_warmup) { @@ -420,7 +488,35 @@ static bool llama_spec_prepare_dflash_capture( return false; } + if (capture.capture_batch_id == 0 || capture.layer_seen_batch_id.size() != (size_t) n_layers) { + profile.capture_prepare_failures++; + 
profile.capture_layer_batch_mismatch++; + if (profile.capture_layer_batch_mismatch <= 3) { + LLAMA_LOG_WARN("%s: DFlash capture batch markers are not initialized (batch_id=%llu layers=%zu expected=%d)\n", + __func__, + (unsigned long long) capture.capture_batch_id, + capture.layer_seen_batch_id.size(), + n_layers); + } + return false; + } + for (int32_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) { + if (capture.layer_seen_batch_id[(size_t) layer_idx] != capture.capture_batch_id) { + profile.capture_prepare_failures++; + profile.capture_layer_batch_mismatch++; + if (profile.capture_layer_batch_mismatch <= 3) { + LLAMA_LOG_WARN("%s: DFlash capture is stale for layer %d (seen_batch=%llu current_batch=%llu rows=%d width=%d)\n", + __func__, + capture.layer_ids[(size_t) layer_idx], + (unsigned long long) capture.layer_seen_batch_id[(size_t) layer_idx], + (unsigned long long) capture.capture_batch_id, + row_count, + row_width); + } + return false; + } + const auto & rows = capture.layer_rows[(size_t) layer_idx]; if (rows.size() != (size_t) row_count * (size_t) row_width) { profile.capture_prepare_failures++; @@ -595,11 +691,41 @@ static void llama_dflash_contract_log_output_indices( have_capture ? 
"true" : "false"); } + static bool llama_spec_materialize_dflash_rows_prepared( + struct llama_context * ctx, + int32_t row_count, + int32_t row_width, + int32_t n_layers, + const std::vector & row_indices, + std::vector & rows_out, + int32_t & combined_width); + static bool llama_spec_materialize_dflash_rows( struct llama_context * ctx, const std::vector & row_indices, std::vector & rows_out, int32_t & combined_width) { + int32_t row_count = 0; + int32_t row_width = 0; + int32_t n_layers = 0; + if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) { + if (ctx != nullptr) { + ctx->dflash_profile.capture_materialize_failures++; + } + return false; + } + + return llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, rows_out, combined_width); +} + +static bool llama_spec_materialize_dflash_rows_prepared( + struct llama_context * ctx, + int32_t row_count, + int32_t row_width, + int32_t n_layers, + const std::vector & row_indices, + std::vector & rows_out, + int32_t & combined_width) { rows_out.clear(); combined_width = 0; if (ctx == nullptr || row_indices.empty()) { @@ -610,10 +736,7 @@ static bool llama_spec_materialize_dflash_rows( profile.capture_materialize_calls++; const int64_t t_start_us = ggml_time_us(); - int32_t row_count = 0; - int32_t row_width = 0; - int32_t n_layers = 0; - if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) { + if (row_count <= 0 || row_width <= 0 || n_layers <= 0 || ctx->dflash_capture == nullptr) { profile.capture_materialize_failures++; return false; } @@ -735,7 +858,7 @@ bool llama_spec_get_dflash_feature_view( view = {}; view.kind = LLAMA_SPEC_FEATURE_HIDDEN_STATE; - if (!llama_spec_materialize_dflash_rows(ctx, row_indices, ctx->dflash_feature_view_buffer, view.width)) { + if (!llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, ctx->dflash_feature_view_buffer, view.width)) { return false; } @@ -808,7 
+931,7 @@ bool llama_spec_get_dflash_feature_view_for_seq( view = {}; view.kind = LLAMA_SPEC_FEATURE_HIDDEN_STATE; - if (!llama_spec_materialize_dflash_rows(ctx, row_indices, ctx->dflash_feature_view_buffer, view.width)) { + if (!llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, ctx->dflash_feature_view_buffer, view.width)) { return false; } diff --git a/src/llama-spec-features.h b/src/llama-spec-features.h index 20f0ff51..9ec2e827 100644 --- a/src/llama-spec-features.h +++ b/src/llama-spec-features.h @@ -24,6 +24,14 @@ struct llama_spec_feature_view { }; struct llama_dflash_profile_stats { + uint64_t decode_output_reserve_calls = 0; + uint64_t decode_output_reserve_us = 0; + uint64_t decode_output_reserve_reallocs = 0; + uint64_t decode_output_reserve_realloc_bytes = 0; + uint64_t decode_prepare_calls = 0; + uint64_t decode_prepare_us = 0; + uint64_t decode_prepare_failures = 0; + uint64_t set_target_copy_calls = 0; uint64_t set_target_copy_us = 0; uint64_t set_target_rows = 0; @@ -35,6 +43,7 @@ struct llama_dflash_profile_stats { uint64_t capture_prepare_sync_us = 0; uint64_t capture_prepare_failures = 0; uint64_t capture_layer_shape_mismatch = 0; + uint64_t capture_layer_batch_mismatch = 0; uint64_t capture_prompt_batches = 0; uint64_t capture_prompt_shape_changes = 0; uint64_t capture_verify_batches = 0; @@ -50,7 +59,17 @@ struct llama_dflash_profile_stats { uint64_t graph_feature_copy_us = 0; uint64_t graph_pos_copy_us = 0; uint64_t graph_mask_build_us = 0; + uint64_t graph_kv_cache_build_us = 0; + uint64_t graph_kv_cache_reserve_us = 0; + uint64_t graph_kv_cache_reset_us = 0; + uint64_t graph_kv_cache_alloc_us = 0; + uint64_t graph_kv_cache_feature_upload_us = 0; + uint64_t graph_kv_cache_pos_upload_us = 0; uint64_t graph_kv_cache_compute_us = 0; + uint64_t graph_kv_cache_sync_us = 0; + uint64_t graph_kv_cache_read_concat_pad_us = 0; + uint64_t graph_kv_cache_read_concat_pad_calls = 0; + uint64_t 
graph_kv_cache_cached_bytes = 0; uint64_t graph_kv_cache_calls = 0; uint64_t graph_feature_bytes = 0; uint64_t graph_pos_bytes = 0; @@ -68,6 +87,7 @@ struct llama_dflash_profile_stats { int32_t last_left_pad = 0; int32_t last_n_tokens = 0; int32_t last_n_kv_total = 0; + int32_t last_kv_cache_host_layers = 0; int32_t capture_prompt_last_rows = 0; int32_t capture_prompt_last_width = 0; int32_t capture_verify_last_rows = 0; @@ -104,6 +124,24 @@ int32_t llama_model_dflash_target_layer_ids( int32_t * layer_ids, int32_t capacity); +enum llama_dflash_io_mode { + LLAMA_DFLASH_IO_MODE_INVALID = 0, + LLAMA_DFLASH_IO_MODE_SHARED, + LLAMA_DFLASH_IO_MODE_SELF_CONTAINED, + LLAMA_DFLASH_IO_MODE_MIXED, +}; + +int32_t llama_model_dflash_target_mask_token_id(const struct llama_model * model); + +int32_t llama_model_dflash_io_mode( + const struct llama_model * draft_model, + const struct llama_model * target_model); + +bool llama_model_dflash_io_tensors_match( + const struct llama_model * draft_model, + int32_t n_embd, + int32_t n_vocab); + bool llama_model_share_dflash_io_tensors( struct llama_model * draft_model, const struct llama_model * target_model); @@ -134,6 +172,8 @@ bool llama_set_dflash_capture_layers( void llama_clear_dflash_capture(struct llama_context * ctx); +void llama_begin_dflash_capture_batch(struct llama_context * ctx); + void llama_finish_dflash_capture_batch( struct llama_context * ctx, bool is_prompt_warmup); diff --git a/src/llama.cpp b/src/llama.cpp index af37211e..e3b91b0b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -565,6 +565,44 @@ void llama_context::reset_scheduler() { prev_mtp.reset(); } +static ggml_backend_buffer_type_t llama_dflash_kv_cache_layer_buft(const llama_context & lctx, int32_t il) { + if (il >= 0 && (size_t) il < lctx.model.buft_layer.size() && lctx.model.buft_layer[(size_t) il].buft != nullptr) { + return lctx.model.buft_layer[(size_t) il].buft; + } + + if (il >= 0 && (size_t) il < lctx.model.layers.size()) { + const ggml_tensor * wk 
= lctx.model.layers[(size_t) il].wk; + if (wk != nullptr && wk->buffer != nullptr) { + return ggml_backend_buffer_get_type(wk->buffer); + } + } + + return llama_default_buffer_type_cpu(true); +} + +static ggml_backend_t llama_backend_for_tensor(const llama_context & lctx, const ggml_tensor * tensor) { + if (tensor == nullptr) { + return nullptr; + } + + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + if (buf == nullptr) { + return nullptr; + } + + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf); + for (ggml_backend_t backend : lctx.backends) { + ggml_backend_buffer_type_t backend_buft = ggml_backend_is_cpu(backend) + ? llama_default_buffer_type_cpu(true) + : ggml_backend_get_default_buffer_type(backend); + if (backend_buft == buft) { + return backend; + } + } + + return nullptr; +} + bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { const int32_t target_cross_ctx = std::max(1, cross_ctx); const int32_t n_layer = model.hparams.n_layer; @@ -587,8 +625,6 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { dflash_buf_compute_meta.clear(); } - ggml_backend_buffer_type_t buft = llama_default_buffer_type_cpu(true); - ggml_init_params params = { /*.mem_size =*/ (size_t) (2 * std::max(1, n_layer)) * ggml_tensor_overhead(), /*.mem_buffer =*/ nullptr, @@ -602,7 +638,21 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { dflash_k_ctx_cache.resize((size_t) n_layer); dflash_v_ctx_cache.resize((size_t) n_layer); + dflash_cache_bufs.clear(); + dflash_cache_bufs.reserve((size_t) std::max(1, n_layer) * 2); + int32_t host_layers = 0; + const char * first_buft_name = nullptr; + const char * last_buft_name = nullptr; for (int32_t il = 0; il < n_layer; ++il) { + ggml_backend_buffer_type_t layer_buft = llama_dflash_kv_cache_layer_buft(*this, il); + if (ggml_backend_buft_is_host(layer_buft)) { + host_layers++; + } + if (first_buft_name == nullptr) { + 
first_buft_name = ggml_backend_buft_name(layer_buft); + } + last_buft_name = ggml_backend_buft_name(layer_buft); + dflash_k_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_k, n_head_kv, target_cross_ctx); dflash_v_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_v, n_head_kv, target_cross_ctx); if (dflash_k_ctx_cache[(size_t) il] == nullptr || dflash_v_ctx_cache[(size_t) il] == nullptr) { @@ -614,15 +664,39 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { ggml_set_input(dflash_v_ctx_cache[(size_t) il]); ggml_format_name(dflash_k_ctx_cache[(size_t) il], "dflash_k_ctx_cache_%d", il); ggml_format_name(dflash_v_ctx_cache[(size_t) il], "dflash_v_ctx_cache_%d", il); + + const size_t k_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_k_ctx_cache[(size_t) il]); + ggml_backend_buffer_t k_buf = ggml_backend_buft_alloc_buffer(layer_buft, k_bytes); + if (k_buf == nullptr) { + free_dflash_kv_cache_tensors(); + return false; + } + ggml_backend_buffer_set_usage(k_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); + ggml_backend_tensor_alloc(k_buf, dflash_k_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(k_buf)); + ggml_backend_buffer_clear(k_buf, 0); + dflash_cache_bufs.push_back(k_buf); + + const size_t v_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_v_ctx_cache[(size_t) il]); + ggml_backend_buffer_t v_buf = ggml_backend_buft_alloc_buffer(layer_buft, v_bytes); + if (v_buf == nullptr) { + free_dflash_kv_cache_tensors(); + return false; + } + ggml_backend_buffer_set_usage(v_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); + ggml_backend_tensor_alloc(v_buf, dflash_v_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(v_buf)); + ggml_backend_buffer_clear(v_buf, 0); + dflash_cache_bufs.push_back(v_buf); } - dflash_cache_buf = ggml_backend_alloc_ctx_tensors_from_buft(dflash_cache_ctx, buft); - if (dflash_cache_buf == nullptr) { - free_dflash_kv_cache_tensors(); - 
return false; - } + dflash_profile.last_kv_cache_host_layers = host_layers; + LLAMA_LOG_INFO("%s: DFlash K/V cache placement cross_ctx=%d host_layers=%d/%d first=%s last=%s\n", + __func__, + target_cross_ctx, + host_layers, + n_layer, + first_buft_name != nullptr ? first_buft_name : "(none)", + last_buft_name != nullptr ? last_buft_name : "(none)"); - ggml_backend_buffer_clear(dflash_cache_buf, 0); return true; } @@ -632,11 +706,14 @@ void llama_context::free_dflash_kv_cache_tensors() { dflash_kv_input_target_features = nullptr; dflash_kv_input_pos_ctx = nullptr; dflash_kq_mask_tensor = nullptr; + dflash_kq_mask_swa_tensor = nullptr; - if (dflash_cache_buf != nullptr) { - ggml_backend_buffer_free(dflash_cache_buf); - dflash_cache_buf = nullptr; + for (ggml_backend_buffer_t buf : dflash_cache_bufs) { + if (buf != nullptr) { + ggml_backend_buffer_free(buf); + } } + dflash_cache_bufs.clear(); if (dflash_cache_ctx != nullptr) { ggml_free(dflash_cache_ctx); dflash_cache_ctx = nullptr; @@ -5087,6 +5164,110 @@ static bool prepare_mtp_graph_inputs( return true; } +static bool dflash_layer_has_attention_bias(const llama_layer & layer) { + return layer.bq != nullptr || + layer.bk != nullptr || + layer.bv != nullptr || + layer.bo != nullptr || + layer.bqkv != nullptr || + layer.bqk != nullptr || + layer.bkv != nullptr; +} + +static bool validate_dflash_graph_contract(const llama_context & lctx) { + const auto & model = lctx.model; + const auto & hparams = model.hparams; + + auto rope_dim_for_layer = [&hparams](int32_t il) -> uint32_t { + if (hparams.rope_dim_per_layer[(size_t) il] != 0) { + return hparams.rope_dim_per_layer[(size_t) il]; + } + + return hparams.swa_layers[(size_t) il] ? hparams.n_rot_swa : hparams.n_rot; + }; + + auto rope_base_for_layer = [&hparams](int32_t il) -> float { + if (hparams.has_rope_freq_base_per_layer) { + return hparams.rope_freq_base_per_layer[(size_t) il]; + } + + return hparams.swa_layers[(size_t) il] ? 
hparams.rope_freq_base_train_swa : hparams.rope_freq_base_train; + }; + + auto rope_scale_for_layer = [&hparams](int32_t il) -> float { + return hparams.swa_layers[(size_t) il] ? hparams.rope_freq_scale_train_swa : hparams.rope_freq_scale_train; + }; + + const uint32_t ref_n_head = hparams.n_head(0); + const uint32_t ref_n_head_kv = hparams.n_head_kv(0); + const uint32_t ref_n_embd_head_k = hparams.n_embd_head_k(0); + const uint32_t ref_n_embd_head_v = hparams.n_embd_head_v(0); + const uint32_t ref_rope_dim = rope_dim_for_layer(0); + const float ref_rope_base = rope_base_for_layer(0); + const float ref_rope_scale = rope_scale_for_layer(0); + + for (int32_t il = 0; il < (int32_t) hparams.n_layer; ++il) { + if (hparams.n_head((uint32_t) il) != ref_n_head || + hparams.n_head_kv((uint32_t) il) != ref_n_head_kv || + hparams.n_embd_head_k(il) != ref_n_embd_head_k || + hparams.n_embd_head_v(il) != ref_n_embd_head_v) { + LLAMA_LOG_ERROR("%s: DFlash graph assumes layer-invariant head config, but layer %d differs (n_head=%u/%u n_head_kv=%u/%u head_k=%u/%u head_v=%u/%u)\n", + __func__, + il, + hparams.n_head((uint32_t) il), ref_n_head, + hparams.n_head_kv((uint32_t) il), ref_n_head_kv, + hparams.n_embd_head_k(il), ref_n_embd_head_k, + hparams.n_embd_head_v(il), ref_n_embd_head_v); + return false; + } + + const uint32_t rope_dim = rope_dim_for_layer(il); + const float rope_base = rope_base_for_layer(il); + const float rope_scale = rope_scale_for_layer(il); + if (rope_dim != ref_rope_dim || std::fabs(rope_base - ref_rope_base) > 1e-6f || std::fabs(rope_scale - ref_rope_scale) > 1e-6f) { + LLAMA_LOG_ERROR("%s: DFlash graph assumes layer-invariant RoPE config, but layer %d differs (dim=%u/%u base=%g/%g scale=%g/%g)\n", + __func__, + il, + rope_dim, ref_rope_dim, + (double) rope_base, (double) ref_rope_base, + (double) rope_scale, (double) ref_rope_scale); + return false; + } + + if (model.layers[(size_t) il].attn_norm == nullptr || + model.layers[(size_t) il].attn_q_norm == 
nullptr || + model.layers[(size_t) il].attn_k_norm == nullptr) { + LLAMA_LOG_ERROR("%s: DFlash graph requires attn_norm, attn_q_norm, and attn_k_norm weights, but layer %d is missing one or more of them\n", + __func__, il); + return false; + } + + const bool has_q_norm = model.layers[(size_t) il].attn_q_norm != nullptr; + const bool has_k_norm = model.layers[(size_t) il].attn_k_norm != nullptr; + if (has_q_norm != has_k_norm) { + LLAMA_LOG_ERROR("%s: DFlash graph requires symmetric Q/K norm presence, but layer %d has q_norm=%d k_norm=%d\n", + __func__, il, (int) has_q_norm, (int) has_k_norm); + return false; + } + + if (model.layers[(size_t) il].attn_norm_b != nullptr || + model.layers[(size_t) il].attn_q_norm_b != nullptr || + model.layers[(size_t) il].attn_k_norm_b != nullptr) { + LLAMA_LOG_ERROR("%s: DFlash graph does not implement norm-bias tensors, but layer %d requires attn_norm_b/q_norm_b/k_norm_b\n", + __func__, il); + return false; + } + + if (dflash_layer_has_attention_bias(model.layers[(size_t) il])) { + LLAMA_LOG_ERROR("%s: DFlash graph does not implement attention bias tensors, but layer %d requires them\n", + __func__, il); + return false; + } + } + + return true; +} + static bool prepare_dflash_graph_inputs( struct llama_context & lctx, uint32_t n_tokens) { @@ -5095,16 +5276,23 @@ static bool prepare_dflash_graph_inputs( std::strcmp(dflash_kv_cache_env, "0") != 0 && std::strcmp(dflash_kv_cache_env, "false") != 0 && std::strcmp(dflash_kv_cache_env, "off") != 0; + auto & profile = lctx.dflash_profile; const int32_t cross_ctx = lctx.dflash_visible_cross_ctx > 0 ? 
lctx.dflash_visible_cross_ctx : std::max(1, (int32_t) lctx.cparams.n_ctx - (int32_t) lctx.model.hparams.dflash_block_size); ggml_tensor * kq_mask = lctx.dflash_kq_mask_tensor; + ggml_tensor * kq_mask_swa = lctx.dflash_kq_mask_swa_tensor; if (kq_mask == nullptr) { LLAMA_LOG_ERROR("%s: DFlash graph inputs are not initialized\n", __func__); return false; } + if (!validate_dflash_graph_contract(lctx)) { + profile.graph_shape_failures++; + return false; + } + if (use_kv_cache) { if (!lctx.ensure_dflash_kv_cache_tensors(cross_ctx) || lctx.dflash_k_ctx_cache.empty() || lctx.dflash_v_ctx_cache.empty()) { LLAMA_LOG_ERROR("%s: DFlash K/V cache inputs are not initialized\n", __func__); @@ -5126,7 +5314,6 @@ static bool prepare_dflash_graph_inputs( : (lctx.inp_dflash_target_features != nullptr ? (int32_t) lctx.inp_dflash_target_features->ne[1] : 0); const int32_t n_mask_tokens = (int32_t) kq_mask->ne[1]; const int32_t n_kv_total = (int32_t) kq_mask->ne[0]; - auto & profile = lctx.dflash_profile; const int64_t t_total_us = ggml_time_us(); profile.graph_prepare_calls++; @@ -5186,36 +5373,32 @@ static bool prepare_dflash_graph_inputs( const int64_t t_pos_us = ggml_time_us(); lctx.dflash_pos_ctx_data.resize((size_t) cross_ctx); std::fill(lctx.dflash_pos_ctx_data.begin(), lctx.dflash_pos_ctx_data.end(), 0); - if (src_pos != nullptr && total_positions == (size_t) n_rows) { - bool monotonic = true; - for (int32_t i = 1; i < n_rows; ++i) { - if (src_pos[i] <= src_pos[i - 1]) { - monotonic = false; - break; - } - } - if (!monotonic) { - profile.graph_pos_non_monotonic++; - if (profile.graph_pos_non_monotonic <= 3) { - LLAMA_LOG_WARN("%s: DFlash target positions are not strictly increasing (rows=%d first=%d last=%d)\n", - __func__, n_rows, (int) src_pos[0], (int) src_pos[n_rows - 1]); - } - } - profile.last_pos_first = src_pos[0]; - profile.last_pos_last = src_pos[n_rows - 1]; - std::copy(src_pos, src_pos + n_rows, lctx.dflash_pos_ctx_data.begin() + (ptrdiff_t) left_pad); - } else { + 
if (src_pos == nullptr || total_positions != (size_t) n_rows) { profile.graph_pos_fallbacks++; + profile.graph_shape_failures++; profile.last_pos_first = -1; profile.last_pos_last = -1; if (profile.graph_pos_fallbacks <= 3) { - LLAMA_LOG_WARN("%s: using synthetic DFlash positions (rows=%d positions=%zu cross_ctx=%d)\n", + LLAMA_LOG_ERROR("%s: missing DFlash target positions (rows=%d positions=%zu cross_ctx=%d)\n", __func__, n_rows, total_positions, cross_ctx); } - for (int32_t i = 0; i < n_rows; ++i) { - lctx.dflash_pos_ctx_data[(size_t) left_pad + (size_t) i] = i; + return false; + } + + profile.last_pos_first = src_pos[0]; + profile.last_pos_last = src_pos[n_rows - 1]; + for (int32_t i = 1; i < n_rows; ++i) { + if (src_pos[i] <= src_pos[i - 1]) { + profile.graph_pos_non_monotonic++; + profile.graph_shape_failures++; + if (profile.graph_pos_non_monotonic <= 3) { + LLAMA_LOG_ERROR("%s: DFlash target positions are not strictly increasing (rows=%d first=%d last=%d)\n", + __func__, n_rows, (int) src_pos[0], (int) src_pos[n_rows - 1]); + } + return false; } } + std::copy(src_pos, src_pos + n_rows, lctx.dflash_pos_ctx_data.begin() + (ptrdiff_t) left_pad); profile.graph_pos_copy_us += (uint64_t) (ggml_time_us() - t_pos_us); profile.graph_pos_bytes += lctx.dflash_pos_ctx_data.size() * sizeof(llama_pos); @@ -5226,7 +5409,9 @@ static bool prepare_dflash_graph_inputs( lctx.dflash_buf_compute_meta.resize(meta_size); } + const int64_t t_build_us = ggml_time_us(); ggml_cgraph * gf_kv = llm_build_context::llama_build_graph_dflash_kv_cache(lctx); + profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us); if (gf_kv == nullptr || lctx.dflash_kv_input_target_features == nullptr || lctx.dflash_kv_input_pos_ctx == nullptr) { profile.graph_shape_failures++; LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__); @@ -5244,22 +5429,50 @@ static bool prepare_dflash_graph_inputs( } } + const int64_t t_reserve_us = ggml_time_us(); lctx.dflash_sched = 
ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false); - if (lctx.dflash_sched == nullptr || !ggml_backend_sched_reserve(lctx.dflash_sched, gf_kv)) { + const bool reserved = lctx.dflash_sched != nullptr && ggml_backend_sched_reserve(lctx.dflash_sched, gf_kv); + profile.graph_kv_cache_reserve_us += (uint64_t) (ggml_time_us() - t_reserve_us); + if (!reserved) { profile.graph_shape_failures++; LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V scheduler\n", __func__); return false; } } + const int64_t t_reset_us = ggml_time_us(); ggml_backend_sched_reset(lctx.dflash_sched); + profile.graph_kv_cache_reset_us += (uint64_t) (ggml_time_us() - t_reset_us); + + const int64_t t_alloc_us = ggml_time_us(); ggml_backend_sched_alloc_graph(lctx.dflash_sched, gf_kv); - ggml_backend_tensor_set(lctx.dflash_kv_input_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); - ggml_backend_tensor_set(lctx.dflash_kv_input_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); + profile.graph_kv_cache_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us); + + ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_target_features); + const int64_t t_feature_upload_us = ggml_time_us(); + if (kv_feature_backend != nullptr) { + ggml_backend_tensor_set_async(kv_feature_backend, lctx.dflash_kv_input_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); + } else { + ggml_backend_tensor_set(lctx.dflash_kv_input_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); + } + profile.graph_kv_cache_feature_upload_us += (uint64_t) (ggml_time_us() - t_feature_upload_us); + + ggml_backend_t kv_pos_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_pos_ctx); + const int64_t t_pos_upload_us = 
ggml_time_us(); + if (kv_pos_backend != nullptr) { + ggml_backend_tensor_set_async(kv_pos_backend, lctx.dflash_kv_input_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); + } else { + ggml_backend_tensor_set(lctx.dflash_kv_input_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); + } + profile.graph_kv_cache_pos_upload_us += (uint64_t) (ggml_time_us() - t_pos_upload_us); + const int64_t t_kv_cache_us = ggml_time_us(); llama_graph_compute_sched(lctx, lctx.dflash_sched, gf_kv, lctx.cparams.n_threads); - llama_synchronize(&lctx); profile.graph_kv_cache_compute_us += (uint64_t) (ggml_time_us() - t_kv_cache_us); + + const int64_t t_sync_us = ggml_time_us(); + ggml_backend_sched_synchronize(lctx.dflash_sched); + profile.graph_kv_cache_sync_us += (uint64_t) (ggml_time_us() - t_sync_us); profile.graph_kv_cache_calls++; } else { ggml_backend_tensor_set(lctx.inp_dflash_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.inp_dflash_target_features)); @@ -5267,28 +5480,66 @@ static bool prepare_dflash_graph_inputs( } const int64_t t_mask_us = ggml_time_us(); + const int32_t full_visible_first = left_pad; + const int32_t full_visible_last = cross_ctx + (int32_t) n_tokens - 1; lctx.dflash_kq_mask_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY); int32_t visible_kv_max = 0; for (uint32_t j = 0; j < n_tokens; ++j) { float * row = lctx.dflash_kq_mask_data.data() + (size_t) j * (size_t) n_kv_total; - const int32_t visible_kv = cross_ctx + (int32_t) j + 1; + const int32_t visible_kv = cross_ctx + (int32_t) n_tokens; visible_kv_max = std::max(visible_kv_max, visible_kv); profile.graph_visible_kv_sum += (uint64_t) visible_kv; - for (int32_t i = left_pad; i < visible_kv; ++i) { + for (int32_t i = full_visible_first; i <= full_visible_last; ++i) { row[i] = 0.0f; } } ggml_backend_tensor_set(kq_mask, lctx.dflash_kq_mask_data.data(), 0, ggml_nbytes(kq_mask)); 
profile.graph_mask_build_us += (uint64_t) (ggml_time_us() - t_mask_us); profile.graph_mask_bytes += ggml_nbytes(kq_mask); + + if (kq_mask_swa != nullptr) { + lctx.dflash_kq_mask_swa_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY); + const int32_t swa_window = (int32_t) lctx.model.hparams.n_swa; + const int32_t draft_pos_base = (int32_t) profile.last_pos_last; + for (uint32_t j = 0; j < n_tokens; ++j) { + float * row = lctx.dflash_kq_mask_swa_data.data() + (size_t) j * (size_t) n_kv_total; + const int32_t q_pos = draft_pos_base + (int32_t) j; + + for (int32_t k = left_pad; k < cross_ctx; ++k) { + const int32_t k_pos = (int32_t) lctx.dflash_pos_ctx_data[(size_t) k]; + if (q_pos - k_pos < swa_window) { + row[k] = 0.0f; + } + } + + for (int32_t k = cross_ctx; k < cross_ctx + (int32_t) n_tokens; ++k) { + const int32_t block_k = k - cross_ctx; + if (block_k <= (int32_t) j) { + row[k] = 0.0f; + } + } + } + + ggml_backend_tensor_set(kq_mask_swa, lctx.dflash_kq_mask_swa_data.data(), 0, ggml_nbytes(kq_mask_swa)); + profile.graph_mask_bytes += ggml_nbytes(kq_mask_swa); + } + profile.graph_visible_kv_max = std::max(profile.graph_visible_kv_max, (uint64_t) visible_kv_max); profile.graph_prepare_total_us += (uint64_t) (ggml_time_us() - t_total_us); if (profile.graph_prepare_calls == 1) { - LLAMA_LOG_INFO("%s: DFlash graph contract rows=%d width=%d cross_ctx=%d n_tokens=%u left_pad=%d n_kv_total=%d draft_n_ctx=%u pos=%s [%d..%d]\n", + int32_t n_swa_layers = 0; + for (int32_t il = 0; il < lctx.model.hparams.n_layer; ++il) { + n_swa_layers += lctx.model.hparams.swa_layers[(size_t) il] ? 1 : 0; + } + + LLAMA_LOG_INFO("%s: DFlash graph contract rows=%d width=%d cross_ctx=%d n_tokens=%u left_pad=%d n_kv_total=%d draft_n_ctx=%u pos=%s [%d..%d] full_mask=[%d..%d] swa_window=%u swa_layers=%d\n", __func__, n_rows, width, cross_ctx, n_tokens, left_pad, n_kv_total, lctx.cparams.n_ctx, (src_pos != nullptr && total_positions == (size_t) n_rows) ? 
"target" : "synthetic", - (int) profile.last_pos_first, (int) profile.last_pos_last); + (int) profile.last_pos_first, (int) profile.last_pos_last, + full_visible_first, full_visible_last, + lctx.model.hparams.n_swa, + n_swa_layers); } return true; @@ -5322,6 +5573,8 @@ static int llama_decode_internal( const auto & hparams = model.hparams; const auto & cparams = lctx.cparams; + llama_begin_dflash_capture_batch(&lctx); + GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT GGML_ASSERT(n_tokens_all <= cparams.n_batch); @@ -5370,8 +5623,24 @@ static int llama_decode_internal( // reserve output buffer n_outputs_embd = has_mtp && cparams.mtp_op_type == MTP_OP_NONE ? n_tokens_all : n_outputs; - if (llama_output_reserve(lctx, std::max(n_outputs, n_outputs_embd)) < std::max(n_outputs, n_outputs_embd)) { - LLAMA_LOG_ERROR("%s: could not reserve space for batch with %zu outputs\n", __func__, std::max(n_outputs, n_outputs_embd)); + const size_t required_outputs = std::max(n_outputs, n_outputs_embd); + const bool is_dflash_decode = lctx.model.arch == LLM_ARCH_DFLASH_DRAFT; + const size_t output_buf_size_before = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output) : 0; + const int64_t t_output_reserve_us = is_dflash_decode ? ggml_time_us() : 0; + const size_t reserved_outputs = llama_output_reserve(lctx, required_outputs); + if (is_dflash_decode) { + auto & profile = lctx.dflash_profile; + profile.decode_output_reserve_calls++; + profile.decode_output_reserve_us += (uint64_t) (ggml_time_us() - t_output_reserve_us); + + const size_t output_buf_size_after = lctx.buf_output ? 
ggml_backend_buffer_get_size(lctx.buf_output) : 0; + if (output_buf_size_after > output_buf_size_before) { + profile.decode_output_reserve_reallocs++; + profile.decode_output_reserve_realloc_bytes += (uint64_t) output_buf_size_after; + } + } + if (reserved_outputs < required_outputs) { + LLAMA_LOG_ERROR("%s: could not reserve space for batch with %zu outputs\n", __func__, required_outputs); return -2; }; @@ -5587,9 +5856,15 @@ static int llama_decode_internal( } if (lctx.model.arch == LLM_ARCH_DFLASH_DRAFT) { + auto & profile = lctx.dflash_profile; + profile.decode_prepare_calls++; + const int64_t t_prepare_dflash_us = ggml_time_us(); if (!prepare_dflash_graph_inputs(lctx, n_tokens)) { + profile.decode_prepare_failures++; + profile.decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us); return GGML_STATUS_FAILED; } + profile.decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us); } // the output is always the last tensor in the graph From ed403dca271e0b013be75d96ee9531fc97b768c2 Mon Sep 17 00:00:00 2001 From: SamuelOliveirads Date: Sun, 31 May 2026 14:51:21 -0300 Subject: [PATCH 05/13] Use windows update in kv cache --- common/speculative.cpp | 393 ++++++++++++++++++++++++++++++++-- src/graphs/build_dflash.cpp | 163 +++++++++++++- src/llama-build-context.cpp | 22 +- src/llama-context.h | 17 ++ src/llama-spec-features.cpp | 142 ++++++++++++- src/llama-spec-features.h | 101 ++++++++- src/llama.cpp | 410 +++++++++++++++++++++++++++++------- 7 files changed, 1133 insertions(+), 115 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 911526a8..e7ce71f9 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -253,6 +253,8 @@ static const common_speculative_state_mtp * common_speculative_get_mtp_state(con static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec); static const common_speculative_state_dflash * common_speculative_get_dflash_state(const 
common_speculative * spec); static int32_t common_speculative_feature_width(const common_speculative * spec); +static void dflash_materialize_target_window_features(common_speculative_state_dflash & state); +static void dflash_ring_reset_rows(common_speculative_state_dflash & state, const float * rows, int32_t n_rows); static void dflash_append_target_features( common_speculative_state_dflash & state, const float * feature_rows, @@ -284,6 +286,17 @@ static bool dflash_contract_log_enabled() { std::strcmp(env, "off") != 0; } +static bool dflash_use_kv_cache_experiment() { + const char * env = std::getenv("IK_DFLASH_KV_CACHE"); + if (env == nullptr || *env == '\0') { + return false; + } + + return std::strcmp(env, "0") != 0 && + std::strcmp(env, "false") != 0 && + std::strcmp(env, "off") != 0; +} + template static std::string dflash_contract_format_values( const std::vector & values, @@ -479,7 +492,18 @@ struct common_speculative_state_dflash : public common_speculative_state { std::vector target_layer_ids; std::vector target_window; std::vector target_window_pos; + std::vector target_window_stage; + std::vector target_window_pos_stage; + std::vector target_window_ring; + std::vector target_window_append_features; int32_t target_window_rows = 0; + int32_t target_window_ring_write_pos = 0; + int32_t target_window_ring_filled = 0; + uint64_t target_window_version = 0; + int32_t target_window_keep_rows = 0; + int32_t target_window_append_rows = 0; + bool target_window_replace = false; + bool target_window_materialized = false; llama_pos last_target_pos = -1; size_t n_window_updates = 0; size_t n_rows_seen = 0; @@ -497,6 +521,13 @@ struct common_speculative_state_dflash : public common_speculative_state { uint64_t t_accept_output_copy_us = 0; uint64_t t_accept_commit_us = 0; uint64_t t_accept_append_us = 0; + uint64_t t_accept_append_filter_us = 0; + uint64_t t_accept_append_window_alloc_us = 0; + uint64_t t_accept_append_replace_us = 0; + uint64_t 
t_accept_append_keep_old_us = 0; + uint64_t t_accept_append_new_rows_us = 0; + uint64_t t_accept_append_commit_detail_us = 0; + uint64_t t_accept_append_log_us = 0; size_t n_warmup_collect_calls = 0; size_t n_warmup_collect_rows = 0; size_t n_warmup_append_calls = 0; @@ -507,6 +538,8 @@ struct common_speculative_state_dflash : public common_speculative_state { size_t n_accept_commit_rows = 0; size_t n_accept_append_calls = 0; size_t n_accept_append_rows = 0; + size_t n_accept_append_replace_calls = 0; + size_t n_accept_append_slide_calls = 0; common_speculative_state_dflash( enum common_speculative_type type, @@ -614,6 +647,12 @@ struct common_speculative_state_dflash : public common_speculative_state { } batch = llama_batch_init(std::max(1, block_size), 0, 1); + target_window.reserve((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_stage.reserve((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_ring.resize((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_append_features.reserve((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_pos.reserve((size_t) this->cross_ctx); + target_window_pos_stage.reserve((size_t) this->cross_ctx); ready = true; llama_set_dflash_visible_cross_ctx(ctx_dft, this->cross_ctx); @@ -648,6 +687,7 @@ struct common_speculative_state_dflash : public common_speculative_state { void begin(const llama_tokens & prompt) override { GGML_UNUSED(prompt); llama_kv_cache_clear(ctx_dft); + llama_reset_dflash_kv_cache_state(ctx_dft); n_window_updates = 0; n_rows_seen = 0; n_rows_dropped = 0; @@ -663,6 +703,13 @@ struct common_speculative_state_dflash : public common_speculative_state { t_accept_output_copy_us = 0; t_accept_commit_us = 0; t_accept_append_us = 0; + t_accept_append_filter_us = 0; + t_accept_append_window_alloc_us = 0; + t_accept_append_replace_us = 0; + t_accept_append_keep_old_us = 0; + t_accept_append_new_rows_us = 0; + t_accept_append_commit_detail_us = 
0; + t_accept_append_log_us = 0; n_warmup_collect_calls = 0; n_warmup_collect_rows = 0; n_warmup_append_calls = 0; @@ -673,6 +720,8 @@ struct common_speculative_state_dflash : public common_speculative_state { n_accept_commit_rows = 0; n_accept_append_calls = 0; n_accept_append_rows = 0; + n_accept_append_replace_calls = 0; + n_accept_append_slide_calls = 0; llama_dflash_profile_reset(ctx_tgt); llama_dflash_profile_reset(ctx_dft); } @@ -695,7 +744,33 @@ struct common_speculative_state_dflash : public common_speculative_state { return; } - if (!llama_set_dflash_target_features_view(ctx_dft, target_window.data(), target_window.size(), target_window_rows, target_window_pos.data())) { + const bool use_kv_cache = dflash_use_kv_cache_experiment(); + const float * target_features = nullptr; + size_t target_feature_floats = 0; + llama_dflash_window_update window_update = { + target_window_version, + target_window_keep_rows, + target_window_append_rows, + target_window_replace, + target_window_append_features.empty() ? nullptr : target_window_append_features.data(), + target_window_append_features.size(), + }; + const llama_dflash_kv_cache_transition cache_plan = use_kv_cache + ? 
llama_plan_dflash_kv_cache_transition_for_ctx(ctx_dft, window_update, target_window_rows) + : llama_dflash_kv_cache_transition{}; + + if (!use_kv_cache || cache_plan.rebuild_cache) { + dflash_materialize_target_window_features(*this); + target_features = target_window.data(); + target_feature_floats = target_window.size(); + } + if (use_kv_cache && cache_plan.rebuild_cache) { + window_update.append_features = target_window.data(); + window_update.append_floats = target_window.size(); + window_update.append_rows = target_window_rows; + } + + if (!llama_set_dflash_target_features_view(ctx_dft, target_features, target_feature_floats, target_window_rows, target_window_pos.data(), &window_update)) { LOG_ERR("%s: failed to set DFlash target features\n", __func__); n_set_target_fail++; return; @@ -2522,6 +2597,24 @@ void common_speculative_print_stats(const common_speculative * spec, double slot const double kv_compute_ms = (double) graph_stats.graph_kv_cache_compute_us / 1000.0; const double kv_sync_ms = (double) graph_stats.graph_kv_cache_sync_us / 1000.0; const double replay_append_ms = (double) dflash_state->t_accept_append_us / 1000.0; + const double feature_path_ms = (double) ( + capture_stats.capture_prepare_sync_us + + capture_stats.capture_materialize_us + + graph_stats.set_target_copy_us + + graph_stats.graph_feature_copy_us + + graph_stats.graph_pos_copy_us + + graph_stats.graph_mask_build_us) / 1000.0; + const double decode_internal_ms = (double) ( + graph_stats.decode_prelude_us + + graph_stats.decode_sched_reset_us + + graph_stats.decode_build_graph_us + + graph_stats.decode_sched_alloc_graph_us + + graph_stats.decode_prepare_us + + graph_stats.decode_set_inputs_us + + graph_stats.decode_graph_compute_us + + graph_stats.decode_result_us + + graph_stats.decode_embedding_us + + graph_stats.decode_final_sched_reset_us) / 1000.0; LOG_INF("statistics dflash profile: capture(sync/materialize)=%.3f/%.3f ms calls=%llu/%llu bytes=%llu phase(prompt/verify batches 
changes)=%llu/%llu %llu/%llu, set_target=%.3f ms rows=%llu bytes=%llu, decode(llama_output_reserve/prepare)=%.3f/%.3f ms calls=%llu/%llu realloc(bytes)=%llu/%llu, prep(total/features/pos/mask)=%.3f/%.3f/%.3f/%.3f ms kv_cache(total/build/reserve/reset/alloc/up_f/up_p/compute/sync/read)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls(prepare/cache/read)=%llu/%llu/%llu bytes(feature/pos/mask/read)=%llu/%llu/%llu/%llu host_layers=%d, fallback_pos(copy/graph)=%llu/%llu, nonmono(copy/graph)=%llu/%llu, capture_fail=%llu/%llu decode_prepare_fail=%llu, visible_kv_max=%llu, last(rows=%d width=%d left_pad=%d n_tokens=%d n_kv=%d pos=[%d..%d])\n", (double) capture_stats.capture_prepare_sync_us / 1000.0, @@ -2580,6 +2673,81 @@ void common_speculative_print_stats(const common_speculative * spec, double slot (int) graph_stats.last_pos_first, (int) graph_stats.last_pos_last); + LOG_INF("statistics dflash features: total=%.3f ms capture(sync/materialize)=%.3f/%.3f ms set_target=%.3f ms prep(feature/pos/mask)=%.3f/%.3f/%.3f ms rows(materialize/set_target)=%llu/%llu bytes(materialize/set_target/feature/pos/mask)=%llu/%llu/%llu/%llu/%llu\n", + feature_path_ms, + (double) capture_stats.capture_prepare_sync_us / 1000.0, + (double) capture_stats.capture_materialize_us / 1000.0, + (double) graph_stats.set_target_copy_us / 1000.0, + (double) graph_stats.graph_feature_copy_us / 1000.0, + (double) graph_stats.graph_pos_copy_us / 1000.0, + (double) graph_stats.graph_mask_build_us / 1000.0, + (unsigned long long) capture_stats.capture_materialize_rows, + (unsigned long long) graph_stats.set_target_rows, + (unsigned long long) capture_stats.capture_materialize_bytes, + (unsigned long long) graph_stats.set_target_copy_bytes, + (unsigned long long) graph_stats.graph_feature_bytes, + (unsigned long long) graph_stats.graph_pos_bytes, + (unsigned long long) graph_stats.graph_mask_bytes); + + LOG_INF("statistics dflash kv: total=%.3f ms 
build/reserve/reset/alloc/upload_f/upload_p/compute/sync/read=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu cached_bytes=%llu host_layers=%d\n", + kv_cache_total_ms, + (double) graph_stats.graph_kv_cache_build_us / 1000.0, + (double) graph_stats.graph_kv_cache_reserve_us / 1000.0, + (double) graph_stats.graph_kv_cache_reset_us / 1000.0, + (double) graph_stats.graph_kv_cache_alloc_us / 1000.0, + (double) graph_stats.graph_kv_cache_feature_upload_us / 1000.0, + (double) graph_stats.graph_kv_cache_pos_upload_us / 1000.0, + (double) graph_stats.graph_kv_cache_compute_us / 1000.0, + (double) graph_stats.graph_kv_cache_sync_us / 1000.0, + (double) graph_stats.graph_kv_cache_read_concat_pad_us / 1000.0, + (unsigned long long) graph_stats.graph_kv_cache_calls, + (unsigned long long) graph_stats.graph_kv_cache_cached_bytes, + graph_stats.last_kv_cache_host_layers); + + if (graph_stats.decode_internal_chunks > 0) { + LOG_INF("statistics dflash decode: llama_decode(total)=%.3f ms calls=%zu chunks=%llu rebuilds=%llu sync_points=%llu internal(total/prelude/sched_reset/build/alloc/prepare/set_inputs/compute/get_result/get_embedding/final_reset)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms\n", + (double) dflash_state->t_draft_decode_us / 1000.0, + dflash_state->n_call_draft, + (unsigned long long) graph_stats.decode_internal_chunks, + (unsigned long long) graph_stats.decode_graph_rebuilds, + (unsigned long long) graph_stats.decode_sync_profile_points, + decode_internal_ms, + (double) graph_stats.decode_prelude_us / 1000.0, + (double) graph_stats.decode_sched_reset_us / 1000.0, + (double) graph_stats.decode_build_graph_us / 1000.0, + (double) graph_stats.decode_sched_alloc_graph_us / 1000.0, + (double) graph_stats.decode_prepare_us / 1000.0, + (double) graph_stats.decode_set_inputs_us / 1000.0, + (double) graph_stats.decode_graph_compute_us / 1000.0, + (double) graph_stats.decode_result_us / 1000.0, + (double) graph_stats.decode_embedding_us / 1000.0, + 
(double) graph_stats.decode_final_sched_reset_us / 1000.0); + } + + if (graph_stats.graph_kv_node_fused_target_calls > 0 || + graph_stats.graph_kv_node_k_proj_calls > 0 || + graph_stats.graph_kv_node_k_norm_calls > 0 || + graph_stats.graph_kv_node_k_rope_calls > 0 || + graph_stats.graph_kv_node_v_proj_calls > 0 || + graph_stats.graph_kv_node_k_store_calls > 0 || + graph_stats.graph_kv_node_v_store_calls > 0) { + LOG_INF("statistics dflash kv nodes: fused_target/k_proj/k_norm/k_rope/v_proj/k_store/v_store=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu/%llu/%llu/%llu/%llu/%llu/%llu\n", + (double) graph_stats.graph_kv_node_fused_target_us / 1000.0, + (double) graph_stats.graph_kv_node_k_proj_us / 1000.0, + (double) graph_stats.graph_kv_node_k_norm_us / 1000.0, + (double) graph_stats.graph_kv_node_k_rope_us / 1000.0, + (double) graph_stats.graph_kv_node_v_proj_us / 1000.0, + (double) graph_stats.graph_kv_node_k_store_us / 1000.0, + (double) graph_stats.graph_kv_node_v_store_us / 1000.0, + (unsigned long long) graph_stats.graph_kv_node_fused_target_calls, + (unsigned long long) graph_stats.graph_kv_node_k_proj_calls, + (unsigned long long) graph_stats.graph_kv_node_k_norm_calls, + (unsigned long long) graph_stats.graph_kv_node_k_rope_calls, + (unsigned long long) graph_stats.graph_kv_node_v_proj_calls, + (unsigned long long) graph_stats.graph_kv_node_k_store_calls, + (unsigned long long) graph_stats.graph_kv_node_v_store_calls); + } + LOG_INF("statistics dflash hot: kv(upload_f/upload_p/upload/compute/sync)=%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu replay(accepted_prefix_append)=%.3f ms calls=%zu rows=%zu\n", kv_upload_feature_ms, kv_upload_pos_ms, @@ -2609,6 +2777,20 @@ void common_speculative_print_stats(const common_speculative * spec, double slot dflash_state->n_accept_commit_rows, dflash_state->n_accept_output_copy_rows, dflash_state->n_accept_append_rows); + + if (dflash_state->n_accept_append_calls > 0) { + LOG_INF("statistics dflash replay: 
append(filter/window_alloc/replace/keep_old/new_rows/commit/log)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%zu replace/slide=%zu/%zu\n", + (double) dflash_state->t_accept_append_filter_us / 1000.0, + (double) dflash_state->t_accept_append_window_alloc_us / 1000.0, + (double) dflash_state->t_accept_append_replace_us / 1000.0, + (double) dflash_state->t_accept_append_keep_old_us / 1000.0, + (double) dflash_state->t_accept_append_new_rows_us / 1000.0, + (double) dflash_state->t_accept_append_commit_detail_us / 1000.0, + (double) dflash_state->t_accept_append_log_us / 1000.0, + dflash_state->n_accept_append_calls, + dflash_state->n_accept_append_replace_calls, + dflash_state->n_accept_append_slide_calls); + } } } } @@ -2728,11 +2910,113 @@ static void mtp_clear_target_hidden(common_speculative_state_mtp & state, llama_ state.draft_cache_by_seq.erase(seq_id); } +struct dflash_append_breakdown { + uint64_t filter_us = 0; + uint64_t window_alloc_us = 0; + uint64_t replace_us = 0; + uint64_t keep_old_us = 0; + uint64_t new_rows_us = 0; + uint64_t commit_us = 0; + uint64_t log_us = 0; + bool replace_call = false; +}; + +static void dflash_record_window_update( + common_speculative_state_dflash & state, + int32_t keep_rows, + int32_t append_rows, + bool replace) { + state.target_window_keep_rows = std::max(0, keep_rows); + state.target_window_append_rows = std::max(0, append_rows); + state.target_window_replace = replace; + state.target_window_version++; +} + +static void dflash_ring_reset_rows( + common_speculative_state_dflash & state, + const float * rows, + int32_t n_rows) { + const size_t row_width = (size_t) state.n_target_features; + if (n_rows <= 0 || rows == nullptr) { + state.target_window_ring_write_pos = 0; + state.target_window_ring_filled = 0; + return; + } + + if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) { + state.target_window_ring.resize((size_t) state.cross_ctx * row_width); + } + + 
std::memcpy(state.target_window_ring.data(), rows, (size_t) n_rows * row_width * sizeof(float)); + state.target_window_ring_write_pos = n_rows % state.cross_ctx; + state.target_window_ring_filled = n_rows; + state.target_window_materialized = false; +} + +static void dflash_ring_append_rows( + common_speculative_state_dflash & state, + const float * rows, + int32_t n_rows) { + const size_t row_width = (size_t) state.n_target_features; + if (n_rows <= 0 || rows == nullptr) { + return; + } + + if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) { + state.target_window_ring.resize((size_t) state.cross_ctx * row_width); + } + + int32_t write_pos = state.target_window_ring_write_pos; + int32_t remaining = n_rows; + const float * src = rows; + while (remaining > 0) { + const int32_t chunk_rows = std::min(remaining, state.cross_ctx - write_pos); + std::memcpy( + state.target_window_ring.data() + (size_t) write_pos * row_width, + src, + (size_t) chunk_rows * row_width * sizeof(float)); + src += (size_t) chunk_rows * row_width; + remaining -= chunk_rows; + write_pos = (write_pos + chunk_rows) % state.cross_ctx; + } + + state.target_window_ring_write_pos = write_pos; + state.target_window_ring_filled = std::min(state.cross_ctx, state.target_window_ring_filled + n_rows); + state.target_window_materialized = false; +} + +static void dflash_materialize_target_window_features(common_speculative_state_dflash & state) { + if (state.target_window_materialized || state.target_window_rows <= 0) { + return; + } + + const size_t row_width = (size_t) state.n_target_features; + state.target_window.resize((size_t) state.target_window_rows * row_width); + + const int32_t read_start = (state.target_window_ring_write_pos - state.target_window_rows + state.cross_ctx) % state.cross_ctx; + const int32_t first_rows = std::min(state.target_window_rows, state.cross_ctx - read_start); + std::memcpy( + state.target_window.data(), + state.target_window_ring.data() + (size_t) 
read_start * row_width, + (size_t) first_rows * row_width * sizeof(float)); + + const int32_t second_rows = state.target_window_rows - first_rows; + if (second_rows > 0) { + std::memcpy( + state.target_window.data() + (size_t) first_rows * row_width, + state.target_window_ring.data(), + (size_t) second_rows * row_width * sizeof(float)); + } + + state.target_window_materialized = true; +} + static bool dflash_append_target_features( common_speculative_state_dflash & state, const common_speculative_feature_view & features, const llama_batch & batch, - llama_seq_id seq_id) { + llama_seq_id seq_id, + dflash_append_breakdown * breakdown = nullptr) { GGML_UNUSED(batch); if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || @@ -2748,6 +3032,7 @@ static bool dflash_append_target_features( new_rows.reserve(features.rows.size() * row_width); new_positions.reserve(features.rows.size()); + const int64_t t_filter_us = ggml_time_us(); for (const auto & row : features.rows) { if (row.seq_id != seq_id || row.data == nullptr) { continue; @@ -2756,6 +3041,9 @@ static bool dflash_append_target_features( new_positions.push_back(row.pos); new_rows.insert(new_rows.end(), row.data, row.data + row_width); } + if (breakdown != nullptr) { + breakdown->filter_us += (uint64_t) (ggml_time_us() - t_filter_us); + } if (new_positions.empty()) { return false; @@ -2767,46 +3055,93 @@ static bool dflash_append_target_features( if (n_rows >= state.cross_ctx) { state.n_rows_dropped += (size_t) state.target_window_rows + (size_t) (n_rows - state.cross_ctx); const int32_t keep_from = n_rows - state.cross_ctx; - state.target_window.assign( + const int64_t t_replace_us = ggml_time_us(); + state.target_window_pos.assign(new_positions.begin() + keep_from, new_positions.end()); + state.target_window_append_features.assign( new_rows.begin() + (ptrdiff_t) keep_from * (ptrdiff_t) row_width, new_rows.end()); - state.target_window_pos.assign(new_positions.begin() + keep_from, new_positions.end()); + 
dflash_ring_reset_rows(state, state.target_window_append_features.data(), state.cross_ctx); + if (breakdown != nullptr) { + breakdown->replace_us += (uint64_t) (ggml_time_us() - t_replace_us); + breakdown->replace_call = true; + } + + const int64_t t_commit_us = ggml_time_us(); state.target_window_rows = state.cross_ctx; + state.target_window_ring_filled = state.target_window_rows; state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); + dflash_record_window_update(state, 0, state.target_window_rows, true); + if (breakdown != nullptr) { + breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us); + } + + const int64_t t_log_us = ggml_time_us(); dflash_contract_log_append(state, seq_id, new_positions); + if (breakdown != nullptr) { + breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us); + } return true; } const int32_t keep_old_rows = std::min(state.target_window_rows, state.cross_ctx - n_rows); state.n_rows_dropped += (size_t) std::max(0, state.target_window_rows - keep_old_rows); - std::vector next_window((size_t) (keep_old_rows + n_rows) * row_width); - std::vector next_window_pos((size_t) (keep_old_rows + n_rows)); - - if (keep_old_rows > 0) { - const float * old_src = state.target_window.data() + (size_t) (state.target_window_rows - keep_old_rows) * row_width; - std::memcpy(next_window.data(), old_src, (size_t) keep_old_rows * row_width * sizeof(float)); - std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), next_window_pos.begin()); + const int64_t t_window_alloc_us = ggml_time_us(); + std::vector & next_window_pos = state.target_window_pos_stage; + next_window_pos.resize((size_t) (keep_old_rows + n_rows)); + if (breakdown != nullptr) { + breakdown->window_alloc_us += (uint64_t) (ggml_time_us() - t_window_alloc_us); } - std::memcpy( - next_window.data() + (size_t) keep_old_rows * row_width, - new_rows.data(), - (size_t) n_rows * row_width * sizeof(float)); - 
std::copy(new_positions.begin(), new_positions.end(), next_window_pos.begin() + keep_old_rows); + if (keep_old_rows > 0) { + const int64_t t_keep_old_us = ggml_time_us(); + std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), next_window_pos.begin()); + if (breakdown != nullptr) { + breakdown->keep_old_us += (uint64_t) (ggml_time_us() - t_keep_old_us); + } + } - state.target_window = std::move(next_window); - state.target_window_pos = std::move(next_window_pos); + const int64_t t_new_rows_us = ggml_time_us(); + state.target_window_append_features.assign(new_rows.begin(), new_rows.end()); + dflash_ring_append_rows(state, state.target_window_append_features.data(), n_rows); + std::copy(new_positions.begin(), new_positions.end(), next_window_pos.begin() + keep_old_rows); + if (breakdown != nullptr) { + breakdown->new_rows_us += (uint64_t) (ggml_time_us() - t_new_rows_us); + } + + const int64_t t_commit_us = ggml_time_us(); + state.target_window_pos.swap(next_window_pos); + next_window_pos.clear(); state.target_window_rows = keep_old_rows + n_rows; + state.target_window_ring_filled = state.target_window_rows; state.last_target_pos = state.target_window_pos.empty() ? 
-1 : state.target_window_pos.back(); + dflash_record_window_update(state, keep_old_rows, n_rows, false); + if (breakdown != nullptr) { + breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us); + } + + const int64_t t_log_us = ggml_time_us(); dflash_contract_log_append(state, seq_id, new_positions); + if (breakdown != nullptr) { + breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us); + } return true; } static void dflash_clear_target_features(common_speculative_state_dflash & state) { state.target_window.clear(); state.target_window_pos.clear(); + state.target_window_stage.clear(); + state.target_window_pos_stage.clear(); + state.target_window_append_features.clear(); state.target_window_rows = 0; + state.target_window_ring_write_pos = 0; + state.target_window_ring_filled = 0; + state.target_window_keep_rows = 0; + state.target_window_append_rows = 0; + state.target_window_replace = false; + state.target_window_materialized = false; state.last_target_pos = -1; + llama_reset_dflash_kv_cache_state(state.ctx_dft); } static void dflash_context_shift( @@ -2814,10 +3149,12 @@ static void dflash_context_shift( llama_pos kv_keep, llama_pos kv_discard, llama_pos kv_past) { - if (kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window.empty() || state.target_window_pos.empty()) { + if (kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window_pos.empty()) { return; } + dflash_materialize_target_window_features(state); + const size_t row_width = (size_t) state.n_target_features; const llama_pos discard_begin = kv_keep; const llama_pos discard_end = kv_keep + kv_discard; @@ -2845,7 +3182,10 @@ static void dflash_context_shift( state.target_window = std::move(shifted_rows); state.target_window_pos = std::move(shifted_positions); state.target_window_rows = (int32_t) state.target_window_pos.size(); + dflash_ring_reset_rows(state, state.target_window.data(), state.target_window_rows); state.last_target_pos = 
state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); + dflash_record_window_update(state, 0, state.target_window_rows, true); + llama_reset_dflash_kv_cache_state(state.ctx_dft); state.n_context_shifts++; } @@ -2959,8 +3299,9 @@ int32_t common_speculative_on_target_batch( } } + dflash_append_breakdown append_breakdown; const int64_t t_append_us = ggml_time_us(); - if (!dflash_append_target_features(*dflash_state, features, batch, seq_id)) { + if (!dflash_append_target_features(*dflash_state, features, batch, seq_id, &append_breakdown)) { return -1; } @@ -2971,8 +3312,20 @@ int32_t common_speculative_on_target_batch( dflash_state->n_warmup_append_rows += (size_t) batch.n_tokens; } else { dflash_state->t_accept_append_us += append_us; + dflash_state->t_accept_append_filter_us += append_breakdown.filter_us; + dflash_state->t_accept_append_window_alloc_us += append_breakdown.window_alloc_us; + dflash_state->t_accept_append_replace_us += append_breakdown.replace_us; + dflash_state->t_accept_append_keep_old_us += append_breakdown.keep_old_us; + dflash_state->t_accept_append_new_rows_us += append_breakdown.new_rows_us; + dflash_state->t_accept_append_commit_detail_us += append_breakdown.commit_us; + dflash_state->t_accept_append_log_us += append_breakdown.log_us; dflash_state->n_accept_append_calls++; dflash_state->n_accept_append_rows += (size_t) batch.n_tokens; + if (append_breakdown.replace_call) { + dflash_state->n_accept_append_replace_calls++; + } else { + dflash_state->n_accept_append_slide_calls++; + } } return 0; diff --git a/src/graphs/build_dflash.cpp b/src/graphs/build_dflash.cpp index b9862c2a..a5b9a815 100644 --- a/src/graphs/build_dflash.cpp +++ b/src/graphs/build_dflash.cpp @@ -23,38 +23,132 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() { const int64_t ctx_len = lctx.dflash_visible_cross_ctx > 0 ? 
(int64_t) lctx.dflash_visible_cross_ctx : std::max(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size); + const int64_t update_rows = std::max(1, lctx.dflash_kv_cache_update_rows > 0 ? lctx.dflash_kv_cache_update_rows : ctx_len); + const int32_t write_pos = lctx.dflash_kv_cache_write_pos; GGML_ASSERT(n_embd_head_k == n_embd_head_v); GGML_ASSERT(n_target_features > 0); GGML_ASSERT(lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len)); + GGML_ASSERT(update_rows > 0 && update_rows <= ctx_len); + GGML_ASSERT(write_pos >= 0 && write_pos < ctx_len); - ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max(1, ctx_len)) + 24 * n_layer, false); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max(1, update_rows)) + 24 * n_layer, false); - lctx.dflash_kv_input_target_features = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_target_features, ctx_len); + lctx.dflash_kv_input_target_features = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_target_features, update_rows); ggml_set_input(lctx.dflash_kv_input_target_features); + cb(lctx.dflash_kv_input_target_features, "dflash_kv_input_target_features", -1); - lctx.dflash_kv_input_pos_ctx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctx_len); + lctx.dflash_kv_input_pos_ctx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, update_rows); ggml_set_input(lctx.dflash_kv_input_pos_ctx); + cb(lctx.dflash_kv_input_pos_ctx, "dflash_kv_input_pos_ctx", -1); ggml_tensor * fused_target = llm_build_lora_mm(lctx, ctx0, model.dflash_fc, lctx.dflash_kv_input_target_features); fused_target = llm_build_norm(ctx0, fused_target, hparams, model.dflash_hidden_norm, nullptr, LLM_NORM_RMS, cb, -1); + cb(fused_target, "dflash_kv_fused_target", -1); for (int il = 0; il < n_layer; ++il) { GGML_ASSERT((size_t) il < lctx.dflash_k_ctx_cache.size()); GGML_ASSERT((size_t) il < lctx.dflash_v_ctx_cache.size()); - ggml_tensor * Kcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused_target); - Kcur_ctx = 
ggml_reshape_3d(ctx0, Kcur_ctx, n_embd_head_k, n_head_kv, ctx_len); + ggml_tensor * Kcur_ctx_proj = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused_target); + cb(Kcur_ctx_proj, "dflash_kv_k_proj", il); + + ggml_tensor * Kcur_ctx = ggml_reshape_3d(ctx0, Kcur_ctx_proj, n_embd_head_k, n_head_kv, update_rows); Kcur_ctx = llm_build_norm(ctx0, Kcur_ctx, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il); + cb(Kcur_ctx, "dflash_kv_k_norm", il); Kcur_ctx = ggml_rope_ext(ctx0, Kcur_ctx, lctx.dflash_kv_input_pos_ctx, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + cb(Kcur_ctx, "dflash_kv_k_rope", il); ggml_tensor * Vcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, fused_target); - Vcur_ctx = ggml_reshape_3d(ctx0, Vcur_ctx, n_embd_head_v, n_head_kv, ctx_len); + cb(Vcur_ctx, "dflash_kv_v_proj", il); + Vcur_ctx = ggml_reshape_3d(ctx0, Vcur_ctx, n_embd_head_v, n_head_kv, update_rows); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur_ctx, lctx.dflash_k_ctx_cache[(size_t) il])); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur_ctx, lctx.dflash_v_ctx_cache[(size_t) il])); + const int32_t first_rows = std::min((int32_t) update_rows, (int32_t) ctx_len - write_pos); + const int32_t second_rows = (int32_t) update_rows - first_rows; + + if (first_rows > 0) { + ggml_tensor * Ksrc_first = first_rows == update_rows + ? Kcur_ctx + : ggml_view_3d(ctx0, Kcur_ctx, + Kcur_ctx->ne[0], + Kcur_ctx->ne[1], + first_rows, + Kcur_ctx->nb[1], + Kcur_ctx->nb[2], + 0); + ggml_tensor * Vsrc_first = first_rows == update_rows + ? 
Vcur_ctx + : ggml_view_3d(ctx0, Vcur_ctx, + Vcur_ctx->ne[0], + Vcur_ctx->ne[1], + first_rows, + Vcur_ctx->nb[1], + Vcur_ctx->nb[2], + 0); + ggml_tensor * Kdst_first = ggml_view_3d(ctx0, lctx.dflash_k_ctx_cache[(size_t) il], + lctx.dflash_k_ctx_cache[(size_t) il]->ne[0], + lctx.dflash_k_ctx_cache[(size_t) il]->ne[1], + first_rows, + lctx.dflash_k_ctx_cache[(size_t) il]->nb[1], + lctx.dflash_k_ctx_cache[(size_t) il]->nb[2], + (size_t) write_pos * lctx.dflash_k_ctx_cache[(size_t) il]->nb[2]); + ggml_tensor * Vdst_first = ggml_view_3d(ctx0, lctx.dflash_v_ctx_cache[(size_t) il], + lctx.dflash_v_ctx_cache[(size_t) il]->ne[0], + lctx.dflash_v_ctx_cache[(size_t) il]->ne[1], + first_rows, + lctx.dflash_v_ctx_cache[(size_t) il]->nb[1], + lctx.dflash_v_ctx_cache[(size_t) il]->nb[2], + (size_t) write_pos * lctx.dflash_v_ctx_cache[(size_t) il]->nb[2]); + + ggml_tensor * Kstore_first = ggml_cpy(ctx0, Ksrc_first, Kdst_first); + cb(Kstore_first, "dflash_kv_k_store", il); + ggml_build_forward_expand(gf, Kstore_first); + + ggml_tensor * Vstore_first = ggml_cpy(ctx0, Vsrc_first, Vdst_first); + cb(Vstore_first, "dflash_kv_v_store", il); + ggml_build_forward_expand(gf, Vstore_first); + } + + if (second_rows > 0) { + ggml_tensor * Ksrc_second = ggml_view_3d(ctx0, Kcur_ctx, + Kcur_ctx->ne[0], + Kcur_ctx->ne[1], + second_rows, + Kcur_ctx->nb[1], + Kcur_ctx->nb[2], + (size_t) first_rows * Kcur_ctx->nb[2]); + ggml_tensor * Vsrc_second = ggml_view_3d(ctx0, Vcur_ctx, + Vcur_ctx->ne[0], + Vcur_ctx->ne[1], + second_rows, + Vcur_ctx->nb[1], + Vcur_ctx->nb[2], + (size_t) first_rows * Vcur_ctx->nb[2]); + ggml_tensor * Kdst_second = ggml_view_3d(ctx0, lctx.dflash_k_ctx_cache[(size_t) il], + lctx.dflash_k_ctx_cache[(size_t) il]->ne[0], + lctx.dflash_k_ctx_cache[(size_t) il]->ne[1], + second_rows, + lctx.dflash_k_ctx_cache[(size_t) il]->nb[1], + lctx.dflash_k_ctx_cache[(size_t) il]->nb[2], + 0); + ggml_tensor * Vdst_second = ggml_view_3d(ctx0, lctx.dflash_v_ctx_cache[(size_t) il], + 
lctx.dflash_v_ctx_cache[(size_t) il]->ne[0], + lctx.dflash_v_ctx_cache[(size_t) il]->ne[1], + second_rows, + lctx.dflash_v_ctx_cache[(size_t) il]->nb[1], + lctx.dflash_v_ctx_cache[(size_t) il]->nb[2], + 0); + + ggml_tensor * Kstore_second = ggml_cpy(ctx0, Ksrc_second, Kdst_second); + cb(Kstore_second, "dflash_kv_k_store", il); + ggml_build_forward_expand(gf, Kstore_second); + + ggml_tensor * Vstore_second = ggml_cpy(ctx0, Vsrc_second, Vdst_second); + cb(Vstore_second, "dflash_kv_v_store", il); + ggml_build_forward_expand(gf, Vstore_second); + } } return gf; @@ -69,12 +163,17 @@ ggml_cgraph * llm_build_context::build_dflash() { const int64_t ctx_len = lctx.dflash_visible_cross_ctx > 0 ? (int64_t) lctx.dflash_visible_cross_ctx : std::max(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size); + const int32_t cache_rows = use_kv_cache ? std::clamp(lctx.dflash_kv_cache_view_n_filled, 0, (int32_t) ctx_len) : 0; + const int32_t cache_write_pos = use_kv_cache && ctx_len > 0 + ? ((lctx.dflash_kv_cache_view_write_pos % (int32_t) ctx_len) + (int32_t) ctx_len) % (int32_t) ctx_len + : 0; const int64_t n_kv_total = GGML_PAD(ctx_len + n_tokens, flash_attn ? 
256 : 32); const int64_t n_kv_pad = n_kv_total - (ctx_len + n_tokens); GGML_ASSERT(n_embd_head_k == n_embd_head_v); GGML_ASSERT(n_target_features > 0); GGML_ASSERT(!use_kv_cache || lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len)); + GGML_ASSERT(!use_kv_cache || (cache_write_pos >= 0 && cache_write_pos < ctx_len)); ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max(n_tokens, ctx_len)) + 32 * n_layer, false); @@ -160,8 +259,52 @@ ggml_cgraph * llm_build_context::build_dflash() { ggml_tensor * Kcur_ctx = nullptr; ggml_tensor * Vcur_ctx = nullptr; if (use_kv_cache) { - Kcur_ctx = lctx.dflash_k_ctx_cache[(size_t) il]; - Vcur_ctx = lctx.dflash_v_ctx_cache[(size_t) il]; + auto build_ordered_cache_view = [&](ggml_tensor * cache) -> ggml_tensor * { + if (!lctx.dflash_kv_cache_view_valid || cache_rows <= 0) { + return cache; + } + + if (cache_rows < ctx_len) { + ggml_tensor * zero_pad = ggml_view_3d(ctx0, cache, + cache->ne[0], + cache->ne[1], + ctx_len - cache_rows, + cache->nb[1], + cache->nb[2], + (size_t) cache_rows * cache->nb[2]); + ggml_tensor * valid = ggml_view_3d(ctx0, cache, + cache->ne[0], + cache->ne[1], + cache_rows, + cache->nb[1], + cache->nb[2], + 0); + return ggml_concat(ctx0, zero_pad, valid, 2); + } + + if (cache_write_pos == 0) { + return cache; + } + + ggml_tensor * tail = ggml_view_3d(ctx0, cache, + cache->ne[0], + cache->ne[1], + ctx_len - cache_write_pos, + cache->nb[1], + cache->nb[2], + (size_t) cache_write_pos * cache->nb[2]); + ggml_tensor * head = ggml_view_3d(ctx0, cache, + cache->ne[0], + cache->ne[1], + cache_write_pos, + cache->nb[1], + cache->nb[2], + 0); + return ggml_concat(ctx0, tail, head, 2); + }; + + Kcur_ctx = build_ordered_cache_view(lctx.dflash_k_ctx_cache[(size_t) il]); + Vcur_ctx = build_ordered_cache_view(lctx.dflash_v_ctx_cache[(size_t) il]); cb(Kcur_ctx, "Kcur_ctx_cache", il); cb(Vcur_ctx, "Vcur_ctx_cache", il); } else { diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp 
index 081215dd..fc03353c 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -2173,7 +2173,27 @@ struct ggml_cgraph * llm_build_context::llama_build_graph_dflash_kv_cache(llama_ llama_batch dummy; dummy.n_tokens = 0; - llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; + llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) { + if (il >= 0) { + int j = 0; + for (; j < GGML_MAX_NAME - 1; ++j) { + cur->name[j] = name[j]; + if (!name[j]) { + break; + } + } + if (j < GGML_MAX_NAME - 3) { + cur->name[j++] = '-'; + auto sil = std::to_string(il); + for (int k = 0; k < (int) sil.size() && j < GGML_MAX_NAME - 1; ++k) { + cur->name[j++] = sil[k]; + } + } + cur->name[j] = 0; + } else { + ggml_set_name(cur, name); + } + }; struct llm_build_context llm(lctx, dummy, cb, false, false, 0, false, &lctx.dflash_buf_compute_meta); diff --git a/src/llama-context.h b/src/llama-context.h index cc207a36..1a7a9d80 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -281,9 +281,17 @@ struct llama_context { const float * dflash_target_features = nullptr; size_t dflash_target_features_n_floats = 0; int32_t dflash_target_features_n_rows = 0; + const float * dflash_target_append_features = nullptr; + size_t dflash_target_append_features_n_floats = 0; + int32_t dflash_target_append_features_n_rows = 0; const llama_pos * dflash_target_positions = nullptr; size_t dflash_target_positions_n = 0; + uint64_t dflash_target_window_version = 0; + int32_t dflash_target_window_keep_rows = 0; + int32_t dflash_target_window_append_rows = 0; + bool dflash_target_window_replace = false; std::vector dflash_target_features_owned; + std::vector dflash_target_append_features_owned; std::vector dflash_target_positions_owned; std::vector dflash_target_features_padded; std::vector dflash_feature_view_buffer; @@ -295,6 +303,15 @@ struct llama_context { std::vector dflash_v_ctx_cache; struct ggml_context * dflash_cache_ctx = nullptr; 
std::vector dflash_cache_bufs; + int32_t dflash_kv_cache_write_pos = 0; + int32_t dflash_kv_cache_n_filled = 0; + int32_t dflash_kv_cache_update_rows = 0; + int32_t dflash_kv_cache_reserved_rows = 0; + int32_t dflash_kv_cache_view_write_pos = 0; + int32_t dflash_kv_cache_view_n_filled = 0; + uint64_t dflash_kv_cache_applied_window_version = 0; + bool dflash_kv_cache_valid = false; + bool dflash_kv_cache_view_valid = false; std::vector dflash_buf_compute_meta; ggml_backend_sched_t dflash_sched = nullptr; struct ggml_tensor * dflash_kv_input_target_features = nullptr; diff --git a/src/llama-spec-features.cpp b/src/llama-spec-features.cpp index bcf1ca89..ab8efddb 100644 --- a/src/llama-spec-features.cpp +++ b/src/llama-spec-features.cpp @@ -55,6 +55,56 @@ void llama_dflash_profile_reset(struct llama_context * ctx) { ctx->dflash_profile = {}; } +void llama_reset_dflash_kv_cache_state(struct llama_context * ctx) { + if (ctx == nullptr) { + return; + } + + ctx->dflash_kv_cache_write_pos = 0; + ctx->dflash_kv_cache_n_filled = 0; + ctx->dflash_kv_cache_update_rows = 0; + ctx->dflash_kv_cache_view_write_pos = 0; + ctx->dflash_kv_cache_view_n_filled = 0; + ctx->dflash_kv_cache_applied_window_version = 0; + ctx->dflash_kv_cache_valid = false; + ctx->dflash_kv_cache_view_valid = false; + + for (ggml_backend_buffer_t buf : ctx->dflash_cache_bufs) { + if (buf != nullptr) { + ggml_backend_buffer_clear(buf, 0); + } + } +} + +llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition_for_ctx( + const struct llama_context * ctx, + const llama_dflash_window_update & window_update, + int32_t n_rows) { + if (ctx == nullptr) { + llama_dflash_kv_cache_transition plan; + plan.rebuild_cache = true; + plan.append_rows = std::clamp(window_update.append_rows, 0, n_rows); + plan.next_n_filled = n_rows; + return plan; + } + + const int32_t cross_ctx = ctx->dflash_visible_cross_ctx > 0 + ? 
ctx->dflash_visible_cross_ctx + : std::max(1, (int32_t) ctx->cparams.n_ctx - (int32_t) ctx->model.hparams.dflash_block_size); + + return llama_plan_dflash_kv_cache_transition( + cross_ctx, + ctx->dflash_kv_cache_n_filled, + ctx->dflash_kv_cache_write_pos, + ctx->dflash_kv_cache_valid, + ctx->dflash_kv_cache_applied_window_version, + window_update.version, + window_update.keep_rows, + window_update.append_rows, + window_update.replace, + n_rows); +} + void llama_set_dflash_visible_cross_ctx( struct llama_context * ctx, int32_t cross_ctx) { @@ -205,26 +255,91 @@ static bool llama_set_dflash_target_features_impl( size_t n_floats, int32_t n_rows, const llama_pos * target_positions, - bool copy_data) { - if (ctx == nullptr || target_features == nullptr || n_floats == 0 || n_rows <= 0) { + bool copy_data, + const llama_dflash_window_update * window_update) { + const bool have_full_features = target_features != nullptr && n_floats > 0; + const bool have_append_features = window_update != nullptr && + window_update->append_features != nullptr && + window_update->append_floats > 0 && + window_update->append_rows > 0; + + if (ctx == nullptr || n_rows <= 0 || (!have_full_features && !have_append_features)) { return false; } auto & profile = ctx->dflash_profile; const int64_t t_start_us = ggml_time_us(); - const int32_t row_width = n_rows > 0 ? (int32_t) (n_floats / (size_t) n_rows) : 0; + const int32_t row_width = have_full_features + ? (n_rows > 0 ? (int32_t) (n_floats / (size_t) n_rows) : 0) + : (window_update->append_rows > 0 ? 
(int32_t) (window_update->append_floats / (size_t) window_update->append_rows) : 0); llama_pos first_pos = -1; llama_pos last_pos = -1; - if (copy_data) { + if (have_full_features && copy_data) { ctx->dflash_target_features_owned.assign(target_features, target_features + n_floats); ctx->dflash_target_features = ctx->dflash_target_features_owned.data(); - } else { + } else if (have_full_features) { ctx->dflash_target_features_owned.clear(); ctx->dflash_target_features = target_features; + } else { + ctx->dflash_target_features_owned.clear(); + ctx->dflash_target_features = nullptr; } - ctx->dflash_target_features_n_floats = n_floats; + ctx->dflash_target_features_n_floats = have_full_features ? n_floats : 0; ctx->dflash_target_features_n_rows = n_rows; + if (have_append_features && copy_data) { + ctx->dflash_target_append_features_owned.assign( + window_update->append_features, + window_update->append_features + window_update->append_floats); + ctx->dflash_target_append_features = ctx->dflash_target_append_features_owned.data(); + } else if (have_append_features) { + ctx->dflash_target_append_features_owned.clear(); + ctx->dflash_target_append_features = window_update->append_features; + } else { + ctx->dflash_target_append_features_owned.clear(); + ctx->dflash_target_append_features = nullptr; + } + ctx->dflash_target_append_features_n_floats = have_append_features ? window_update->append_floats : 0; + ctx->dflash_target_append_features_n_rows = have_append_features ? window_update->append_rows : 0; + ctx->dflash_target_window_version = window_update != nullptr && window_update->version > 0 + ? window_update->version + : ctx->dflash_target_window_version + 1; + ctx->dflash_target_window_keep_rows = window_update != nullptr + ? std::max(0, std::min(n_rows, window_update->keep_rows)) + : 0; + ctx->dflash_target_window_append_rows = window_update != nullptr + ? 
std::max(0, std::min(n_rows, window_update->append_rows)) + : n_rows; + ctx->dflash_target_window_replace = window_update != nullptr + ? window_update->replace + : true; + if (ctx->dflash_target_window_keep_rows + ctx->dflash_target_window_append_rows > n_rows) { + ctx->dflash_target_window_keep_rows = std::max(0, n_rows - ctx->dflash_target_window_append_rows); + } + + const int32_t cross_ctx = ctx->dflash_visible_cross_ctx > 0 + ? ctx->dflash_visible_cross_ctx + : std::max(1, (int32_t) ctx->cparams.n_ctx - (int32_t) ctx->model.hparams.dflash_block_size); + const llama_dflash_window_update cache_window_update = { + ctx->dflash_target_window_version, + ctx->dflash_target_window_keep_rows, + ctx->dflash_target_window_append_rows, + ctx->dflash_target_window_replace, + ctx->dflash_target_append_features, + ctx->dflash_target_append_features_n_floats, + }; + const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition_for_ctx(ctx, cache_window_update, n_rows); + + if (cache_plan.cache_up_to_date) { + ctx->dflash_kv_cache_view_n_filled = ctx->dflash_kv_cache_n_filled; + ctx->dflash_kv_cache_view_write_pos = ctx->dflash_kv_cache_write_pos; + ctx->dflash_kv_cache_view_valid = ctx->dflash_kv_cache_valid; + } else if (cross_ctx > 0) { + ctx->dflash_kv_cache_view_n_filled = cache_plan.next_n_filled; + ctx->dflash_kv_cache_view_write_pos = cache_plan.next_write_pos; + ctx->dflash_kv_cache_view_valid = cache_plan.next_n_filled > 0; + } + if (target_positions != nullptr) { if (copy_data) { ctx->dflash_target_positions_owned.assign(target_positions, target_positions + n_rows); @@ -243,7 +358,10 @@ static bool llama_set_dflash_target_features_impl( profile.set_target_copy_calls++; profile.set_target_copy_us += (uint64_t) (ggml_time_us() - t_start_us); profile.set_target_rows += (uint64_t) n_rows; - profile.set_target_copy_bytes += n_floats * sizeof(float) + (target_positions ? 
(size_t) n_rows * sizeof(llama_pos) : 0); + profile.set_target_copy_bytes += + (have_full_features ? n_floats : 0) * sizeof(float) + + (have_append_features ? window_update->append_floats : 0) * sizeof(float) + + (target_positions ? (size_t) n_rows * sizeof(llama_pos) : 0); profile.last_n_rows = n_rows; profile.last_width = row_width; @@ -267,8 +385,9 @@ bool llama_set_dflash_target_features_copy( const float * target_features, size_t n_floats, int32_t n_rows, - const llama_pos * target_positions) { - return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, true); + const llama_pos * target_positions, + const llama_dflash_window_update * window_update) { + return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, true, window_update); } bool llama_set_dflash_target_features_view( @@ -276,8 +395,9 @@ bool llama_set_dflash_target_features_view( const float * target_features, size_t n_floats, int32_t n_rows, - const llama_pos * target_positions) { - return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, false); + const llama_pos * target_positions, + const llama_dflash_window_update * window_update) { + return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, false, window_update); } static void llama_record_dflash_capture_phase( diff --git a/src/llama-spec-features.h b/src/llama-spec-features.h index 9ec2e827..d976c89a 100644 --- a/src/llama-spec-features.h +++ b/src/llama-spec-features.h @@ -2,6 +2,8 @@ #include "llama.h" +#include +#include #include struct llama_context; @@ -24,6 +26,19 @@ struct llama_spec_feature_view { }; struct llama_dflash_profile_stats { + uint64_t decode_internal_chunks = 0; + uint64_t decode_graph_rebuilds = 0; + uint64_t decode_sync_profile_points = 0; + uint64_t decode_prelude_us = 0; + uint64_t decode_sched_reset_us = 0; + uint64_t decode_build_graph_us = 
0; + uint64_t decode_sched_alloc_graph_us = 0; + uint64_t decode_set_inputs_us = 0; + uint64_t decode_graph_compute_us = 0; + uint64_t decode_result_us = 0; + uint64_t decode_embedding_us = 0; + uint64_t decode_final_sched_reset_us = 0; + uint64_t decode_output_reserve_calls = 0; uint64_t decode_output_reserve_us = 0; uint64_t decode_output_reserve_reallocs = 0; @@ -71,6 +86,20 @@ struct llama_dflash_profile_stats { uint64_t graph_kv_cache_read_concat_pad_calls = 0; uint64_t graph_kv_cache_cached_bytes = 0; uint64_t graph_kv_cache_calls = 0; + uint64_t graph_kv_node_fused_target_calls = 0; + uint64_t graph_kv_node_fused_target_us = 0; + uint64_t graph_kv_node_k_proj_calls = 0; + uint64_t graph_kv_node_k_proj_us = 0; + uint64_t graph_kv_node_k_norm_calls = 0; + uint64_t graph_kv_node_k_norm_us = 0; + uint64_t graph_kv_node_k_rope_calls = 0; + uint64_t graph_kv_node_k_rope_us = 0; + uint64_t graph_kv_node_v_proj_calls = 0; + uint64_t graph_kv_node_v_proj_us = 0; + uint64_t graph_kv_node_k_store_calls = 0; + uint64_t graph_kv_node_k_store_us = 0; + uint64_t graph_kv_node_v_store_calls = 0; + uint64_t graph_kv_node_v_store_us = 0; uint64_t graph_feature_bytes = 0; uint64_t graph_pos_bytes = 0; uint64_t graph_mask_bytes = 0; @@ -96,10 +125,76 @@ struct llama_dflash_profile_stats { llama_pos last_pos_last = -1; }; +struct llama_dflash_window_update { + uint64_t version = 0; + int32_t keep_rows = 0; + int32_t append_rows = 0; + bool replace = false; + const float * append_features = nullptr; + size_t append_floats = 0; +}; + +struct llama_dflash_kv_cache_transition { + bool cache_up_to_date = false; + bool rebuild_cache = false; + int32_t append_rows = 0; + int32_t next_n_filled = 0; + int32_t next_write_pos = 0; +}; + +static inline llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition( + int32_t cross_ctx, + int32_t current_n_filled, + int32_t current_write_pos, + bool cache_valid, + uint64_t applied_window_version, + uint64_t target_window_version, + 
int32_t keep_rows, + int32_t append_rows, + bool replace, + int32_t n_rows) { + llama_dflash_kv_cache_transition plan; + + const int32_t safe_cross_ctx = std::max(1, cross_ctx); + const int32_t bounded_n_filled = std::clamp(current_n_filled, 0, safe_cross_ctx); + const int32_t bounded_append_rows = std::clamp(append_rows, 0, n_rows); + const int32_t bounded_keep_rows = std::clamp(keep_rows, 0, n_rows); + const int32_t expected_keep_rows = std::min(bounded_n_filled, std::max(0, safe_cross_ctx - bounded_append_rows)); + + plan.cache_up_to_date = cache_valid && applied_window_version == target_window_version; + plan.rebuild_cache = !cache_valid || replace || bounded_append_rows <= 0 || bounded_append_rows > n_rows; + if (!plan.rebuild_cache && bounded_keep_rows != expected_keep_rows) { + plan.rebuild_cache = true; + } + + plan.append_rows = bounded_append_rows; + if (plan.cache_up_to_date) { + plan.next_n_filled = bounded_n_filled; + plan.next_write_pos = safe_cross_ctx > 0 + ? ((current_write_pos % safe_cross_ctx) + safe_cross_ctx) % safe_cross_ctx + : 0; + } else if (plan.rebuild_cache) { + plan.next_n_filled = std::min(safe_cross_ctx, n_rows); + plan.next_write_pos = plan.next_n_filled % safe_cross_ctx; + } else { + plan.next_n_filled = std::min(safe_cross_ctx, bounded_n_filled + bounded_append_rows); + plan.next_write_pos = (current_write_pos + bounded_append_rows) % safe_cross_ctx; + } + + return plan; +} + +llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition_for_ctx( + const struct llama_context * ctx, + const llama_dflash_window_update & window_update, + int32_t n_rows); + uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx); void llama_dflash_profile_reset(struct llama_context * ctx); +void llama_reset_dflash_kv_cache_state(struct llama_context * ctx); + void llama_set_dflash_visible_cross_ctx( struct llama_context * ctx, int32_t cross_ctx); @@ -156,14 +251,16 @@ bool llama_set_dflash_target_features_copy( const float * 
target_features, size_t n_floats, int32_t n_rows, - const llama_pos * target_positions); + const llama_pos * target_positions, + const llama_dflash_window_update * window_update = nullptr); bool llama_set_dflash_target_features_view( struct llama_context * ctx, const float * target_features, size_t n_floats, int32_t n_rows, - const llama_pos * target_positions); + const llama_pos * target_positions, + const llama_dflash_window_update * window_update = nullptr); bool llama_set_dflash_capture_layers( struct llama_context * ctx, diff --git a/src/llama.cpp b/src/llama.cpp index e3b91b0b..e53940a2 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -171,6 +171,129 @@ static std::vector string_split(const std::string& str, const std:: return parts; } +static bool llama_env_flag_enabled(const char * name) { + const char * env = std::getenv(name); + return env != nullptr && *env != '\0' && + std::strcmp(env, "0") != 0 && + std::strcmp(env, "false") != 0 && + std::strcmp(env, "off") != 0; +} + +enum llama_dflash_kv_node_kind { + LLAMA_DFLASH_KV_NODE_NONE = 0, + LLAMA_DFLASH_KV_NODE_FUSED_TARGET, + LLAMA_DFLASH_KV_NODE_K_PROJ, + LLAMA_DFLASH_KV_NODE_K_NORM, + LLAMA_DFLASH_KV_NODE_K_ROPE, + LLAMA_DFLASH_KV_NODE_V_PROJ, + LLAMA_DFLASH_KV_NODE_K_STORE, + LLAMA_DFLASH_KV_NODE_V_STORE, +}; + +struct llama_dflash_kv_node_profiler { + llama_dflash_profile_stats * profile = nullptr; + int64_t t_start_us = 0; + llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE; +}; + +static bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) { + if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') { + return false; + } + + return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0; +} + +static llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) { + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_fused_target")) { + return LLAMA_DFLASH_KV_NODE_FUSED_TARGET; + } + 
if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_proj")) { + return LLAMA_DFLASH_KV_NODE_K_PROJ; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_norm")) { + return LLAMA_DFLASH_KV_NODE_K_NORM; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_rope")) { + return LLAMA_DFLASH_KV_NODE_K_ROPE; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_proj")) { + return LLAMA_DFLASH_KV_NODE_V_PROJ; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_store")) { + return LLAMA_DFLASH_KV_NODE_K_STORE; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_store")) { + return LLAMA_DFLASH_KV_NODE_V_STORE; + } + + return LLAMA_DFLASH_KV_NODE_NONE; +} + +static void llama_dflash_kv_node_profile_add( + llama_dflash_profile_stats & profile, + llama_dflash_kv_node_kind kind, + uint64_t elapsed_us) { + switch (kind) { + case LLAMA_DFLASH_KV_NODE_FUSED_TARGET: + profile.graph_kv_node_fused_target_calls++; + profile.graph_kv_node_fused_target_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_PROJ: + profile.graph_kv_node_k_proj_calls++; + profile.graph_kv_node_k_proj_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_NORM: + profile.graph_kv_node_k_norm_calls++; + profile.graph_kv_node_k_norm_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_ROPE: + profile.graph_kv_node_k_rope_calls++; + profile.graph_kv_node_k_rope_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_V_PROJ: + profile.graph_kv_node_v_proj_calls++; + profile.graph_kv_node_v_proj_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_STORE: + profile.graph_kv_node_k_store_calls++; + profile.graph_kv_node_k_store_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_V_STORE: + profile.graph_kv_node_v_store_calls++; + profile.graph_kv_node_v_store_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_NONE: + break; + } +} + +static bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, 
void * user_data) { + auto * profiler = static_cast(user_data); + if (profiler == nullptr || profiler->profile == nullptr) { + return false; + } + + const llama_dflash_kv_node_kind kind = llama_dflash_kv_node_kind_from_tensor(tensor); + if (ask) { + if (kind == LLAMA_DFLASH_KV_NODE_NONE) { + return false; + } + + profiler->active_kind = kind; + profiler->t_start_us = ggml_time_us(); + return true; + } + + if (kind != LLAMA_DFLASH_KV_NODE_NONE && profiler->active_kind == kind && profiler->t_start_us > 0) { + llama_dflash_kv_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); + } + + profiler->active_kind = LLAMA_DFLASH_KV_NODE_NONE; + profiler->t_start_us = 0; + return true; +} + // extract ip and port from RPC[ip:port] for rpc and keep other device names static std::vector extract_device_from_rpc_device(std::vector devices) { std::vector rpc_servers; @@ -689,6 +812,7 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { } dflash_profile.last_kv_cache_host_layers = host_layers; + llama_reset_dflash_kv_cache_state(this); LLAMA_LOG_INFO("%s: DFlash K/V cache placement cross_ctx=%d host_layers=%d/%d first=%s last=%s\n", __func__, target_cross_ctx, @@ -703,6 +827,15 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { void llama_context::free_dflash_kv_cache_tensors() { dflash_k_ctx_cache.clear(); dflash_v_ctx_cache.clear(); + dflash_kv_cache_write_pos = 0; + dflash_kv_cache_n_filled = 0; + dflash_kv_cache_update_rows = 0; + dflash_kv_cache_reserved_rows = 0; + dflash_kv_cache_view_write_pos = 0; + dflash_kv_cache_view_n_filled = 0; + dflash_kv_cache_applied_window_version = 0; + dflash_kv_cache_valid = false; + dflash_kv_cache_view_valid = false; dflash_kv_input_target_features = nullptr; dflash_kv_input_pos_ctx = nullptr; dflash_kq_mask_tensor = nullptr; @@ -5271,11 +5404,8 @@ static bool validate_dflash_graph_contract(const llama_context & lctx) { static bool 
prepare_dflash_graph_inputs( struct llama_context & lctx, uint32_t n_tokens) { - const char * dflash_kv_cache_env = std::getenv("IK_DFLASH_KV_CACHE"); - const bool use_kv_cache = dflash_kv_cache_env != nullptr && *dflash_kv_cache_env != '\0' && - std::strcmp(dflash_kv_cache_env, "0") != 0 && - std::strcmp(dflash_kv_cache_env, "false") != 0 && - std::strcmp(dflash_kv_cache_env, "off") != 0; + const bool use_kv_cache = llama_env_flag_enabled("IK_DFLASH_KV_CACHE"); + const bool kv_node_timing = llama_env_flag_enabled("IK_DFLASH_KV_NODE_TIMING"); auto & profile = lctx.dflash_profile; const int32_t cross_ctx = lctx.dflash_visible_cross_ctx > 0 ? lctx.dflash_visible_cross_ctx @@ -5304,10 +5434,13 @@ static bool prepare_dflash_graph_inputs( } const float * src = lctx.dflash_target_features; + const float * append_src = lctx.dflash_target_append_features; const llama_pos * src_pos = lctx.dflash_target_positions; const size_t total_floats = lctx.dflash_target_features_n_floats; + const size_t append_floats = lctx.dflash_target_append_features_n_floats; const size_t total_positions = lctx.dflash_target_positions_n; const int32_t n_rows = lctx.dflash_target_features_n_rows; + const int32_t append_rows_available = lctx.dflash_target_append_features_n_rows; const int32_t width = (int32_t) lctx.model.hparams.dflash_n_target_features; const int32_t graph_cross_ctx = use_kv_cache ? (lctx.dflash_k_ctx_cache.front() != nullptr ? 
(int32_t) lctx.dflash_k_ctx_cache.front()->ne[2] : 0) @@ -5330,19 +5463,26 @@ static bool prepare_dflash_graph_inputs( __func__, graph_cross_ctx, cross_ctx); return false; } - if (src == nullptr || total_floats == 0 || n_rows <= 0) { + if (n_rows <= 0) { profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: missing DFlash target features\n", __func__); + LLAMA_LOG_ERROR("%s: missing DFlash target feature rows\n", __func__); return false; } - if (n_rows > cross_ctx || total_floats != (size_t) n_rows * (size_t) width) { + const bool have_full_src = src != nullptr && total_floats == (size_t) n_rows * (size_t) width; + if (n_rows > cross_ctx || (src != nullptr && !have_full_src)) { profile.graph_shape_failures++; LLAMA_LOG_ERROR("%s: invalid DFlash target feature shape (rows=%d width=%d floats=%zu cross_ctx=%d)\n", __func__, n_rows, width, total_floats, cross_ctx); return false; } + if (!use_kv_cache && !have_full_src) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: missing contiguous DFlash target features for inline path\n", __func__); + return false; + } + if (n_kv_total < cross_ctx + (int32_t) n_tokens) { profile.graph_mask_overflow++; LLAMA_LOG_ERROR("%s: invalid DFlash mask shape (n_kv_total=%d < cross_ctx+n_tokens=%d)\n", @@ -5351,24 +5491,26 @@ static bool prepare_dflash_graph_inputs( } const int32_t left_pad = cross_ctx - n_rows; - const size_t padded_floats = (size_t) cross_ctx * (size_t) width; - const size_t dst_offset = (size_t) left_pad * (size_t) width; - const int64_t t_feature_us = ggml_time_us(); profile.last_left_pad = left_pad; - if (lctx.dflash_target_features_padded.size() != padded_floats) { - lctx.dflash_target_features_padded.resize(padded_floats); - } - if (left_pad == 0 && total_floats == padded_floats) { - std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin()); - } else { - if (dst_offset > 0) { - std::fill(lctx.dflash_target_features_padded.begin(), - lctx.dflash_target_features_padded.begin() + 
(ptrdiff_t) dst_offset, 0.0f); + if (!use_kv_cache) { + const size_t padded_floats = (size_t) cross_ctx * (size_t) width; + const size_t dst_offset = (size_t) left_pad * (size_t) width; + const int64_t t_feature_us = ggml_time_us(); + if (lctx.dflash_target_features_padded.size() != padded_floats) { + lctx.dflash_target_features_padded.resize(padded_floats); } - std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset); + if (left_pad == 0 && total_floats == padded_floats) { + std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin()); + } else { + if (dst_offset > 0) { + std::fill(lctx.dflash_target_features_padded.begin(), + lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset, 0.0f); + } + std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset); + } + profile.graph_feature_copy_us += (uint64_t) (ggml_time_us() - t_feature_us); + profile.graph_feature_bytes += padded_floats * sizeof(float); } - profile.graph_feature_copy_us += (uint64_t) (ggml_time_us() - t_feature_us); - profile.graph_feature_bytes += padded_floats * sizeof(float); const int64_t t_pos_us = ggml_time_us(); lctx.dflash_pos_ctx_data.resize((size_t) cross_ctx); @@ -5403,22 +5545,32 @@ static bool prepare_dflash_graph_inputs( profile.graph_pos_bytes += lctx.dflash_pos_ctx_data.size() * sizeof(llama_pos); if (use_kv_cache) { + const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition( + cross_ctx, + lctx.dflash_kv_cache_n_filled, + lctx.dflash_kv_cache_write_pos, + lctx.dflash_kv_cache_valid, + lctx.dflash_kv_cache_applied_window_version, + lctx.dflash_target_window_version, + lctx.dflash_target_window_keep_rows, + lctx.dflash_target_window_append_rows, + lctx.dflash_target_window_replace, + n_rows); + + const bool have_append_src = append_src != nullptr && + append_rows_available == cache_plan.append_rows && + append_floats == (size_t) 
cache_plan.append_rows * (size_t) width; + + const int32_t update_rows = cache_plan.cache_up_to_date + ? 0 + : (cache_plan.rebuild_cache ? n_rows : cache_plan.append_rows); const size_t max_nodes = lctx.model.max_nodes((int) std::max(1, cross_ctx)) + 24 * lctx.model.hparams.n_layer; const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false); if (lctx.dflash_buf_compute_meta.size() != meta_size) { lctx.dflash_buf_compute_meta.resize(meta_size); } - const int64_t t_build_us = ggml_time_us(); - ggml_cgraph * gf_kv = llm_build_context::llama_build_graph_dflash_kv_cache(lctx); - profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us); - if (gf_kv == nullptr || lctx.dflash_kv_input_target_features == nullptr || lctx.dflash_kv_input_pos_ctx == nullptr) { - profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__); - return false; - } - - if (lctx.dflash_sched == nullptr) { + if (lctx.dflash_sched == nullptr || lctx.dflash_kv_cache_reserved_rows != cross_ctx) { std::vector backend_buft; backend_buft.reserve(lctx.backends.size()); for (auto * backend : lctx.backends) { @@ -5429,51 +5581,117 @@ static bool prepare_dflash_graph_inputs( } } + if (lctx.dflash_sched != nullptr) { + ggml_backend_sched_free(lctx.dflash_sched); + lctx.dflash_sched = nullptr; + } + + const int32_t saved_update_rows = lctx.dflash_kv_cache_update_rows; + lctx.dflash_kv_cache_update_rows = cross_ctx; + const int64_t t_build_us = ggml_time_us(); + ggml_cgraph * gf_reserve = llm_build_context::llama_build_graph_dflash_kv_cache(lctx); + profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us); + lctx.dflash_kv_cache_update_rows = saved_update_rows; + if (gf_reserve == nullptr) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache reserve graph\n", __func__); + return false; + } + const int64_t t_reserve_us = ggml_time_us(); 
lctx.dflash_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false); - const bool reserved = lctx.dflash_sched != nullptr && ggml_backend_sched_reserve(lctx.dflash_sched, gf_kv); + const bool reserved = lctx.dflash_sched != nullptr && ggml_backend_sched_reserve(lctx.dflash_sched, gf_reserve); profile.graph_kv_cache_reserve_us += (uint64_t) (ggml_time_us() - t_reserve_us); if (!reserved) { profile.graph_shape_failures++; LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V scheduler\n", __func__); return false; } + lctx.dflash_kv_cache_reserved_rows = cross_ctx; } - const int64_t t_reset_us = ggml_time_us(); - ggml_backend_sched_reset(lctx.dflash_sched); - profile.graph_kv_cache_reset_us += (uint64_t) (ggml_time_us() - t_reset_us); + if (update_rows > 0) { + const float * update_src = nullptr; + if (have_append_src && update_rows == cache_plan.append_rows) { + update_src = append_src; + } else if (have_full_src) { + update_src = src + (size_t) (n_rows - update_rows) * (size_t) width; + } + const llama_pos * update_pos = src_pos + (n_rows - update_rows); - const int64_t t_alloc_us = ggml_time_us(); - ggml_backend_sched_alloc_graph(lctx.dflash_sched, gf_kv); - profile.graph_kv_cache_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us); + if (update_src == nullptr) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: missing DFlash appended target features for cached update (rows=%d append_rows=%d floats=%zu)\n", + __func__, n_rows, update_rows, append_floats); + return false; + } - ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_target_features); - const int64_t t_feature_upload_us = ggml_time_us(); - if (kv_feature_backend != nullptr) { - ggml_backend_tensor_set_async(kv_feature_backend, lctx.dflash_kv_input_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); - } else { - 
ggml_backend_tensor_set(lctx.dflash_kv_input_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); + if (cache_plan.rebuild_cache) { + llama_reset_dflash_kv_cache_state(&lctx); + } + + lctx.dflash_kv_cache_update_rows = update_rows; + const int64_t t_build_us = ggml_time_us(); + ggml_cgraph * gf_kv = llm_build_context::llama_build_graph_dflash_kv_cache(lctx); + profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us); + if (gf_kv == nullptr || lctx.dflash_kv_input_target_features == nullptr || lctx.dflash_kv_input_pos_ctx == nullptr) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__); + return false; + } + + const int64_t t_reset_us = ggml_time_us(); + ggml_backend_sched_reset(lctx.dflash_sched); + profile.graph_kv_cache_reset_us += (uint64_t) (ggml_time_us() - t_reset_us); + + const int64_t t_alloc_us = ggml_time_us(); + ggml_backend_sched_alloc_graph(lctx.dflash_sched, gf_kv); + profile.graph_kv_cache_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us); + + ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_target_features); + const int64_t t_feature_upload_us = ggml_time_us(); + if (kv_feature_backend != nullptr) { + ggml_backend_tensor_set_async(kv_feature_backend, lctx.dflash_kv_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); + } else { + ggml_backend_tensor_set(lctx.dflash_kv_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); + } + profile.graph_kv_cache_feature_upload_us += (uint64_t) (ggml_time_us() - t_feature_upload_us); + profile.graph_feature_bytes += (size_t) update_rows * (size_t) width * sizeof(float); + + ggml_backend_t kv_pos_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_pos_ctx); + const int64_t t_pos_upload_us = ggml_time_us(); + if (kv_pos_backend != nullptr) { + 
ggml_backend_tensor_set_async(kv_pos_backend, lctx.dflash_kv_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); + } else { + ggml_backend_tensor_set(lctx.dflash_kv_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); + } + profile.graph_kv_cache_pos_upload_us += (uint64_t) (ggml_time_us() - t_pos_upload_us); + + const int64_t t_kv_cache_us = ggml_time_us(); + llama_dflash_kv_node_profiler kv_node_profiler; + if (kv_node_timing) { + kv_node_profiler.profile = &profile; + ggml_backend_sched_set_eval_callback(lctx.dflash_sched, llama_dflash_kv_node_eval_callback, &kv_node_profiler); + } + llama_graph_compute_sched(lctx, lctx.dflash_sched, gf_kv, lctx.cparams.n_threads); + if (kv_node_timing) { + ggml_backend_sched_set_eval_callback(lctx.dflash_sched, nullptr, nullptr); + } + profile.graph_kv_cache_compute_us += (uint64_t) (ggml_time_us() - t_kv_cache_us); + + const int64_t t_sync_us = ggml_time_us(); + ggml_backend_sched_synchronize(lctx.dflash_sched); + profile.graph_kv_cache_sync_us += (uint64_t) (ggml_time_us() - t_sync_us); + profile.graph_kv_cache_calls++; + + lctx.dflash_kv_cache_n_filled = std::min(cross_ctx, lctx.dflash_kv_cache_n_filled + update_rows); + lctx.dflash_kv_cache_write_pos = (lctx.dflash_kv_cache_write_pos + update_rows) % cross_ctx; + lctx.dflash_kv_cache_applied_window_version = lctx.dflash_target_window_version; + lctx.dflash_kv_cache_valid = true; + lctx.dflash_kv_cache_view_n_filled = lctx.dflash_kv_cache_n_filled; + lctx.dflash_kv_cache_view_write_pos = lctx.dflash_kv_cache_write_pos; + lctx.dflash_kv_cache_view_valid = true; } - profile.graph_kv_cache_feature_upload_us += (uint64_t) (ggml_time_us() - t_feature_upload_us); - - ggml_backend_t kv_pos_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_pos_ctx); - const int64_t t_pos_upload_us = ggml_time_us(); - if (kv_pos_backend != nullptr) { - ggml_backend_tensor_set_async(kv_pos_backend, lctx.dflash_kv_input_pos_ctx, 
lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); - } else { - ggml_backend_tensor_set(lctx.dflash_kv_input_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); - } - profile.graph_kv_cache_pos_upload_us += (uint64_t) (ggml_time_us() - t_pos_upload_us); - - const int64_t t_kv_cache_us = ggml_time_us(); - llama_graph_compute_sched(lctx, lctx.dflash_sched, gf_kv, lctx.cparams.n_threads); - profile.graph_kv_cache_compute_us += (uint64_t) (ggml_time_us() - t_kv_cache_us); - - const int64_t t_sync_us = ggml_time_us(); - ggml_backend_sched_synchronize(lctx.dflash_sched); - profile.graph_kv_cache_sync_us += (uint64_t) (ggml_time_us() - t_sync_us); - profile.graph_kv_cache_calls++; } else { ggml_backend_tensor_set(lctx.inp_dflash_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.inp_dflash_target_features)); ggml_backend_tensor_set(lctx.inp_dflash_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.inp_dflash_pos_ctx)); @@ -5586,6 +5804,9 @@ static int llama_decode_internal( } lctx.n_queued_tokens += n_tokens_all; + auto * dflash_profile = lctx.model.arch == LLM_ARCH_DFLASH_DRAFT ? &lctx.dflash_profile : nullptr; + const bool dflash_decode_timing = dflash_profile != nullptr && llama_env_flag_enabled("IK_DFLASH_DECODE_TIMING"); + auto & kv_self = lctx.kv_self; const int64_t n_embd = hparams.n_embd; @@ -5670,6 +5891,10 @@ static int llama_decode_internal( #if IK_PRINT_TIMING auto tim1 = ggml_time_us(); #endif + const int64_t t_dflash_prelude_us = dflash_decode_timing ? 
ggml_time_us() : 0; + if (dflash_decode_timing) { + dflash_profile->decode_internal_chunks++; + } uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token); if (llm_arch_is_hybrid(model.arch) && n_tokens > 1 && @@ -5804,6 +6029,9 @@ static int llama_decode_internal( auto tim2 = ggml_time_us(); printf("prelude(...): %d us\n", int(tim2-tim1)); #endif + if (dflash_decode_timing) { + dflash_profile->decode_prelude_us += (uint64_t) (ggml_time_us() - t_dflash_prelude_us); + } #if IK_PRINT_TIMING tim1 = ggml_time_us(); @@ -5811,30 +6039,45 @@ static int llama_decode_internal( auto & prev = cparams.mtp_op_type == MTP_OP_NONE ? lctx.prev : lctx.prev_mtp; ggml_cgraph * gf = nullptr; if (!lctx.can_reuse_graph(u_batch)) { + if (dflash_decode_timing) { + dflash_profile->decode_graph_rebuilds++; + } + const int64_t t_dflash_sched_reset_us = dflash_decode_timing ? ggml_time_us() : 0; lctx.reset_scheduler(); ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); #if IK_PRINT_TIMING tim2 = ggml_time_us(); printf("sched_reset(...): %d us\n", int(tim2-tim1)); #endif + if (dflash_decode_timing) { + dflash_profile->decode_sched_reset_us += (uint64_t) (ggml_time_us() - t_dflash_sched_reset_us); + } #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif + const int64_t t_dflash_build_graph_us = dflash_decode_timing ? ggml_time_us() : 0; gf = llm_build_context::llama_build_graph(lctx, u_batch, false); #if IK_PRINT_TIMING tim2 = ggml_time_us(); printf("build_graph(...): %d us\n", int(tim2-tim1)); #endif + if (dflash_decode_timing) { + dflash_profile->decode_build_graph_us += (uint64_t) (ggml_time_us() - t_dflash_build_graph_us); + } #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif + const int64_t t_dflash_sched_alloc_us = dflash_decode_timing ? 
ggml_time_us() : 0; ggml_backend_sched_alloc_graph(lctx.sched, gf); #if IK_PRINT_TIMING tim2 = ggml_time_us(); printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1)); #endif + if (dflash_decode_timing) { + dflash_profile->decode_sched_alloc_graph_us += (uint64_t) (ggml_time_us() - t_dflash_sched_alloc_us); + } //if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) { if (u_batch.embd == nullptr && lctx.cparams.graph_reuse && !(lctx.model.arch == LLM_ARCH_GEMMA4_MTP && lctx.mtp_target_ctx != nullptr)) { @@ -5855,16 +6098,15 @@ static int llama_decode_internal( } } - if (lctx.model.arch == LLM_ARCH_DFLASH_DRAFT) { - auto & profile = lctx.dflash_profile; - profile.decode_prepare_calls++; + if (dflash_profile != nullptr) { + dflash_profile->decode_prepare_calls++; const int64_t t_prepare_dflash_us = ggml_time_us(); if (!prepare_dflash_graph_inputs(lctx, n_tokens)) { - profile.decode_prepare_failures++; - profile.decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us); + dflash_profile->decode_prepare_failures++; + dflash_profile->decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us); return GGML_STATUS_FAILED; } - profile.decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us); + dflash_profile->decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us); } // the output is always the last tensor in the graph @@ -5910,16 +6152,26 @@ static int llama_decode_internal( #if IK_PRINT_TIMING == 1 tim1 = ggml_time_us(); #endif + const int64_t t_dflash_set_inputs_us = dflash_decode_timing ? ggml_time_us() : 0; llama_set_inputs(lctx, u_batch); #if IK_PRINT_TIMING == 1 tim2 = ggml_time_us(); printf("set_inputs(...): %d us\n", int(tim2-tim1)); #endif + if (dflash_decode_timing) { + dflash_profile->decode_set_inputs_us += (uint64_t) (ggml_time_us() - t_dflash_set_inputs_us); + } #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif + const int64_t t_dflash_graph_compute_us = dflash_decode_timing ? 
ggml_time_us() : 0; llama_graph_compute(lctx, gf, n_threads); + if (dflash_decode_timing) { + llama_synchronize(&lctx); + dflash_profile->decode_sync_profile_points++; + dflash_profile->decode_graph_compute_us += (uint64_t) (ggml_time_us() - t_dflash_graph_compute_us); + } #if IK_PRINT_TIMING llama_synchronize(&lctx); tim2 = ggml_time_us(); @@ -5950,6 +6202,7 @@ static int llama_decode_internal( #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif + const int64_t t_dflash_get_result_us = dflash_decode_timing ? ggml_time_us() : 0; // Do not process logits if MTP is only updating the KV cache. if (cparams.mtp_op_type != MTP_OP_WARMUP) { // && cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res); @@ -5980,6 +6233,11 @@ static int llama_decode_internal( } } } + if (dflash_decode_timing) { + llama_synchronize(&lctx); + dflash_profile->decode_sync_profile_points++; + dflash_profile->decode_result_us += (uint64_t) (ggml_time_us() - t_dflash_get_result_us); + } #if IK_PRINT_TIMING tim2 = ggml_time_us(); printf("get_result(...): %d us\n", int(tim2-tim1)); @@ -5992,6 +6250,7 @@ static int llama_decode_internal( #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif + const int64_t t_dflash_get_embedding_us = dflash_decode_timing ? 
ggml_time_us() : 0; ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd); GGML_ASSERT(backend_embd != nullptr); @@ -6031,6 +6290,11 @@ static int llama_decode_internal( GGML_ABORT("unknown pooling type"); } } + if (dflash_decode_timing) { + llama_synchronize(&lctx); + dflash_profile->decode_sync_profile_points++; + dflash_profile->decode_embedding_us += (uint64_t) (ggml_time_us() - t_dflash_get_embedding_us); + } #if IK_PRINT_TIMING tim2 = ggml_time_us(); printf("get_embedding(...): %d us\n", int(tim2-tim1)); @@ -6074,9 +6338,13 @@ static int llama_decode_internal( #if IK_PRINT_TIMING auto tim1 = ggml_time_us(); #endif + const int64_t t_dflash_final_sched_reset_us = dflash_decode_timing ? ggml_time_us() : 0; if (!lctx.prev) { lctx.reset_scheduler(); } + if (dflash_decode_timing) { + dflash_profile->decode_final_sched_reset_us += (uint64_t) (ggml_time_us() - t_dflash_final_sched_reset_us); + } #if IK_PRINT_TIMING auto tim2 = ggml_time_us(); printf("sched_reset(...): %d us\n", int(tim2-tim1)); From 3d73312d9d1cf671c3c99f490857d6a47d53e26d Mon Sep 17 00:00:00 2001 From: SamuelOliveirads Date: Mon, 1 Jun 2026 09:55:34 -0300 Subject: [PATCH 06/13] apply workspace support for KV cache --- common/speculative.cpp | 99 +++++++ src/graphs/build_dflash.cpp | 218 ++++++++++++++-- src/llama-build-context.cpp | 37 +++ src/llama-build-context.h | 4 + src/llama-context.h | 18 ++ src/llama-load-tensors.cpp | 1 + src/llama-spec-features.cpp | 98 ++++++- src/llama-spec-features.h | 44 ++++ src/llama.cpp | 502 ++++++++++++++++++++++++++++++++++-- 9 files changed, 969 insertions(+), 52 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index e7ce71f9..016aeaa3 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -2596,6 +2596,32 @@ void common_speculative_print_stats(const common_speculative * spec, double slot const double kv_upload_total_ms = kv_upload_feature_ms + kv_upload_pos_ms; const double kv_compute_ms = 
(double) graph_stats.graph_kv_cache_compute_us / 1000.0; const double kv_sync_ms = (double) graph_stats.graph_kv_cache_sync_us / 1000.0; + const double kv_workspace_total_ms = (double) ( + graph_stats.graph_kv_workspace_build_us + + graph_stats.graph_kv_workspace_reserve_us + + graph_stats.graph_kv_workspace_reset_us + + graph_stats.graph_kv_workspace_alloc_us + + graph_stats.graph_kv_workspace_compute_us + + graph_stats.graph_kv_workspace_sync_us) / 1000.0; + const double draft_kv_traffic_ms = (double) ( + graph_stats.graph_main_node_k_ctx_view_us + + graph_stats.graph_main_node_v_ctx_view_us + + graph_stats.graph_main_node_k_concat_us + + graph_stats.graph_main_node_v_concat_us + + graph_stats.graph_main_node_k_pad_us + + graph_stats.graph_main_node_v_pad_us + + graph_stats.graph_main_node_k_perm_cont_us + + graph_stats.graph_main_node_v_perm_cont_us) / 1000.0; + const double draft_main_profiled_ms = (double) ( + graph_stats.graph_main_node_qcur_us + + graph_stats.graph_main_node_k_draft_us + + graph_stats.graph_main_node_v_draft_us + + graph_stats.graph_main_node_flash_attn_us + + graph_stats.graph_main_node_attn_out_us + + graph_stats.graph_main_node_ffn_us + + graph_stats.graph_main_node_result_rows_us + + graph_stats.graph_main_node_result_norm_us + + graph_stats.graph_main_node_result_us) / 1000.0; const double replay_append_ms = (double) dflash_state->t_accept_append_us / 1000.0; const double feature_path_ms = (double) ( capture_stats.capture_prepare_sync_us + @@ -2704,6 +2730,18 @@ void common_speculative_print_stats(const common_speculative * spec, double slot (unsigned long long) graph_stats.graph_kv_cache_cached_bytes, graph_stats.last_kv_cache_host_layers); + if (graph_stats.graph_kv_workspace_calls > 0) { + LOG_INF("statistics dflash kv workspace: total=%.3f ms build/reserve/reset/alloc/compute/sync=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu\n", + kv_workspace_total_ms, + (double) graph_stats.graph_kv_workspace_build_us / 1000.0, + (double) 
graph_stats.graph_kv_workspace_reserve_us / 1000.0, + (double) graph_stats.graph_kv_workspace_reset_us / 1000.0, + (double) graph_stats.graph_kv_workspace_alloc_us / 1000.0, + (double) graph_stats.graph_kv_workspace_compute_us / 1000.0, + (double) graph_stats.graph_kv_workspace_sync_us / 1000.0, + (unsigned long long) graph_stats.graph_kv_workspace_calls); + } + if (graph_stats.decode_internal_chunks > 0) { LOG_INF("statistics dflash decode: llama_decode(total)=%.3f ms calls=%zu chunks=%llu rebuilds=%llu sync_points=%llu internal(total/prelude/sched_reset/build/alloc/prepare/set_inputs/compute/get_result/get_embedding/final_reset)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms\n", (double) dflash_state->t_draft_decode_us / 1000.0, @@ -2748,6 +2786,67 @@ void common_speculative_print_stats(const common_speculative * spec, double slot (unsigned long long) graph_stats.graph_kv_node_v_store_calls); } + if (graph_stats.graph_main_node_qcur_calls > 0 || + graph_stats.graph_main_node_k_draft_calls > 0 || + graph_stats.graph_main_node_v_draft_calls > 0 || + graph_stats.graph_main_node_flash_attn_calls > 0 || + graph_stats.graph_main_node_attn_out_calls > 0 || + graph_stats.graph_main_node_ffn_calls > 0 || + graph_stats.graph_main_node_result_rows_calls > 0 || + graph_stats.graph_main_node_result_norm_calls > 0 || + graph_stats.graph_main_node_result_calls > 0) { + LOG_INF("statistics dflash draft nodes: profiled=%.3f ms graph_compute=%.3f ms qcur/k_draft/v_draft/flash_attn/attn_out/ffn/result_rows/result_norm/result=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu/%llu/%llu/%llu/%llu/%llu/%llu/%llu/%llu\n", + draft_main_profiled_ms, + (double) graph_stats.decode_graph_compute_us / 1000.0, + (double) graph_stats.graph_main_node_qcur_us / 1000.0, + (double) graph_stats.graph_main_node_k_draft_us / 1000.0, + (double) graph_stats.graph_main_node_v_draft_us / 1000.0, + (double) graph_stats.graph_main_node_flash_attn_us / 1000.0, + (double) 
graph_stats.graph_main_node_attn_out_us / 1000.0, + (double) graph_stats.graph_main_node_ffn_us / 1000.0, + (double) graph_stats.graph_main_node_result_rows_us / 1000.0, + (double) graph_stats.graph_main_node_result_norm_us / 1000.0, + (double) graph_stats.graph_main_node_result_us / 1000.0, + (unsigned long long) graph_stats.graph_main_node_qcur_calls, + (unsigned long long) graph_stats.graph_main_node_k_draft_calls, + (unsigned long long) graph_stats.graph_main_node_v_draft_calls, + (unsigned long long) graph_stats.graph_main_node_flash_attn_calls, + (unsigned long long) graph_stats.graph_main_node_attn_out_calls, + (unsigned long long) graph_stats.graph_main_node_ffn_calls, + (unsigned long long) graph_stats.graph_main_node_result_rows_calls, + (unsigned long long) graph_stats.graph_main_node_result_norm_calls, + (unsigned long long) graph_stats.graph_main_node_result_calls); + } + + if (graph_stats.graph_main_node_k_ctx_view_calls > 0 || + graph_stats.graph_main_node_v_ctx_view_calls > 0 || + graph_stats.graph_main_node_k_concat_calls > 0 || + graph_stats.graph_main_node_v_concat_calls > 0 || + graph_stats.graph_main_node_k_pad_calls > 0 || + graph_stats.graph_main_node_v_pad_calls > 0 || + graph_stats.graph_main_node_k_perm_cont_calls > 0 || + graph_stats.graph_main_node_v_perm_cont_calls > 0) { + LOG_INF("statistics dflash draft kv traffic: total=%.3f ms graph_compute=%.3f ms k_ctx_view/v_ctx_view/k_concat/v_concat/k_pad/v_pad/k_perm_cont/v_perm_cont=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu/%llu/%llu/%llu/%llu/%llu/%llu/%llu\n", + draft_kv_traffic_ms, + (double) graph_stats.decode_graph_compute_us / 1000.0, + (double) graph_stats.graph_main_node_k_ctx_view_us / 1000.0, + (double) graph_stats.graph_main_node_v_ctx_view_us / 1000.0, + (double) graph_stats.graph_main_node_k_concat_us / 1000.0, + (double) graph_stats.graph_main_node_v_concat_us / 1000.0, + (double) graph_stats.graph_main_node_k_pad_us / 1000.0, + (double) 
graph_stats.graph_main_node_v_pad_us / 1000.0, + (double) graph_stats.graph_main_node_k_perm_cont_us / 1000.0, + (double) graph_stats.graph_main_node_v_perm_cont_us / 1000.0, + (unsigned long long) graph_stats.graph_main_node_k_ctx_view_calls, + (unsigned long long) graph_stats.graph_main_node_v_ctx_view_calls, + (unsigned long long) graph_stats.graph_main_node_k_concat_calls, + (unsigned long long) graph_stats.graph_main_node_v_concat_calls, + (unsigned long long) graph_stats.graph_main_node_k_pad_calls, + (unsigned long long) graph_stats.graph_main_node_v_pad_calls, + (unsigned long long) graph_stats.graph_main_node_k_perm_cont_calls, + (unsigned long long) graph_stats.graph_main_node_v_perm_cont_calls); + } + LOG_INF("statistics dflash hot: kv(upload_f/upload_p/upload/compute/sync)=%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu replay(accepted_prefix_append)=%.3f ms calls=%zu rows=%zu\n", kv_upload_feature_ms, kv_upload_pos_ms, diff --git a/src/graphs/build_dflash.cpp b/src/graphs/build_dflash.cpp index a5b9a815..80c45c1e 100644 --- a/src/graphs/build_dflash.cpp +++ b/src/graphs/build_dflash.cpp @@ -16,6 +16,119 @@ static bool dflash_use_kv_cache_experiment() { std::strcmp(env, "off") != 0; } +static bool dflash_use_kv_workspace_experiment() { + const char * env = std::getenv("IK_DFLASH_KV_WORKSPACE"); + if (env == nullptr || *env == '\0') { + return false; + } + + return std::strcmp(env, "0") != 0 && + std::strcmp(env, "false") != 0 && + std::strcmp(env, "off") != 0; +} + +ggml_cgraph * llm_build_context::build_dflash_kv_workspace() { + const int64_t n_embd_head_k = hparams.n_embd_head_k(0); + const int64_t n_embd_head_v = hparams.n_embd_head_v(0); + const int64_t ctx_len = lctx.dflash_visible_cross_ctx > 0 + ? 
(int64_t) lctx.dflash_visible_cross_ctx + : std::max(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size); + const int32_t cache_rows = std::clamp(lctx.dflash_kv_cache_view_n_filled, 0, (int32_t) ctx_len); + const int32_t cache_write_pos = ctx_len > 0 + ? ((lctx.dflash_kv_cache_view_write_pos % (int32_t) ctx_len) + (int32_t) ctx_len) % (int32_t) ctx_len + : 0; + + GGML_ASSERT(n_embd_head_k == n_embd_head_v); + GGML_ASSERT(lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len)); + GGML_ASSERT((int32_t) lctx.dflash_k_ctx_workspace.size() == n_layer); + GGML_ASSERT((int32_t) lctx.dflash_v_ctx_workspace.size() == n_layer); + + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max(1, ctx_len)) + 16 * n_layer, false); + + auto build_ordered_cache_view = [&](ggml_tensor * cache) -> ggml_tensor * { + if (!lctx.dflash_kv_cache_view_valid || cache_rows <= 0) { + return cache; + } + + if (cache_rows < ctx_len) { + ggml_tensor * zero_pad = ggml_view_3d(ctx0, cache, + cache->ne[0], + cache->ne[1], + ctx_len - cache_rows, + cache->nb[1], + cache->nb[2], + (size_t) cache_rows * cache->nb[2]); + ggml_tensor * valid = ggml_view_3d(ctx0, cache, + cache->ne[0], + cache->ne[1], + cache_rows, + cache->nb[1], + cache->nb[2], + 0); + return ggml_concat(ctx0, zero_pad, valid, 2); + } + + if (cache_write_pos == 0) { + return cache; + } + + ggml_tensor * tail = ggml_view_3d(ctx0, cache, + cache->ne[0], + cache->ne[1], + ctx_len - cache_write_pos, + cache->nb[1], + cache->nb[2], + (size_t) cache_write_pos * cache->nb[2]); + ggml_tensor * head = ggml_view_3d(ctx0, cache, + cache->ne[0], + cache->ne[1], + cache_write_pos, + cache->nb[1], + cache->nb[2], + 0); + return ggml_concat(ctx0, tail, head, 2); + }; + + for (int il = 0; il < n_layer; ++il) { + GGML_ASSERT((size_t) il < lctx.dflash_k_ctx_cache.size()); + GGML_ASSERT((size_t) il < lctx.dflash_v_ctx_cache.size()); + + ggml_tensor * Kordered = build_ordered_cache_view(lctx.dflash_k_ctx_cache[(size_t) 
il]); + ggml_tensor * Vordered = build_ordered_cache_view(lctx.dflash_v_ctx_cache[(size_t) il]); + cb(Kordered, "dflash_workspace_k_ctx_view", il); + cb(Vordered, "dflash_workspace_v_ctx_view", il); + + ggml_tensor * Kworkspace = ggml_cont(ctx0, ggml_permute(ctx0, Kordered, 0, 2, 1, 3)); + ggml_tensor * Vworkspace = ggml_cont(ctx0, ggml_permute(ctx0, Vordered, 0, 2, 1, 3)); + cb(Kworkspace, "dflash_workspace_k_perm_cont", il); + cb(Vworkspace, "dflash_workspace_v_perm_cont", il); + + ggml_tensor * Kdst = ggml_view_3d(ctx0, lctx.dflash_k_ctx_workspace[(size_t) il], + lctx.dflash_k_ctx_workspace[(size_t) il]->ne[0], + ctx_len, + lctx.dflash_k_ctx_workspace[(size_t) il]->ne[2], + lctx.dflash_k_ctx_workspace[(size_t) il]->nb[1], + lctx.dflash_k_ctx_workspace[(size_t) il]->nb[2], + 0); + ggml_tensor * Vdst = ggml_view_3d(ctx0, lctx.dflash_v_ctx_workspace[(size_t) il], + lctx.dflash_v_ctx_workspace[(size_t) il]->ne[0], + ctx_len, + lctx.dflash_v_ctx_workspace[(size_t) il]->ne[2], + lctx.dflash_v_ctx_workspace[(size_t) il]->nb[1], + lctx.dflash_v_ctx_workspace[(size_t) il]->nb[2], + 0); + + ggml_tensor * Kstore = ggml_cpy(ctx0, Kworkspace, Kdst); + ggml_tensor * Vstore = ggml_cpy(ctx0, Vworkspace, Vdst); + cb(Kstore, "dflash_workspace_k_store", il); + cb(Vstore, "dflash_workspace_v_store", il); + ggml_build_forward_expand(gf, Kstore); + ggml_build_forward_expand(gf, Vstore); + } + + return gf; +} + ggml_cgraph * llm_build_context::build_dflash_kv_cache() { const int64_t n_embd_head_k = hparams.n_embd_head_k(0); const int64_t n_embd_head_v = hparams.n_embd_head_v(0); @@ -160,6 +273,7 @@ ggml_cgraph * llm_build_context::build_dflash() { const int64_t n_target_features = hparams.dflash_n_target_features; auto & profile = lctx.dflash_profile; const bool use_kv_cache = dflash_use_kv_cache_experiment(); + const bool use_kv_workspace = use_kv_cache && dflash_use_kv_workspace_experiment(); const int64_t ctx_len = lctx.dflash_visible_cross_ctx > 0 ? 
(int64_t) lctx.dflash_visible_cross_ctx : std::max(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size); @@ -226,6 +340,7 @@ ggml_cgraph * llm_build_context::build_dflash() { ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, tok_embd, cb); ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = (n_tokens > 1 && n_outputs < n_tokens) ? build_inp_out_ids() : nullptr; + bool result_rows_selected = false; const float kq_scale = 1.0f / std::sqrt((float) n_embd_head_k); @@ -258,7 +373,30 @@ ggml_cgraph * llm_build_context::build_dflash() { const int64_t t_cache_read_us = use_kv_cache ? ggml_time_us() : 0; ggml_tensor * Kcur_ctx = nullptr; ggml_tensor * Vcur_ctx = nullptr; - if (use_kv_cache) { + const bool have_workspace_ctx = use_kv_workspace && + (size_t) il < lctx.dflash_k_ctx_workspace.size() && + (size_t) il < lctx.dflash_v_ctx_workspace.size() && + lctx.dflash_k_ctx_workspace[(size_t) il] != nullptr && + lctx.dflash_v_ctx_workspace[(size_t) il] != nullptr; + + if (have_workspace_ctx) { + Kcur_ctx = ggml_view_3d(ctx0, lctx.dflash_k_ctx_workspace[(size_t) il], + lctx.dflash_k_ctx_workspace[(size_t) il]->ne[0], + ctx_len, + lctx.dflash_k_ctx_workspace[(size_t) il]->ne[2], + lctx.dflash_k_ctx_workspace[(size_t) il]->nb[1], + lctx.dflash_k_ctx_workspace[(size_t) il]->nb[2], + 0); + Vcur_ctx = ggml_view_3d(ctx0, lctx.dflash_v_ctx_workspace[(size_t) il], + lctx.dflash_v_ctx_workspace[(size_t) il]->ne[0], + ctx_len, + lctx.dflash_v_ctx_workspace[(size_t) il]->ne[2], + lctx.dflash_v_ctx_workspace[(size_t) il]->nb[1], + lctx.dflash_v_ctx_workspace[(size_t) il]->nb[2], + 0); + cb(Kcur_ctx, "Kcur_ctx_workspace", il); + cb(Vcur_ctx, "Vcur_ctx_workspace", il); + } else if (use_kv_cache) { auto build_ordered_cache_view = [&](ggml_tensor * cache) -> ggml_tensor * { if (!lctx.dflash_kv_cache_view_valid || cache_rows <= 0) { return cache; @@ -267,19 +405,19 @@ ggml_cgraph * llm_build_context::build_dflash() { if (cache_rows < ctx_len) { 
ggml_tensor * zero_pad = ggml_view_3d(ctx0, cache, cache->ne[0], - cache->ne[1], + cache->ne[1], ctx_len - cache_rows, cache->nb[1], cache->nb[2], - (size_t) cache_rows * cache->nb[2]); + (size_t) cache_rows * cache->nb[2]); ggml_tensor * valid = ggml_view_3d(ctx0, cache, cache->ne[0], - cache->ne[1], + cache->ne[1], cache_rows, cache->nb[1], cache->nb[2], 0); - return ggml_concat(ctx0, zero_pad, valid, 2); + return ggml_concat(ctx0, zero_pad, valid, 2); } if (cache_write_pos == 0) { @@ -288,19 +426,19 @@ ggml_cgraph * llm_build_context::build_dflash() { ggml_tensor * tail = ggml_view_3d(ctx0, cache, cache->ne[0], - cache->ne[1], + cache->ne[1], ctx_len - cache_write_pos, cache->nb[1], cache->nb[2], - (size_t) cache_write_pos * cache->nb[2]); + (size_t) cache_write_pos * cache->nb[2]); ggml_tensor * head = ggml_view_3d(ctx0, cache, cache->ne[0], - cache->ne[1], + cache->ne[1], cache_write_pos, cache->nb[1], cache->nb[2], 0); - return ggml_concat(ctx0, tail, head, 2); + return ggml_concat(ctx0, tail, head, 2); }; Kcur_ctx = build_ordered_cache_view(lctx.dflash_k_ctx_cache[(size_t) il]); @@ -321,32 +459,58 @@ ggml_cgraph * llm_build_context::build_dflash() { cb(Vcur_ctx, "Vcur_ctx", il); } - ggml_tensor * Kcur = ggml_concat(ctx0, Kcur_ctx, Kcur_noise, 2); - ggml_tensor * Vcur = ggml_concat(ctx0, Vcur_ctx, Vcur_noise, 2); - if (n_kv_pad > 0) { - Kcur = ggml_pad(ctx0, Kcur, 0, 0, (int) n_kv_pad, 0); - Vcur = ggml_pad(ctx0, Vcur, 0, 0, (int) n_kv_pad, 0); + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + if (have_workspace_ctx) { + ggml_tensor * Kcur_draft = ggml_cont(ctx0, ggml_permute(ctx0, Kcur_noise, 0, 2, 1, 3)); + ggml_tensor * Vcur_draft = ggml_cont(ctx0, ggml_permute(ctx0, Vcur_noise, 0, 2, 1, 3)); + cb(Kcur_draft, "dflash_main_k_perm_cont", il); + cb(Vcur_draft, "dflash_main_v_perm_cont", il); + + Kcur = ggml_concat(ctx0, Kcur_ctx, Kcur_draft, 1); + Vcur = ggml_concat(ctx0, Vcur_ctx, Vcur_draft, 1); + cb(Kcur, "dflash_main_k_concat", il); + 
cb(Vcur, "dflash_main_v_concat", il); + + if (n_kv_pad > 0) { + Kcur = ggml_pad(ctx0, Kcur, 0, (int) n_kv_pad, 0, 0); + Vcur = ggml_pad(ctx0, Vcur, 0, (int) n_kv_pad, 0, 0); + cb(Kcur, "dflash_main_k_pad", il); + cb(Vcur, "dflash_main_v_pad", il); + } + } else { + ggml_tensor * Kcur_concat = ggml_concat(ctx0, Kcur_ctx, Kcur_noise, 2); + ggml_tensor * Vcur_concat = ggml_concat(ctx0, Vcur_ctx, Vcur_noise, 2); + cb(Kcur_concat, "dflash_main_k_concat", il); + cb(Vcur_concat, "dflash_main_v_concat", il); + + Kcur = Kcur_concat; + Vcur = Vcur_concat; + if (n_kv_pad > 0) { + Kcur = ggml_pad(ctx0, Kcur, 0, 0, (int) n_kv_pad, 0); + Vcur = ggml_pad(ctx0, Vcur, 0, 0, (int) n_kv_pad, 0); + cb(Kcur, "dflash_main_k_pad", il); + cb(Vcur, "dflash_main_v_pad", il); + } } if (use_kv_cache) { profile.graph_kv_cache_read_concat_pad_us += (uint64_t) (ggml_time_us() - t_cache_read_us); profile.graph_kv_cache_read_concat_pad_calls++; profile.graph_kv_cache_cached_bytes += ggml_nbytes(lctx.dflash_k_ctx_cache[(size_t) il]) + ggml_nbytes(lctx.dflash_v_ctx_cache[(size_t) il]); } - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur_f16", il); - cb(Vcur, "Vcur_f16", il); ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); - ggml_tensor * v = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 0, 2, 1, 3)); + ggml_tensor * k = have_workspace_ctx ? Kcur : ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + ggml_tensor * v = have_workspace_ctx ? Vcur : ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 0, 2, 1, 3)); ggml_tensor * dflash_kq_mask_l = (hparams.swa_layers[il] && dflash_kq_mask_swa != nullptr) ? 
dflash_kq_mask_swa : dflash_kq_mask_full; cb(q, "q", il); - cb(k, "k", il); - cb(v, "v", il); + if (!have_workspace_ctx) { + cb(k, "dflash_main_k_perm_cont", il); + cb(v, "dflash_main_v_perm_cont", il); + } cur = ggml_flash_attn_ext(ctx0, q, k, v, dflash_kq_mask_l, kq_scale, hparams.f_max_alibi_bias, hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); @@ -362,6 +526,12 @@ ggml_cgraph * llm_build_context::build_dflash() { cur = ggml_add(ctx0, cur, inpSA); cb(cur, "attn_residual", il); + if (inp_out_ids != nullptr && il == n_layer - 1) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + cb(cur, "result_output_rows", -1); + result_rows_selected = true; + } + ggml_tensor * ffn_residual = cur; cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, cb, il); cb(cur, "attn_post_norm", il); @@ -380,13 +550,13 @@ ggml_cgraph * llm_build_context::build_dflash() { inpL = cur; } - ggml_tensor * output = model.output; + ggml_tensor * output = const_cast(llama_model_dflash_output_tensor(&model)); if (output == nullptr) { output = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab); } ggml_tensor * result_input = inpL; - if (inp_out_ids) { + if (inp_out_ids && !result_rows_selected) { result_input = ggml_get_rows(ctx0, result_input, inp_out_ids); cb(result_input, "result_output_rows", -1); } diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index fc03353c..eff7d675 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -2206,6 +2206,43 @@ struct ggml_cgraph * llm_build_context::llama_build_graph_dflash_kv_cache(llama_ return result; } +struct ggml_cgraph * llm_build_context::llama_build_graph_dflash_kv_workspace(llama_context & lctx) { + llama_batch dummy; + dummy.n_tokens = 0; + + llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) { + if (il >= 0) { + int j = 0; + for (; j < GGML_MAX_NAME - 1; ++j) { + cur->name[j] = name[j]; + if (!name[j]) { + 
break; + } + } + if (j < GGML_MAX_NAME - 3) { + cur->name[j++] = '-'; + auto sil = std::to_string(il); + for (int k = 0; k < (int) sil.size() && j < GGML_MAX_NAME - 1; ++k) { + cur->name[j++] = sil[k]; + } + } + cur->name[j] = 0; + } else { + ggml_set_name(cur, name); + } + }; + + struct llm_build_context llm(lctx, dummy, cb, false, false, 0, false, &lctx.dflash_workspace_buf_compute_meta); + + llm.init(); + + struct ggml_cgraph * result = llm.build_dflash_kv_workspace(); + + llm.free(); + + return result; +} + ggml_cgraph * llm_build_context::llama_build_graph( llama_context & lctx, const llama_batch & batch, diff --git a/src/llama-build-context.h b/src/llama-build-context.h index ec7cbbb9..a33fdf39 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -249,6 +249,8 @@ struct llm_build_context { ggml_cgraph * build_dflash_kv_cache(); + ggml_cgraph * build_dflash_kv_workspace(); + ggml_cgraph * build_starcoder2(); ggml_cgraph * build_mamba(); @@ -466,6 +468,8 @@ llm_expert_gating_func_type gating_op, static ggml_cgraph * llama_build_graph_dflash_kv_cache(llama_context & lctx); + static ggml_cgraph * llama_build_graph_dflash_kv_workspace(llama_context & lctx); + static ggml_cgraph * llama_build_graph(llama_context & lctx, const llama_batch & batch, bool worst_case, int n_outputs = 0); ggml_tensor * build_std_attention(ggml_cgraph * gf, ggml_tensor * attn_norm, ggml_tensor * cur, diff --git a/src/llama-context.h b/src/llama-context.h index 1a7a9d80..8ad9d74b 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -301,6 +301,8 @@ struct llama_context { int32_t dflash_visible_cross_ctx = 0; std::vector dflash_k_ctx_cache; std::vector dflash_v_ctx_cache; + std::vector dflash_k_ctx_workspace; + std::vector dflash_v_ctx_workspace; struct ggml_context * dflash_cache_ctx = nullptr; std::vector dflash_cache_bufs; int32_t dflash_kv_cache_write_pos = 0; @@ -312,8 +314,24 @@ struct llama_context { uint64_t dflash_kv_cache_applied_window_version = 0; 
bool dflash_kv_cache_valid = false; bool dflash_kv_cache_view_valid = false; + int32_t dflash_kv_workspace_write_pos = 0; + int32_t dflash_kv_workspace_n_filled = 0; + int32_t dflash_kv_workspace_reserved_rows = 0; + int32_t dflash_kv_workspace_token_capacity = 0; + int32_t dflash_kv_workspace_n_kv_total = 0; + uint64_t dflash_kv_workspace_applied_window_version = 0; + bool dflash_kv_workspace_valid = false; + bool dflash_kv_workspace_sync_pending = false; std::vector dflash_buf_compute_meta; + std::vector dflash_workspace_buf_compute_meta; ggml_backend_sched_t dflash_sched = nullptr; + ggml_backend_sched_t dflash_workspace_sched = nullptr; + ggml_cgraph * dflash_kv_graph = nullptr; + ggml_cgraph * dflash_kv_workspace_graph = nullptr; + int32_t dflash_kv_graph_rows = 0; + int32_t dflash_kv_graph_write_pos = 0; + int32_t dflash_kv_workspace_graph_rows = 0; + int32_t dflash_kv_workspace_graph_write_pos = 0; struct ggml_tensor * dflash_kv_input_target_features = nullptr; struct ggml_tensor * dflash_kv_input_pos_ctx = nullptr; struct ggml_tensor * dflash_kq_mask_tensor = nullptr; diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index bcd08ff4..ab84302d 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -2202,6 +2202,7 @@ bool create_tensors_helper::create_dflash_tensors(const LLM_TN & tn) { model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_mtp = create_tensor(ctx_output, "output_extra.weight", {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); if (model.output == nullptr && model.tok_embd != nullptr) { model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), 
{n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } diff --git a/src/llama-spec-features.cpp b/src/llama-spec-features.cpp index ab8efddb..00c4b6e2 100644 --- a/src/llama-spec-features.cpp +++ b/src/llama-spec-features.cpp @@ -68,6 +68,11 @@ void llama_reset_dflash_kv_cache_state(struct llama_context * ctx) { ctx->dflash_kv_cache_applied_window_version = 0; ctx->dflash_kv_cache_valid = false; ctx->dflash_kv_cache_view_valid = false; + ctx->dflash_kv_workspace_write_pos = 0; + ctx->dflash_kv_workspace_n_filled = 0; + ctx->dflash_kv_workspace_applied_window_version = 0; + ctx->dflash_kv_workspace_valid = false; + ctx->dflash_kv_workspace_sync_pending = false; for (ggml_backend_buffer_t buf : ctx->dflash_cache_bufs) { if (buf != nullptr) { @@ -171,6 +176,65 @@ int32_t llama_model_dflash_target_mask_token_id(const struct llama_model * model return (int32_t) model->vocab.token_mask(); } +const struct ggml_tensor * llama_model_dflash_output_tensor( + const struct llama_model * model) { + if (model == nullptr) { + return nullptr; + } + + if (model->output_mtp != nullptr) { + return model->output_mtp; + } + + if (model->output != nullptr) { + return model->output; + } + + return model->tok_embd; +} + +static const char * llama_dflash_io_mode_name(int32_t io_mode) { + switch (io_mode) { + case LLAMA_DFLASH_IO_MODE_SHARED: + return "shared"; + case LLAMA_DFLASH_IO_MODE_SELF_CONTAINED: + return "self-contained"; + case LLAMA_DFLASH_IO_MODE_MIXED: + return "mixed"; + default: + return "invalid"; + } +} + +static const char * llama_dflash_output_head_kind( + const struct llama_model * draft_model, + const struct llama_model * target_model) { + const struct ggml_tensor * output = llama_model_dflash_output_tensor(draft_model); + if (output == nullptr) { + return "missing"; + } + + if (output == draft_model->tok_embd) { + return draft_model->tok_embd == (target_model ? target_model->tok_embd : nullptr) + ? 
"shared_token_embedding" + : "token_embedding"; + } + + if (draft_model->output_mtp != nullptr && output == draft_model->output_mtp) { + if (target_model != nullptr && target_model->output_mtp != nullptr && output == target_model->output_mtp) { + return "output_mtp"; + } + + if (std::strcmp(output->name, "output_extra.weight") == 0) { + return "output_extra"; + } + + return "output_mtp"; + } + + return "output"; +} + int32_t llama_model_dflash_io_mode( const struct llama_model * draft_model, const struct llama_model * target_model) { @@ -178,13 +242,14 @@ int32_t llama_model_dflash_io_mode( return LLAMA_DFLASH_IO_MODE_INVALID; } - const ggml_tensor * target_output = target_model->output != nullptr ? target_model->output : target_model->tok_embd; - if (draft_model->tok_embd == nullptr || draft_model->output == nullptr || target_model->tok_embd == nullptr || target_output == nullptr) { + const ggml_tensor * draft_output = llama_model_dflash_output_tensor(draft_model); + const ggml_tensor * target_output = llama_model_dflash_output_tensor(target_model); + if (draft_model->tok_embd == nullptr || draft_output == nullptr || target_model->tok_embd == nullptr || target_output == nullptr) { return LLAMA_DFLASH_IO_MODE_INVALID; } const bool shared_tok = draft_model->tok_embd == target_model->tok_embd; - const bool shared_output = draft_model->output == target_output; + const bool shared_output = draft_output == target_output; if (shared_tok && shared_output) { return LLAMA_DFLASH_IO_MODE_SHARED; } @@ -200,14 +265,15 @@ bool llama_model_dflash_io_tensors_match( const struct llama_model * draft_model, int32_t n_embd, int32_t n_vocab) { - if (draft_model == nullptr || draft_model->tok_embd == nullptr || draft_model->output == nullptr || n_embd <= 0 || n_vocab <= 0) { + const ggml_tensor * output = llama_model_dflash_output_tensor(draft_model); + if (draft_model == nullptr || draft_model->tok_embd == nullptr || output == nullptr || n_embd <= 0 || n_vocab <= 0) { return false; } 
return (int32_t) draft_model->tok_embd->ne[0] == n_embd && (int32_t) draft_model->tok_embd->ne[1] == n_vocab && - (int32_t) draft_model->output->ne[0] == n_embd && - (int32_t) draft_model->output->ne[1] == n_vocab; + (int32_t) output->ne[0] == n_embd && + (int32_t) output->ne[1] == n_vocab; } bool llama_model_share_dflash_io_tensors( @@ -232,7 +298,25 @@ bool llama_model_share_dflash_io_tensors( } } - return draft_model->tok_embd != nullptr && draft_model->output != nullptr; + const bool uses_shared_tok = draft_model->tok_embd == target_model->tok_embd; + const bool uses_shared_output = draft_model->output == target_model->output || + draft_model->output == target_model->tok_embd; + + if (draft_model->output_mtp == nullptr && target_model->output_mtp != nullptr && uses_shared_tok && uses_shared_output) { + draft_model->output_mtp = target_model->output_mtp; + } + + const struct ggml_tensor * output = llama_model_dflash_output_tensor(draft_model); + if (draft_model->tok_embd != nullptr && output != nullptr) { + LLAMA_LOG_INFO("%s: DFlash IO mode=%s output_head=%s tensor=%s type=%s\n", + __func__, + llama_dflash_io_mode_name(llama_model_dflash_io_mode(draft_model, target_model)), + llama_dflash_output_head_kind(draft_model, target_model), + output->name[0] != '\0' ? 
output->name : "(unnamed)", + ggml_type_name(output->type)); + } + + return draft_model->tok_embd != nullptr && output != nullptr; } bool llama_set_draft_input_hidden_state_copy( diff --git a/src/llama-spec-features.h b/src/llama-spec-features.h index d976c89a..1c327049 100644 --- a/src/llama-spec-features.h +++ b/src/llama-spec-features.h @@ -86,6 +86,13 @@ struct llama_dflash_profile_stats { uint64_t graph_kv_cache_read_concat_pad_calls = 0; uint64_t graph_kv_cache_cached_bytes = 0; uint64_t graph_kv_cache_calls = 0; + uint64_t graph_kv_workspace_build_us = 0; + uint64_t graph_kv_workspace_reserve_us = 0; + uint64_t graph_kv_workspace_reset_us = 0; + uint64_t graph_kv_workspace_alloc_us = 0; + uint64_t graph_kv_workspace_compute_us = 0; + uint64_t graph_kv_workspace_sync_us = 0; + uint64_t graph_kv_workspace_calls = 0; uint64_t graph_kv_node_fused_target_calls = 0; uint64_t graph_kv_node_fused_target_us = 0; uint64_t graph_kv_node_k_proj_calls = 0; @@ -100,6 +107,40 @@ struct llama_dflash_profile_stats { uint64_t graph_kv_node_k_store_us = 0; uint64_t graph_kv_node_v_store_calls = 0; uint64_t graph_kv_node_v_store_us = 0; + uint64_t graph_main_node_qcur_calls = 0; + uint64_t graph_main_node_qcur_us = 0; + uint64_t graph_main_node_k_draft_calls = 0; + uint64_t graph_main_node_k_draft_us = 0; + uint64_t graph_main_node_v_draft_calls = 0; + uint64_t graph_main_node_v_draft_us = 0; + uint64_t graph_main_node_k_ctx_view_calls = 0; + uint64_t graph_main_node_k_ctx_view_us = 0; + uint64_t graph_main_node_v_ctx_view_calls = 0; + uint64_t graph_main_node_v_ctx_view_us = 0; + uint64_t graph_main_node_k_concat_calls = 0; + uint64_t graph_main_node_k_concat_us = 0; + uint64_t graph_main_node_v_concat_calls = 0; + uint64_t graph_main_node_v_concat_us = 0; + uint64_t graph_main_node_k_pad_calls = 0; + uint64_t graph_main_node_k_pad_us = 0; + uint64_t graph_main_node_v_pad_calls = 0; + uint64_t graph_main_node_v_pad_us = 0; + uint64_t graph_main_node_k_perm_cont_calls = 0; + 
uint64_t graph_main_node_k_perm_cont_us = 0; + uint64_t graph_main_node_v_perm_cont_calls = 0; + uint64_t graph_main_node_v_perm_cont_us = 0; + uint64_t graph_main_node_flash_attn_calls = 0; + uint64_t graph_main_node_flash_attn_us = 0; + uint64_t graph_main_node_attn_out_calls = 0; + uint64_t graph_main_node_attn_out_us = 0; + uint64_t graph_main_node_ffn_calls = 0; + uint64_t graph_main_node_ffn_us = 0; + uint64_t graph_main_node_result_rows_calls = 0; + uint64_t graph_main_node_result_rows_us = 0; + uint64_t graph_main_node_result_norm_calls = 0; + uint64_t graph_main_node_result_norm_us = 0; + uint64_t graph_main_node_result_calls = 0; + uint64_t graph_main_node_result_us = 0; uint64_t graph_feature_bytes = 0; uint64_t graph_pos_bytes = 0; uint64_t graph_mask_bytes = 0; @@ -232,6 +273,9 @@ int32_t llama_model_dflash_io_mode( const struct llama_model * draft_model, const struct llama_model * target_model); +const struct ggml_tensor * llama_model_dflash_output_tensor( + const struct llama_model * model); + bool llama_model_dflash_io_tensors_match( const struct llama_model * draft_model, int32_t n_embd, diff --git a/src/llama.cpp b/src/llama.cpp index e53940a2..a1b63a73 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -190,12 +190,42 @@ enum llama_dflash_kv_node_kind { LLAMA_DFLASH_KV_NODE_V_STORE, }; +enum llama_dflash_main_node_kind { + LLAMA_DFLASH_MAIN_NODE_NONE = 0, + LLAMA_DFLASH_MAIN_NODE_QCUR, + LLAMA_DFLASH_MAIN_NODE_K_DRAFT, + LLAMA_DFLASH_MAIN_NODE_V_DRAFT, + LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW, + LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW, + LLAMA_DFLASH_MAIN_NODE_K_CONCAT, + LLAMA_DFLASH_MAIN_NODE_V_CONCAT, + LLAMA_DFLASH_MAIN_NODE_K_PAD, + LLAMA_DFLASH_MAIN_NODE_V_PAD, + LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT, + LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT, + LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN, + LLAMA_DFLASH_MAIN_NODE_ATTN_OUT, + LLAMA_DFLASH_MAIN_NODE_FFN, + LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS, + LLAMA_DFLASH_MAIN_NODE_RESULT_NORM, + LLAMA_DFLASH_MAIN_NODE_RESULT, +}; + 
struct llama_dflash_kv_node_profiler { llama_dflash_profile_stats * profile = nullptr; int64_t t_start_us = 0; llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE; }; +struct llama_dflash_main_node_profiler { + llama_dflash_profile_stats * profile = nullptr; + ggml_backend_sched_eval_callback prev_callback = nullptr; + void * prev_user_data = nullptr; + bool prev_active = false; + int64_t t_start_us = 0; + llama_dflash_main_node_kind active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; +}; + static bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) { if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') { return false; @@ -204,6 +234,16 @@ static bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tenso return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0; } +static bool llama_dflash_tensor_name_matches_label(const struct ggml_tensor * tensor, const char * label) { + if (!llama_dflash_tensor_name_has_prefix(tensor, label)) { + return false; + } + + const size_t label_len = std::strlen(label); + const char next = tensor->name[label_len]; + return next == '\0' || next == '-'; +} + static llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) { if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_fused_target")) { return LLAMA_DFLASH_KV_NODE_FUSED_TARGET; @@ -268,6 +308,146 @@ static void llama_dflash_kv_node_profile_add( } } +static llama_dflash_main_node_kind llama_dflash_main_node_kind_from_tensor(const struct ggml_tensor * tensor) { + if (llama_dflash_tensor_name_has_prefix(tensor, "Qcur")) { + return LLAMA_DFLASH_MAIN_NODE_QCUR; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_noise")) { + return LLAMA_DFLASH_MAIN_NODE_K_DRAFT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_noise")) { + return LLAMA_DFLASH_MAIN_NODE_V_DRAFT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_ctx_cache")) { + 
return LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_ctx_cache")) { + return LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_concat")) { + return LLAMA_DFLASH_MAIN_NODE_K_CONCAT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_concat")) { + return LLAMA_DFLASH_MAIN_NODE_V_CONCAT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_pad")) { + return LLAMA_DFLASH_MAIN_NODE_K_PAD; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_pad")) { + return LLAMA_DFLASH_MAIN_NODE_V_PAD; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_perm_cont")) { + return LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_perm_cont")) { + return LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "flash_attn_reshaped")) { + return LLAMA_DFLASH_MAIN_NODE_NONE; + } + if (llama_dflash_tensor_name_matches_label(tensor, "flash_attn")) { + return LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "kqv_out")) { + return LLAMA_DFLASH_MAIN_NODE_ATTN_OUT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "ffn_out")) { + return LLAMA_DFLASH_MAIN_NODE_FFN; + } + if (llama_dflash_tensor_name_matches_label(tensor, "result_output_rows")) { + return LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS; + } + if (llama_dflash_tensor_name_matches_label(tensor, "result_norm")) { + return LLAMA_DFLASH_MAIN_NODE_RESULT_NORM; + } + if (llama_dflash_tensor_name_matches_label(tensor, "output")) { + return LLAMA_DFLASH_MAIN_NODE_RESULT; + } + if (llama_dflash_tensor_name_matches_label(tensor, "result_output")) { + return LLAMA_DFLASH_MAIN_NODE_RESULT; + } + + return LLAMA_DFLASH_MAIN_NODE_NONE; +} + +static void llama_dflash_main_node_profile_add( + llama_dflash_profile_stats & profile, + llama_dflash_main_node_kind kind, + 
uint64_t elapsed_us) { + switch (kind) { + case LLAMA_DFLASH_MAIN_NODE_QCUR: + profile.graph_main_node_qcur_calls++; + profile.graph_main_node_qcur_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_DRAFT: + profile.graph_main_node_k_draft_calls++; + profile.graph_main_node_k_draft_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_DRAFT: + profile.graph_main_node_v_draft_calls++; + profile.graph_main_node_v_draft_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW: + profile.graph_main_node_k_ctx_view_calls++; + profile.graph_main_node_k_ctx_view_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW: + profile.graph_main_node_v_ctx_view_calls++; + profile.graph_main_node_v_ctx_view_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_CONCAT: + profile.graph_main_node_k_concat_calls++; + profile.graph_main_node_k_concat_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_CONCAT: + profile.graph_main_node_v_concat_calls++; + profile.graph_main_node_v_concat_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_PAD: + profile.graph_main_node_k_pad_calls++; + profile.graph_main_node_k_pad_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_PAD: + profile.graph_main_node_v_pad_calls++; + profile.graph_main_node_v_pad_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT: + profile.graph_main_node_k_perm_cont_calls++; + profile.graph_main_node_k_perm_cont_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT: + profile.graph_main_node_v_perm_cont_calls++; + profile.graph_main_node_v_perm_cont_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN: + profile.graph_main_node_flash_attn_calls++; + profile.graph_main_node_flash_attn_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_ATTN_OUT: + profile.graph_main_node_attn_out_calls++; + profile.graph_main_node_attn_out_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_FFN: + 
profile.graph_main_node_ffn_calls++; + profile.graph_main_node_ffn_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS: + profile.graph_main_node_result_rows_calls++; + profile.graph_main_node_result_rows_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_RESULT_NORM: + profile.graph_main_node_result_norm_calls++; + profile.graph_main_node_result_norm_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_RESULT: + profile.graph_main_node_result_calls++; + profile.graph_main_node_result_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_NONE: + break; + } +} + static bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { auto * profiler = static_cast(user_data); if (profiler == nullptr || profiler->profile == nullptr) { @@ -294,6 +474,62 @@ static bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool return true; } +static bool llama_dflash_main_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { + auto * profiler = static_cast(user_data); + if (profiler == nullptr || profiler->profile == nullptr) { + return false; + } + + const llama_dflash_main_node_kind kind = llama_dflash_main_node_kind_from_tensor(tensor); + if (ask) { + profiler->prev_active = profiler->prev_callback != nullptr + ? 
profiler->prev_callback(tensor, ask, profiler->prev_user_data) + : false; + + if (kind == LLAMA_DFLASH_MAIN_NODE_NONE) { + profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; + profiler->t_start_us = 0; + return profiler->prev_active; + } + + profiler->active_kind = kind; + profiler->t_start_us = ggml_time_us(); + return true; + } + + bool prev_result = false; + if (profiler->prev_active && profiler->prev_callback != nullptr) { + prev_result = profiler->prev_callback(tensor, ask, profiler->prev_user_data); + } + + const bool tracked = kind != LLAMA_DFLASH_MAIN_NODE_NONE && + profiler->active_kind == kind && + profiler->t_start_us > 0; + if (tracked) { + llama_dflash_main_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); + } + + profiler->prev_active = false; + profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; + profiler->t_start_us = 0; + return prev_result || tracked; +} + +static bool llama_dflash_use_kv_workspace_experiment() { + return llama_env_flag_enabled("IK_DFLASH_KV_WORKSPACE"); +} + +static void llama_sync_dflash_workspace_if_pending(struct llama_context & lctx) { + if (!lctx.dflash_kv_workspace_sync_pending || lctx.dflash_workspace_sched == nullptr) { + return; + } + + const int64_t t_workspace_sync_us = ggml_time_us(); + ggml_backend_sched_synchronize(lctx.dflash_workspace_sched); + lctx.dflash_profile.graph_kv_workspace_sync_us += (uint64_t) (ggml_time_us() - t_workspace_sync_us); + lctx.dflash_kv_workspace_sync_pending = false; +} + // extract ip and port from RPC[ip:port] for rpc and keep other device names static std::vector extract_device_from_rpc_device(std::vector devices) { std::vector rpc_servers; @@ -727,16 +963,26 @@ static ggml_backend_t llama_backend_for_tensor(const llama_context & lctx, const } bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { + const bool use_kv_workspace = llama_env_flag_enabled("IK_DFLASH_KV_WORKSPACE"); const int32_t target_cross_ctx = 
std::max(1, cross_ctx); + const int32_t target_token_capacity = std::max(1, (int32_t) model.hparams.dflash_block_size); + const int32_t target_workspace_n_kv_total = GGML_PAD(target_cross_ctx + target_token_capacity, cparams.flash_attn ? 256 : 32); const int32_t n_layer = model.hparams.n_layer; const int64_t n_embd_head_k = model.hparams.n_embd_head_k(0); const int64_t n_embd_head_v = model.hparams.n_embd_head_v(0); const int64_t n_head_kv = model.hparams.n_head_kv(); if (dflash_cache_ctx != nullptr && !dflash_k_ctx_cache.empty()) { - if ((int32_t) dflash_k_ctx_cache.size() == n_layer && - dflash_k_ctx_cache.front() != nullptr && - (int32_t) dflash_k_ctx_cache.front()->ne[2] == target_cross_ctx) { + const bool cache_matches = (int32_t) dflash_k_ctx_cache.size() == n_layer && + dflash_k_ctx_cache.front() != nullptr && + (int32_t) dflash_k_ctx_cache.front()->ne[2] == target_cross_ctx; + const bool workspace_matches = use_kv_workspace + ? ((int32_t) dflash_k_ctx_workspace.size() == n_layer && + dflash_k_ctx_workspace.front() != nullptr && + (int32_t) dflash_k_ctx_workspace.front()->ne[1] == target_workspace_n_kv_total) + : dflash_k_ctx_workspace.empty() && dflash_v_ctx_workspace.empty(); + + if (cache_matches && workspace_matches) { return true; } @@ -745,11 +991,23 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { ggml_backend_sched_free(dflash_sched); dflash_sched = nullptr; } + if (dflash_workspace_sched != nullptr) { + ggml_backend_sched_free(dflash_workspace_sched); + dflash_workspace_sched = nullptr; + } + dflash_kv_graph = nullptr; + dflash_kv_workspace_graph = nullptr; + dflash_kv_graph_rows = 0; + dflash_kv_graph_write_pos = 0; + dflash_kv_workspace_graph_rows = 0; + dflash_kv_workspace_graph_write_pos = 0; + dflash_kv_workspace_reserved_rows = 0; dflash_buf_compute_meta.clear(); + dflash_workspace_buf_compute_meta.clear(); } ggml_init_params params = { - /*.mem_size =*/ (size_t) (2 * std::max(1, n_layer)) * ggml_tensor_overhead(), + 
/*.mem_size =*/ (size_t) ((use_kv_workspace ? 4 : 2) * std::max(1, n_layer)) * ggml_tensor_overhead(), /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; @@ -761,8 +1019,14 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { dflash_k_ctx_cache.resize((size_t) n_layer); dflash_v_ctx_cache.resize((size_t) n_layer); + dflash_k_ctx_workspace.clear(); + dflash_v_ctx_workspace.clear(); + if (use_kv_workspace) { + dflash_k_ctx_workspace.resize((size_t) n_layer); + dflash_v_ctx_workspace.resize((size_t) n_layer); + } dflash_cache_bufs.clear(); - dflash_cache_bufs.reserve((size_t) std::max(1, n_layer) * 2); + dflash_cache_bufs.reserve((size_t) std::max(1, n_layer) * (use_kv_workspace ? 4 : 2)); int32_t host_layers = 0; const char * first_buft_name = nullptr; const char * last_buft_name = nullptr; @@ -809,9 +1073,47 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { ggml_backend_tensor_alloc(v_buf, dflash_v_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(v_buf)); ggml_backend_buffer_clear(v_buf, 0); dflash_cache_bufs.push_back(v_buf); + + if (use_kv_workspace) { + dflash_k_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_k, target_workspace_n_kv_total, n_head_kv); + dflash_v_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_v, target_workspace_n_kv_total, n_head_kv); + if (dflash_k_ctx_workspace[(size_t) il] == nullptr || dflash_v_ctx_workspace[(size_t) il] == nullptr) { + free_dflash_kv_cache_tensors(); + return false; + } + + ggml_set_input(dflash_k_ctx_workspace[(size_t) il]); + ggml_set_input(dflash_v_ctx_workspace[(size_t) il]); + ggml_format_name(dflash_k_ctx_workspace[(size_t) il], "dflash_k_ctx_workspace_%d", il); + ggml_format_name(dflash_v_ctx_workspace[(size_t) il], "dflash_v_ctx_workspace_%d", il); + + const size_t k_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_k_ctx_workspace[(size_t) il]); + 
ggml_backend_buffer_t k_workspace_buf = ggml_backend_buft_alloc_buffer(layer_buft, k_workspace_bytes); + if (k_workspace_buf == nullptr) { + free_dflash_kv_cache_tensors(); + return false; + } + ggml_backend_buffer_set_usage(k_workspace_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); + ggml_backend_tensor_alloc(k_workspace_buf, dflash_k_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(k_workspace_buf)); + ggml_backend_buffer_clear(k_workspace_buf, 0); + dflash_cache_bufs.push_back(k_workspace_buf); + + const size_t v_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_v_ctx_workspace[(size_t) il]); + ggml_backend_buffer_t v_workspace_buf = ggml_backend_buft_alloc_buffer(layer_buft, v_workspace_bytes); + if (v_workspace_buf == nullptr) { + free_dflash_kv_cache_tensors(); + return false; + } + ggml_backend_buffer_set_usage(v_workspace_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); + ggml_backend_tensor_alloc(v_workspace_buf, dflash_v_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(v_workspace_buf)); + ggml_backend_buffer_clear(v_workspace_buf, 0); + dflash_cache_bufs.push_back(v_workspace_buf); + } } dflash_profile.last_kv_cache_host_layers = host_layers; + dflash_kv_workspace_token_capacity = use_kv_workspace ? target_token_capacity : 0; + dflash_kv_workspace_n_kv_total = use_kv_workspace ? 
target_workspace_n_kv_total : 0; llama_reset_dflash_kv_cache_state(this); LLAMA_LOG_INFO("%s: DFlash K/V cache placement cross_ctx=%d host_layers=%d/%d first=%s last=%s\n", __func__, @@ -827,6 +1129,8 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { void llama_context::free_dflash_kv_cache_tensors() { dflash_k_ctx_cache.clear(); dflash_v_ctx_cache.clear(); + dflash_k_ctx_workspace.clear(); + dflash_v_ctx_workspace.clear(); dflash_kv_cache_write_pos = 0; dflash_kv_cache_n_filled = 0; dflash_kv_cache_update_rows = 0; @@ -836,11 +1140,31 @@ void llama_context::free_dflash_kv_cache_tensors() { dflash_kv_cache_applied_window_version = 0; dflash_kv_cache_valid = false; dflash_kv_cache_view_valid = false; + dflash_kv_workspace_write_pos = 0; + dflash_kv_workspace_n_filled = 0; + dflash_kv_workspace_reserved_rows = 0; + dflash_kv_workspace_token_capacity = 0; + dflash_kv_workspace_n_kv_total = 0; + dflash_kv_workspace_applied_window_version = 0; + dflash_kv_workspace_valid = false; + dflash_kv_workspace_sync_pending = false; + dflash_kv_graph = nullptr; + dflash_kv_workspace_graph = nullptr; + dflash_kv_graph_rows = 0; + dflash_kv_graph_write_pos = 0; + dflash_kv_workspace_graph_rows = 0; + dflash_kv_workspace_graph_write_pos = 0; dflash_kv_input_target_features = nullptr; dflash_kv_input_pos_ctx = nullptr; dflash_kq_mask_tensor = nullptr; dflash_kq_mask_swa_tensor = nullptr; + if (dflash_workspace_sched != nullptr) { + ggml_backend_sched_synchronize(dflash_workspace_sched); + ggml_backend_sched_free(dflash_workspace_sched); + dflash_workspace_sched = nullptr; + } + for (ggml_backend_buffer_t buf : dflash_cache_bufs) { if (buf != nullptr) { ggml_backend_buffer_free(buf); @@ -4229,7 +4553,7 @@ static bool llm_load_tensors( if (model.arch == LLM_ARCH_GEMMA4) { llm_scale_gate_inp_s(model, use_mmap_buffer); } - if ((model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) && extra_output_type != GGML_TYPE_COUNT) { + if ((model.arch == 
LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE || model.arch == LLM_ARCH_DFLASH_DRAFT) && extra_output_type != GGML_TYPE_COUNT) { llm_requantize_output_tensor(model, extra_output_type); } @@ -5405,6 +5729,7 @@ static bool prepare_dflash_graph_inputs( struct llama_context & lctx, uint32_t n_tokens) { const bool use_kv_cache = llama_env_flag_enabled("IK_DFLASH_KV_CACHE"); + const bool use_kv_workspace = use_kv_cache && llama_dflash_use_kv_workspace_experiment(); const bool kv_node_timing = llama_env_flag_enabled("IK_DFLASH_KV_NODE_TIMING"); auto & profile = lctx.dflash_profile; const int32_t cross_ctx = lctx.dflash_visible_cross_ctx > 0 @@ -5456,6 +5781,10 @@ static bool prepare_dflash_graph_inputs( profile.last_n_tokens = (int32_t) n_tokens; profile.last_n_kv_total = n_kv_total; + if (use_kv_workspace) { + llama_sync_dflash_workspace_if_pending(lctx); + } + if (graph_cross_ctx != cross_ctx) { profile.graph_shape_failures++; @@ -5585,6 +5914,9 @@ static bool prepare_dflash_graph_inputs( ggml_backend_sched_free(lctx.dflash_sched); lctx.dflash_sched = nullptr; } + lctx.dflash_kv_graph = nullptr; + lctx.dflash_kv_graph_rows = 0; + lctx.dflash_kv_graph_write_pos = 0; const int32_t saved_update_rows = lctx.dflash_kv_cache_update_rows; lctx.dflash_kv_cache_update_rows = cross_ctx; @@ -5631,23 +5963,35 @@ static bool prepare_dflash_graph_inputs( } lctx.dflash_kv_cache_update_rows = update_rows; - const int64_t t_build_us = ggml_time_us(); - ggml_cgraph * gf_kv = llm_build_context::llama_build_graph_dflash_kv_cache(lctx); - profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us); - if (gf_kv == nullptr || lctx.dflash_kv_input_target_features == nullptr || lctx.dflash_kv_input_pos_ctx == nullptr) { - profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__); - return false; + ggml_cgraph * gf_kv = nullptr; + const bool can_reuse_kv_graph = lctx.dflash_kv_graph != nullptr && + lctx.dflash_kv_graph_rows 
== update_rows && + lctx.dflash_kv_graph_write_pos == lctx.dflash_kv_cache_write_pos; + if (can_reuse_kv_graph) { + gf_kv = lctx.dflash_kv_graph; + } else { + const int64_t t_build_us = ggml_time_us(); + gf_kv = llm_build_context::llama_build_graph_dflash_kv_cache(lctx); + profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us); + if (gf_kv == nullptr || lctx.dflash_kv_input_target_features == nullptr || lctx.dflash_kv_input_pos_ctx == nullptr) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__); + return false; + } + + const int64_t t_reset_us = ggml_time_us(); + ggml_backend_sched_reset(lctx.dflash_sched); + profile.graph_kv_cache_reset_us += (uint64_t) (ggml_time_us() - t_reset_us); + + const int64_t t_alloc_us = ggml_time_us(); + ggml_backend_sched_alloc_graph(lctx.dflash_sched, gf_kv); + profile.graph_kv_cache_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us); + + lctx.dflash_kv_graph = gf_kv; + lctx.dflash_kv_graph_rows = update_rows; + lctx.dflash_kv_graph_write_pos = lctx.dflash_kv_cache_write_pos; } - const int64_t t_reset_us = ggml_time_us(); - ggml_backend_sched_reset(lctx.dflash_sched); - profile.graph_kv_cache_reset_us += (uint64_t) (ggml_time_us() - t_reset_us); - - const int64_t t_alloc_us = ggml_time_us(); - ggml_backend_sched_alloc_graph(lctx.dflash_sched, gf_kv); - profile.graph_kv_cache_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us); - ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_target_features); const int64_t t_feature_upload_us = ggml_time_us(); if (kv_feature_backend != nullptr) { @@ -5692,6 +6036,108 @@ static bool prepare_dflash_graph_inputs( lctx.dflash_kv_cache_view_write_pos = lctx.dflash_kv_cache_write_pos; lctx.dflash_kv_cache_view_valid = true; } + + if (use_kv_workspace && lctx.dflash_kv_cache_view_valid && + !lctx.dflash_k_ctx_workspace.empty() && !lctx.dflash_v_ctx_workspace.empty()) { + const bool 
need_workspace_refresh = !lctx.dflash_kv_workspace_valid || + lctx.dflash_kv_workspace_n_filled != lctx.dflash_kv_cache_view_n_filled || + lctx.dflash_kv_workspace_write_pos != lctx.dflash_kv_cache_view_write_pos || + lctx.dflash_kv_workspace_applied_window_version != lctx.dflash_kv_cache_applied_window_version; + + if (need_workspace_refresh) { + const size_t max_nodes = lctx.model.max_nodes((int) std::max(1, cross_ctx)) + 16 * lctx.model.hparams.n_layer; + const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false); + if (lctx.dflash_workspace_buf_compute_meta.size() != meta_size) { + lctx.dflash_workspace_buf_compute_meta.resize(meta_size); + } + + ggml_cgraph * gf_workspace = nullptr; + const bool can_reuse_workspace_graph = lctx.dflash_kv_workspace_graph != nullptr && + lctx.dflash_kv_workspace_graph_rows == lctx.dflash_kv_cache_view_n_filled && + lctx.dflash_kv_workspace_graph_write_pos == lctx.dflash_kv_cache_view_write_pos; + + if (can_reuse_workspace_graph) { + gf_workspace = lctx.dflash_kv_workspace_graph; + } else { + const int64_t t_build_us = ggml_time_us(); + gf_workspace = llm_build_context::llama_build_graph_dflash_kv_workspace(lctx); + profile.graph_kv_workspace_build_us += (uint64_t) (ggml_time_us() - t_build_us); + if (gf_workspace == nullptr) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: failed to build DFlash K/V workspace graph\n", __func__); + return false; + } + + std::vector backend_buft; + backend_buft.reserve(lctx.backends.size()); + for (auto * backend : lctx.backends) { + if (ggml_backend_is_cpu(backend)) { + backend_buft.push_back(llama_default_buffer_type_cpu(true)); + } else { + backend_buft.push_back(ggml_backend_get_default_buffer_type(backend)); + } + } + + if (lctx.dflash_workspace_sched == nullptr) { + lctx.dflash_workspace_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false); + } + + if 
(lctx.dflash_kv_workspace_reserved_rows != cross_ctx) { + const bool saved_view_valid = lctx.dflash_kv_cache_view_valid; + const int32_t saved_view_rows = lctx.dflash_kv_cache_view_n_filled; + const int32_t saved_view_write_pos = lctx.dflash_kv_cache_view_write_pos; + + lctx.dflash_kv_cache_view_valid = true; + lctx.dflash_kv_cache_view_n_filled = cross_ctx; + lctx.dflash_kv_cache_view_write_pos = cross_ctx > 1 ? 1 : 0; + + const int64_t t_reserve_build_us = ggml_time_us(); + ggml_cgraph * gf_workspace_reserve = llm_build_context::llama_build_graph_dflash_kv_workspace(lctx); + profile.graph_kv_workspace_build_us += (uint64_t) (ggml_time_us() - t_reserve_build_us); + + lctx.dflash_kv_cache_view_valid = saved_view_valid; + lctx.dflash_kv_cache_view_n_filled = saved_view_rows; + lctx.dflash_kv_cache_view_write_pos = saved_view_write_pos; + + const int64_t t_reserve_us = ggml_time_us(); + const bool reserved = lctx.dflash_workspace_sched != nullptr && + gf_workspace_reserve != nullptr && + ggml_backend_sched_reserve(lctx.dflash_workspace_sched, gf_workspace_reserve); + profile.graph_kv_workspace_reserve_us += (uint64_t) (ggml_time_us() - t_reserve_us); + if (!reserved) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V workspace scheduler\n", __func__); + return false; + } + + lctx.dflash_kv_workspace_reserved_rows = cross_ctx; + } + + const int64_t t_reset_us = ggml_time_us(); + ggml_backend_sched_reset(lctx.dflash_workspace_sched); + profile.graph_kv_workspace_reset_us += (uint64_t) (ggml_time_us() - t_reset_us); + + const int64_t t_alloc_us = ggml_time_us(); + ggml_backend_sched_alloc_graph(lctx.dflash_workspace_sched, gf_workspace); + profile.graph_kv_workspace_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us); + + lctx.dflash_kv_workspace_graph = gf_workspace; + lctx.dflash_kv_workspace_graph_rows = lctx.dflash_kv_cache_view_n_filled; + lctx.dflash_kv_workspace_graph_write_pos = lctx.dflash_kv_cache_view_write_pos; + } + 
+ const int64_t t_workspace_us = ggml_time_us(); + llama_graph_compute_sched(lctx, lctx.dflash_workspace_sched, gf_workspace, lctx.cparams.n_threads); + profile.graph_kv_workspace_compute_us += (uint64_t) (ggml_time_us() - t_workspace_us); + lctx.dflash_kv_workspace_sync_pending = true; + profile.graph_kv_workspace_calls++; + + lctx.dflash_kv_workspace_n_filled = lctx.dflash_kv_cache_view_n_filled; + lctx.dflash_kv_workspace_write_pos = lctx.dflash_kv_cache_view_write_pos; + lctx.dflash_kv_workspace_applied_window_version = lctx.dflash_kv_cache_applied_window_version; + lctx.dflash_kv_workspace_valid = true; + } + } } else { ggml_backend_tensor_set(lctx.inp_dflash_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.inp_dflash_target_features)); ggml_backend_tensor_set(lctx.inp_dflash_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.inp_dflash_pos_ctx)); @@ -5806,6 +6252,7 @@ static int llama_decode_internal( auto * dflash_profile = lctx.model.arch == LLM_ARCH_DFLASH_DRAFT ? &lctx.dflash_profile : nullptr; const bool dflash_decode_timing = dflash_profile != nullptr && llama_env_flag_enabled("IK_DFLASH_DECODE_TIMING"); + const bool dflash_draft_node_timing = dflash_profile != nullptr && llama_env_flag_enabled("IK_DFLASH_DRAFT_NODE_TIMING"); auto & kv_self = lctx.kv_self; @@ -6165,8 +6612,21 @@ static int llama_decode_internal( #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif + if (lctx.dflash_kv_workspace_sync_pending) { + llama_sync_dflash_workspace_if_pending(lctx); + } const int64_t t_dflash_graph_compute_us = dflash_decode_timing ? 
ggml_time_us() : 0; + llama_dflash_main_node_profiler draft_node_profiler; + if (dflash_draft_node_timing) { + draft_node_profiler.profile = dflash_profile; + draft_node_profiler.prev_callback = lctx.cparams.cb_eval; + draft_node_profiler.prev_user_data = lctx.cparams.cb_eval_user_data; + ggml_backend_sched_set_eval_callback(lctx.sched, llama_dflash_main_node_eval_callback, &draft_node_profiler); + } llama_graph_compute(lctx, gf, n_threads); + if (dflash_draft_node_timing) { + ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); + } if (dflash_decode_timing) { llama_synchronize(&lctx); dflash_profile->decode_sync_profile_points++; From 1250f522ed557e1c38fcec91cf76b6634ccf180e Mon Sep 17 00:00:00 2001 From: SamuelOliveirads Date: Mon, 1 Jun 2026 17:14:25 -0300 Subject: [PATCH 07/13] add qwen, gemma and kimi dflash support --- convert_hf_to_gguf.py | 102 ++++++++++++++++++++++++++++- examples/server/server-context.cpp | 12 ++++ 2 files changed, 111 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index acf3ecb7..dd2766fa 100644 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2287,6 +2287,7 @@ class DFlashDraftModel(Qwen3Model): model_arch = gguf.MODEL_ARCH.DFLASH_DRAFT _target_hparams: dict[str, Any] | None = None + _target_raw_hparams: dict[str, Any] | None = None _saw_token_embd = False _saw_output = False @@ -2300,10 +2301,83 @@ class DFlashDraftModel(Qwen3Model): self._target_hparams = Model.load_text_hparams(self._require_target_model_dir()) return self._target_hparams + def _get_target_raw_hparams(self) -> dict[str, Any]: + if self._target_raw_hparams is None: + self._target_raw_hparams = Model.load_hparams(self._require_target_model_dir()) + return self._target_raw_hparams + + def _target_uses_gemma4_vocab(self) -> bool: + raw_hparams = self._get_target_raw_hparams() + model_type = str(raw_hparams.get("model_type", "")) + if model_type.startswith("gemma4"): + 
return True + architectures = raw_hparams.get("architectures") + if isinstance(architectures, list): + return any(str(arch).startswith("Gemma4") for arch in architectures) + return False + + def _get_target_hidden_size(self) -> int | None: + raw_hparams = self._get_target_raw_hparams() + if (hidden_size := raw_hparams.get("hidden_size")) is not None: + return int(hidden_size) + if (hidden_size := raw_hparams.get("backbone_hidden_size")) is not None: + return int(hidden_size) + text_hparams = raw_hparams.get("text_config") + if isinstance(text_hparams, dict) and (hidden_size := text_hparams.get("hidden_size")) is not None: + return int(hidden_size) + return None + + def _set_vocab_gemma4(self, dir_model: Path, vocab_size: int | None = None) -> None: + vocab = gguf.LlamaHfVocab(dir_model) + tokens = [] + scores = [] + toktypes = [] + visible_tokens = { + "<|channel>", + "", + "<|tool_call>", + "", + "<|tool_response>", + "", + "<|\"|>", + } + + for text, score, toktype in vocab.all_tokens(): + tokens.append(text) + scores.append(score) + text_str = text.decode() + if text_str in visible_tokens: + toktypes.append(gguf.TokenType.USER_DEFINED) + logger.info(f"Token {text_str!r} is set to USER_DEFINED") + else: + toktypes.append(toktype) + + if vocab_size is not None and len(tokens) != int(vocab_size): + raise ValueError( + f"DFlashDraftModel: Gemma4 tokenizer size {len(tokens)} does not match expected vocab_size={int(vocab_size)}" + ) + + self.gguf_writer.add_tokenizer_model("gemma4") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(dir_model, load_merges=True) + special_vocab.add_to_gguf(self.gguf_writer) + self.gguf_writer.add_add_space_prefix(False) + self.gguf_writer.add_add_bos_token(True) + def set_vocab(self): target_hparams = self._get_target_hparams() + target_model_dir = self._require_target_model_dir() + if self._target_uses_gemma4_vocab(): 
+ self._set_vocab_gemma4( + dir_model=target_model_dir, + vocab_size=target_hparams.get("vocab_size"), + ) + return self._set_vocab_gpt2( - dir_model=self._require_target_model_dir(), + dir_model=target_model_dir, vocab_size=target_hparams.get("vocab_size"), ) @@ -2313,6 +2387,29 @@ class DFlashDraftModel(Qwen3Model): self.gguf_writer.add_causal_attention(False) self.gguf_writer.add_rope_dimension_count(self.hparams.get("head_dim", 128)) + rope_scaling = self.hparams.get("rope_scaling") + if isinstance(rope_scaling, dict): + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type")) + rope_factor = rope_scaling.get("factor") + + if rope_type == "linear" and rope_factor is not None: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_factor) + elif rope_type == "yarn" and rope_factor is not None: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_factor) + + if (orig_ctx_len := rope_scaling.get("original_max_position_embeddings")) is not None: + self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_ctx_len) + if (yarn_ext_factor := rope_scaling.get("extrapolation_factor")) is not None: + self.gguf_writer.add_rope_scaling_yarn_ext_factor(yarn_ext_factor) + if (yarn_attn_factor := rope_scaling.get("attention_factor", rope_scaling.get("attn_factor"))) is not None: + self.gguf_writer.add_rope_scaling_yarn_attn_factor(yarn_attn_factor) + if (yarn_beta_fast := rope_scaling.get("beta_fast")) is not None: + self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_beta_fast) + if (yarn_beta_slow := rope_scaling.get("beta_slow")) is not None: + self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_beta_slow) + arch = self.gguf_writer.arch dflash_cfg = self.hparams.get("dflash_config") dflash_cfg = dflash_cfg if isinstance(dflash_cfg, dict) else {} @@ -2340,8 +2437,7 @@ class DFlashDraftModel(Qwen3Model): elif "n_target_features" in 
self.hparams: n_target_features = int(self.hparams["n_target_features"]) else: - target_hparams = self._get_target_hparams() - target_hidden_size = target_hparams.get("hidden_size") + target_hidden_size = self._get_target_hidden_size() if target_hidden_size is None: raise ValueError("DFlashDraftModel: target config is missing hidden_size") diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 81139ac1..25d2e09d 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -446,6 +446,18 @@ bool server_context::load_model(const gpt_params& params_) { params_dft.flash_attn = params_base.flash_attn; params_dft.k_cache_hadamard = params_base.k_cache_hadamard; params_dft.v_cache_hadamard = params_base.v_cache_hadamard; + if (server_speculative_has_dflash(params_base.speculative)) { + params_dft.split_mode = params_base.split_mode; + for (size_t i = 0; i < std::size(params_dft.tensor_split); ++i) { + params_dft.tensor_split[i] = params_base.tensor_split[i]; + } + params_dft.attn_max_batch = params_base.attn_max_batch; + params_dft.graph_reuse = params_base.graph_reuse; + params_dft.split_mode_graph_scheduling = params_base.split_mode_graph_scheduling; + params_dft.scheduler_async = params_base.scheduler_async; + params_dft.max_extra_alloc_MiB = params_base.max_extra_alloc_MiB; + params_dft.reduce_type = params_base.reduce_type; + } if (!params_base.speculative.params.empty()) { auto [argc, argv] = parse_command_line("llama-server " + params_base.speculative.params); if (!gpt_params_parse(argc, argv, params_dft)) { From dc43cdf06b0f1761090a579027f4bab352c01046 Mon Sep 17 00:00:00 2001 From: SamuelOliveirads Date: Tue, 2 Jun 2026 10:22:13 -0300 Subject: [PATCH 08/13] move dflash to its own file --- common/speculative-impl.h | 1740 +++++++++++++++++++++++++++ common/speculative.cpp | 1741 +--------------------------- src/CMakeLists.txt | 2 + src/llama-build-context.cpp | 10 +- src/llama-context.h | 227 ++-- 
src/llama-dflash.cpp | 1240 ++++++++++++++++++++ src/llama-dflash.h | 8 + src/llama-quantize.cpp | 5 +- src/llama-spec-features-dflash.cpp | 1097 ++++++++++++++++++ src/llama-spec-features-dflash.h | 279 +++++ src/llama-spec-features.cpp | 1084 ----------------- src/llama-spec-features.h | 304 +---- src/llama.cpp | 848 +------------- 13 files changed, 4533 insertions(+), 4052 deletions(-) create mode 100644 common/speculative-impl.h create mode 100644 src/llama-dflash.cpp create mode 100644 src/llama-dflash.h create mode 100644 src/llama-spec-features-dflash.cpp create mode 100644 src/llama-spec-features-dflash.h diff --git a/common/speculative-impl.h b/common/speculative-impl.h new file mode 100644 index 00000000..47603461 --- /dev/null +++ b/common/speculative-impl.h @@ -0,0 +1,1740 @@ +// DFlash runtime state and draft path. +struct common_speculative_state_dflash : public common_speculative_state { + llama_context * ctx_tgt; + llama_context * ctx_dft; + + llama_batch batch = {}; + + int32_t block_size = 0; + int32_t mask_token_id = -1; + int32_t n_target_features = 0; + int32_t cross_ctx = 0; + bool ready = false; + + std::vector target_layer_ids; + std::vector target_window; + std::vector target_window_pos; + std::vector target_window_stage; + std::vector target_window_pos_stage; + std::vector target_window_ring; + std::vector target_window_append_features; + int32_t target_window_rows = 0; + int32_t target_window_ring_write_pos = 0; + int32_t target_window_ring_filled = 0; + uint64_t target_window_version = 0; + int32_t target_window_keep_rows = 0; + int32_t target_window_append_rows = 0; + bool target_window_replace = false; + bool target_window_materialized = false; + llama_pos last_target_pos = -1; + size_t n_window_updates = 0; + size_t n_rows_seen = 0; + size_t n_rows_dropped = 0; + size_t n_context_shifts = 0; + size_t n_draft_empty = 0; + size_t n_set_target_fail = 0; + size_t n_decode_fail = 0; + llama_pos last_draft_pos_base = -1; + + uint64_t 
t_draft_decode_us = 0; + uint64_t t_draft_sample_us = 0; + uint64_t t_warmup_collect_us = 0; + uint64_t t_warmup_append_us = 0; + uint64_t t_accept_output_copy_us = 0; + uint64_t t_accept_commit_us = 0; + uint64_t t_accept_append_us = 0; + uint64_t t_accept_append_filter_us = 0; + uint64_t t_accept_append_window_alloc_us = 0; + uint64_t t_accept_append_replace_us = 0; + uint64_t t_accept_append_keep_old_us = 0; + uint64_t t_accept_append_new_rows_us = 0; + uint64_t t_accept_append_commit_detail_us = 0; + uint64_t t_accept_append_log_us = 0; + size_t n_warmup_collect_calls = 0; + size_t n_warmup_collect_rows = 0; + size_t n_warmup_append_calls = 0; + size_t n_warmup_append_rows = 0; + size_t n_accept_output_copy_calls = 0; + size_t n_accept_output_copy_rows = 0; + size_t n_accept_commit_calls = 0; + size_t n_accept_commit_rows = 0; + size_t n_accept_append_calls = 0; + size_t n_accept_append_rows = 0; + size_t n_accept_append_replace_calls = 0; + size_t n_accept_append_slide_calls = 0; + + common_speculative_state_dflash( + enum common_speculative_type type, + llama_context * ctx_tgt, + llama_context * ctx_dft, + int32_t cross_ctx) + : common_speculative_state(type) + , ctx_tgt(ctx_tgt) + , ctx_dft(ctx_dft) + , cross_ctx(std::max(1, cross_ctx)) + { + const llama_model * model_tgt = llama_get_model(ctx_tgt); + const llama_model * model_dft = llama_get_model(ctx_dft); + + if (!common_speculative_are_dflash_compatible(model_tgt, model_dft)) { + LOG_ERR("%s: DFlash draft model vocab/tokenizer is incompatible with the target model\n", __func__); + return; + } + + block_size = llama_model_dflash_block_size(model_dft); + mask_token_id = llama_model_dflash_mask_token_id(model_dft); + n_target_features = llama_model_dflash_n_target_features(model_dft); + const int32_t n_target_layers = llama_model_dflash_n_target_layers(model_dft); + + if (block_size <= 0 || mask_token_id < 0 || n_target_features <= 0 || n_target_layers <= 0) { + LOG_ERR("%s: invalid DFlash metadata 
(block_size=%d, mask_token_id=%d, n_target_features=%d, n_target_layers=%d)\n", + __func__, block_size, mask_token_id, n_target_features, n_target_layers); + return; + } + + target_layer_ids.resize((size_t) n_target_layers); + if (llama_model_dflash_target_layer_ids(model_dft, target_layer_ids.data(), n_target_layers) != n_target_layers) { + LOG_ERR("%s: failed to read DFlash target layer ids\n", __func__); + target_layer_ids.clear(); + return; + } + + const auto * vocab_tgt = llama_model_get_vocab(model_tgt); + const auto * vocab_dft = llama_model_get_vocab(model_dft); + const int32_t target_vocab_size = llama_vocab_n_tokens(vocab_tgt); + const int32_t draft_vocab_size = llama_vocab_n_tokens(vocab_dft); + const int32_t target_hidden_size = llama_model_n_embd(model_tgt); + const int32_t draft_hidden_size = llama_model_n_embd(model_dft); + const int32_t target_mask_token_id = llama_model_dflash_target_mask_token_id(model_tgt); + const int32_t expected_n_target_features = target_hidden_size > 0 ? 
target_hidden_size * n_target_layers : 0; + + if (target_mask_token_id != (int32_t) LLAMA_TOKEN_NULL && mask_token_id != target_mask_token_id) { + LOG_ERR("%s: DFlash mask token mismatch (draft=%d target=%d)\n", + __func__, mask_token_id, target_mask_token_id); + return; + } + + if (target_hidden_size <= 0 || draft_hidden_size <= 0) { + LOG_ERR("%s: invalid DFlash hidden sizes (draft=%d target=%d)\n", + __func__, draft_hidden_size, target_hidden_size); + return; + } + + if (expected_n_target_features <= 0 || n_target_features != expected_n_target_features) { + LOG_ERR("%s: DFlash target feature width mismatch (metadata=%d expected=%d target_hidden=%d target_layers=%d)\n", + __func__, n_target_features, expected_n_target_features, target_hidden_size, n_target_layers); + return; + } + + std::vector sorted_target_layer_ids = target_layer_ids; + std::sort(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()); + if (std::adjacent_find(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()) != sorted_target_layer_ids.end()) { + LOG_ERR("%s: duplicate DFlash target layer ids survived into runtime validation\n", __func__); + target_layer_ids.clear(); + return; + } + + const int32_t n_target_model_layers = llama_n_layer(model_tgt); + for (int32_t layer_id : target_layer_ids) { + if (layer_id < 0 || layer_id >= n_target_model_layers) { + LOG_ERR("%s: invalid DFlash target layer id %d for target model with %d layers\n", + __func__, layer_id, n_target_model_layers); + target_layer_ids.clear(); + return; + } + } + + const int32_t io_mode = llama_model_dflash_io_mode(model_dft, model_tgt); + if (io_mode == LLAMA_DFLASH_IO_MODE_INVALID) { + LOG_ERR("%s: DFlash draft is missing required IO tensors after target sharing\n", __func__); + return; + } + + if (io_mode == LLAMA_DFLASH_IO_MODE_MIXED) { + LOG_ERR("%s: DFlash IO contract must be fully shared or fully self-contained, but resolved to mixed mode\n", __func__); + return; + } + + if (io_mode == 
LLAMA_DFLASH_IO_MODE_SELF_CONTAINED && !llama_model_dflash_io_tensors_match(model_dft, target_hidden_size, target_vocab_size)) { + LOG_ERR("%s: DFlash self-contained IO tensors do not match the target hidden/vocab contract (target_hidden=%d target_vocab=%d)\n", + __func__, + target_hidden_size, + target_vocab_size); + return; + } + + if (!llama_set_dflash_capture_layers(ctx_tgt, target_layer_ids.data(), (int32_t) target_layer_ids.size())) { + LOG_ERR("%s: failed to configure DFlash target capture callback\n", __func__); + return; + } + + batch = llama_batch_init(std::max(1, block_size), 0, 1); + target_window.reserve((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_stage.reserve((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_ring.resize((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_append_features.reserve((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_pos.reserve((size_t) this->cross_ctx); + target_window_pos_stage.reserve((size_t) this->cross_ctx); + ready = true; + + llama_set_dflash_visible_cross_ctx(ctx_dft, this->cross_ctx); + llama_dflash_profile_reset(ctx_tgt); + llama_dflash_profile_reset(ctx_dft); + + std::ostringstream layers_oss; + for (size_t i = 0; i < target_layer_ids.size(); ++i) { + if (i > 0) { + layers_oss << ","; + } + layers_oss << target_layer_ids[i]; + } + + const char * io_mode_name = io_mode == LLAMA_DFLASH_IO_MODE_SHARED ? 
"shared" : "self-contained"; + LOG_INF("%s: DFlash context ready (n_ctx=%d, block_size=%d, cross_ctx=%d, n_target_features=%d, target_layer_ids=[%s])\n", + __func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features, layers_oss.str().c_str()); + LOG_INF("%s: DFlash artifact io=%s draft_vocab=%d target_vocab=%d draft_hidden=%d target_hidden=%d mask_token_id=%d target_mask_token_id=%d\n", + __func__, io_mode_name, draft_vocab_size, target_vocab_size, draft_hidden_size, target_hidden_size, mask_token_id, target_mask_token_id); + } + + ~common_speculative_state_dflash() override { + llama_clear_dflash_capture(ctx_tgt); + if (ctx_dft) { + llama_free(ctx_dft); + } + if (batch.token != nullptr) { + llama_batch_free(batch); + } + } + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + llama_kv_cache_clear(ctx_dft); + llama_reset_dflash_kv_cache_state(ctx_dft); + n_window_updates = 0; + n_rows_seen = 0; + n_rows_dropped = 0; + n_context_shifts = 0; + n_draft_empty = 0; + n_set_target_fail = 0; + n_decode_fail = 0; + last_draft_pos_base = -1; + t_draft_decode_us = 0; + t_draft_sample_us = 0; + t_warmup_collect_us = 0; + t_warmup_append_us = 0; + t_accept_output_copy_us = 0; + t_accept_commit_us = 0; + t_accept_append_us = 0; + t_accept_append_filter_us = 0; + t_accept_append_window_alloc_us = 0; + t_accept_append_replace_us = 0; + t_accept_append_keep_old_us = 0; + t_accept_append_new_rows_us = 0; + t_accept_append_commit_detail_us = 0; + t_accept_append_log_us = 0; + n_warmup_collect_calls = 0; + n_warmup_collect_rows = 0; + n_warmup_append_calls = 0; + n_warmup_append_rows = 0; + n_accept_output_copy_calls = 0; + n_accept_output_copy_rows = 0; + n_accept_commit_calls = 0; + n_accept_commit_rows = 0; + n_accept_append_calls = 0; + n_accept_append_rows = 0; + n_accept_append_replace_calls = 0; + n_accept_append_slide_calls = 0; + llama_dflash_profile_reset(ctx_tgt); + llama_dflash_profile_reset(ctx_dft); + } + + void draft( + 
const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + GGML_UNUSED(prompt_tgt); + + result.clear(); + if (!ready || target_window_rows <= 0) { + n_draft_empty++; + return; + } + + const int32_t n_keep = std::min(params.n_max, block_size - 1); + if (n_keep <= 0) { + return; + } + + const bool use_kv_cache = dflash_use_kv_cache_experiment(); + const float * target_features = nullptr; + size_t target_feature_floats = 0; + llama_dflash_window_update window_update = { + target_window_version, + target_window_keep_rows, + target_window_append_rows, + target_window_replace, + target_window_append_features.empty() ? nullptr : target_window_append_features.data(), + target_window_append_features.size(), + }; + const llama_dflash_kv_cache_transition cache_plan = use_kv_cache + ? llama_plan_dflash_kv_cache_transition_for_ctx(ctx_dft, window_update, target_window_rows) + : llama_dflash_kv_cache_transition{}; + + if (!use_kv_cache || cache_plan.rebuild_cache) { + dflash_materialize_target_window_features(*this); + target_features = target_window.data(); + target_feature_floats = target_window.size(); + } + if (use_kv_cache && cache_plan.rebuild_cache) { + window_update.append_features = target_window.data(); + window_update.append_floats = target_window.size(); + window_update.append_rows = target_window_rows; + } + + if (!llama_set_dflash_target_features_view(ctx_dft, target_features, target_feature_floats, target_window_rows, target_window_pos.data(), &window_update)) { + LOG_ERR("%s: failed to set DFlash target features\n", __func__); + n_set_target_fail++; + return; + } + + llama_kv_cache_clear(ctx_dft); + batch.n_tokens = 0; + const int32_t batch_len = n_keep + 1; + const llama_pos draft_pos_base = last_target_pos >= 0 ? last_target_pos + 1 : (llama_pos) target_window_rows; + const llama_pos seed_pos = last_target_pos >= 0 ? 
last_target_pos : draft_pos_base - 1; + last_draft_pos_base = draft_pos_base; + common_batch_add(batch, id_last, seed_pos, { 0 }, false); + for (int32_t i = 1; i < batch_len; ++i) { + common_batch_add(batch, mask_token_id, draft_pos_base + (i - 1), { 0 }, i <= n_keep); + } + + const int64_t t_decode_us = ggml_time_us(); + if (llama_decode(ctx_dft, batch) != 0) { + LOG_ERR("%s: llama_decode() failed for DFlash draft batch\n", __func__); + n_decode_fail++; + batch.n_tokens = 0; + return; + } + t_draft_decode_us += (uint64_t) (ggml_time_us() - t_decode_us); + + result.reserve((size_t) n_keep); + const int64_t t_sample_us = ggml_time_us(); + for (int32_t i = 0; i < n_keep; ++i) { + result.push_back(common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr)); + } + t_draft_sample_us += (uint64_t) (ggml_time_us() - t_sample_us); + + batch.n_tokens = 0; + dflash_contract_log_draft(*this, n_keep, result.size()); + } + + void accept(uint16_t n_accepted) override { + GGML_UNUSED(n_accepted); + } +}; + +static void dflash_contract_log_append( + const common_speculative_state_dflash & state, + llama_seq_id seq_id, + const std::vector & new_positions) { + if (!dflash_contract_log_enabled()) { + return; + } + + static std::atomic counter = 0; + const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); + if (ordinal >= 8) { + return; + } + + const dflash_contract_pos_summary incoming = dflash_contract_summarize_positions(new_positions); + const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos); + + LOG_INF("dflash contract append[%llu]: seq=%d incoming_rows=%zu incoming_pos=%s pos=[%d..%d] gaps=%d nonmono=%d window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d\n", + (unsigned long long) (ordinal + 1), + (int) seq_id, + new_positions.size(), + dflash_contract_format_values(new_positions).c_str(), + (int) incoming.first, + (int) incoming.last, + incoming.gap_count, + 
incoming.nonmono_count, + state.target_window_rows, + dflash_contract_format_values(state.target_window_pos).c_str(), + (int) window.first, + (int) window.last, + window.gap_count, + window.nonmono_count, + (int) state.last_target_pos); +} + +static void dflash_contract_log_draft( + const common_speculative_state_dflash & state, + int32_t n_keep, + size_t result_size) { + if (!dflash_contract_log_enabled()) { + return; + } + + static std::atomic counter = 0; + const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); + if (ordinal >= 8) { + return; + } + + const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos); + llama_dflash_profile_stats graph_stats = {}; + llama_dflash_profile_get_stats(state.ctx_dft, &graph_stats); + const int draft_delta = (state.last_target_pos >= 0 && state.last_draft_pos_base >= 0) + ? (int) (state.last_draft_pos_base - state.last_target_pos) + : -1; + const llama_pos seed_pos = state.last_target_pos; + const llama_pos mask_first_pos = state.last_draft_pos_base; + const llama_pos mask_last_pos = state.last_draft_pos_base >= 0 + ? 
state.last_draft_pos_base + n_keep - 1 + : -1; + + LOG_INF("dflash contract draft[%llu]: window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d seed_pos=%d mask_pos=[%d..%d] sample_rows=[1..%d] output_rows=[1..%d] draft_pos_base=%d delta=%d n_keep=%d result=%zu set_target(missing/nonmono)=%llu/%llu graph(fallback/nonmono)=%llu/%llu graph_pos=[%d..%d]\n", + (unsigned long long) (ordinal + 1), + state.target_window_rows, + dflash_contract_format_values(state.target_window_pos).c_str(), + (int) window.first, + (int) window.last, + window.gap_count, + window.nonmono_count, + (int) state.last_target_pos, + (int) seed_pos, + (int) mask_first_pos, + (int) mask_last_pos, + n_keep, + n_keep, + (int) state.last_draft_pos_base, + draft_delta, + n_keep, + result_size, + (unsigned long long) graph_stats.set_target_missing_positions, + (unsigned long long) graph_stats.set_target_non_monotonic_positions, + (unsigned long long) graph_stats.graph_pos_fallbacks, + (unsigned long long) graph_stats.graph_pos_non_monotonic, + (int) graph_stats.last_pos_first, + (int) graph_stats.last_pos_last); +} + +struct common_speculative_state_draft : public common_speculative_state { + llama_context * ctx_tgt; // only used for retokenizing from ctx_dft + llama_context * ctx_dft; + + common_sampler * smpl; + + llama_batch batch; + llama_tokens prompt_dft; + + bool vocab_cmpt = true; // whether retokenization is needed + std::unordered_map vocab_map; + + common_speculative_state_draft( + enum common_speculative_type type, + llama_context * ctx_tgt, + llama_context * ctx_dft, + const std::vector> & replacements) + : common_speculative_state(type) + , ctx_tgt(ctx_tgt) + , ctx_dft(ctx_dft) + { + batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); + smpl = nullptr; + { + struct common_params_sampling params; + params.top_k = 10; + params.samplers_sequence = { + llama_sampler_type::TOP_K, + llama_sampler_type::DIST, // needed to get probabilities + }; + smpl = 
common_sampler_init(llama_get_model(ctx_dft), params); + } + + vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); + LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt); + + if (!vocab_cmpt) { + LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n"); + + for (const auto & pair : replacements) { + vocab_map[pair.first] = pair.second; + } + } + } + + ~common_speculative_state_draft() override { + llama_free(ctx_dft); + + common_sampler_free(smpl); + + llama_batch_free(batch); + } + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + auto * spec = this; + + auto & batch = spec->batch; + auto & ctx_tgt = spec->ctx_tgt; + auto & ctx_dft = spec->ctx_dft; + auto & smpl = spec->smpl; + auto & prompt_dft = spec->prompt_dft; + + int reuse_i = 0; + int reuse_n = 0; + + const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max; + + llama_tokens prompt_cnv; + if (!spec->vocab_cmpt) { + // convert id_last to draft vocab. 
llama_detokenize is called directly to avoid an allocation + const auto * model_tgt = llama_get_model(ctx_tgt); + const auto * vocab_tgt = llama_model_get_vocab(model_tgt); + + std::string text; + + text = common_detokenize(ctx_tgt, prompt_tgt, true); + text = replace_to_dft(text); + + LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str()); + + prompt_cnv = common_tokenize(ctx_dft, text, false, true); + + + + int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false); + GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last"); + + text.resize(-n_chars); + llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false); + text = replace_to_dft(text); + + LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str()); + id_last = common_tokenize(ctx_dft, text, false, true)[0]; + } + + const llama_tokens & prompt_cur = spec->vocab_cmpt ? prompt_tgt : prompt_cnv; + + const int i_start = std::max(0, (int) prompt_cur.size() - n_ctx); + + // reuse as much as possible from the old draft context + // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt + for (int i = 0; i < (int) prompt_dft.size(); ++i) { + int cur = 0; + while (i_start + cur < (int) prompt_cur.size() && + i + cur < (int) prompt_dft.size() && + prompt_cur[i_start + cur] == prompt_dft[i + cur]) { + cur++; + } + + if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) { + reuse_i = i; + reuse_n = cur; + } + } + + LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size()); + + result.clear(); + result.reserve(params.n_max); + + if (reuse_n == 0) { + llama_kv_cache_clear(ctx_dft); + prompt_dft.clear(); + } else { + // this happens when a previous draft has been discarded (for example, due to being too small), but the + // target model agreed with it. 
in this case, we simply pass back the previous results to save compute + if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { + for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { + result.push_back(prompt_dft[i]); + + if (params.n_max <= (int) result.size()) { + break; + } + } + + return; + } + + if (reuse_i > 0) { + llama_kv_cache_seq_rm (ctx_dft, 0, 0, reuse_i); + llama_kv_cache_seq_add(ctx_dft, 0, reuse_i, -1, -reuse_i); + + prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); + } + + if (reuse_n < (int) prompt_dft.size()) { + llama_kv_cache_seq_rm (ctx_dft, 0, reuse_n, -1); + prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); + } + } + + // prepare a batch to evaluate any new tokens in the prompt + common_batch_clear(batch); + + for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) { + //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]); + common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false); + + prompt_dft.push_back(prompt_cur[i]); + } + + // we should rarely end-up here during normal decoding + if (batch.n_tokens > 0) { + //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); + + llama_decode(ctx_dft, batch); + } + + const llama_pos n_past = prompt_dft.size(); + + LOG_DBG("%s: n_past = %d\n", __func__, n_past); + + common_batch_clear(batch); + common_batch_add (batch, id_last, n_past, { 0 }, true); + + prompt_dft.push_back(id_last); + + //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); + + llama_decode(ctx_dft, batch); + + common_sampler_reset(smpl); + + // sample n_draft tokens from the draft model + for (int i = 0; i < params.n_max; ++i) { + common_batch_clear(batch); + + common_sampler_sample(smpl, ctx_dft, 0, true); + + const auto * cur_p = common_sampler_get_candidates(smpl, true); + + for (int k = 
0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + } + + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; + + common_sampler_accept(smpl, nullptr, id, true); + + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < params.p_min) { + if (i == 0) { + result.push_back(id); + } + break; + } + + result.push_back(id); + + if (params.n_max <= (int) result.size()) { + break; + } + + + common_batch_add(batch, id, n_past + i + 1, { 0 }, true); + + // evaluate the drafted tokens on the draft model + llama_decode(ctx_dft, batch); + + prompt_dft.push_back(id); + } + + if (!spec->vocab_cmpt) { + std::string detokenized = common_detokenize(ctx_dft, result, true); + detokenized = replace_to_tgt(detokenized); + LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str()); + result = common_tokenize(ctx_tgt, detokenized, false, true); + if (result.size() > (size_t)params.n_max) { + result.resize(params.n_max); + } + } + } + + void accept(uint16_t n_accepted) override { + // noop + GGML_UNUSED(n_accepted); + } + + std::string replace_to_dft(const std::string & input) const { + std::string result = input; + + for (const auto & pair : this->vocab_map) { + size_t pos = result.find(pair.first); + while (pos != std::string::npos) { + result.replace(pos, pair.first.length(), pair.second); + pos = result.find(pair.first, pos + pair.second.length()); + } + } + + return result; + } + + std::string replace_to_tgt(const std::string & input) const { + std::string result = input; + + for (const auto & pair : this->vocab_map) { + size_t pos = result.find(pair.second); + while (pos != std::string::npos) { + result.replace(pos, pair.second.length(), pair.first); + pos = result.find(pair.second, pos + pair.first.length()); + } + } + + return result; + } +}; + +struct 
common_speculative_state_eagle3 : public common_speculative_state { + common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {} + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & draft_tokens) override { + // TODO: implement + GGML_UNUSED(params); + GGML_UNUSED(prompt_tgt); + GGML_UNUSED(id_last); + GGML_UNUSED(draft_tokens); + } + + void accept(uint16_t n_accepted) override { + // noop + GGML_UNUSED(n_accepted); + } +}; + +// state of self-speculation (simple implementation, not ngram-map) +struct common_speculative_state_ngram_simple : public common_speculative_state { + common_ngram_simple_config config; + + common_speculative_state_ngram_simple( + enum common_speculative_type type, + common_ngram_simple_config config) + : common_speculative_state(type), config(config) {} + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + + result = common_ngram_simple_draft(config, prompt_tgt, id_last); + GGML_UNUSED(params); + } + + void accept(uint16_t n_accepted) override { + // noop + GGML_UNUSED(n_accepted); + } +}; + +struct common_speculative_state_ngram_map_k : public common_speculative_state { + // draft ngram map for speculative decoding without draft model + common_ngram_map map; + + common_speculative_state_ngram_map_k( + enum common_speculative_type type, + common_ngram_map map) + : common_speculative_state(type), map(std::move(map)) {} + + void begin(const llama_tokens & prompt) override { + common_ngram_map_begin(map, prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) 
override { + common_ngram_map_draft(map, prompt_tgt, id_last, result); + GGML_UNUSED(params); + } + + void accept(uint16_t n_accepted) override { + common_ngram_map_accept(map, n_accepted); + } +}; + +struct common_speculative_state_ngram_mod : public common_speculative_state { + common_ngram_mod & mod; + + // the last position in the prompt that was added to the ngram container + size_t i_last = 0; + + // length of the last drafted n‑gram (number of tokens returned by draft) + size_t n_draft_last = 0; + + // consecutive accept rounds with low acceptance fraction (< 0.5) + int n_low = 0; + + // enable trace logging if LLAMA_TRACE is set + const bool verbose; + + common_speculative_state_ngram_mod(enum common_speculative_type type, common_ngram_mod & mod) + : common_speculative_state(type), mod(mod), verbose(std::getenv("LLAMA_TRACE") != nullptr) { + static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t)); + } + + void begin(const llama_tokens & prompt) override { + i_last = 0; + + n_draft_last = 0; + n_low = 0; + + const size_t n = mod.get_n(); + + if (prompt.size() < n) { + return; + } + + for (size_t i = 0; i < prompt.size() - n; ++i) { + mod.add(prompt.data() + i); + } + + i_last = prompt.size() - n; + + const double f = (double)mod.get_used() / (double)mod.size(); + LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f); + + constexpr double f_thold = 0.25; + if (f > f_thold) { + LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold); + + mod.reset(); + } + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + GGML_UNUSED(params); + + n_draft_last = 0; + + const size_t cur_len = prompt_tgt.size(); + if (cur_len < mod.get_n()) { + return; + } + + const size_t n = mod.get_n(); + + // add new ngrams in chunks + if (i_last + 32 < cur_len) { + for (size_t i = 
i_last; i < cur_len - n; ++i) { + mod.add(prompt_tgt.data() + i); + } + + i_last = cur_len - n; + } + + result.resize(n + params.n_max); + for (size_t i = 0; i < n - 1; ++i) { + result[i] = prompt_tgt[cur_len - n + 1 + i]; + } + result[n - 1] = id_last; + + for (int i = 0; i < params.n_max; ++i) { + const llama_token token = mod.get(result.data() + i); + if (token == common_ngram_mod::EMPTY) { + if (i < params.n_min) { + result.clear(); + return; + } + + result.resize(n + i); + break; + } + result[n + i] = token; + } + + // only return the m tokens that were drafted + for (size_t i = 0; n + i < result.size(); ++i) { + result[i] = result[n + i]; + } + result.resize(result.size() - n); + + // store length of drafted n‑gram for later acceptance analysis + n_draft_last = result.size(); + } + + void accept(uint16_t n_accepted) override { + if (verbose) { + LOG_INF("%s: accepted %d tokens from %zu drafted tokens\n", __func__, n_accepted, n_draft_last); + } + + // compute acceptance fraction if we have a recorded draft length + if (n_draft_last > 0) { + const double f_acc = (double)n_accepted / (double)n_draft_last; + if (f_acc < 0.5) { + n_low++; + if (n_low >= 3) { + LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low); + + mod.reset(); + n_low = 0; + i_last = 0; + } + } else { + n_low = 0; + } + } + } +}; + +struct common_speculative_state_ngram_cache : public common_speculative_state { + uint16_t n_draft; + bool save_dynamic; + bool save_static; + + common_ngram_cache ngram_cache_context; + common_ngram_cache ngram_cache_dynamic; + common_ngram_cache ngram_cache_static; + + size_t cache_size = 0; // number of tokens in n-gram cache + + common_speculative_state_ngram_cache( + const enum common_speculative_type type, + const std::string & path_static, + const std::string & path_dynamic, + uint16_t n_draft, + bool save_dynamic, + bool save_static) + : common_speculative_state(type) + , n_draft(n_draft) + , save_dynamic(save_dynamic) + , 
save_static(save_static) + { + if (!path_static.empty()) { + try { + ngram_cache_static = common_ngram_cache_load(path_static); + } catch (...) { + LOG_ERR("failed to open static lookup cache: %s", path_static.c_str()); + GGML_ABORT("Couldn't read static lookup cache"); + } + } + + if (!path_dynamic.empty()) { + try { + ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); + } catch (...) { + LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str()); + GGML_ABORT("Couldn't read dynamic lookup cache"); + } + } + } + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + GGML_UNUSED(params); + + if (cache_size < prompt_tgt.size() + 1) { + llama_tokens tokens_new; + tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); + for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { + tokens_new.push_back(prompt_tgt[j]); + } + tokens_new.push_back(id_last); // add the last token + + // Update context ngram cache with new prompt_tgt: + common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + tokens_new, tokens_new.size(), false); + cache_size = prompt_tgt.size() + 1; + } + + llama_tokens inp; + inp.reserve(prompt_tgt.size() + 1); + for (size_t j = 0; j < prompt_tgt.size(); ++j) { + inp.push_back(prompt_tgt[j]); + } + inp.push_back(id_last); + + result.push_back(id_last); + + common_ngram_cache_draft(inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + ngram_cache_context, + ngram_cache_dynamic, + ngram_cache_static); + + if (result.size() > 0) { + // delete first token in result (which is the id_last token) + result.erase(result.begin()); + } + } + + void accept(uint16_t n_accepted) override { + // TODO: noop + GGML_UNUSED(n_accepted); + } +}; + +struct common_speculative_state_suffix : public common_speculative_state { + common_suffix_tree 
tree; + common_suffix_tree corpus_tree; + bool has_corpus = false; + size_t cache_size = 0; + + // Acceptance feedback + size_t n_draft_last = 0; + bool had_accept = false; + int n_low = 0; + float base_p_min = 0.1f; + float eff_p_min = 0.1f; + + common_speculative_state_suffix( + enum common_speculative_type type, + int max_depth, + const std::string & corpus_path, + const llama_model * model) + : common_speculative_state(type) + , tree(max_depth) + , corpus_tree(max_depth) + { + if (!corpus_path.empty()) { + std::function(const std::string &)> tokenize_fn; + if (model) { + tokenize_fn = [model](const std::string & text) -> std::vector { + return common_tokenize(model, text, false, true); + }; + } + has_corpus = corpus_tree.load_corpus(corpus_path, tokenize_fn); + } + } + + void begin(const llama_tokens & prompt) override { + cache_size = 0; + n_draft_last = 0; + had_accept = false; + n_low = 0; + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + + base_p_min = params.p_min; + if (n_draft_last > 0 && !had_accept) { + if (++n_low >= 3) { + eff_p_min = std::min(eff_p_min + 0.1f, 0.5f); + n_low = 0; + } + } + had_accept = false; + + if (cache_size < prompt_tgt.size() + 1) { + llama_tokens tokens_new; + tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); + for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { + tokens_new.push_back(prompt_tgt[j]); + } + tokens_new.push_back(id_last); + + tree.extend(tokens_new.data(), (int)tokens_new.size()); + cache_size = prompt_tgt.size() + 1; + } + + const int ctx_len = std::min((int)(prompt_tgt.size() + 1), tree.max_depth()); + llama_tokens context; + context.reserve(ctx_len); + const int ctx_start = (int)prompt_tgt.size() + 1 - ctx_len; + for (int j = ctx_start; j < (int)prompt_tgt.size(); ++j) { + context.push_back(prompt_tgt[j]); + } + context.push_back(id_last); + const int min_match_len = 
std::max(1, params.suffix_min_match_len); + + result = tree.speculate( + context.data(), (int)context.size(), + params.n_max, + eff_p_min, + 1, + min_match_len); + + if (has_corpus) { + auto corpus_result = corpus_tree.speculate( + context.data(), (int)context.size(), + params.n_max, + eff_p_min, + 1, + min_match_len); + if (corpus_result.size() > result.size()) { + result = std::move(corpus_result); + } + } + + n_draft_last = result.size(); + } + + void accept(uint16_t n_accepted) override { + if (n_draft_last == 0) { + return; + } + had_accept = true; + const double f_acc = (double)n_accepted / (double)n_draft_last; + if (f_acc < 0.5) { + if (++n_low >= 3) { + eff_p_min = std::min(eff_p_min + 0.1f, 0.5f); + n_low = 0; + } + } else { + n_low = 0; + if (eff_p_min > base_p_min) { + eff_p_min = std::max(eff_p_min - 0.05f, base_p_min); + } + } + } +}; + +struct common_speculative { + std::vector configs; // resolved stage config for each implementation + std::vector> impls; // list of implementations to use and their states + common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) + std::unique_ptr tuner; + int last_n_drafted = 0; + int64_t t_step_start_us = 0; +}; + +static bool common_speculative_stage_chain_matches( + const std::vector & stages, + const std::vector & configs) { + if (stages.size() != configs.size()) { + return false; + } + + for (size_t i = 0; i < stages.size(); ++i) { + if (stages[i].type != configs[i].type) { + return false; + } + } + + return true; +} + +static common_params_speculative common_speculative_get_runtime_params( + const common_speculative_config & config, + const common_params_speculative & params, + const common_speculative_stage_params & stage) { + common_params_speculative result = config.params; + + result.type = config.type; + result.n_max = stage.has_n_max_override() ? stage.n_max : params.n_max; + result.n_min = stage.has_n_min_override() ? 
stage.n_min : params.n_min; + result.p_min = stage.has_p_min_override() ? stage.p_min : params.p_min; + + if (config.type == COMMON_SPECULATIVE_TYPE_SUFFIX) { + result.suffix_min_match_len = stage.has_suffix_min_match_len_override() + ? stage.suffix_min_match_len + : params.suffix_min_match_len; + } + + result.n_max = std::max(result.n_max, 0); + result.n_min = std::max(0, std::min(result.n_min, result.n_max)); + result.stages.clear(); + + return result; +} + +static common_ngram_map get_common_ngram_map(const common_speculative_config & config) { + uint16_t size_key = config.params.ngram_size_n; + uint16_t size_value = config.params.ngram_size_m; + bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); + uint16_t min_hits = config.params.ngram_min_hits; + + return common_ngram_map(size_key, size_value, key_only, min_hits); +} + +static common_speculative_state_ngram_cache create_state_ngram_cache( + const std::string & path_static, const std::string & path_dynamic, + const common_speculative_config & config) { + uint16_t n_draft = 8; // TODO get from config? + + // TODO bool param in common/common.h to set save_static/save_dynamic? 
+ bool save_static = false; + bool save_dynamic = false; + + common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic); + + return state; +} + +std::string common_speculative_type_name_str() { + std::string result; + for (size_t i = 0; i < common_speculative_types.size(); i++) { + if (i > 0) { + result += ", "; + } + result += common_speculative_type_to_str(common_speculative_types[i]); + } + return result; +} + +std::string common_speculative_type_to_str(enum common_speculative_type type) { + switch (type) { + case COMMON_SPECULATIVE_TYPE_NONE: return "none"; + case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft"; + case COMMON_SPECULATIVE_TYPE_DFLASH: return "dflash"; + case COMMON_SPECULATIVE_TYPE_MTP: return "mtp"; + case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3"; + case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod"; + case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache"; + case COMMON_SPECULATIVE_TYPE_SUFFIX: return "suffix"; + default: return "unknown"; + } +} + +enum common_speculative_type common_speculative_type_from_name(const std::string & name) { + std::string normalized = name; + std::replace(normalized.begin(), normalized.end(), '-', '_'); + + const auto it = common_speculative_type_from_name_map.find(normalized); + if (it == common_speculative_type_from_name_map.end()) { + return COMMON_SPECULATIVE_TYPE_COUNT; + } + return it->second; +} + +bool common_speculative_is_compat(llama_context * ctx_tgt) { + bool res = true; + + llama_kv_cache_clear(ctx_tgt); + + // eval 2 tokens to check if the context is compatible + std::vector tmp; + tmp.push_back(0); + tmp.push_back(0); + + int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0)); + if (ret != 
0) { + LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret); + res = false; + goto done; + } + + // try to remove the last tokens + if (!llama_kv_cache_seq_rm(ctx_tgt, 0, 1, -1)) { + LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); + res = false; + goto done; + } + +done: + llama_kv_cache_clear(ctx_tgt); + llama_synchronize(ctx_tgt); + + return res; +} + +// initialization of the speculative decoding system +// +common_speculative * common_speculative_init( + common_params_speculative & params, + llama_context * ctx_tgt) { + std::string chain_error; + if (!common_speculative_validate_chain(params, &chain_error)) { + LOG_ERR("%s: invalid speculative stage chain: %s\n", __func__, chain_error.c_str()); + return nullptr; + } + + const auto stages = params.get_resolved_stages(); + if (params.model_dft && llama_model_is_gemma4_mtp_assistant(params.model_dft)) { + const bool has_draft_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) { + return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT; + }); + + if (has_draft_stage) { + LOG_ERR("%s: Gemma4 assistant models only support MTP stages; omit -md for self-spec-only runs or use -mtp/--spec-stage mtp for assistant-backed MTP\n", __func__); + return nullptr; + } + } + + const bool has_dflash_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) { + return stage.type == COMMON_SPECULATIVE_TYPE_DFLASH; + }); + + const bool needs_draft_ctx = std::any_of(stages.begin(), stages.end(), [&params](const common_speculative_stage_params & stage) { + return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT || + stage.type == COMMON_SPECULATIVE_TYPE_DFLASH || + (stage.type == COMMON_SPECULATIVE_TYPE_MTP && params.model_dft != nullptr); + }); + + llama_context * ctx_dft = nullptr; + if (needs_draft_ctx) { + if (!params.model_dft) { + LOG_ERR("%s: draft speculative stage requires a loaded draft model\n", 
__func__); + return nullptr; + } + + llama_context_params cparams_dft = params.cparams_dft; + + if (has_dflash_stage) { + if (!llama_model_share_dflash_io_tensors(params.model_dft, llama_get_model(ctx_tgt))) { + LOG_ERR("%s: failed to share target IO tensors with DFlash draft model\n", __func__); + return nullptr; + } + + int32_t max_cross_ctx = 0; + for (const auto & stage : stages) { + if (stage.type != COMMON_SPECULATIVE_TYPE_DFLASH) { + continue; + } + + max_cross_ctx = std::max(max_cross_ctx, params.with_stage_overrides(stage).dflash_cross_ctx); + } + + const int32_t block_size = llama_model_dflash_block_size(params.model_dft); + if (block_size <= 0) { + LOG_ERR("%s: invalid DFlash draft block size\n", __func__); + return nullptr; + } + + const int64_t required_n_ctx = (int64_t) max_cross_ctx + (int64_t) block_size; + if (required_n_ctx > std::numeric_limits::max()) { + LOG_ERR("%s: invalid DFlash draft context size cross_ctx=%d block_size=%d required_n_ctx=%lld\n", + __func__, max_cross_ctx, block_size, (long long) required_n_ctx); + return nullptr; + } + + cparams_dft.n_ctx = (uint32_t) required_n_ctx; + } + + ctx_dft = llama_init_from_model(params.model_dft, cparams_dft); + if (ctx_dft == nullptr) { + LOG_ERR("%s", "failed to create draft context\n"); + return nullptr; + } + } + + // Compute the implementations to use based on the resolved stage chain. 
+ std::vector configs = {}; + configs.reserve(stages.size()); + + for (const auto & stage : stages) { + common_params_speculative stage_params = params.with_stage_overrides(stage); + + if (stage.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD && !stage_params.ngram_mod) { + stage_params.ngram_mod = std::make_shared(stage_params.ngram_size_n, 4*1024*1024); + + LOG_INF("%s: initialized ngram_mod with n=%d, size=%zu (%.3f MB)\n", __func__, + stage_params.ngram_size_n, stage_params.ngram_mod->size(), + (float)(stage_params.ngram_mod->size_bytes())/1024/1024); + + if (stage_params.ngram_size_n < 16) { + LOG_WRN("%s: ngram_mod n=%d is too small - poor quality is possible, see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, stage_params.ngram_size_n); + } + } + + configs.push_back(common_speculative_config(stage, stage_params)); + } + + if (!configs.empty() && llama_model_has_recurrent(llama_get_model(ctx_tgt))) { + const int ckpt_tokens = std::max(1, params.get_max_stage_n_max() + 1); + const int actual_mode = llama_spec_ckpt_init(ctx_tgt, params.recurrent_ckpt_mode, ckpt_tokens); + if (actual_mode == LLAMA_SPEC_CKPT_NONE) { + LOG_ERR("%s: failed to prepare recurrent checkpoint mode '%s' during speculative init (max_tokens=%d)\n", + __func__, + params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_PER_STEP ? "per-step" : + params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_GPU_FALLBACK ? "gpu-fallback" : + params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_CPU ? 
"cpu" : "auto", + ckpt_tokens); + if (ctx_dft != nullptr) { + llama_free(ctx_dft); + } + return nullptr; + } + llama_spec_ckpt_discard(ctx_tgt); + params.recurrent_ckpt_mode = actual_mode; + } + + std::vector> impls = {}; + + for (const common_speculative_config & config : configs) { + LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str()); + switch (config.type) { + case COMMON_SPECULATIVE_TYPE_NONE: + break; + case COMMON_SPECULATIVE_TYPE_DRAFT: { + impls.push_back(std::make_unique(config.type, + /* .ctx_tgt = */ ctx_tgt, + /* .ctx_dft = */ ctx_dft, + /* .replacements = */ config.params.replacements + )); + break; + } + case COMMON_SPECULATIVE_TYPE_DFLASH: { + auto state = std::make_unique( + config.type, + ctx_tgt, + ctx_dft, + config.params.dflash_cross_ctx); + if (!state->ready) { + LOG_ERR("%s: failed to initialize DFlash speculative state\n", __func__); + return nullptr; + } + impls.push_back(std::move(state)); + ctx_dft = nullptr; + break; + } + case COMMON_SPECULATIVE_TYPE_MTP: { + llama_context * ctx_mtp = ctx_dft; + if (!ctx_mtp) { + const llama_model * model = llama_get_model(ctx_tgt); + ctx_mtp = llama_init_from_model(const_cast(model), config.params.cparams_dft); + if (!ctx_mtp) { + LOG_ERR("%s: failed to create MTP context\n", __func__); + return nullptr; + } + } + ctx_dft = nullptr; + + const bool use_constant_draft_positions = llama_model_is_gemma4_mtp_assistant(llama_get_model(ctx_mtp)); + impls.push_back(std::make_unique( + config.type, ctx_tgt, ctx_mtp, use_constant_draft_positions)); + break; + } + case COMMON_SPECULATIVE_TYPE_EAGLE3: { + impls.push_back(std::make_unique(config.type)); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { + common_ngram_map ngram_map = get_common_ngram_map(config); + + uint16_t ngram_size_key = ngram_map.size_key; + uint16_t mgram_size_value = ngram_map.size_value; + + auto config_simple = common_ngram_simple_config { + /* .size_ngram = */ ngram_size_key, 
+ /* .size_mgram = */ mgram_size_value + }; + auto state = std::make_unique( + /* .type = */ config.type, + /* .state = */ config_simple + ); + impls.push_back(std::move(state)); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: { + impls.push_back(std::make_unique( + (config.type), + get_common_ngram_map(config) + )); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: { + GGML_ASSERT(config.params.ngram_mod); + impls.push_back(std::make_unique(config.type, *config.params.ngram_mod)); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: { + auto state = create_state_ngram_cache( + config.params.lookup_cache_static, config.params.lookup_cache_dynamic, config); + impls.push_back(std::make_unique(state)); + break; + } + case COMMON_SPECULATIVE_TYPE_SUFFIX: { + int depth = config.params.suffix_max_depth > 0 ? config.params.suffix_max_depth : 64; + const llama_model * model = llama_get_model(ctx_tgt); + impls.push_back(std::make_unique( + config.type, depth, config.params.suffix_corpus, model)); + break; + } + default: + break; + } + } + + if (impls.empty()) { + LOG_WRN("%s", "no implementations specified for speculative decoding\n"); + return nullptr; + } + + auto * result = new common_speculative { + /* .configs = */ std::move(configs), + /* .impls = */ std::move(impls) + }; + + // initialize autotune if requested + if (params.autotune && params.has_composite_stage_chain()) { + LOG_WRN("Autotune disabled — explicit speculative stage chains are not supported yet\n"); + } else if (params.autotune && !result->impls.empty()) { + auto actual_type = result->impls[0]->type; + if (actual_type != COMMON_SPECULATIVE_TYPE_NONE && + actual_type != COMMON_SPECULATIVE_TYPE_EAGLE3) { + result->tuner = std::make_unique(); + result->tuner->init(actual_type, params, llama_get_model(ctx_tgt)); + LOG_DBG("Autotune initialized for %s, tuning %zu parameters\n", + common_speculative_type_to_str(actual_type).c_str(), + 
result->tuner->coords.size()); + } else { + LOG_WRN("Autotune disabled — speculative type %s is not supported for autotuning\n", + common_speculative_type_to_str(actual_type).c_str()); + } + } + + return result; +} + +void common_speculative_free(common_speculative * spec) { + if (spec == nullptr) { + return; + } + + delete spec; +} + +void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt) { + if (spec == nullptr) { + return; + } + + for (auto & impl : spec->impls) { + common_time_meas tm(impl->t_begin_us, !impl->gen_perf); + impl->begin(prompt); + impl->n_call_begin++; + } +} + +llama_tokens common_speculative_draft( + common_speculative * spec, + common_params_speculative & params, + const llama_tokens & prompt_tgt, // specified in target model vocab + llama_token id_last, + llama_pos draft_base_pos, + llama_seq_id draft_seq_id) { + llama_tokens result; + + spec->t_step_start_us = ggml_time_us(); + + // apply autotune proposal if enabled + if (spec->tuner && spec->tuner->enabled) { + spec->tuner->propose(params); + } + + const auto runtime_stages = params.get_resolved_stages(); + const bool use_runtime_stage_overrides = common_speculative_stage_chain_matches(runtime_stages, spec->configs); + + spec->curr_impl = nullptr; // reset current implementation + + for (size_t i = 0; i < spec->impls.size(); ++i) { + auto & impl = spec->impls[i]; + const auto & runtime_stage = use_runtime_stage_overrides ? 
runtime_stages[i] : spec->configs[i].stage; + common_params_speculative impl_params = common_speculative_get_runtime_params(spec->configs[i], params, runtime_stage); + result.clear(); + + { + common_time_meas tm(impl->t_draft_us, !impl->gen_perf); + impl->draft(impl_params, prompt_tgt, id_last, draft_base_pos, draft_seq_id, result); + impl->n_call_draft++; + } + + if (result.empty()) { + continue; + } + + if (common_speculative_type_is_self_spec(impl->type) && impl_params.n_min > 0 && (int)result.size() < impl_params.n_min) { + LOG_DBG("%s: impl %s drafted %zu tokens, below fallback threshold %d - trying next implementation\n", + __func__, common_speculative_type_to_str(impl->type).c_str(), result.size(), impl_params.n_min); + result.clear(); + continue; + } + LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, + common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(), + impl.get()->n_call_draft, result.size()); + + spec->curr_impl = impl.get(); + impl->n_gen_drafts++; + impl->n_gen_tokens += result.size(); + + break; // We have a draft, so break out of the loop and return it. + } + + // store draft count for tuner feedback + if (spec->tuner && spec->tuner->enabled) { + spec->last_n_drafted = (int)result.size(); + } + + return result; +} + +void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { + if (spec->tuner && spec->tuner->enabled && spec->t_step_start_us > 0) { + int64_t step_time_us = ggml_time_us() - spec->t_step_start_us; + double step_tps = (step_time_us > 100) + ? 
(n_accepted + 1.0) * 1e6 / (double)step_time_us + : 0.0; + spec->tuner->accept_feedback(n_accepted, spec->last_n_drafted, step_tps); + spec->t_step_start_us = 0; + } + + common_speculative_state * impl = spec->curr_impl; + + if (!impl) { + return; + } + + { + common_time_meas tm(impl->t_accept_us, !impl->gen_perf); + if (n_accepted > 0) { + impl->n_acc_drafts++; + impl->n_acc_tokens += n_accepted; + } + + impl->accept(n_accepted); + impl->n_call_accept++; + } + + if (impl->type != COMMON_SPECULATIVE_TYPE_MTP) { + if (auto * mtp_state = common_speculative_get_mtp_state(spec); mtp_state != nullptr) { + mtp_invalidate_cached_drafts(*mtp_state); + } + } +} + +static bool common_speculative_has_type(const common_speculative * spec, common_speculative_type type) { + if (spec == nullptr) { + return false; + } + + return std::any_of(spec->configs.begin(), spec->configs.end(), [type](const common_speculative_config & config) { + return config.type == type; + }); +} + +static int common_speculative_ctx_mtp_n_embd(llama_context * ctx) { + return ctx ? 
(int) llama_mtp_state_n_embd(ctx) : 0; +} + +static bool common_speculative_batch_token_has_seq_id( + const llama_batch & batch, + int token_index, + llama_seq_id seq_id) { + if (batch.n_seq_id == nullptr || batch.seq_id == nullptr || batch.n_seq_id[token_index] <= 0 || batch.seq_id[token_index] == nullptr) { + return false; + } + + for (int i = 0; i < batch.n_seq_id[token_index]; ++i) { + if (batch.seq_id[token_index][i] == seq_id) { + return true; + } + } + + return false; +} + +static bool common_speculative_batch_is_exact_single_seq( + const llama_batch & batch, + llama_seq_id seq_id) { + if (batch.n_tokens <= 0 || batch.n_seq_id == nullptr || batch.seq_id == nullptr) { + return false; + } + + for (int i = 0; i < batch.n_tokens; ++i) { + if (batch.n_seq_id[i] != 1 || batch.seq_id[i] == nullptr || batch.seq_id[i][0] != seq_id) { + return false; + } + } + + return true; +} + +static int common_speculative_copy_seq_batch( + const llama_batch & batch, + llama_seq_id seq_id, + llama_batch & seq_batch) { + if (batch.token == nullptr || batch.pos == nullptr) { + return -1; + } + + if (batch.n_tokens < 1) { + return 0; + } + + std::vector token_indices; + token_indices.reserve(batch.n_tokens); + for (int i = 0; i < batch.n_tokens; ++i) { + if (common_speculative_batch_token_has_seq_id(batch, i, seq_id)) { + token_indices.push_back(i); + } + } + + if (token_indices.empty()) { + return 0; + } + + seq_batch = llama_batch_init((int) token_indices.size(), 0, 1); + for (const int i : token_indices) { + common_batch_add(seq_batch, batch.token[i], batch.pos[i], { seq_id }, batch.logits != nullptr && batch.logits[i]); + } + + return (int) token_indices.size(); +} + +static bool common_speculative_feature_view_copy_batch_rows( + const common_speculative_feature_view & view, + const llama_batch & batch, + llama_seq_id seq_id, + std::vector * hidden_rows) { + if (hidden_rows == nullptr || view.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || view.width <= 0 || batch.n_tokens <= 
0 || batch.pos == nullptr) { + return false; + } + + std::unordered_map rows_by_pos; + rows_by_pos.reserve(view.rows.size()); + for (const auto & row : view.rows) { + if (row.seq_id == seq_id && row.data != nullptr) { + rows_by_pos[row.pos] = row.data; + } + } + + hidden_rows->clear(); + hidden_rows->reserve((size_t) batch.n_tokens * view.width); + for (int i = 0; i < batch.n_tokens; ++i) { + auto it = rows_by_pos.find(batch.pos[i]); + if (it == rows_by_pos.end()) { + hidden_rows->clear(); + return false; + } + + hidden_rows->insert(hidden_rows->end(), it->second, it->second + view.width); + } + + return hidden_rows->size() == (size_t) batch.n_tokens * view.width; +} diff --git a/common/speculative.cpp b/common/speculative.cpp index 016aeaa3..e7e6559f 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -477,1745 +477,7 @@ struct common_speculative_state_mtp : public common_speculative_state { } }; -struct common_speculative_state_dflash : public common_speculative_state { - llama_context * ctx_tgt; - llama_context * ctx_dft; - - llama_batch batch = {}; - - int32_t block_size = 0; - int32_t mask_token_id = -1; - int32_t n_target_features = 0; - int32_t cross_ctx = 0; - bool ready = false; - - std::vector target_layer_ids; - std::vector target_window; - std::vector target_window_pos; - std::vector target_window_stage; - std::vector target_window_pos_stage; - std::vector target_window_ring; - std::vector target_window_append_features; - int32_t target_window_rows = 0; - int32_t target_window_ring_write_pos = 0; - int32_t target_window_ring_filled = 0; - uint64_t target_window_version = 0; - int32_t target_window_keep_rows = 0; - int32_t target_window_append_rows = 0; - bool target_window_replace = false; - bool target_window_materialized = false; - llama_pos last_target_pos = -1; - size_t n_window_updates = 0; - size_t n_rows_seen = 0; - size_t n_rows_dropped = 0; - size_t n_context_shifts = 0; - size_t n_draft_empty = 0; - size_t n_set_target_fail = 
0; - size_t n_decode_fail = 0; - llama_pos last_draft_pos_base = -1; - - uint64_t t_draft_decode_us = 0; - uint64_t t_draft_sample_us = 0; - uint64_t t_warmup_collect_us = 0; - uint64_t t_warmup_append_us = 0; - uint64_t t_accept_output_copy_us = 0; - uint64_t t_accept_commit_us = 0; - uint64_t t_accept_append_us = 0; - uint64_t t_accept_append_filter_us = 0; - uint64_t t_accept_append_window_alloc_us = 0; - uint64_t t_accept_append_replace_us = 0; - uint64_t t_accept_append_keep_old_us = 0; - uint64_t t_accept_append_new_rows_us = 0; - uint64_t t_accept_append_commit_detail_us = 0; - uint64_t t_accept_append_log_us = 0; - size_t n_warmup_collect_calls = 0; - size_t n_warmup_collect_rows = 0; - size_t n_warmup_append_calls = 0; - size_t n_warmup_append_rows = 0; - size_t n_accept_output_copy_calls = 0; - size_t n_accept_output_copy_rows = 0; - size_t n_accept_commit_calls = 0; - size_t n_accept_commit_rows = 0; - size_t n_accept_append_calls = 0; - size_t n_accept_append_rows = 0; - size_t n_accept_append_replace_calls = 0; - size_t n_accept_append_slide_calls = 0; - - common_speculative_state_dflash( - enum common_speculative_type type, - llama_context * ctx_tgt, - llama_context * ctx_dft, - int32_t cross_ctx) - : common_speculative_state(type) - , ctx_tgt(ctx_tgt) - , ctx_dft(ctx_dft) - , cross_ctx(std::max(1, cross_ctx)) - { - const llama_model * model_tgt = llama_get_model(ctx_tgt); - const llama_model * model_dft = llama_get_model(ctx_dft); - - if (!common_speculative_are_dflash_compatible(model_tgt, model_dft)) { - LOG_ERR("%s: DFlash draft model vocab/tokenizer is incompatible with the target model\n", __func__); - return; - } - - block_size = llama_model_dflash_block_size(model_dft); - mask_token_id = llama_model_dflash_mask_token_id(model_dft); - n_target_features = llama_model_dflash_n_target_features(model_dft); - const int32_t n_target_layers = llama_model_dflash_n_target_layers(model_dft); - - if (block_size <= 0 || mask_token_id < 0 || 
n_target_features <= 0 || n_target_layers <= 0) { - LOG_ERR("%s: invalid DFlash metadata (block_size=%d, mask_token_id=%d, n_target_features=%d, n_target_layers=%d)\n", - __func__, block_size, mask_token_id, n_target_features, n_target_layers); - return; - } - - target_layer_ids.resize((size_t) n_target_layers); - if (llama_model_dflash_target_layer_ids(model_dft, target_layer_ids.data(), n_target_layers) != n_target_layers) { - LOG_ERR("%s: failed to read DFlash target layer ids\n", __func__); - target_layer_ids.clear(); - return; - } - - const auto * vocab_tgt = llama_model_get_vocab(model_tgt); - const auto * vocab_dft = llama_model_get_vocab(model_dft); - const int32_t target_vocab_size = llama_vocab_n_tokens(vocab_tgt); - const int32_t draft_vocab_size = llama_vocab_n_tokens(vocab_dft); - const int32_t target_hidden_size = llama_model_n_embd(model_tgt); - const int32_t draft_hidden_size = llama_model_n_embd(model_dft); - const int32_t target_mask_token_id = llama_model_dflash_target_mask_token_id(model_tgt); - const int32_t expected_n_target_features = target_hidden_size > 0 ? 
target_hidden_size * n_target_layers : 0; - - if (target_mask_token_id != (int32_t) LLAMA_TOKEN_NULL && mask_token_id != target_mask_token_id) { - LOG_ERR("%s: DFlash mask token mismatch (draft=%d target=%d)\n", - __func__, mask_token_id, target_mask_token_id); - return; - } - - if (target_hidden_size <= 0 || draft_hidden_size <= 0) { - LOG_ERR("%s: invalid DFlash hidden sizes (draft=%d target=%d)\n", - __func__, draft_hidden_size, target_hidden_size); - return; - } - - if (expected_n_target_features <= 0 || n_target_features != expected_n_target_features) { - LOG_ERR("%s: DFlash target feature width mismatch (metadata=%d expected=%d target_hidden=%d target_layers=%d)\n", - __func__, n_target_features, expected_n_target_features, target_hidden_size, n_target_layers); - return; - } - - std::vector sorted_target_layer_ids = target_layer_ids; - std::sort(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()); - if (std::adjacent_find(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()) != sorted_target_layer_ids.end()) { - LOG_ERR("%s: duplicate DFlash target layer ids survived into runtime validation\n", __func__); - target_layer_ids.clear(); - return; - } - - const int32_t n_target_model_layers = llama_n_layer(model_tgt); - for (int32_t layer_id : target_layer_ids) { - if (layer_id < 0 || layer_id >= n_target_model_layers) { - LOG_ERR("%s: invalid DFlash target layer id %d for target model with %d layers\n", - __func__, layer_id, n_target_model_layers); - target_layer_ids.clear(); - return; - } - } - - const int32_t io_mode = llama_model_dflash_io_mode(model_dft, model_tgt); - if (io_mode == LLAMA_DFLASH_IO_MODE_INVALID) { - LOG_ERR("%s: DFlash draft is missing required IO tensors after target sharing\n", __func__); - return; - } - - if (io_mode == LLAMA_DFLASH_IO_MODE_MIXED) { - LOG_ERR("%s: DFlash IO contract must be fully shared or fully self-contained, but resolved to mixed mode\n", __func__); - return; - } - - if (io_mode == 
LLAMA_DFLASH_IO_MODE_SELF_CONTAINED && !llama_model_dflash_io_tensors_match(model_dft, target_hidden_size, target_vocab_size)) { - LOG_ERR("%s: DFlash self-contained IO tensors do not match the target hidden/vocab contract (target_hidden=%d target_vocab=%d)\n", - __func__, - target_hidden_size, - target_vocab_size); - return; - } - - if (!llama_set_dflash_capture_layers(ctx_tgt, target_layer_ids.data(), (int32_t) target_layer_ids.size())) { - LOG_ERR("%s: failed to configure DFlash target capture callback\n", __func__); - return; - } - - batch = llama_batch_init(std::max(1, block_size), 0, 1); - target_window.reserve((size_t) this->cross_ctx * (size_t) n_target_features); - target_window_stage.reserve((size_t) this->cross_ctx * (size_t) n_target_features); - target_window_ring.resize((size_t) this->cross_ctx * (size_t) n_target_features); - target_window_append_features.reserve((size_t) this->cross_ctx * (size_t) n_target_features); - target_window_pos.reserve((size_t) this->cross_ctx); - target_window_pos_stage.reserve((size_t) this->cross_ctx); - ready = true; - - llama_set_dflash_visible_cross_ctx(ctx_dft, this->cross_ctx); - llama_dflash_profile_reset(ctx_tgt); - llama_dflash_profile_reset(ctx_dft); - - std::ostringstream layers_oss; - for (size_t i = 0; i < target_layer_ids.size(); ++i) { - if (i > 0) { - layers_oss << ","; - } - layers_oss << target_layer_ids[i]; - } - - const char * io_mode_name = io_mode == LLAMA_DFLASH_IO_MODE_SHARED ? 
"shared" : "self-contained"; - LOG_INF("%s: DFlash context ready (n_ctx=%d, block_size=%d, cross_ctx=%d, n_target_features=%d, target_layer_ids=[%s])\n", - __func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features, layers_oss.str().c_str()); - LOG_INF("%s: DFlash artifact io=%s draft_vocab=%d target_vocab=%d draft_hidden=%d target_hidden=%d mask_token_id=%d target_mask_token_id=%d\n", - __func__, io_mode_name, draft_vocab_size, target_vocab_size, draft_hidden_size, target_hidden_size, mask_token_id, target_mask_token_id); - } - - ~common_speculative_state_dflash() override { - llama_clear_dflash_capture(ctx_tgt); - if (ctx_dft) { - llama_free(ctx_dft); - } - if (batch.token != nullptr) { - llama_batch_free(batch); - } - } - - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - llama_kv_cache_clear(ctx_dft); - llama_reset_dflash_kv_cache_state(ctx_dft); - n_window_updates = 0; - n_rows_seen = 0; - n_rows_dropped = 0; - n_context_shifts = 0; - n_draft_empty = 0; - n_set_target_fail = 0; - n_decode_fail = 0; - last_draft_pos_base = -1; - t_draft_decode_us = 0; - t_draft_sample_us = 0; - t_warmup_collect_us = 0; - t_warmup_append_us = 0; - t_accept_output_copy_us = 0; - t_accept_commit_us = 0; - t_accept_append_us = 0; - t_accept_append_filter_us = 0; - t_accept_append_window_alloc_us = 0; - t_accept_append_replace_us = 0; - t_accept_append_keep_old_us = 0; - t_accept_append_new_rows_us = 0; - t_accept_append_commit_detail_us = 0; - t_accept_append_log_us = 0; - n_warmup_collect_calls = 0; - n_warmup_collect_rows = 0; - n_warmup_append_calls = 0; - n_warmup_append_rows = 0; - n_accept_output_copy_calls = 0; - n_accept_output_copy_rows = 0; - n_accept_commit_calls = 0; - n_accept_commit_rows = 0; - n_accept_append_calls = 0; - n_accept_append_rows = 0; - n_accept_append_replace_calls = 0; - n_accept_append_slide_calls = 0; - llama_dflash_profile_reset(ctx_tgt); - llama_dflash_profile_reset(ctx_dft); - } - - void draft( - 
const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - GGML_UNUSED(prompt_tgt); - - result.clear(); - if (!ready || target_window_rows <= 0) { - n_draft_empty++; - return; - } - - const int32_t n_keep = std::min(params.n_max, block_size - 1); - if (n_keep <= 0) { - return; - } - - const bool use_kv_cache = dflash_use_kv_cache_experiment(); - const float * target_features = nullptr; - size_t target_feature_floats = 0; - llama_dflash_window_update window_update = { - target_window_version, - target_window_keep_rows, - target_window_append_rows, - target_window_replace, - target_window_append_features.empty() ? nullptr : target_window_append_features.data(), - target_window_append_features.size(), - }; - const llama_dflash_kv_cache_transition cache_plan = use_kv_cache - ? llama_plan_dflash_kv_cache_transition_for_ctx(ctx_dft, window_update, target_window_rows) - : llama_dflash_kv_cache_transition{}; - - if (!use_kv_cache || cache_plan.rebuild_cache) { - dflash_materialize_target_window_features(*this); - target_features = target_window.data(); - target_feature_floats = target_window.size(); - } - if (use_kv_cache && cache_plan.rebuild_cache) { - window_update.append_features = target_window.data(); - window_update.append_floats = target_window.size(); - window_update.append_rows = target_window_rows; - } - - if (!llama_set_dflash_target_features_view(ctx_dft, target_features, target_feature_floats, target_window_rows, target_window_pos.data(), &window_update)) { - LOG_ERR("%s: failed to set DFlash target features\n", __func__); - n_set_target_fail++; - return; - } - - llama_kv_cache_clear(ctx_dft); - batch.n_tokens = 0; - const int32_t batch_len = n_keep + 1; - const llama_pos draft_pos_base = last_target_pos >= 0 ? last_target_pos + 1 : (llama_pos) target_window_rows; - const llama_pos seed_pos = last_target_pos >= 0 ? 
last_target_pos : draft_pos_base - 1; - last_draft_pos_base = draft_pos_base; - common_batch_add(batch, id_last, seed_pos, { 0 }, false); - for (int32_t i = 1; i < batch_len; ++i) { - common_batch_add(batch, mask_token_id, draft_pos_base + (i - 1), { 0 }, i <= n_keep); - } - - const int64_t t_decode_us = ggml_time_us(); - if (llama_decode(ctx_dft, batch) != 0) { - LOG_ERR("%s: llama_decode() failed for DFlash draft batch\n", __func__); - n_decode_fail++; - batch.n_tokens = 0; - return; - } - t_draft_decode_us += (uint64_t) (ggml_time_us() - t_decode_us); - - result.reserve((size_t) n_keep); - const int64_t t_sample_us = ggml_time_us(); - for (int32_t i = 0; i < n_keep; ++i) { - result.push_back(common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr)); - } - t_draft_sample_us += (uint64_t) (ggml_time_us() - t_sample_us); - - batch.n_tokens = 0; - dflash_contract_log_draft(*this, n_keep, result.size()); - } - - void accept(uint16_t n_accepted) override { - GGML_UNUSED(n_accepted); - } -}; - -static void dflash_contract_log_append( - const common_speculative_state_dflash & state, - llama_seq_id seq_id, - const std::vector & new_positions) { - if (!dflash_contract_log_enabled()) { - return; - } - - static std::atomic counter = 0; - const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); - if (ordinal >= 8) { - return; - } - - const dflash_contract_pos_summary incoming = dflash_contract_summarize_positions(new_positions); - const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos); - - LOG_INF("dflash contract append[%llu]: seq=%d incoming_rows=%zu incoming_pos=%s pos=[%d..%d] gaps=%d nonmono=%d window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d\n", - (unsigned long long) (ordinal + 1), - (int) seq_id, - new_positions.size(), - dflash_contract_format_values(new_positions).c_str(), - (int) incoming.first, - (int) incoming.last, - incoming.gap_count, - 
incoming.nonmono_count, - state.target_window_rows, - dflash_contract_format_values(state.target_window_pos).c_str(), - (int) window.first, - (int) window.last, - window.gap_count, - window.nonmono_count, - (int) state.last_target_pos); -} - -static void dflash_contract_log_draft( - const common_speculative_state_dflash & state, - int32_t n_keep, - size_t result_size) { - if (!dflash_contract_log_enabled()) { - return; - } - - static std::atomic counter = 0; - const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); - if (ordinal >= 8) { - return; - } - - const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos); - llama_dflash_profile_stats graph_stats = {}; - llama_dflash_profile_get_stats(state.ctx_dft, &graph_stats); - const int draft_delta = (state.last_target_pos >= 0 && state.last_draft_pos_base >= 0) - ? (int) (state.last_draft_pos_base - state.last_target_pos) - : -1; - const llama_pos seed_pos = state.last_target_pos; - const llama_pos mask_first_pos = state.last_draft_pos_base; - const llama_pos mask_last_pos = state.last_draft_pos_base >= 0 - ? 
state.last_draft_pos_base + n_keep - 1 - : -1; - - LOG_INF("dflash contract draft[%llu]: window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d seed_pos=%d mask_pos=[%d..%d] sample_rows=[1..%d] output_rows=[1..%d] draft_pos_base=%d delta=%d n_keep=%d result=%zu set_target(missing/nonmono)=%llu/%llu graph(fallback/nonmono)=%llu/%llu graph_pos=[%d..%d]\n", - (unsigned long long) (ordinal + 1), - state.target_window_rows, - dflash_contract_format_values(state.target_window_pos).c_str(), - (int) window.first, - (int) window.last, - window.gap_count, - window.nonmono_count, - (int) state.last_target_pos, - (int) seed_pos, - (int) mask_first_pos, - (int) mask_last_pos, - n_keep, - n_keep, - (int) state.last_draft_pos_base, - draft_delta, - n_keep, - result_size, - (unsigned long long) graph_stats.set_target_missing_positions, - (unsigned long long) graph_stats.set_target_non_monotonic_positions, - (unsigned long long) graph_stats.graph_pos_fallbacks, - (unsigned long long) graph_stats.graph_pos_non_monotonic, - (int) graph_stats.last_pos_first, - (int) graph_stats.last_pos_last); -} - -struct common_speculative_state_draft : public common_speculative_state { - llama_context * ctx_tgt; // only used for retokenizing from ctx_dft - llama_context * ctx_dft; - - common_sampler * smpl; - - llama_batch batch; - llama_tokens prompt_dft; - - bool vocab_cmpt = true; // whether retokenization is needed - std::unordered_map vocab_map; - - common_speculative_state_draft( - enum common_speculative_type type, - llama_context * ctx_tgt, - llama_context * ctx_dft, - const std::vector> & replacements) - : common_speculative_state(type) - , ctx_tgt(ctx_tgt) - , ctx_dft(ctx_dft) - { - batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); - smpl = nullptr; - { - struct common_params_sampling params; - params.top_k = 10; - params.samplers_sequence = { - llama_sampler_type::TOP_K, - llama_sampler_type::DIST, // needed to get probabilities - }; - smpl = 
common_sampler_init(llama_get_model(ctx_dft), params); - } - - vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); - LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt); - - if (!vocab_cmpt) { - LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n"); - - for (const auto & pair : replacements) { - vocab_map[pair.first] = pair.second; - } - } - } - - ~common_speculative_state_draft() override { - llama_free(ctx_dft); - - common_sampler_free(smpl); - - llama_batch_free(batch); - } - - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - auto * spec = this; - - auto & batch = spec->batch; - auto & ctx_tgt = spec->ctx_tgt; - auto & ctx_dft = spec->ctx_dft; - auto & smpl = spec->smpl; - auto & prompt_dft = spec->prompt_dft; - - int reuse_i = 0; - int reuse_n = 0; - - const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max; - - llama_tokens prompt_cnv; - if (!spec->vocab_cmpt) { - // convert id_last to draft vocab. 
llama_detokenize is called directly to avoid an allocation - const auto * model_tgt = llama_get_model(ctx_tgt); - const auto * vocab_tgt = llama_model_get_vocab(model_tgt); - - std::string text; - - text = common_detokenize(ctx_tgt, prompt_tgt, true); - text = replace_to_dft(text); - - LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str()); - - prompt_cnv = common_tokenize(ctx_dft, text, false, true); - - - - int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false); - GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last"); - - text.resize(-n_chars); - llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false); - text = replace_to_dft(text); - - LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str()); - id_last = common_tokenize(ctx_dft, text, false, true)[0]; - } - - const llama_tokens & prompt_cur = spec->vocab_cmpt ? prompt_tgt : prompt_cnv; - - const int i_start = std::max(0, (int) prompt_cur.size() - n_ctx); - - // reuse as much as possible from the old draft context - // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt - for (int i = 0; i < (int) prompt_dft.size(); ++i) { - int cur = 0; - while (i_start + cur < (int) prompt_cur.size() && - i + cur < (int) prompt_dft.size() && - prompt_cur[i_start + cur] == prompt_dft[i + cur]) { - cur++; - } - - if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) { - reuse_i = i; - reuse_n = cur; - } - } - - LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size()); - - result.clear(); - result.reserve(params.n_max); - - if (reuse_n == 0) { - llama_kv_cache_clear(ctx_dft); - prompt_dft.clear(); - } else { - // this happens when a previous draft has been discarded (for example, due to being too small), but the - // target model agreed with it. 
in this case, we simply pass back the previous results to save compute - if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { - for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { - result.push_back(prompt_dft[i]); - - if (params.n_max <= (int) result.size()) { - break; - } - } - - return; - } - - if (reuse_i > 0) { - llama_kv_cache_seq_rm (ctx_dft, 0, 0, reuse_i); - llama_kv_cache_seq_add(ctx_dft, 0, reuse_i, -1, -reuse_i); - - prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); - } - - if (reuse_n < (int) prompt_dft.size()) { - llama_kv_cache_seq_rm (ctx_dft, 0, reuse_n, -1); - prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); - } - } - - // prepare a batch to evaluate any new tokens in the prompt - common_batch_clear(batch); - - for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) { - //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]); - common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false); - - prompt_dft.push_back(prompt_cur[i]); - } - - // we should rarely end-up here during normal decoding - if (batch.n_tokens > 0) { - //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); - - llama_decode(ctx_dft, batch); - } - - const llama_pos n_past = prompt_dft.size(); - - LOG_DBG("%s: n_past = %d\n", __func__, n_past); - - common_batch_clear(batch); - common_batch_add (batch, id_last, n_past, { 0 }, true); - - prompt_dft.push_back(id_last); - - //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); - - llama_decode(ctx_dft, batch); - - common_sampler_reset(smpl); - - // sample n_draft tokens from the draft model - for (int i = 0; i < params.n_max; ++i) { - common_batch_clear(batch); - - common_sampler_sample(smpl, ctx_dft, 0, true); - - const auto * cur_p = common_sampler_get_candidates(smpl, true); - - for (int k = 
0; k < std::min(3, (int) cur_p->size); ++k) { - LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); - } - - // add drafted token for each sequence - const llama_token id = cur_p->data[0].id; - - common_sampler_accept(smpl, nullptr, id, true); - - // only collect very high-confidence draft tokens - if (cur_p->data[0].p < params.p_min) { - if (i == 0) { - result.push_back(id); - } - break; - } - - result.push_back(id); - - if (params.n_max <= (int) result.size()) { - break; - } - - - common_batch_add(batch, id, n_past + i + 1, { 0 }, true); - - // evaluate the drafted tokens on the draft model - llama_decode(ctx_dft, batch); - - prompt_dft.push_back(id); - } - - if (!spec->vocab_cmpt) { - std::string detokenized = common_detokenize(ctx_dft, result, true); - detokenized = replace_to_tgt(detokenized); - LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str()); - result = common_tokenize(ctx_tgt, detokenized, false, true); - if (result.size() > (size_t)params.n_max) { - result.resize(params.n_max); - } - } - } - - void accept(uint16_t n_accepted) override { - // noop - GGML_UNUSED(n_accepted); - } - - std::string replace_to_dft(const std::string & input) const { - std::string result = input; - - for (const auto & pair : this->vocab_map) { - size_t pos = result.find(pair.first); - while (pos != std::string::npos) { - result.replace(pos, pair.first.length(), pair.second); - pos = result.find(pair.first, pos + pair.second.length()); - } - } - - return result; - } - - std::string replace_to_tgt(const std::string & input) const { - std::string result = input; - - for (const auto & pair : this->vocab_map) { - size_t pos = result.find(pair.second); - while (pos != std::string::npos) { - result.replace(pos, pair.second.length(), pair.first); - pos = result.find(pair.second, pos + pair.first.length()); - } - } - - return result; - } -}; - -struct 
common_speculative_state_eagle3 : public common_speculative_state { - common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {} - - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & draft_tokens) override { - // TODO: implement - GGML_UNUSED(params); - GGML_UNUSED(prompt_tgt); - GGML_UNUSED(id_last); - GGML_UNUSED(draft_tokens); - } - - void accept(uint16_t n_accepted) override { - // noop - GGML_UNUSED(n_accepted); - } -}; - -// state of self-speculation (simple implementation, not ngram-map) -struct common_speculative_state_ngram_simple : public common_speculative_state { - common_ngram_simple_config config; - - common_speculative_state_ngram_simple( - enum common_speculative_type type, - common_ngram_simple_config config) - : common_speculative_state(type), config(config) {} - - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - - result = common_ngram_simple_draft(config, prompt_tgt, id_last); - GGML_UNUSED(params); - } - - void accept(uint16_t n_accepted) override { - // noop - GGML_UNUSED(n_accepted); - } -}; - -struct common_speculative_state_ngram_map_k : public common_speculative_state { - // draft ngram map for speculative decoding without draft model - common_ngram_map map; - - common_speculative_state_ngram_map_k( - enum common_speculative_type type, - common_ngram_map map) - : common_speculative_state(type), map(std::move(map)) {} - - void begin(const llama_tokens & prompt) override { - common_ngram_map_begin(map, prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) 
override { - common_ngram_map_draft(map, prompt_tgt, id_last, result); - GGML_UNUSED(params); - } - - void accept(uint16_t n_accepted) override { - common_ngram_map_accept(map, n_accepted); - } -}; - -struct common_speculative_state_ngram_mod : public common_speculative_state { - common_ngram_mod & mod; - - // the last position in the prompt that was added to the ngram container - size_t i_last = 0; - - // length of the last drafted n‑gram (number of tokens returned by draft) - size_t n_draft_last = 0; - - // consecutive accept rounds with low acceptance fraction (< 0.5) - int n_low = 0; - - // enable trace logging if LLAMA_TRACE is set - const bool verbose; - - common_speculative_state_ngram_mod(enum common_speculative_type type, common_ngram_mod & mod) - : common_speculative_state(type), mod(mod), verbose(std::getenv("LLAMA_TRACE") != nullptr) { - static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t)); - } - - void begin(const llama_tokens & prompt) override { - i_last = 0; - - n_draft_last = 0; - n_low = 0; - - const size_t n = mod.get_n(); - - if (prompt.size() < n) { - return; - } - - for (size_t i = 0; i < prompt.size() - n; ++i) { - mod.add(prompt.data() + i); - } - - i_last = prompt.size() - n; - - const double f = (double)mod.get_used() / (double)mod.size(); - LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f); - - constexpr double f_thold = 0.25; - if (f > f_thold) { - LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold); - - mod.reset(); - } - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - GGML_UNUSED(params); - - n_draft_last = 0; - - const size_t cur_len = prompt_tgt.size(); - if (cur_len < mod.get_n()) { - return; - } - - const size_t n = mod.get_n(); - - // add new ngrams in chunks - if (i_last + 32 < cur_len) { - for (size_t i = 
i_last; i < cur_len - n; ++i) { - mod.add(prompt_tgt.data() + i); - } - - i_last = cur_len - n; - } - - result.resize(n + params.n_max); - for (size_t i = 0; i < n - 1; ++i) { - result[i] = prompt_tgt[cur_len - n + 1 + i]; - } - result[n - 1] = id_last; - - for (int i = 0; i < params.n_max; ++i) { - const llama_token token = mod.get(result.data() + i); - if (token == common_ngram_mod::EMPTY) { - if (i < params.n_min) { - result.clear(); - return; - } - - result.resize(n + i); - break; - } - result[n + i] = token; - } - - // only return the m tokens that were drafted - for (size_t i = 0; n + i < result.size(); ++i) { - result[i] = result[n + i]; - } - result.resize(result.size() - n); - - // store length of drafted n‑gram for later acceptance analysis - n_draft_last = result.size(); - } - - void accept(uint16_t n_accepted) override { - if (verbose) { - LOG_INF("%s: accepted %d tokens from %zu drafted tokens\n", __func__, n_accepted, n_draft_last); - } - - // compute acceptance fraction if we have a recorded draft length - if (n_draft_last > 0) { - const double f_acc = (double)n_accepted / (double)n_draft_last; - if (f_acc < 0.5) { - n_low++; - if (n_low >= 3) { - LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low); - - mod.reset(); - n_low = 0; - i_last = 0; - } - } else { - n_low = 0; - } - } - } -}; - -struct common_speculative_state_ngram_cache : public common_speculative_state { - uint16_t n_draft; - bool save_dynamic; - bool save_static; - - common_ngram_cache ngram_cache_context; - common_ngram_cache ngram_cache_dynamic; - common_ngram_cache ngram_cache_static; - - size_t cache_size = 0; // number of tokens in n-gram cache - - common_speculative_state_ngram_cache( - const enum common_speculative_type type, - const std::string & path_static, - const std::string & path_dynamic, - uint16_t n_draft, - bool save_dynamic, - bool save_static) - : common_speculative_state(type) - , n_draft(n_draft) - , save_dynamic(save_dynamic) - , 
save_static(save_static) - { - if (!path_static.empty()) { - try { - ngram_cache_static = common_ngram_cache_load(path_static); - } catch (...) { - LOG_ERR("failed to open static lookup cache: %s", path_static.c_str()); - GGML_ABORT("Couldn't read static lookup cache"); - } - } - - if (!path_dynamic.empty()) { - try { - ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); - } catch (...) { - LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str()); - GGML_ABORT("Couldn't read dynamic lookup cache"); - } - } - } - - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - GGML_UNUSED(params); - - if (cache_size < prompt_tgt.size() + 1) { - llama_tokens tokens_new; - tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); - for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { - tokens_new.push_back(prompt_tgt[j]); - } - tokens_new.push_back(id_last); // add the last token - - // Update context ngram cache with new prompt_tgt: - common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, - tokens_new, tokens_new.size(), false); - cache_size = prompt_tgt.size() + 1; - } - - llama_tokens inp; - inp.reserve(prompt_tgt.size() + 1); - for (size_t j = 0; j < prompt_tgt.size(); ++j) { - inp.push_back(prompt_tgt[j]); - } - inp.push_back(id_last); - - result.push_back(id_last); - - common_ngram_cache_draft(inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, - ngram_cache_context, - ngram_cache_dynamic, - ngram_cache_static); - - if (result.size() > 0) { - // delete first token in result (which is the id_last token) - result.erase(result.begin()); - } - } - - void accept(uint16_t n_accepted) override { - // TODO: noop - GGML_UNUSED(n_accepted); - } -}; - -struct common_speculative_state_suffix : public common_speculative_state { - common_suffix_tree 
tree; - common_suffix_tree corpus_tree; - bool has_corpus = false; - size_t cache_size = 0; - - // Acceptance feedback - size_t n_draft_last = 0; - bool had_accept = false; - int n_low = 0; - float base_p_min = 0.1f; - float eff_p_min = 0.1f; - - common_speculative_state_suffix( - enum common_speculative_type type, - int max_depth, - const std::string & corpus_path, - const llama_model * model) - : common_speculative_state(type) - , tree(max_depth) - , corpus_tree(max_depth) - { - if (!corpus_path.empty()) { - std::function(const std::string &)> tokenize_fn; - if (model) { - tokenize_fn = [model](const std::string & text) -> std::vector { - return common_tokenize(model, text, false, true); - }; - } - has_corpus = corpus_tree.load_corpus(corpus_path, tokenize_fn); - } - } - - void begin(const llama_tokens & prompt) override { - cache_size = 0; - n_draft_last = 0; - had_accept = false; - n_low = 0; - GGML_UNUSED(prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - - base_p_min = params.p_min; - if (n_draft_last > 0 && !had_accept) { - if (++n_low >= 3) { - eff_p_min = std::min(eff_p_min + 0.1f, 0.5f); - n_low = 0; - } - } - had_accept = false; - - if (cache_size < prompt_tgt.size() + 1) { - llama_tokens tokens_new; - tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); - for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { - tokens_new.push_back(prompt_tgt[j]); - } - tokens_new.push_back(id_last); - - tree.extend(tokens_new.data(), (int)tokens_new.size()); - cache_size = prompt_tgt.size() + 1; - } - - const int ctx_len = std::min((int)(prompt_tgt.size() + 1), tree.max_depth()); - llama_tokens context; - context.reserve(ctx_len); - const int ctx_start = (int)prompt_tgt.size() + 1 - ctx_len; - for (int j = ctx_start; j < (int)prompt_tgt.size(); ++j) { - context.push_back(prompt_tgt[j]); - } - context.push_back(id_last); - const int min_match_len = 
std::max(1, params.suffix_min_match_len); - - result = tree.speculate( - context.data(), (int)context.size(), - params.n_max, - eff_p_min, - 1, - min_match_len); - - if (has_corpus) { - auto corpus_result = corpus_tree.speculate( - context.data(), (int)context.size(), - params.n_max, - eff_p_min, - 1, - min_match_len); - if (corpus_result.size() > result.size()) { - result = std::move(corpus_result); - } - } - - n_draft_last = result.size(); - } - - void accept(uint16_t n_accepted) override { - if (n_draft_last == 0) { - return; - } - had_accept = true; - const double f_acc = (double)n_accepted / (double)n_draft_last; - if (f_acc < 0.5) { - if (++n_low >= 3) { - eff_p_min = std::min(eff_p_min + 0.1f, 0.5f); - n_low = 0; - } - } else { - n_low = 0; - if (eff_p_min > base_p_min) { - eff_p_min = std::max(eff_p_min - 0.05f, base_p_min); - } - } - } -}; - -struct common_speculative { - std::vector configs; // resolved stage config for each implementation - std::vector> impls; // list of implementations to use and their states - common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) - std::unique_ptr tuner; - int last_n_drafted = 0; - int64_t t_step_start_us = 0; -}; - -static bool common_speculative_stage_chain_matches( - const std::vector & stages, - const std::vector & configs) { - if (stages.size() != configs.size()) { - return false; - } - - for (size_t i = 0; i < stages.size(); ++i) { - if (stages[i].type != configs[i].type) { - return false; - } - } - - return true; -} - -static common_params_speculative common_speculative_get_runtime_params( - const common_speculative_config & config, - const common_params_speculative & params, - const common_speculative_stage_params & stage) { - common_params_speculative result = config.params; - - result.type = config.type; - result.n_max = stage.has_n_max_override() ? stage.n_max : params.n_max; - result.n_min = stage.has_n_min_override() ? 
stage.n_min : params.n_min; - result.p_min = stage.has_p_min_override() ? stage.p_min : params.p_min; - - if (config.type == COMMON_SPECULATIVE_TYPE_SUFFIX) { - result.suffix_min_match_len = stage.has_suffix_min_match_len_override() - ? stage.suffix_min_match_len - : params.suffix_min_match_len; - } - - result.n_max = std::max(result.n_max, 0); - result.n_min = std::max(0, std::min(result.n_min, result.n_max)); - result.stages.clear(); - - return result; -} - -static common_ngram_map get_common_ngram_map(const common_speculative_config & config) { - uint16_t size_key = config.params.ngram_size_n; - uint16_t size_value = config.params.ngram_size_m; - bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); - uint16_t min_hits = config.params.ngram_min_hits; - - return common_ngram_map(size_key, size_value, key_only, min_hits); -} - -static common_speculative_state_ngram_cache create_state_ngram_cache( - const std::string & path_static, const std::string & path_dynamic, - const common_speculative_config & config) { - uint16_t n_draft = 8; // TODO get from config? - - // TODO bool param in common/common.h to set save_static/save_dynamic? 
- bool save_static = false; - bool save_dynamic = false; - - common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic); - - return state; -} - -std::string common_speculative_type_name_str() { - std::string result; - for (size_t i = 0; i < common_speculative_types.size(); i++) { - if (i > 0) { - result += ", "; - } - result += common_speculative_type_to_str(common_speculative_types[i]); - } - return result; -} - -std::string common_speculative_type_to_str(enum common_speculative_type type) { - switch (type) { - case COMMON_SPECULATIVE_TYPE_NONE: return "none"; - case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft"; - case COMMON_SPECULATIVE_TYPE_DFLASH: return "dflash"; - case COMMON_SPECULATIVE_TYPE_MTP: return "mtp"; - case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3"; - case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod"; - case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache"; - case COMMON_SPECULATIVE_TYPE_SUFFIX: return "suffix"; - default: return "unknown"; - } -} - -enum common_speculative_type common_speculative_type_from_name(const std::string & name) { - std::string normalized = name; - std::replace(normalized.begin(), normalized.end(), '-', '_'); - - const auto it = common_speculative_type_from_name_map.find(normalized); - if (it == common_speculative_type_from_name_map.end()) { - return COMMON_SPECULATIVE_TYPE_COUNT; - } - return it->second; -} - -bool common_speculative_is_compat(llama_context * ctx_tgt) { - bool res = true; - - llama_kv_cache_clear(ctx_tgt); - - // eval 2 tokens to check if the context is compatible - std::vector tmp; - tmp.push_back(0); - tmp.push_back(0); - - int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0)); - if (ret != 
0) { - LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret); - res = false; - goto done; - } - - // try to remove the last tokens - if (!llama_kv_cache_seq_rm(ctx_tgt, 0, 1, -1)) { - LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); - res = false; - goto done; - } - -done: - llama_kv_cache_clear(ctx_tgt); - llama_synchronize(ctx_tgt); - - return res; -} - -// initialization of the speculative decoding system -// -common_speculative * common_speculative_init( - common_params_speculative & params, - llama_context * ctx_tgt) { - std::string chain_error; - if (!common_speculative_validate_chain(params, &chain_error)) { - LOG_ERR("%s: invalid speculative stage chain: %s\n", __func__, chain_error.c_str()); - return nullptr; - } - - const auto stages = params.get_resolved_stages(); - if (params.model_dft && llama_model_is_gemma4_mtp_assistant(params.model_dft)) { - const bool has_draft_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) { - return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT; - }); - - if (has_draft_stage) { - LOG_ERR("%s: Gemma4 assistant models only support MTP stages; omit -md for self-spec-only runs or use -mtp/--spec-stage mtp for assistant-backed MTP\n", __func__); - return nullptr; - } - } - - const bool has_dflash_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) { - return stage.type == COMMON_SPECULATIVE_TYPE_DFLASH; - }); - - const bool needs_draft_ctx = std::any_of(stages.begin(), stages.end(), [¶ms](const common_speculative_stage_params & stage) { - return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT || - stage.type == COMMON_SPECULATIVE_TYPE_DFLASH || - (stage.type == COMMON_SPECULATIVE_TYPE_MTP && params.model_dft != nullptr); - }); - - llama_context * ctx_dft = nullptr; - if (needs_draft_ctx) { - if (!params.model_dft) { - LOG_ERR("%s: draft speculative stage requires a loaded draft model\n", 
__func__); - return nullptr; - } - - llama_context_params cparams_dft = params.cparams_dft; - - if (has_dflash_stage) { - if (!llama_model_share_dflash_io_tensors(params.model_dft, llama_get_model(ctx_tgt))) { - LOG_ERR("%s: failed to share target IO tensors with DFlash draft model\n", __func__); - return nullptr; - } - - int32_t max_cross_ctx = 0; - for (const auto & stage : stages) { - if (stage.type != COMMON_SPECULATIVE_TYPE_DFLASH) { - continue; - } - - max_cross_ctx = std::max(max_cross_ctx, params.with_stage_overrides(stage).dflash_cross_ctx); - } - - const int32_t block_size = llama_model_dflash_block_size(params.model_dft); - if (block_size <= 0) { - LOG_ERR("%s: invalid DFlash draft block size\n", __func__); - return nullptr; - } - - const int64_t required_n_ctx = (int64_t) max_cross_ctx + (int64_t) block_size; - if (required_n_ctx > std::numeric_limits::max()) { - LOG_ERR("%s: invalid DFlash draft context size cross_ctx=%d block_size=%d required_n_ctx=%lld\n", - __func__, max_cross_ctx, block_size, (long long) required_n_ctx); - return nullptr; - } - - cparams_dft.n_ctx = (uint32_t) required_n_ctx; - } - - ctx_dft = llama_init_from_model(params.model_dft, cparams_dft); - if (ctx_dft == nullptr) { - LOG_ERR("%s", "failed to create draft context\n"); - return nullptr; - } - } - - // Compute the implementations to use based on the resolved stage chain. 
- std::vector configs = {}; - configs.reserve(stages.size()); - - for (const auto & stage : stages) { - common_params_speculative stage_params = params.with_stage_overrides(stage); - - if (stage.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD && !stage_params.ngram_mod) { - stage_params.ngram_mod = std::make_shared(stage_params.ngram_size_n, 4*1024*1024); - - LOG_INF("%s: initialized ngram_mod with n=%d, size=%zu (%.3f MB)\n", __func__, - stage_params.ngram_size_n, stage_params.ngram_mod->size(), - (float)(stage_params.ngram_mod->size_bytes())/1024/1024); - - if (stage_params.ngram_size_n < 16) { - LOG_WRN("%s: ngram_mod n=%d is too small - poor quality is possible, see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, stage_params.ngram_size_n); - } - } - - configs.push_back(common_speculative_config(stage, stage_params)); - } - - if (!configs.empty() && llama_model_has_recurrent(llama_get_model(ctx_tgt))) { - const int ckpt_tokens = std::max(1, params.get_max_stage_n_max() + 1); - const int actual_mode = llama_spec_ckpt_init(ctx_tgt, params.recurrent_ckpt_mode, ckpt_tokens); - if (actual_mode == LLAMA_SPEC_CKPT_NONE) { - LOG_ERR("%s: failed to prepare recurrent checkpoint mode '%s' during speculative init (max_tokens=%d)\n", - __func__, - params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_PER_STEP ? "per-step" : - params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_GPU_FALLBACK ? "gpu-fallback" : - params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_CPU ? 
"cpu" : "auto", - ckpt_tokens); - if (ctx_dft != nullptr) { - llama_free(ctx_dft); - } - return nullptr; - } - llama_spec_ckpt_discard(ctx_tgt); - params.recurrent_ckpt_mode = actual_mode; - } - - std::vector> impls = {}; - - for (const common_speculative_config & config : configs) { - LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str()); - switch (config.type) { - case COMMON_SPECULATIVE_TYPE_NONE: - break; - case COMMON_SPECULATIVE_TYPE_DRAFT: { - impls.push_back(std::make_unique(config.type, - /* .ctx_tgt = */ ctx_tgt, - /* .ctx_dft = */ ctx_dft, - /* .replacements = */ config.params.replacements - )); - break; - } - case COMMON_SPECULATIVE_TYPE_DFLASH: { - auto state = std::make_unique( - config.type, - ctx_tgt, - ctx_dft, - config.params.dflash_cross_ctx); - if (!state->ready) { - LOG_ERR("%s: failed to initialize DFlash speculative state\n", __func__); - return nullptr; - } - impls.push_back(std::move(state)); - ctx_dft = nullptr; - break; - } - case COMMON_SPECULATIVE_TYPE_MTP: { - llama_context * ctx_mtp = ctx_dft; - if (!ctx_mtp) { - const llama_model * model = llama_get_model(ctx_tgt); - ctx_mtp = llama_init_from_model(const_cast(model), config.params.cparams_dft); - if (!ctx_mtp) { - LOG_ERR("%s: failed to create MTP context\n", __func__); - return nullptr; - } - } - ctx_dft = nullptr; - - const bool use_constant_draft_positions = llama_model_is_gemma4_mtp_assistant(llama_get_model(ctx_mtp)); - impls.push_back(std::make_unique( - config.type, ctx_tgt, ctx_mtp, use_constant_draft_positions)); - break; - } - case COMMON_SPECULATIVE_TYPE_EAGLE3: { - impls.push_back(std::make_unique(config.type)); - break; - } - case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { - common_ngram_map ngram_map = get_common_ngram_map(config); - - uint16_t ngram_size_key = ngram_map.size_key; - uint16_t mgram_size_value = ngram_map.size_value; - - auto config_simple = common_ngram_simple_config { - /* .size_ngram = */ ngram_size_key, 
- /* .size_mgram = */ mgram_size_value - }; - auto state = std::make_unique( - /* .type = */ config.type, - /* .state = */ config_simple - ); - impls.push_back(std::move(state)); - break; - } - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: { - impls.push_back(std::make_unique( - (config.type), - get_common_ngram_map(config) - )); - break; - } - case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: { - GGML_ASSERT(config.params.ngram_mod); - impls.push_back(std::make_unique(config.type, *config.params.ngram_mod)); - break; - } - case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: { - auto state = create_state_ngram_cache( - config.params.lookup_cache_static, config.params.lookup_cache_dynamic, config); - impls.push_back(std::make_unique(state)); - break; - } - case COMMON_SPECULATIVE_TYPE_SUFFIX: { - int depth = config.params.suffix_max_depth > 0 ? config.params.suffix_max_depth : 64; - const llama_model * model = llama_get_model(ctx_tgt); - impls.push_back(std::make_unique( - config.type, depth, config.params.suffix_corpus, model)); - break; - } - default: - break; - } - } - - if (impls.empty()) { - LOG_WRN("%s", "no implementations specified for speculative decoding\n"); - return nullptr; - } - - auto * result = new common_speculative { - /* .configs = */ std::move(configs), - /* .impls = */ std::move(impls) - }; - - // initialize autotune if requested - if (params.autotune && params.has_composite_stage_chain()) { - LOG_WRN("Autotune disabled — explicit speculative stage chains are not supported yet\n"); - } else if (params.autotune && !result->impls.empty()) { - auto actual_type = result->impls[0]->type; - if (actual_type != COMMON_SPECULATIVE_TYPE_NONE && - actual_type != COMMON_SPECULATIVE_TYPE_EAGLE3) { - result->tuner = std::make_unique(); - result->tuner->init(actual_type, params, llama_get_model(ctx_tgt)); - LOG_DBG("Autotune initialized for %s, tuning %zu parameters\n", - common_speculative_type_to_str(actual_type).c_str(), - 
result->tuner->coords.size()); - } else { - LOG_WRN("Autotune disabled — speculative type %s is not supported for autotuning\n", - common_speculative_type_to_str(actual_type).c_str()); - } - } - - return result; -} - -void common_speculative_free(common_speculative * spec) { - if (spec == nullptr) { - return; - } - - delete spec; -} - -void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt) { - if (spec == nullptr) { - return; - } - - for (auto & impl : spec->impls) { - common_time_meas tm(impl->t_begin_us, !impl->gen_perf); - impl->begin(prompt); - impl->n_call_begin++; - } -} - -llama_tokens common_speculative_draft( - common_speculative * spec, - common_params_speculative & params, - const llama_tokens & prompt_tgt, // specified in target model vocab - llama_token id_last, - llama_pos draft_base_pos, - llama_seq_id draft_seq_id) { - llama_tokens result; - - spec->t_step_start_us = ggml_time_us(); - - // apply autotune proposal if enabled - if (spec->tuner && spec->tuner->enabled) { - spec->tuner->propose(params); - } - - const auto runtime_stages = params.get_resolved_stages(); - const bool use_runtime_stage_overrides = common_speculative_stage_chain_matches(runtime_stages, spec->configs); - - spec->curr_impl = nullptr; // reset current implementation - - for (size_t i = 0; i < spec->impls.size(); ++i) { - auto & impl = spec->impls[i]; - const auto & runtime_stage = use_runtime_stage_overrides ? 
runtime_stages[i] : spec->configs[i].stage; - common_params_speculative impl_params = common_speculative_get_runtime_params(spec->configs[i], params, runtime_stage); - result.clear(); - - { - common_time_meas tm(impl->t_draft_us, !impl->gen_perf); - impl->draft(impl_params, prompt_tgt, id_last, draft_base_pos, draft_seq_id, result); - impl->n_call_draft++; - } - - if (result.empty()) { - continue; - } - - if (common_speculative_type_is_self_spec(impl->type) && impl_params.n_min > 0 && (int)result.size() < impl_params.n_min) { - LOG_DBG("%s: impl %s drafted %zu tokens, below fallback threshold %d - trying next implementation\n", - __func__, common_speculative_type_to_str(impl->type).c_str(), result.size(), impl_params.n_min); - result.clear(); - continue; - } - LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, - common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(), - impl.get()->n_call_draft, result.size()); - - spec->curr_impl = impl.get(); - impl->n_gen_drafts++; - impl->n_gen_tokens += result.size(); - - break; // We have a draft, so break out of the loop and return it. - } - - // store draft count for tuner feedback - if (spec->tuner && spec->tuner->enabled) { - spec->last_n_drafted = (int)result.size(); - } - - return result; -} - -void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { - if (spec->tuner && spec->tuner->enabled && spec->t_step_start_us > 0) { - int64_t step_time_us = ggml_time_us() - spec->t_step_start_us; - double step_tps = (step_time_us > 100) - ? 
(n_accepted + 1.0) * 1e6 / (double)step_time_us - : 0.0; - spec->tuner->accept_feedback(n_accepted, spec->last_n_drafted, step_tps); - spec->t_step_start_us = 0; - } - - common_speculative_state * impl = spec->curr_impl; - - if (!impl) { - return; - } - - { - common_time_meas tm(impl->t_accept_us, !impl->gen_perf); - if (n_accepted > 0) { - impl->n_acc_drafts++; - impl->n_acc_tokens += n_accepted; - } - - impl->accept(n_accepted); - impl->n_call_accept++; - } - - if (impl->type != COMMON_SPECULATIVE_TYPE_MTP) { - if (auto * mtp_state = common_speculative_get_mtp_state(spec); mtp_state != nullptr) { - mtp_invalidate_cached_drafts(*mtp_state); - } - } -} - -static bool common_speculative_has_type(const common_speculative * spec, common_speculative_type type) { - if (spec == nullptr) { - return false; - } - - return std::any_of(spec->configs.begin(), spec->configs.end(), [type](const common_speculative_config & config) { - return config.type == type; - }); -} - -static int common_speculative_ctx_mtp_n_embd(llama_context * ctx) { - return ctx ? 
(int) llama_mtp_state_n_embd(ctx) : 0; -} - -static bool common_speculative_batch_token_has_seq_id( - const llama_batch & batch, - int token_index, - llama_seq_id seq_id) { - if (batch.n_seq_id == nullptr || batch.seq_id == nullptr || batch.n_seq_id[token_index] <= 0 || batch.seq_id[token_index] == nullptr) { - return false; - } - - for (int i = 0; i < batch.n_seq_id[token_index]; ++i) { - if (batch.seq_id[token_index][i] == seq_id) { - return true; - } - } - - return false; -} - -static bool common_speculative_batch_is_exact_single_seq( - const llama_batch & batch, - llama_seq_id seq_id) { - if (batch.n_tokens <= 0 || batch.n_seq_id == nullptr || batch.seq_id == nullptr) { - return false; - } - - for (int i = 0; i < batch.n_tokens; ++i) { - if (batch.n_seq_id[i] != 1 || batch.seq_id[i] == nullptr || batch.seq_id[i][0] != seq_id) { - return false; - } - } - - return true; -} - -static int common_speculative_copy_seq_batch( - const llama_batch & batch, - llama_seq_id seq_id, - llama_batch & seq_batch) { - if (batch.token == nullptr || batch.pos == nullptr) { - return -1; - } - - if (batch.n_tokens < 1) { - return 0; - } - - std::vector token_indices; - token_indices.reserve(batch.n_tokens); - for (int i = 0; i < batch.n_tokens; ++i) { - if (common_speculative_batch_token_has_seq_id(batch, i, seq_id)) { - token_indices.push_back(i); - } - } - - if (token_indices.empty()) { - return 0; - } - - seq_batch = llama_batch_init((int) token_indices.size(), 0, 1); - for (const int i : token_indices) { - common_batch_add(seq_batch, batch.token[i], batch.pos[i], { seq_id }, batch.logits != nullptr && batch.logits[i]); - } - - return (int) token_indices.size(); -} - -static bool common_speculative_feature_view_copy_batch_rows( - const common_speculative_feature_view & view, - const llama_batch & batch, - llama_seq_id seq_id, - std::vector * hidden_rows) { - if (hidden_rows == nullptr || view.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || view.width <= 0 || batch.n_tokens <= 
0 || batch.pos == nullptr) { - return false; - } - - std::unordered_map rows_by_pos; - rows_by_pos.reserve(view.rows.size()); - for (const auto & row : view.rows) { - if (row.seq_id == seq_id && row.data != nullptr) { - rows_by_pos[row.pos] = row.data; - } - } - - hidden_rows->clear(); - hidden_rows->reserve((size_t) batch.n_tokens * view.width); - for (int i = 0; i < batch.n_tokens; ++i) { - auto it = rows_by_pos.find(batch.pos[i]); - if (it == rows_by_pos.end()) { - hidden_rows->clear(); - return false; - } - - hidden_rows->insert(hidden_rows->end(), it->second, it->second + view.width); - } - - return hidden_rows->size() == (size_t) batch.n_tokens * view.width; -} +#include "speculative-impl.h" static bool common_speculative_capture_target_features( common_speculative * spec, @@ -3009,6 +1271,7 @@ static void mtp_clear_target_hidden(common_speculative_state_mtp & state, llama_ state.draft_cache_by_seq.erase(seq_id); } +// DFlash target-window replay and maintenance helpers. struct dflash_append_breakdown { uint64_t filter_us = 0; uint64_t window_alloc_us = 0; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 035dd8e6..87f375bd 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -41,6 +41,8 @@ add_library(llama ../include/llama.h llama.cpp llama-spec-features.cpp + llama-spec-features-dflash.cpp + llama-dflash.cpp llama-vocab.cpp llama-grammar.cpp llama-sampling.cpp diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index eff7d675..7f3e4f33 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -116,9 +116,9 @@ void llm_build_context::init() { lctx.inp_pos_bucket = nullptr; lctx.inp_embd_enc = nullptr; lctx.inp_KQ_mask_cross = nullptr; - lctx.inp_dflash_target_features = nullptr; - lctx.inp_dflash_pos_ctx = nullptr; - lctx.inp_dflash_kq_mask = nullptr; + lctx.dflash.inputs.target_features = nullptr; + lctx.dflash.inputs.pos_ctx = nullptr; + lctx.dflash.inputs.kq_mask = nullptr; } } @@ -2195,7 +2195,7 @@ 
struct ggml_cgraph * llm_build_context::llama_build_graph_dflash_kv_cache(llama_ } }; - struct llm_build_context llm(lctx, dummy, cb, false, false, 0, false, &lctx.dflash_buf_compute_meta); + struct llm_build_context llm(lctx, dummy, cb, false, false, 0, false, &lctx.dflash.kv.cache_compute_meta); llm.init(); @@ -2232,7 +2232,7 @@ struct ggml_cgraph * llm_build_context::llama_build_graph_dflash_kv_workspace(ll } }; - struct llm_build_context llm(lctx, dummy, cb, false, false, 0, false, &lctx.dflash_workspace_buf_compute_meta); + struct llm_build_context llm(lctx, dummy, cb, false, false, 0, false, &lctx.dflash.kv.workspace_compute_meta); llm.init(); diff --git a/src/llama-context.h b/src/llama-context.h index 8ad9d74b..ebd4ded3 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -278,77 +278,162 @@ struct llama_context { size_t draft_input_hidden_state_n_floats = 0; std::vector draft_input_hidden_state_owned; - const float * dflash_target_features = nullptr; - size_t dflash_target_features_n_floats = 0; - int32_t dflash_target_features_n_rows = 0; - const float * dflash_target_append_features = nullptr; - size_t dflash_target_append_features_n_floats = 0; - int32_t dflash_target_append_features_n_rows = 0; - const llama_pos * dflash_target_positions = nullptr; - size_t dflash_target_positions_n = 0; - uint64_t dflash_target_window_version = 0; - int32_t dflash_target_window_keep_rows = 0; - int32_t dflash_target_window_append_rows = 0; - bool dflash_target_window_replace = false; - std::vector dflash_target_features_owned; - std::vector dflash_target_append_features_owned; - std::vector dflash_target_positions_owned; - std::vector dflash_target_features_padded; - std::vector dflash_feature_view_buffer; - std::vector dflash_pos_ctx_data; - std::vector dflash_kq_mask_data; - std::vector dflash_kq_mask_swa_data; - int32_t dflash_visible_cross_ctx = 0; - std::vector dflash_k_ctx_cache; - std::vector dflash_v_ctx_cache; - std::vector dflash_k_ctx_workspace; - 
std::vector dflash_v_ctx_workspace; - struct ggml_context * dflash_cache_ctx = nullptr; - std::vector dflash_cache_bufs; - int32_t dflash_kv_cache_write_pos = 0; - int32_t dflash_kv_cache_n_filled = 0; - int32_t dflash_kv_cache_update_rows = 0; - int32_t dflash_kv_cache_reserved_rows = 0; - int32_t dflash_kv_cache_view_write_pos = 0; - int32_t dflash_kv_cache_view_n_filled = 0; - uint64_t dflash_kv_cache_applied_window_version = 0; - bool dflash_kv_cache_valid = false; - bool dflash_kv_cache_view_valid = false; - int32_t dflash_kv_workspace_write_pos = 0; - int32_t dflash_kv_workspace_n_filled = 0; - int32_t dflash_kv_workspace_reserved_rows = 0; - int32_t dflash_kv_workspace_token_capacity = 0; - int32_t dflash_kv_workspace_n_kv_total = 0; - uint64_t dflash_kv_workspace_applied_window_version = 0; - bool dflash_kv_workspace_valid = false; - bool dflash_kv_workspace_sync_pending = false; - std::vector dflash_buf_compute_meta; - std::vector dflash_workspace_buf_compute_meta; - ggml_backend_sched_t dflash_sched = nullptr; - ggml_backend_sched_t dflash_workspace_sched = nullptr; - ggml_cgraph * dflash_kv_graph = nullptr; - ggml_cgraph * dflash_kv_workspace_graph = nullptr; - int32_t dflash_kv_graph_rows = 0; - int32_t dflash_kv_graph_write_pos = 0; - int32_t dflash_kv_workspace_graph_rows = 0; - int32_t dflash_kv_workspace_graph_write_pos = 0; - struct ggml_tensor * dflash_kv_input_target_features = nullptr; - struct ggml_tensor * dflash_kv_input_pos_ctx = nullptr; - struct ggml_tensor * dflash_kq_mask_tensor = nullptr; - struct ggml_tensor * dflash_kq_mask_swa_tensor = nullptr; + struct dflash_runtime { + struct target_window_state { + const float * features = nullptr; + size_t features_n_floats = 0; + int32_t features_n_rows = 0; + const float * append_features = nullptr; + size_t append_features_n_floats = 0; + int32_t append_features_n_rows = 0; + const llama_pos * positions = nullptr; + size_t positions_n = 0; + uint64_t version = 0; + int32_t keep_rows = 0; + 
int32_t append_rows = 0; + bool replace = false; + std::vector features_owned; + std::vector append_features_owned; + std::vector positions_owned; + std::vector features_padded; + std::vector pos_ctx_data; + std::vector kq_mask_data; + std::vector kq_mask_swa_data; + }; - struct dflash_capture_state { - std::vector layer_ids; - std::vector> layer_rows; - int32_t row_count = 0; - int32_t row_width = 0; - uint64_t capture_batch_id = 0; - std::vector layer_seen_batch_id; - ggml_backend_sched_eval_callback prev_cb_eval = nullptr; - void * prev_cb_eval_user_data = nullptr; + struct kv_runtime_state { + std::vector k_ctx_cache; + std::vector v_ctx_cache; + std::vector k_ctx_workspace; + std::vector v_ctx_workspace; + struct ggml_context * cache_ctx = nullptr; + std::vector cache_bufs; + int32_t cache_write_pos = 0; + int32_t cache_n_filled = 0; + int32_t cache_update_rows = 0; + int32_t cache_reserved_rows = 0; + int32_t cache_view_write_pos = 0; + int32_t cache_view_n_filled = 0; + uint64_t cache_applied_window_version = 0; + bool cache_valid = false; + bool cache_view_valid = false; + int32_t workspace_write_pos = 0; + int32_t workspace_n_filled = 0; + int32_t workspace_reserved_rows = 0; + int32_t workspace_token_capacity = 0; + int32_t workspace_n_kv_total = 0; + uint64_t workspace_applied_window_version = 0; + bool workspace_valid = false; + bool workspace_sync_pending = false; + std::vector cache_compute_meta; + std::vector workspace_compute_meta; + ggml_backend_sched_t cache_sched = nullptr; + ggml_backend_sched_t workspace_sched = nullptr; + ggml_cgraph * cache_graph = nullptr; + ggml_cgraph * workspace_graph = nullptr; + int32_t cache_graph_rows = 0; + int32_t cache_graph_write_pos = 0; + int32_t workspace_graph_rows = 0; + int32_t workspace_graph_write_pos = 0; + struct ggml_tensor * cache_input_target_features = nullptr; + struct ggml_tensor * cache_input_pos_ctx = nullptr; + struct ggml_tensor * kq_mask_tensor = nullptr; + struct ggml_tensor * 
kq_mask_swa_tensor = nullptr; + }; + + struct capture_state { + std::vector layer_ids; + std::vector> layer_rows; + int32_t row_count = 0; + int32_t row_width = 0; + uint64_t capture_batch_id = 0; + std::vector layer_seen_batch_id; + ggml_backend_sched_eval_callback prev_cb_eval = nullptr; + void * prev_cb_eval_user_data = nullptr; + }; + + struct input_state { + struct ggml_tensor * target_features = nullptr; // F32 [n_target_features, cross_ctx] + struct ggml_tensor * pos_ctx = nullptr; // I32 [cross_ctx] + struct ggml_tensor * kq_mask = nullptr; // F32 [cross_ctx + n_batch, GGML_PAD(n_batch)] + struct ggml_tensor * kq_mask_swa = nullptr; // F32 [cross_ctx + n_batch, GGML_PAD(n_batch)] + }; + + target_window_state target; + kv_runtime_state kv; + std::unique_ptr capture; + std::vector feature_view_buffer; + input_state inputs; + int32_t visible_cross_ctx = 0; + llama_dflash_profile_stats profile; }; - std::unique_ptr dflash_capture; - llama_dflash_profile_stats dflash_profile; + dflash_runtime dflash; + using dflash_capture_state = dflash_runtime::capture_state; + + const float * & dflash_target_features = dflash.target.features; + size_t & dflash_target_features_n_floats = dflash.target.features_n_floats; + int32_t & dflash_target_features_n_rows = dflash.target.features_n_rows; + const float * & dflash_target_append_features = dflash.target.append_features; + size_t & dflash_target_append_features_n_floats = dflash.target.append_features_n_floats; + int32_t & dflash_target_append_features_n_rows = dflash.target.append_features_n_rows; + const llama_pos * & dflash_target_positions = dflash.target.positions; + size_t & dflash_target_positions_n = dflash.target.positions_n; + uint64_t & dflash_target_window_version = dflash.target.version; + int32_t & dflash_target_window_keep_rows = dflash.target.keep_rows; + int32_t & dflash_target_window_append_rows = dflash.target.append_rows; + bool & dflash_target_window_replace = dflash.target.replace; + std::vector & 
dflash_target_features_owned = dflash.target.features_owned; + std::vector & dflash_target_append_features_owned = dflash.target.append_features_owned; + std::vector & dflash_target_positions_owned = dflash.target.positions_owned; + std::vector & dflash_target_features_padded = dflash.target.features_padded; + std::vector & dflash_feature_view_buffer = dflash.feature_view_buffer; + std::vector & dflash_pos_ctx_data = dflash.target.pos_ctx_data; + std::vector & dflash_kq_mask_data = dflash.target.kq_mask_data; + std::vector & dflash_kq_mask_swa_data = dflash.target.kq_mask_swa_data; + int32_t & dflash_visible_cross_ctx = dflash.visible_cross_ctx; + std::vector & dflash_k_ctx_cache = dflash.kv.k_ctx_cache; + std::vector & dflash_v_ctx_cache = dflash.kv.v_ctx_cache; + std::vector & dflash_k_ctx_workspace = dflash.kv.k_ctx_workspace; + std::vector & dflash_v_ctx_workspace = dflash.kv.v_ctx_workspace; + struct ggml_context * & dflash_cache_ctx = dflash.kv.cache_ctx; + std::vector & dflash_cache_bufs = dflash.kv.cache_bufs; + int32_t & dflash_kv_cache_write_pos = dflash.kv.cache_write_pos; + int32_t & dflash_kv_cache_n_filled = dflash.kv.cache_n_filled; + int32_t & dflash_kv_cache_update_rows = dflash.kv.cache_update_rows; + int32_t & dflash_kv_cache_reserved_rows = dflash.kv.cache_reserved_rows; + int32_t & dflash_kv_cache_view_write_pos = dflash.kv.cache_view_write_pos; + int32_t & dflash_kv_cache_view_n_filled = dflash.kv.cache_view_n_filled; + uint64_t & dflash_kv_cache_applied_window_version = dflash.kv.cache_applied_window_version; + bool & dflash_kv_cache_valid = dflash.kv.cache_valid; + bool & dflash_kv_cache_view_valid = dflash.kv.cache_view_valid; + int32_t & dflash_kv_workspace_write_pos = dflash.kv.workspace_write_pos; + int32_t & dflash_kv_workspace_n_filled = dflash.kv.workspace_n_filled; + int32_t & dflash_kv_workspace_reserved_rows = dflash.kv.workspace_reserved_rows; + int32_t & dflash_kv_workspace_token_capacity = dflash.kv.workspace_token_capacity; + 
int32_t & dflash_kv_workspace_n_kv_total = dflash.kv.workspace_n_kv_total; + uint64_t & dflash_kv_workspace_applied_window_version = dflash.kv.workspace_applied_window_version; + bool & dflash_kv_workspace_valid = dflash.kv.workspace_valid; + bool & dflash_kv_workspace_sync_pending = dflash.kv.workspace_sync_pending; + std::vector & dflash_buf_compute_meta = dflash.kv.cache_compute_meta; + std::vector & dflash_workspace_buf_compute_meta = dflash.kv.workspace_compute_meta; + ggml_backend_sched_t & dflash_sched = dflash.kv.cache_sched; + ggml_backend_sched_t & dflash_workspace_sched = dflash.kv.workspace_sched; + ggml_cgraph * & dflash_kv_graph = dflash.kv.cache_graph; + ggml_cgraph * & dflash_kv_workspace_graph = dflash.kv.workspace_graph; + int32_t & dflash_kv_graph_rows = dflash.kv.cache_graph_rows; + int32_t & dflash_kv_graph_write_pos = dflash.kv.cache_graph_write_pos; + int32_t & dflash_kv_workspace_graph_rows = dflash.kv.workspace_graph_rows; + int32_t & dflash_kv_workspace_graph_write_pos = dflash.kv.workspace_graph_write_pos; + struct ggml_tensor * & dflash_kv_input_target_features = dflash.kv.cache_input_target_features; + struct ggml_tensor * & dflash_kv_input_pos_ctx = dflash.kv.cache_input_pos_ctx; + struct ggml_tensor * & dflash_kq_mask_tensor = dflash.kv.kq_mask_tensor; + struct ggml_tensor * & dflash_kq_mask_swa_tensor = dflash.kv.kq_mask_swa_tensor; + std::unique_ptr & dflash_capture = dflash.capture; + llama_dflash_profile_stats & dflash_profile = dflash.profile; + struct ggml_tensor * & inp_dflash_target_features = dflash.inputs.target_features; + struct ggml_tensor * & inp_dflash_pos_ctx = dflash.inputs.pos_ctx; + struct ggml_tensor * & inp_dflash_kq_mask = dflash.inputs.kq_mask; + struct ggml_tensor * & inp_dflash_kq_mask_swa = dflash.inputs.kq_mask_swa; // input tensors struct ggml_tensor * inp_tokens; // I32 [n_batch] @@ -369,10 +454,6 @@ struct llama_context { struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] struct 
ggml_tensor * inp_scale = nullptr; // F32 [n_tokens] struct ggml_tensor * inp_mtp_states = nullptr; - struct ggml_tensor * inp_dflash_target_features = nullptr; // F32 [n_target_features, cross_ctx] - struct ggml_tensor * inp_dflash_pos_ctx = nullptr; // I32 [cross_ctx] - struct ggml_tensor * inp_dflash_kq_mask = nullptr; // F32 [cross_ctx + n_batch, GGML_PAD(n_batch)] - struct ggml_tensor * inp_dflash_kq_mask_swa = nullptr; // F32 [cross_ctx + n_batch, GGML_PAD(n_batch)] ggml_backend_t ggml_backend_by_name(const char * name); diff --git a/src/llama-dflash.cpp b/src/llama-dflash.cpp new file mode 100644 index 00000000..aed84a25 --- /dev/null +++ b/src/llama-dflash.cpp @@ -0,0 +1,1240 @@ +#include "llama-dflash.h" + +#include "llama-impl.h" +#include "llama-build-context.h" +#include "llama-context.h" +#include "llama-model.h" +#include "llama-spec-features.h" + +#include "ggml.h" +#include "ggml-backend.h" + +#include +#include +#include +#include +#include + +static bool llama_env_flag_enabled_local(const char * name) { + const char * env = std::getenv(name); + return env != nullptr && *env != '\0' && + std::strcmp(env, "0") != 0 && + std::strcmp(env, "false") != 0 && + std::strcmp(env, "off") != 0; +} + +enum llama_dflash_kv_node_kind { + LLAMA_DFLASH_KV_NODE_NONE = 0, + LLAMA_DFLASH_KV_NODE_FUSED_TARGET, + LLAMA_DFLASH_KV_NODE_K_PROJ, + LLAMA_DFLASH_KV_NODE_K_NORM, + LLAMA_DFLASH_KV_NODE_K_ROPE, + LLAMA_DFLASH_KV_NODE_V_PROJ, + LLAMA_DFLASH_KV_NODE_K_STORE, + LLAMA_DFLASH_KV_NODE_V_STORE, +}; + +enum llama_dflash_main_node_kind { + LLAMA_DFLASH_MAIN_NODE_NONE = 0, + LLAMA_DFLASH_MAIN_NODE_QCUR, + LLAMA_DFLASH_MAIN_NODE_K_DRAFT, + LLAMA_DFLASH_MAIN_NODE_V_DRAFT, + LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW, + LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW, + LLAMA_DFLASH_MAIN_NODE_K_CONCAT, + LLAMA_DFLASH_MAIN_NODE_V_CONCAT, + LLAMA_DFLASH_MAIN_NODE_K_PAD, + LLAMA_DFLASH_MAIN_NODE_V_PAD, + LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT, + LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT, + 
LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN, + LLAMA_DFLASH_MAIN_NODE_ATTN_OUT, + LLAMA_DFLASH_MAIN_NODE_FFN, + LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS, + LLAMA_DFLASH_MAIN_NODE_RESULT_NORM, + LLAMA_DFLASH_MAIN_NODE_RESULT, +}; + +struct llama_dflash_kv_node_profiler { + llama_dflash_profile_stats * profile = nullptr; + int64_t t_start_us = 0; + llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE; +}; + +struct llama_dflash_main_node_profiler { + llama_dflash_profile_stats * profile = nullptr; + ggml_backend_sched_eval_callback prev_callback = nullptr; + void * prev_user_data = nullptr; + bool prev_active = false; + int64_t t_start_us = 0; + llama_dflash_main_node_kind active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; +}; + +static bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) { + if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') { + return false; + } + + return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0; +} + +static bool llama_dflash_tensor_name_matches_label(const struct ggml_tensor * tensor, const char * label) { + if (!llama_dflash_tensor_name_has_prefix(tensor, label)) { + return false; + } + + const size_t label_len = std::strlen(label); + const char next = tensor->name[label_len]; + return next == '\0' || next == '-'; +} + +static llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) { + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_fused_target")) { + return LLAMA_DFLASH_KV_NODE_FUSED_TARGET; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_proj")) { + return LLAMA_DFLASH_KV_NODE_K_PROJ; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_norm")) { + return LLAMA_DFLASH_KV_NODE_K_NORM; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_rope")) { + return LLAMA_DFLASH_KV_NODE_K_ROPE; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_proj")) { + return 
LLAMA_DFLASH_KV_NODE_V_PROJ; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_store")) { + return LLAMA_DFLASH_KV_NODE_K_STORE; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_store")) { + return LLAMA_DFLASH_KV_NODE_V_STORE; + } + + return LLAMA_DFLASH_KV_NODE_NONE; +} + +static void llama_dflash_kv_node_profile_add( + llama_dflash_profile_stats & profile, + llama_dflash_kv_node_kind kind, + uint64_t elapsed_us) { + switch (kind) { + case LLAMA_DFLASH_KV_NODE_FUSED_TARGET: + profile.graph_kv_node_fused_target_calls++; + profile.graph_kv_node_fused_target_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_PROJ: + profile.graph_kv_node_k_proj_calls++; + profile.graph_kv_node_k_proj_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_NORM: + profile.graph_kv_node_k_norm_calls++; + profile.graph_kv_node_k_norm_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_ROPE: + profile.graph_kv_node_k_rope_calls++; + profile.graph_kv_node_k_rope_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_V_PROJ: + profile.graph_kv_node_v_proj_calls++; + profile.graph_kv_node_v_proj_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_STORE: + profile.graph_kv_node_k_store_calls++; + profile.graph_kv_node_k_store_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_V_STORE: + profile.graph_kv_node_v_store_calls++; + profile.graph_kv_node_v_store_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_NONE: + break; + } +} + +static llama_dflash_main_node_kind llama_dflash_main_node_kind_from_tensor(const struct ggml_tensor * tensor) { + if (llama_dflash_tensor_name_has_prefix(tensor, "Qcur")) { + return LLAMA_DFLASH_MAIN_NODE_QCUR; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_noise")) { + return LLAMA_DFLASH_MAIN_NODE_K_DRAFT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_noise")) { + return LLAMA_DFLASH_MAIN_NODE_V_DRAFT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_ctx_cache")) { 
+ return LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_ctx_cache")) { + return LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_concat")) { + return LLAMA_DFLASH_MAIN_NODE_K_CONCAT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_concat")) { + return LLAMA_DFLASH_MAIN_NODE_V_CONCAT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_pad")) { + return LLAMA_DFLASH_MAIN_NODE_K_PAD; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_pad")) { + return LLAMA_DFLASH_MAIN_NODE_V_PAD; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_perm_cont")) { + return LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_perm_cont")) { + return LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "flash_attn_reshaped")) { + return LLAMA_DFLASH_MAIN_NODE_NONE; + } + if (llama_dflash_tensor_name_matches_label(tensor, "flash_attn")) { + return LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "kqv_out")) { + return LLAMA_DFLASH_MAIN_NODE_ATTN_OUT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "ffn_out")) { + return LLAMA_DFLASH_MAIN_NODE_FFN; + } + if (llama_dflash_tensor_name_matches_label(tensor, "result_output_rows")) { + return LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS; + } + if (llama_dflash_tensor_name_matches_label(tensor, "result_norm")) { + return LLAMA_DFLASH_MAIN_NODE_RESULT_NORM; + } + if (llama_dflash_tensor_name_matches_label(tensor, "output")) { + return LLAMA_DFLASH_MAIN_NODE_RESULT; + } + if (llama_dflash_tensor_name_matches_label(tensor, "result_output")) { + return LLAMA_DFLASH_MAIN_NODE_RESULT; + } + + return LLAMA_DFLASH_MAIN_NODE_NONE; +} + +static void llama_dflash_main_node_profile_add( + llama_dflash_profile_stats & profile, + llama_dflash_main_node_kind kind, 
+ uint64_t elapsed_us) { + switch (kind) { + case LLAMA_DFLASH_MAIN_NODE_QCUR: + profile.graph_main_node_qcur_calls++; + profile.graph_main_node_qcur_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_DRAFT: + profile.graph_main_node_k_draft_calls++; + profile.graph_main_node_k_draft_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_DRAFT: + profile.graph_main_node_v_draft_calls++; + profile.graph_main_node_v_draft_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW: + profile.graph_main_node_k_ctx_view_calls++; + profile.graph_main_node_k_ctx_view_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW: + profile.graph_main_node_v_ctx_view_calls++; + profile.graph_main_node_v_ctx_view_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_CONCAT: + profile.graph_main_node_k_concat_calls++; + profile.graph_main_node_k_concat_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_CONCAT: + profile.graph_main_node_v_concat_calls++; + profile.graph_main_node_v_concat_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_PAD: + profile.graph_main_node_k_pad_calls++; + profile.graph_main_node_k_pad_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_PAD: + profile.graph_main_node_v_pad_calls++; + profile.graph_main_node_v_pad_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT: + profile.graph_main_node_k_perm_cont_calls++; + profile.graph_main_node_k_perm_cont_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT: + profile.graph_main_node_v_perm_cont_calls++; + profile.graph_main_node_v_perm_cont_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN: + profile.graph_main_node_flash_attn_calls++; + profile.graph_main_node_flash_attn_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_ATTN_OUT: + profile.graph_main_node_attn_out_calls++; + profile.graph_main_node_attn_out_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_FFN: + 
profile.graph_main_node_ffn_calls++; + profile.graph_main_node_ffn_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS: + profile.graph_main_node_result_rows_calls++; + profile.graph_main_node_result_rows_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_RESULT_NORM: + profile.graph_main_node_result_norm_calls++; + profile.graph_main_node_result_norm_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_RESULT: + profile.graph_main_node_result_calls++; + profile.graph_main_node_result_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_NONE: + break; + } +} + +static bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { + auto * profiler = static_cast(user_data); + if (profiler == nullptr || profiler->profile == nullptr) { + return false; + } + + const llama_dflash_kv_node_kind kind = llama_dflash_kv_node_kind_from_tensor(tensor); + if (ask) { + if (kind == LLAMA_DFLASH_KV_NODE_NONE) { + return false; + } + + profiler->active_kind = kind; + profiler->t_start_us = ggml_time_us(); + return true; + } + + if (kind != LLAMA_DFLASH_KV_NODE_NONE && profiler->active_kind == kind && profiler->t_start_us > 0) { + llama_dflash_kv_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); + } + + profiler->active_kind = LLAMA_DFLASH_KV_NODE_NONE; + profiler->t_start_us = 0; + return true; +} + +static bool llama_dflash_main_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { + auto * profiler = static_cast(user_data); + if (profiler == nullptr || profiler->profile == nullptr) { + return false; + } + + const llama_dflash_main_node_kind kind = llama_dflash_main_node_kind_from_tensor(tensor); + if (ask) { + profiler->prev_active = profiler->prev_callback != nullptr + ? 
profiler->prev_callback(tensor, ask, profiler->prev_user_data) + : false; + + if (kind == LLAMA_DFLASH_MAIN_NODE_NONE) { + profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; + profiler->t_start_us = 0; + return profiler->prev_active; + } + + profiler->active_kind = kind; + profiler->t_start_us = ggml_time_us(); + return true; + } + + bool prev_result = false; + if (profiler->prev_active && profiler->prev_callback != nullptr) { + prev_result = profiler->prev_callback(tensor, ask, profiler->prev_user_data); + } + + const bool tracked = kind != LLAMA_DFLASH_MAIN_NODE_NONE && + profiler->active_kind == kind && + profiler->t_start_us > 0; + if (tracked) { + llama_dflash_main_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); + } + + profiler->prev_active = false; + profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; + profiler->t_start_us = 0; + return prev_result || tracked; +} + +static bool llama_dflash_use_kv_workspace_experiment() { + return llama_env_flag_enabled_local("IK_DFLASH_KV_WORKSPACE"); +} + +void llama_sync_dflash_workspace_if_pending(struct llama_context & lctx) { + if (!lctx.dflash_kv_workspace_sync_pending || lctx.dflash_workspace_sched == nullptr) { + return; + } + + const int64_t t_workspace_sync_us = ggml_time_us(); + ggml_backend_sched_synchronize(lctx.dflash_workspace_sched); + lctx.dflash_profile.graph_kv_workspace_sync_us += (uint64_t) (ggml_time_us() - t_workspace_sync_us); + lctx.dflash_kv_workspace_sync_pending = false; +} + +static ggml_backend_buffer_type_t llama_dflash_kv_cache_layer_buft(const llama_context & lctx, int32_t il) { + if (il >= 0 && (size_t) il < lctx.model.buft_layer.size() && lctx.model.buft_layer[(size_t) il].buft != nullptr) { + return lctx.model.buft_layer[(size_t) il].buft; + } + + if (il >= 0 && (size_t) il < lctx.model.layers.size()) { + const ggml_tensor * wk = lctx.model.layers[(size_t) il].wk; + if (wk != nullptr && wk->buffer != nullptr) { + return 
ggml_backend_buffer_get_type(wk->buffer); + } + } + + return llama_default_buffer_type_cpu(true); +} + +static ggml_backend_t llama_backend_for_tensor(const llama_context & lctx, const ggml_tensor * tensor) { + if (tensor == nullptr) { + return nullptr; + } + + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + if (buf == nullptr) { + return nullptr; + } + + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf); + for (ggml_backend_t backend : lctx.backends) { + ggml_backend_buffer_type_t backend_buft = ggml_backend_is_cpu(backend) + ? llama_default_buffer_type_cpu(true) + : ggml_backend_get_default_buffer_type(backend); + if (backend_buft == buft) { + return backend; + } + } + + return nullptr; +} + +bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { + const bool use_kv_workspace = llama_env_flag_enabled_local("IK_DFLASH_KV_WORKSPACE"); + const int32_t target_cross_ctx = std::max(1, cross_ctx); + const int32_t target_token_capacity = std::max(1, (int32_t) model.hparams.dflash_block_size); + const int32_t target_workspace_n_kv_total = GGML_PAD(target_cross_ctx + target_token_capacity, cparams.flash_attn ? 256 : 32); + const int32_t n_layer = model.hparams.n_layer; + const int64_t n_embd_head_k = model.hparams.n_embd_head_k(0); + const int64_t n_embd_head_v = model.hparams.n_embd_head_v(0); + const int64_t n_head_kv = model.hparams.n_head_kv(); + + if (dflash_cache_ctx != nullptr && !dflash_k_ctx_cache.empty()) { + const bool cache_matches = (int32_t) dflash_k_ctx_cache.size() == n_layer && + dflash_k_ctx_cache.front() != nullptr && + (int32_t) dflash_k_ctx_cache.front()->ne[2] == target_cross_ctx; + const bool workspace_matches = use_kv_workspace + ? 
((int32_t) dflash_k_ctx_workspace.size() == n_layer && + dflash_k_ctx_workspace.front() != nullptr && + (int32_t) dflash_k_ctx_workspace.front()->ne[1] == target_workspace_n_kv_total) + : dflash_k_ctx_workspace.empty() && dflash_v_ctx_workspace.empty(); + + if (cache_matches && workspace_matches) { + return true; + } + + free_dflash_kv_cache_tensors(); + if (dflash_sched != nullptr) { + ggml_backend_sched_free(dflash_sched); + dflash_sched = nullptr; + } + if (dflash_workspace_sched != nullptr) { + ggml_backend_sched_free(dflash_workspace_sched); + dflash_workspace_sched = nullptr; + } + dflash_kv_graph = nullptr; + dflash_kv_workspace_graph = nullptr; + dflash_kv_graph_rows = 0; + dflash_kv_graph_write_pos = 0; + dflash_kv_workspace_graph_rows = 0; + dflash_kv_workspace_graph_write_pos = 0; + dflash_kv_workspace_reserved_rows = 0; + dflash_buf_compute_meta.clear(); + dflash_workspace_buf_compute_meta.clear(); + } + + ggml_init_params params = { + /*.mem_size =*/ (size_t) ((use_kv_workspace ? 4 : 2) * std::max(1, n_layer)) * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + dflash_cache_ctx = ggml_init(params); + if (dflash_cache_ctx == nullptr) { + return false; + } + + dflash_k_ctx_cache.resize((size_t) n_layer); + dflash_v_ctx_cache.resize((size_t) n_layer); + dflash_k_ctx_workspace.clear(); + dflash_v_ctx_workspace.clear(); + if (use_kv_workspace) { + dflash_k_ctx_workspace.resize((size_t) n_layer); + dflash_v_ctx_workspace.resize((size_t) n_layer); + } + dflash_cache_bufs.clear(); + dflash_cache_bufs.reserve((size_t) std::max(1, n_layer) * (use_kv_workspace ? 
4 : 2)); + int32_t host_layers = 0; + const char * first_buft_name = nullptr; + const char * last_buft_name = nullptr; + for (int32_t il = 0; il < n_layer; ++il) { + ggml_backend_buffer_type_t layer_buft = llama_dflash_kv_cache_layer_buft(*this, il); + if (ggml_backend_buft_is_host(layer_buft)) { + host_layers++; + } + if (first_buft_name == nullptr) { + first_buft_name = ggml_backend_buft_name(layer_buft); + } + last_buft_name = ggml_backend_buft_name(layer_buft); + + dflash_k_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_k, n_head_kv, target_cross_ctx); + dflash_v_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_v, n_head_kv, target_cross_ctx); + if (dflash_k_ctx_cache[(size_t) il] == nullptr || dflash_v_ctx_cache[(size_t) il] == nullptr) { + free_dflash_kv_cache_tensors(); + return false; + } + + ggml_set_input(dflash_k_ctx_cache[(size_t) il]); + ggml_set_input(dflash_v_ctx_cache[(size_t) il]); + ggml_format_name(dflash_k_ctx_cache[(size_t) il], "dflash_k_ctx_cache_%d", il); + ggml_format_name(dflash_v_ctx_cache[(size_t) il], "dflash_v_ctx_cache_%d", il); + + const size_t k_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_k_ctx_cache[(size_t) il]); + ggml_backend_buffer_t k_buf = ggml_backend_buft_alloc_buffer(layer_buft, k_bytes); + if (k_buf == nullptr) { + free_dflash_kv_cache_tensors(); + return false; + } + ggml_backend_buffer_set_usage(k_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); + ggml_backend_tensor_alloc(k_buf, dflash_k_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(k_buf)); + ggml_backend_buffer_clear(k_buf, 0); + dflash_cache_bufs.push_back(k_buf); + + const size_t v_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_v_ctx_cache[(size_t) il]); + ggml_backend_buffer_t v_buf = ggml_backend_buft_alloc_buffer(layer_buft, v_bytes); + if (v_buf == nullptr) { + free_dflash_kv_cache_tensors(); + return false; + } + ggml_backend_buffer_set_usage(v_buf, 
GGML_BACKEND_BUFFER_USAGE_COMPUTE); + ggml_backend_tensor_alloc(v_buf, dflash_v_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(v_buf)); + ggml_backend_buffer_clear(v_buf, 0); + dflash_cache_bufs.push_back(v_buf); + + if (use_kv_workspace) { + dflash_k_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_k, target_workspace_n_kv_total, n_head_kv); + dflash_v_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_v, target_workspace_n_kv_total, n_head_kv); + if (dflash_k_ctx_workspace[(size_t) il] == nullptr || dflash_v_ctx_workspace[(size_t) il] == nullptr) { + free_dflash_kv_cache_tensors(); + return false; + } + + ggml_set_input(dflash_k_ctx_workspace[(size_t) il]); + ggml_set_input(dflash_v_ctx_workspace[(size_t) il]); + ggml_format_name(dflash_k_ctx_workspace[(size_t) il], "dflash_k_ctx_workspace_%d", il); + ggml_format_name(dflash_v_ctx_workspace[(size_t) il], "dflash_v_ctx_workspace_%d", il); + + const size_t k_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_k_ctx_workspace[(size_t) il]); + ggml_backend_buffer_t k_workspace_buf = ggml_backend_buft_alloc_buffer(layer_buft, k_workspace_bytes); + if (k_workspace_buf == nullptr) { + free_dflash_kv_cache_tensors(); + return false; + } + ggml_backend_buffer_set_usage(k_workspace_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); + ggml_backend_tensor_alloc(k_workspace_buf, dflash_k_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(k_workspace_buf)); + ggml_backend_buffer_clear(k_workspace_buf, 0); + dflash_cache_bufs.push_back(k_workspace_buf); + + const size_t v_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_v_ctx_workspace[(size_t) il]); + ggml_backend_buffer_t v_workspace_buf = ggml_backend_buft_alloc_buffer(layer_buft, v_workspace_bytes); + if (v_workspace_buf == nullptr) { + free_dflash_kv_cache_tensors(); + return false; + } + ggml_backend_buffer_set_usage(v_workspace_buf, 
GGML_BACKEND_BUFFER_USAGE_COMPUTE); + ggml_backend_tensor_alloc(v_workspace_buf, dflash_v_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(v_workspace_buf)); + ggml_backend_buffer_clear(v_workspace_buf, 0); + dflash_cache_bufs.push_back(v_workspace_buf); + } + } + + dflash_profile.last_kv_cache_host_layers = host_layers; + dflash_kv_workspace_token_capacity = use_kv_workspace ? target_token_capacity : 0; + dflash_kv_workspace_n_kv_total = use_kv_workspace ? target_workspace_n_kv_total : 0; + llama_reset_dflash_kv_cache_state(this); + LLAMA_LOG_INFO("%s: DFlash K/V cache placement cross_ctx=%d host_layers=%d/%d first=%s last=%s\n", + __func__, + target_cross_ctx, + host_layers, + n_layer, + first_buft_name != nullptr ? first_buft_name : "(none)", + last_buft_name != nullptr ? last_buft_name : "(none)"); + + return true; +} + +void llama_context::free_dflash_kv_cache_tensors() { + dflash_k_ctx_cache.clear(); + dflash_v_ctx_cache.clear(); + dflash_k_ctx_workspace.clear(); + dflash_v_ctx_workspace.clear(); + dflash_kv_cache_write_pos = 0; + dflash_kv_cache_n_filled = 0; + dflash_kv_cache_update_rows = 0; + dflash_kv_cache_reserved_rows = 0; + dflash_kv_cache_view_write_pos = 0; + dflash_kv_cache_view_n_filled = 0; + dflash_kv_cache_applied_window_version = 0; + dflash_kv_cache_valid = false; + dflash_kv_cache_view_valid = false; + dflash_kv_workspace_write_pos = 0; + dflash_kv_workspace_n_filled = 0; + dflash_kv_workspace_reserved_rows = 0; + dflash_kv_workspace_token_capacity = 0; + dflash_kv_workspace_n_kv_total = 0; + dflash_kv_workspace_applied_window_version = 0; + dflash_kv_workspace_valid = false; + dflash_kv_workspace_sync_pending = false; + dflash_kv_graph = nullptr; + dflash_kv_workspace_graph = nullptr; + dflash_kv_graph_rows = 0; + dflash_kv_graph_write_pos = 0; + dflash_kv_workspace_graph_rows = 0; + dflash_kv_workspace_graph_write_pos = 0; + dflash_kv_input_target_features = nullptr; + dflash_kv_input_pos_ctx = nullptr; + dflash_kq_mask_tensor = 
nullptr; + dflash_kq_mask_swa_tensor = nullptr; + + if (dflash_workspace_sched != nullptr) { + ggml_backend_sched_synchronize(dflash_workspace_sched); + ggml_backend_sched_free(dflash_workspace_sched); + dflash_workspace_sched = nullptr; + } + + for (ggml_backend_buffer_t buf : dflash_cache_bufs) { + if (buf != nullptr) { + ggml_backend_buffer_free(buf); + } + } + dflash_cache_bufs.clear(); + if (dflash_cache_ctx != nullptr) { + ggml_free(dflash_cache_ctx); + dflash_cache_ctx = nullptr; + } +} + +static void llama_graph_compute_sched( + llama_context & lctx, + ggml_backend_sched_t sched, + ggml_cgraph * gf, + int n_threads) { +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(lctx.backend_metal)) { + ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads); + } +#endif + + if (lctx.backend_cpu != nullptr) { + ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads); + ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data); + } +#ifdef GGML_USE_BLAS + if (lctx.backend_blas != nullptr) { + ggml_backend_blas_set_n_threads(lctx.backend_blas, n_threads); + } +#endif + + ggml_backend_sched_graph_compute_async(sched, gf); +} + +static bool dflash_layer_has_attention_bias(const llama_layer & layer) { + return layer.bq != nullptr || + layer.bk != nullptr || + layer.bv != nullptr || + layer.bo != nullptr || + layer.bqkv != nullptr || + layer.bqk != nullptr || + layer.bkv != nullptr; +} + +static bool validate_dflash_graph_contract(const llama_context & lctx) { + const auto & model = lctx.model; + const auto & hparams = model.hparams; + + auto rope_dim_for_layer = [&hparams](int32_t il) -> uint32_t { + if (hparams.rope_dim_per_layer[(size_t) il] != 0) { + return hparams.rope_dim_per_layer[(size_t) il]; + } + + return hparams.swa_layers[(size_t) il] ? 
hparams.n_rot_swa : hparams.n_rot; + }; + + auto rope_base_for_layer = [&hparams](int32_t il) -> float { + if (hparams.has_rope_freq_base_per_layer) { + return hparams.rope_freq_base_per_layer[(size_t) il]; + } + + return hparams.swa_layers[(size_t) il] ? hparams.rope_freq_base_train_swa : hparams.rope_freq_base_train; + }; + + auto rope_scale_for_layer = [&hparams](int32_t il) -> float { + return hparams.swa_layers[(size_t) il] ? hparams.rope_freq_scale_train_swa : hparams.rope_freq_scale_train; + }; + + const uint32_t ref_n_head = hparams.n_head(0); + const uint32_t ref_n_head_kv = hparams.n_head_kv(0); + const uint32_t ref_n_embd_head_k = hparams.n_embd_head_k(0); + const uint32_t ref_n_embd_head_v = hparams.n_embd_head_v(0); + const uint32_t ref_rope_dim = rope_dim_for_layer(0); + const float ref_rope_base = rope_base_for_layer(0); + const float ref_rope_scale = rope_scale_for_layer(0); + + for (int32_t il = 0; il < (int32_t) hparams.n_layer; ++il) { + if (hparams.n_head((uint32_t) il) != ref_n_head || + hparams.n_head_kv((uint32_t) il) != ref_n_head_kv || + hparams.n_embd_head_k(il) != ref_n_embd_head_k || + hparams.n_embd_head_v(il) != ref_n_embd_head_v) { + LLAMA_LOG_ERROR("%s: DFlash graph assumes layer-invariant head config, but layer %d differs (n_head=%u/%u n_head_kv=%u/%u head_k=%u/%u head_v=%u/%u)\n", + __func__, + il, + hparams.n_head((uint32_t) il), ref_n_head, + hparams.n_head_kv((uint32_t) il), ref_n_head_kv, + hparams.n_embd_head_k(il), ref_n_embd_head_k, + hparams.n_embd_head_v(il), ref_n_embd_head_v); + return false; + } + + const uint32_t rope_dim = rope_dim_for_layer(il); + const float rope_base = rope_base_for_layer(il); + const float rope_scale = rope_scale_for_layer(il); + if (rope_dim != ref_rope_dim || std::fabs(rope_base - ref_rope_base) > 1e-6f || std::fabs(rope_scale - ref_rope_scale) > 1e-6f) { + LLAMA_LOG_ERROR("%s: DFlash graph assumes layer-invariant RoPE config, but layer %d differs (dim=%u/%u base=%g/%g scale=%g/%g)\n", + 
__func__, + il, + rope_dim, ref_rope_dim, + (double) rope_base, (double) ref_rope_base, + (double) rope_scale, (double) ref_rope_scale); + return false; + } + + if (model.layers[(size_t) il].attn_norm == nullptr || + model.layers[(size_t) il].attn_q_norm == nullptr || + model.layers[(size_t) il].attn_k_norm == nullptr) { + LLAMA_LOG_ERROR("%s: DFlash graph requires attn_norm, attn_q_norm, and attn_k_norm weights, but layer %d is missing one or more of them\n", + __func__, il); + return false; + } + + const bool has_q_norm = model.layers[(size_t) il].attn_q_norm != nullptr; + const bool has_k_norm = model.layers[(size_t) il].attn_k_norm != nullptr; + if (has_q_norm != has_k_norm) { + LLAMA_LOG_ERROR("%s: DFlash graph requires symmetric Q/K norm presence, but layer %d has q_norm=%d k_norm=%d\n", + __func__, il, (int) has_q_norm, (int) has_k_norm); + return false; + } + + if (model.layers[(size_t) il].attn_norm_b != nullptr || + model.layers[(size_t) il].attn_q_norm_b != nullptr || + model.layers[(size_t) il].attn_k_norm_b != nullptr) { + LLAMA_LOG_ERROR("%s: DFlash graph does not implement norm-bias tensors, but layer %d requires attn_norm_b/q_norm_b/k_norm_b\n", + __func__, il); + return false; + } + + if (dflash_layer_has_attention_bias(model.layers[(size_t) il])) { + LLAMA_LOG_ERROR("%s: DFlash graph does not implement attention bias tensors, but layer %d requires them\n", + __func__, il); + return false; + } + } + + return true; +} + +bool llama_prepare_dflash_graph_inputs( + struct llama_context & lctx, + uint32_t n_tokens) { + const bool use_kv_cache = llama_env_flag_enabled_local("IK_DFLASH_KV_CACHE"); + const bool use_kv_workspace = use_kv_cache && llama_dflash_use_kv_workspace_experiment(); + const bool kv_node_timing = llama_env_flag_enabled_local("IK_DFLASH_KV_NODE_TIMING"); + auto & profile = lctx.dflash_profile; + const int32_t cross_ctx = lctx.dflash_visible_cross_ctx > 0 + ? 
lctx.dflash_visible_cross_ctx + : std::max(1, (int32_t) lctx.cparams.n_ctx - (int32_t) lctx.model.hparams.dflash_block_size); + ggml_tensor * kq_mask = lctx.dflash_kq_mask_tensor; + ggml_tensor * kq_mask_swa = lctx.dflash_kq_mask_swa_tensor; + + if (kq_mask == nullptr) { + LLAMA_LOG_ERROR("%s: DFlash graph inputs are not initialized\n", __func__); + return false; + } + + if (!validate_dflash_graph_contract(lctx)) { + profile.graph_shape_failures++; + return false; + } + + if (use_kv_cache) { + if (!lctx.ensure_dflash_kv_cache_tensors(cross_ctx) || lctx.dflash_k_ctx_cache.empty() || lctx.dflash_v_ctx_cache.empty()) { + LLAMA_LOG_ERROR("%s: DFlash K/V cache inputs are not initialized\n", __func__); + return false; + } + } else if (lctx.inp_dflash_target_features == nullptr || lctx.inp_dflash_pos_ctx == nullptr) { + LLAMA_LOG_ERROR("%s: DFlash inline inputs are not initialized\n", __func__); + return false; + } + + const float * src = lctx.dflash_target_features; + const float * append_src = lctx.dflash_target_append_features; + const llama_pos * src_pos = lctx.dflash_target_positions; + const size_t total_floats = lctx.dflash_target_features_n_floats; + const size_t append_floats = lctx.dflash_target_append_features_n_floats; + const size_t total_positions = lctx.dflash_target_positions_n; + const int32_t n_rows = lctx.dflash_target_features_n_rows; + const int32_t append_rows_available = lctx.dflash_target_append_features_n_rows; + const int32_t width = (int32_t) lctx.model.hparams.dflash_n_target_features; + const int32_t graph_cross_ctx = use_kv_cache + ? (lctx.dflash_k_ctx_cache.front() != nullptr ? (int32_t) lctx.dflash_k_ctx_cache.front()->ne[2] : 0) + : (lctx.inp_dflash_target_features != nullptr ? 
(int32_t) lctx.inp_dflash_target_features->ne[1] : 0); + const int32_t n_mask_tokens = (int32_t) kq_mask->ne[1]; + const int32_t n_kv_total = (int32_t) kq_mask->ne[0]; + const int64_t t_total_us = ggml_time_us(); + + profile.graph_prepare_calls++; + profile.last_n_rows = n_rows; + profile.last_width = width; + profile.last_cross_ctx = cross_ctx; + profile.last_n_tokens = (int32_t) n_tokens; + profile.last_n_kv_total = n_kv_total; + + if (use_kv_workspace) { + llama_sync_dflash_workspace_if_pending(lctx); + } + + if (graph_cross_ctx != cross_ctx) { + profile.graph_shape_failures++; + + LLAMA_LOG_ERROR("%s: DFlash graph cross_ctx drift (graph=%d configured=%d)\n", + __func__, graph_cross_ctx, cross_ctx); + return false; + } + if (n_rows <= 0) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: missing DFlash target feature rows\n", __func__); + return false; + } + + const bool have_full_src = src != nullptr && total_floats == (size_t) n_rows * (size_t) width; + if (n_rows > cross_ctx || (src != nullptr && !have_full_src)) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: invalid DFlash target feature shape (rows=%d width=%d floats=%zu cross_ctx=%d)\n", + __func__, n_rows, width, total_floats, cross_ctx); + return false; + } + + if (!use_kv_cache && !have_full_src) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: missing contiguous DFlash target features for inline path\n", __func__); + return false; + } + + if (n_kv_total < cross_ctx + (int32_t) n_tokens) { + profile.graph_mask_overflow++; + LLAMA_LOG_ERROR("%s: invalid DFlash mask shape (n_kv_total=%d < cross_ctx+n_tokens=%d)\n", + __func__, n_kv_total, cross_ctx + (int32_t) n_tokens); + return false; + } + + const int32_t left_pad = cross_ctx - n_rows; + profile.last_left_pad = left_pad; + if (!use_kv_cache) { + const size_t padded_floats = (size_t) cross_ctx * (size_t) width; + const size_t dst_offset = (size_t) left_pad * (size_t) width; + const int64_t t_feature_us = ggml_time_us(); + 
if (lctx.dflash_target_features_padded.size() != padded_floats) { + lctx.dflash_target_features_padded.resize(padded_floats); + } + if (left_pad == 0 && total_floats == padded_floats) { + std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin()); + } else { + if (dst_offset > 0) { + std::fill(lctx.dflash_target_features_padded.begin(), + lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset, 0.0f); + } + std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset); + } + profile.graph_feature_copy_us += (uint64_t) (ggml_time_us() - t_feature_us); + profile.graph_feature_bytes += padded_floats * sizeof(float); + } + + const int64_t t_pos_us = ggml_time_us(); + lctx.dflash_pos_ctx_data.resize((size_t) cross_ctx); + std::fill(lctx.dflash_pos_ctx_data.begin(), lctx.dflash_pos_ctx_data.end(), 0); + if (src_pos == nullptr || total_positions != (size_t) n_rows) { + profile.graph_pos_fallbacks++; + profile.graph_shape_failures++; + profile.last_pos_first = -1; + profile.last_pos_last = -1; + if (profile.graph_pos_fallbacks <= 3) { + LLAMA_LOG_ERROR("%s: missing DFlash target positions (rows=%d positions=%zu cross_ctx=%d)\n", + __func__, n_rows, total_positions, cross_ctx); + } + return false; + } + + profile.last_pos_first = src_pos[0]; + profile.last_pos_last = src_pos[n_rows - 1]; + for (int32_t i = 1; i < n_rows; ++i) { + if (src_pos[i] <= src_pos[i - 1]) { + profile.graph_pos_non_monotonic++; + profile.graph_shape_failures++; + if (profile.graph_pos_non_monotonic <= 3) { + LLAMA_LOG_ERROR("%s: DFlash target positions are not strictly increasing (rows=%d first=%d last=%d)\n", + __func__, n_rows, (int) src_pos[0], (int) src_pos[n_rows - 1]); + } + return false; + } + } + std::copy(src_pos, src_pos + n_rows, lctx.dflash_pos_ctx_data.begin() + (ptrdiff_t) left_pad); + profile.graph_pos_copy_us += (uint64_t) (ggml_time_us() - t_pos_us); + profile.graph_pos_bytes += lctx.dflash_pos_ctx_data.size() 
* sizeof(llama_pos); + + if (use_kv_cache) { + const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition( + cross_ctx, + lctx.dflash_kv_cache_n_filled, + lctx.dflash_kv_cache_write_pos, + lctx.dflash_kv_cache_valid, + lctx.dflash_kv_cache_applied_window_version, + lctx.dflash_target_window_version, + lctx.dflash_target_window_keep_rows, + lctx.dflash_target_window_append_rows, + lctx.dflash_target_window_replace, + n_rows); + + const bool have_append_src = append_src != nullptr && + append_rows_available == cache_plan.append_rows && + append_floats == (size_t) cache_plan.append_rows * (size_t) width; + + const int32_t update_rows = cache_plan.cache_up_to_date + ? 0 + : (cache_plan.rebuild_cache ? n_rows : cache_plan.append_rows); + const size_t max_nodes = lctx.model.max_nodes((int) std::max(1, cross_ctx)) + 24 * lctx.model.hparams.n_layer; + const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false); + if (lctx.dflash_buf_compute_meta.size() != meta_size) { + lctx.dflash_buf_compute_meta.resize(meta_size); + } + + if (lctx.dflash_sched == nullptr || lctx.dflash_kv_cache_reserved_rows != cross_ctx) { + std::vector backend_buft; + backend_buft.reserve(lctx.backends.size()); + for (auto * backend : lctx.backends) { + if (ggml_backend_is_cpu(backend)) { + backend_buft.push_back(llama_default_buffer_type_cpu(true)); + } else { + backend_buft.push_back(ggml_backend_get_default_buffer_type(backend)); + } + } + + if (lctx.dflash_sched != nullptr) { + ggml_backend_sched_free(lctx.dflash_sched); + lctx.dflash_sched = nullptr; + } + lctx.dflash_kv_graph = nullptr; + lctx.dflash_kv_graph_rows = 0; + lctx.dflash_kv_graph_write_pos = 0; + + const int32_t saved_update_rows = lctx.dflash_kv_cache_update_rows; + lctx.dflash_kv_cache_update_rows = cross_ctx; + const int64_t t_build_us = ggml_time_us(); + ggml_cgraph * gf_reserve = llm_build_context::llama_build_graph_dflash_kv_cache(lctx); + 
profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us); + lctx.dflash_kv_cache_update_rows = saved_update_rows; + if (gf_reserve == nullptr) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache reserve graph\n", __func__); + return false; + } + + const int64_t t_reserve_us = ggml_time_us(); + lctx.dflash_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false); + const bool reserved = lctx.dflash_sched != nullptr && ggml_backend_sched_reserve(lctx.dflash_sched, gf_reserve); + profile.graph_kv_cache_reserve_us += (uint64_t) (ggml_time_us() - t_reserve_us); + if (!reserved) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V scheduler\n", __func__); + return false; + } + lctx.dflash_kv_cache_reserved_rows = cross_ctx; + } + + if (update_rows > 0) { + const float * update_src = nullptr; + if (have_append_src && update_rows == cache_plan.append_rows) { + update_src = append_src; + } else if (have_full_src) { + update_src = src + (size_t) (n_rows - update_rows) * (size_t) width; + } + const llama_pos * update_pos = src_pos + (n_rows - update_rows); + + if (update_src == nullptr) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: missing DFlash appended target features for cached update (rows=%d append_rows=%d floats=%zu)\n", + __func__, n_rows, update_rows, append_floats); + return false; + } + + if (cache_plan.rebuild_cache) { + llama_reset_dflash_kv_cache_state(&lctx); + } + + lctx.dflash_kv_cache_update_rows = update_rows; + ggml_cgraph * gf_kv = nullptr; + const bool can_reuse_kv_graph = lctx.dflash_kv_graph != nullptr && + lctx.dflash_kv_graph_rows == update_rows && + lctx.dflash_kv_graph_write_pos == lctx.dflash_kv_cache_write_pos; + if (can_reuse_kv_graph) { + gf_kv = lctx.dflash_kv_graph; + } else { + const int64_t t_build_us = ggml_time_us(); + gf_kv = 
llm_build_context::llama_build_graph_dflash_kv_cache(lctx); + profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us); + if (gf_kv == nullptr || lctx.dflash_kv_input_target_features == nullptr || lctx.dflash_kv_input_pos_ctx == nullptr) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__); + return false; + } + + const int64_t t_reset_us = ggml_time_us(); + ggml_backend_sched_reset(lctx.dflash_sched); + profile.graph_kv_cache_reset_us += (uint64_t) (ggml_time_us() - t_reset_us); + + const int64_t t_alloc_us = ggml_time_us(); + ggml_backend_sched_alloc_graph(lctx.dflash_sched, gf_kv); + profile.graph_kv_cache_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us); + + lctx.dflash_kv_graph = gf_kv; + lctx.dflash_kv_graph_rows = update_rows; + lctx.dflash_kv_graph_write_pos = lctx.dflash_kv_cache_write_pos; + } + + ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_target_features); + const int64_t t_feature_upload_us = ggml_time_us(); + if (kv_feature_backend != nullptr) { + ggml_backend_tensor_set_async(kv_feature_backend, lctx.dflash_kv_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); + } else { + ggml_backend_tensor_set(lctx.dflash_kv_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); + } + profile.graph_kv_cache_feature_upload_us += (uint64_t) (ggml_time_us() - t_feature_upload_us); + profile.graph_feature_bytes += (size_t) update_rows * (size_t) width * sizeof(float); + + ggml_backend_t kv_pos_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_pos_ctx); + const int64_t t_pos_upload_us = ggml_time_us(); + if (kv_pos_backend != nullptr) { + ggml_backend_tensor_set_async(kv_pos_backend, lctx.dflash_kv_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); + } else { + ggml_backend_tensor_set(lctx.dflash_kv_input_pos_ctx, update_pos, 0, 
ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); + } + profile.graph_kv_cache_pos_upload_us += (uint64_t) (ggml_time_us() - t_pos_upload_us); + + const int64_t t_kv_cache_us = ggml_time_us(); + llama_dflash_kv_node_profiler kv_node_profiler; + if (kv_node_timing) { + kv_node_profiler.profile = &profile; + ggml_backend_sched_set_eval_callback(lctx.dflash_sched, llama_dflash_kv_node_eval_callback, &kv_node_profiler); + } + llama_graph_compute_sched(lctx, lctx.dflash_sched, gf_kv, lctx.cparams.n_threads); + if (kv_node_timing) { + ggml_backend_sched_set_eval_callback(lctx.dflash_sched, nullptr, nullptr); + } + profile.graph_kv_cache_compute_us += (uint64_t) (ggml_time_us() - t_kv_cache_us); + + const int64_t t_sync_us = ggml_time_us(); + ggml_backend_sched_synchronize(lctx.dflash_sched); + profile.graph_kv_cache_sync_us += (uint64_t) (ggml_time_us() - t_sync_us); + profile.graph_kv_cache_calls++; + + lctx.dflash_kv_cache_n_filled = std::min(cross_ctx, lctx.dflash_kv_cache_n_filled + update_rows); + lctx.dflash_kv_cache_write_pos = (lctx.dflash_kv_cache_write_pos + update_rows) % cross_ctx; + lctx.dflash_kv_cache_applied_window_version = lctx.dflash_target_window_version; + lctx.dflash_kv_cache_valid = true; + lctx.dflash_kv_cache_view_n_filled = lctx.dflash_kv_cache_n_filled; + lctx.dflash_kv_cache_view_write_pos = lctx.dflash_kv_cache_write_pos; + lctx.dflash_kv_cache_view_valid = true; + } + + if (use_kv_workspace && lctx.dflash_kv_cache_view_valid && + !lctx.dflash_k_ctx_workspace.empty() && !lctx.dflash_v_ctx_workspace.empty()) { + const bool need_workspace_refresh = !lctx.dflash_kv_workspace_valid || + lctx.dflash_kv_workspace_n_filled != lctx.dflash_kv_cache_view_n_filled || + lctx.dflash_kv_workspace_write_pos != lctx.dflash_kv_cache_view_write_pos || + lctx.dflash_kv_workspace_applied_window_version != lctx.dflash_kv_cache_applied_window_version; + + if (need_workspace_refresh) { + const size_t max_nodes = lctx.model.max_nodes((int) std::max(1, cross_ctx)) + 16 * 
lctx.model.hparams.n_layer; + const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false); + if (lctx.dflash_workspace_buf_compute_meta.size() != meta_size) { + lctx.dflash_workspace_buf_compute_meta.resize(meta_size); + } + + ggml_cgraph * gf_workspace = nullptr; + const bool can_reuse_workspace_graph = lctx.dflash_kv_workspace_graph != nullptr && + lctx.dflash_kv_workspace_graph_rows == lctx.dflash_kv_cache_view_n_filled && + lctx.dflash_kv_workspace_graph_write_pos == lctx.dflash_kv_cache_view_write_pos; + + if (can_reuse_workspace_graph) { + gf_workspace = lctx.dflash_kv_workspace_graph; + } else { + const int64_t t_build_us = ggml_time_us(); + gf_workspace = llm_build_context::llama_build_graph_dflash_kv_workspace(lctx); + profile.graph_kv_workspace_build_us += (uint64_t) (ggml_time_us() - t_build_us); + if (gf_workspace == nullptr) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: failed to build DFlash K/V workspace graph\n", __func__); + return false; + } + + std::vector backend_buft; + backend_buft.reserve(lctx.backends.size()); + for (auto * backend : lctx.backends) { + if (ggml_backend_is_cpu(backend)) { + backend_buft.push_back(llama_default_buffer_type_cpu(true)); + } else { + backend_buft.push_back(ggml_backend_get_default_buffer_type(backend)); + } + } + + if (lctx.dflash_workspace_sched == nullptr) { + lctx.dflash_workspace_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false); + } + + if (lctx.dflash_kv_workspace_reserved_rows != cross_ctx) { + const bool saved_view_valid = lctx.dflash_kv_cache_view_valid; + const int32_t saved_view_rows = lctx.dflash_kv_cache_view_n_filled; + const int32_t saved_view_write_pos = lctx.dflash_kv_cache_view_write_pos; + + lctx.dflash_kv_cache_view_valid = true; + lctx.dflash_kv_cache_view_n_filled = cross_ctx; + lctx.dflash_kv_cache_view_write_pos = cross_ctx > 1 ? 
1 : 0; + + const int64_t t_reserve_build_us = ggml_time_us(); + ggml_cgraph * gf_workspace_reserve = llm_build_context::llama_build_graph_dflash_kv_workspace(lctx); + profile.graph_kv_workspace_build_us += (uint64_t) (ggml_time_us() - t_reserve_build_us); + + lctx.dflash_kv_cache_view_valid = saved_view_valid; + lctx.dflash_kv_cache_view_n_filled = saved_view_rows; + lctx.dflash_kv_cache_view_write_pos = saved_view_write_pos; + + const int64_t t_reserve_us = ggml_time_us(); + const bool reserved = lctx.dflash_workspace_sched != nullptr && + gf_workspace_reserve != nullptr && + ggml_backend_sched_reserve(lctx.dflash_workspace_sched, gf_workspace_reserve); + profile.graph_kv_workspace_reserve_us += (uint64_t) (ggml_time_us() - t_reserve_us); + if (!reserved) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V workspace scheduler\n", __func__); + return false; + } + + lctx.dflash_kv_workspace_reserved_rows = cross_ctx; + } + + const int64_t t_reset_us = ggml_time_us(); + ggml_backend_sched_reset(lctx.dflash_workspace_sched); + profile.graph_kv_workspace_reset_us += (uint64_t) (ggml_time_us() - t_reset_us); + + const int64_t t_alloc_us = ggml_time_us(); + ggml_backend_sched_alloc_graph(lctx.dflash_workspace_sched, gf_workspace); + profile.graph_kv_workspace_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us); + + lctx.dflash_kv_workspace_graph = gf_workspace; + lctx.dflash_kv_workspace_graph_rows = lctx.dflash_kv_cache_view_n_filled; + lctx.dflash_kv_workspace_graph_write_pos = lctx.dflash_kv_cache_view_write_pos; + } + + const int64_t t_workspace_us = ggml_time_us(); + llama_graph_compute_sched(lctx, lctx.dflash_workspace_sched, gf_workspace, lctx.cparams.n_threads); + profile.graph_kv_workspace_compute_us += (uint64_t) (ggml_time_us() - t_workspace_us); + lctx.dflash_kv_workspace_sync_pending = true; + profile.graph_kv_workspace_calls++; + + lctx.dflash_kv_workspace_n_filled = lctx.dflash_kv_cache_view_n_filled; + 
lctx.dflash_kv_workspace_write_pos = lctx.dflash_kv_cache_view_write_pos; + lctx.dflash_kv_workspace_applied_window_version = lctx.dflash_kv_cache_applied_window_version; + lctx.dflash_kv_workspace_valid = true; + } + } + } else { + ggml_backend_tensor_set(lctx.inp_dflash_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.inp_dflash_target_features)); + ggml_backend_tensor_set(lctx.inp_dflash_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.inp_dflash_pos_ctx)); + } + + const int64_t t_mask_us = ggml_time_us(); + const int32_t full_visible_first = left_pad; + const int32_t full_visible_last = cross_ctx + (int32_t) n_tokens - 1; + lctx.dflash_kq_mask_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY); + int32_t visible_kv_max = 0; + for (uint32_t j = 0; j < n_tokens; ++j) { + float * row = lctx.dflash_kq_mask_data.data() + (size_t) j * (size_t) n_kv_total; + const int32_t visible_kv = cross_ctx + (int32_t) n_tokens; + visible_kv_max = std::max(visible_kv_max, visible_kv); + profile.graph_visible_kv_sum += (uint64_t) visible_kv; + for (int32_t i = full_visible_first; i <= full_visible_last; ++i) { + row[i] = 0.0f; + } + } + ggml_backend_tensor_set(kq_mask, lctx.dflash_kq_mask_data.data(), 0, ggml_nbytes(kq_mask)); + profile.graph_mask_build_us += (uint64_t) (ggml_time_us() - t_mask_us); + profile.graph_mask_bytes += ggml_nbytes(kq_mask); + + if (kq_mask_swa != nullptr) { + lctx.dflash_kq_mask_swa_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY); + const int32_t swa_window = (int32_t) lctx.model.hparams.n_swa; + const int32_t draft_pos_base = (int32_t) profile.last_pos_last; + for (uint32_t j = 0; j < n_tokens; ++j) { + float * row = lctx.dflash_kq_mask_swa_data.data() + (size_t) j * (size_t) n_kv_total; + const int32_t q_pos = draft_pos_base + (int32_t) j; + + for (int32_t k = left_pad; k < cross_ctx; ++k) { + const int32_t k_pos = (int32_t) lctx.dflash_pos_ctx_data[(size_t) k]; + if 
(q_pos - k_pos < swa_window) { + row[k] = 0.0f; + } + } + + for (int32_t k = cross_ctx; k < cross_ctx + (int32_t) n_tokens; ++k) { + const int32_t block_k = k - cross_ctx; + if (block_k <= (int32_t) j) { + row[k] = 0.0f; + } + } + } + + ggml_backend_tensor_set(kq_mask_swa, lctx.dflash_kq_mask_swa_data.data(), 0, ggml_nbytes(kq_mask_swa)); + profile.graph_mask_bytes += ggml_nbytes(kq_mask_swa); + } + + profile.graph_visible_kv_max = std::max(profile.graph_visible_kv_max, (uint64_t) visible_kv_max); + profile.graph_prepare_total_us += (uint64_t) (ggml_time_us() - t_total_us); + + if (profile.graph_prepare_calls == 1) { + int32_t n_swa_layers = 0; + for (int32_t il = 0; il < lctx.model.hparams.n_layer; ++il) { + n_swa_layers += lctx.model.hparams.swa_layers[(size_t) il] ? 1 : 0; + } + + LLAMA_LOG_INFO("%s: DFlash graph contract rows=%d width=%d cross_ctx=%d n_tokens=%u left_pad=%d n_kv_total=%d draft_n_ctx=%u pos=%s [%d..%d] full_mask=[%d..%d] swa_window=%u swa_layers=%d\n", + __func__, n_rows, width, cross_ctx, n_tokens, left_pad, n_kv_total, lctx.cparams.n_ctx, + (src_pos != nullptr && total_positions == (size_t) n_rows) ? 
"target" : "synthetic", + (int) profile.last_pos_first, (int) profile.last_pos_last, + full_visible_first, full_visible_last, + lctx.model.hparams.n_swa, + n_swa_layers); + } + + return true; +} diff --git a/src/llama-dflash.h b/src/llama-dflash.h new file mode 100644 index 00000000..8280c6ca --- /dev/null +++ b/src/llama-dflash.h @@ -0,0 +1,8 @@ +#pragma once + +#include + +struct llama_context; + +bool llama_prepare_dflash_graph_inputs(llama_context & lctx, uint32_t n_tokens); +void llama_sync_dflash_workspace_if_pending(llama_context & lctx); diff --git a/src/llama-quantize.cpp b/src/llama-quantize.cpp index 1f538882..367b7225 100644 --- a/src/llama-quantize.cpp +++ b/src/llama-quantize.cpp @@ -616,7 +616,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n if (qs.model.hparams.n_vocab >= 127999 && (qs.model.type == MODEL_8B || qs.model.type == MODEL_70B)) new_type = GGML_TYPE_IQ6_K; } - else if (qs.model.hparams.n_gqa() >= 4) { + else if (qs.model.hparams.n_gqa() >= 4 && + !(arch == LLM_ARCH_DFLASH_DRAFT && + (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M))) { if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_IQ3_XXS) new_type = GGML_TYPE_IQ3_S; else if (new_type == GGML_TYPE_Q2_K_R4 || new_type == GGML_TYPE_IQ3_XXS_R4) new_type = GGML_TYPE_IQ3_K_R4; else if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_IQ3_S) new_type = GGML_TYPE_Q4_K; @@ -1778,4 +1780,3 @@ uint32_t llama_model_quantize( return 1; } } - diff --git a/src/llama-spec-features-dflash.cpp b/src/llama-spec-features-dflash.cpp new file mode 100644 index 00000000..088f6b2d --- /dev/null +++ b/src/llama-spec-features-dflash.cpp @@ -0,0 +1,1097 @@ +#include "llama-spec-features.h" + +#include +#include +#include +#include +#include +#include + +#include "llama-model.h" +#include "llama-context.h" + +static bool llama_dflash_positions_strictly_increasing( + const llama_pos * positions, + int32_t n_rows, + llama_pos & first_pos, + 
llama_pos & last_pos) { + first_pos = -1; + last_pos = -1; + + if (positions == nullptr || n_rows <= 0) { + return false; + } + + first_pos = positions[0]; + last_pos = positions[n_rows - 1]; + + for (int32_t i = 1; i < n_rows; ++i) { + if (positions[i] <= positions[i - 1]) { + return false; + } + } + + return true; +} + +void llama_dflash_profile_reset(struct llama_context * ctx) { + if (ctx == nullptr) { + return; + } + + ctx->dflash.profile = {}; +} + +void llama_reset_dflash_kv_cache_state(struct llama_context * ctx) { + if (ctx == nullptr) { + return; + } + + ctx->dflash.kv.cache_write_pos = 0; + ctx->dflash.kv.cache_n_filled = 0; + ctx->dflash.kv.cache_update_rows = 0; + ctx->dflash.kv.cache_view_write_pos = 0; + ctx->dflash.kv.cache_view_n_filled = 0; + ctx->dflash.kv.cache_applied_window_version = 0; + ctx->dflash.kv.cache_valid = false; + ctx->dflash.kv.cache_view_valid = false; + ctx->dflash.kv.workspace_write_pos = 0; + ctx->dflash.kv.workspace_n_filled = 0; + ctx->dflash.kv.workspace_applied_window_version = 0; + ctx->dflash.kv.workspace_valid = false; + ctx->dflash.kv.workspace_sync_pending = false; + + for (ggml_backend_buffer_t buf : ctx->dflash.kv.cache_bufs) { + if (buf != nullptr) { + ggml_backend_buffer_clear(buf, 0); + } + } +} + +llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition_for_ctx( + const struct llama_context * ctx, + const llama_dflash_window_update & window_update, + int32_t n_rows) { + if (ctx == nullptr) { + llama_dflash_kv_cache_transition plan; + plan.rebuild_cache = true; + plan.append_rows = std::clamp(window_update.append_rows, 0, n_rows); + plan.next_n_filled = n_rows; + return plan; + } + + const int32_t cross_ctx = ctx->dflash.visible_cross_ctx > 0 + ? 
ctx->dflash.visible_cross_ctx + : std::max(1, (int32_t) ctx->cparams.n_ctx - (int32_t) ctx->model.hparams.dflash_block_size); + + return llama_plan_dflash_kv_cache_transition( + cross_ctx, + ctx->dflash.kv.cache_n_filled, + ctx->dflash.kv.cache_write_pos, + ctx->dflash.kv.cache_valid, + ctx->dflash.kv.cache_applied_window_version, + window_update.version, + window_update.keep_rows, + window_update.append_rows, + window_update.replace, + n_rows); +} + +void llama_set_dflash_visible_cross_ctx( + struct llama_context * ctx, + int32_t cross_ctx) { + if (ctx == nullptr) { + return; + } + + ctx->dflash.visible_cross_ctx = std::max(0, cross_ctx); +} + +int32_t llama_get_dflash_visible_cross_ctx( + const struct llama_context * ctx) { + return ctx != nullptr ? ctx->dflash.visible_cross_ctx : 0; +} + +bool llama_dflash_profile_get_stats( + const struct llama_context * ctx, + llama_dflash_profile_stats * stats) { + if (ctx == nullptr || stats == nullptr) { + return false; + } + + *stats = ctx->dflash.profile; + return true; +} + +int32_t llama_model_dflash_block_size(const struct llama_model * model) { + return model ? (int32_t) model->hparams.dflash_block_size : 0; +} + +int32_t llama_model_dflash_mask_token_id(const struct llama_model * model) { + return model ? (int32_t) model->hparams.dflash_mask_token_id : -1; +} + +int32_t llama_model_dflash_n_target_layers(const struct llama_model * model) { + return model ? (int32_t) model->hparams.dflash_n_target_layers : 0; +} + +int32_t llama_model_dflash_n_target_features(const struct llama_model * model) { + return model ? 
(int32_t) model->hparams.dflash_n_target_features : 0; +} + +int32_t llama_model_dflash_target_layer_ids( + const struct llama_model * model, + int32_t * layer_ids, + int32_t capacity) { + if (model == nullptr || layer_ids == nullptr || capacity <= 0) { + return 0; + } + + const int32_t n_layers = std::min((int32_t) model->hparams.dflash_n_target_layers, capacity); + for (int32_t i = 0; i < n_layers; ++i) { + layer_ids[i] = (int32_t) model->hparams.dflash_target_layer_ids[i]; + } + + return n_layers; +} + +int32_t llama_model_dflash_target_mask_token_id(const struct llama_model * model) { + if (model == nullptr) { + return (int32_t) LLAMA_TOKEN_NULL; + } + + return (int32_t) model->vocab.token_mask(); +} + +const struct ggml_tensor * llama_model_dflash_output_tensor( + const struct llama_model * model) { + if (model == nullptr) { + return nullptr; + } + + if (model->output_mtp != nullptr) { + return model->output_mtp; + } + + if (model->output != nullptr) { + return model->output; + } + + return model->tok_embd; +} + +static const char * llama_dflash_io_mode_name(int32_t io_mode) { + switch (io_mode) { + case LLAMA_DFLASH_IO_MODE_SHARED: + return "shared"; + case LLAMA_DFLASH_IO_MODE_SELF_CONTAINED: + return "self-contained"; + case LLAMA_DFLASH_IO_MODE_MIXED: + return "mixed"; + default: + return "invalid"; + } +} + +static const char * llama_dflash_output_head_kind( + const struct llama_model * draft_model, + const struct llama_model * target_model) { + const struct ggml_tensor * output = llama_model_dflash_output_tensor(draft_model); + if (output == nullptr) { + return "missing"; + } + + if (output == draft_model->tok_embd) { + return draft_model->tok_embd == (target_model ? target_model->tok_embd : nullptr) + ? 
"shared_token_embedding" + : "token_embedding"; + } + + if (draft_model->output_mtp != nullptr && output == draft_model->output_mtp) { + if (target_model != nullptr && target_model->output_mtp != nullptr && output == target_model->output_mtp) { + return "output_mtp"; + } + + if (std::strcmp(output->name, "output_extra.weight") == 0) { + return "output_extra"; + } + + return "output_mtp"; + } + + return "output"; +} + +int32_t llama_model_dflash_io_mode( + const struct llama_model * draft_model, + const struct llama_model * target_model) { + if (draft_model == nullptr || target_model == nullptr || draft_model->arch != LLM_ARCH_DFLASH_DRAFT) { + return LLAMA_DFLASH_IO_MODE_INVALID; + } + + const ggml_tensor * draft_output = llama_model_dflash_output_tensor(draft_model); + const ggml_tensor * target_output = llama_model_dflash_output_tensor(target_model); + if (draft_model->tok_embd == nullptr || draft_output == nullptr || target_model->tok_embd == nullptr || target_output == nullptr) { + return LLAMA_DFLASH_IO_MODE_INVALID; + } + + const bool shared_tok = draft_model->tok_embd == target_model->tok_embd; + const bool shared_output = draft_output == target_output; + if (shared_tok && shared_output) { + return LLAMA_DFLASH_IO_MODE_SHARED; + } + + if (!shared_tok && !shared_output) { + return LLAMA_DFLASH_IO_MODE_SELF_CONTAINED; + } + + return LLAMA_DFLASH_IO_MODE_MIXED; +} + +bool llama_model_dflash_io_tensors_match( + const struct llama_model * draft_model, + int32_t n_embd, + int32_t n_vocab) { + const ggml_tensor * output = llama_model_dflash_output_tensor(draft_model); + if (draft_model == nullptr || draft_model->tok_embd == nullptr || output == nullptr || n_embd <= 0 || n_vocab <= 0) { + return false; + } + + return (int32_t) draft_model->tok_embd->ne[0] == n_embd && + (int32_t) draft_model->tok_embd->ne[1] == n_vocab && + (int32_t) output->ne[0] == n_embd && + (int32_t) output->ne[1] == n_vocab; +} + +bool llama_model_share_dflash_io_tensors( + struct llama_model 
* draft_model, + const struct llama_model * target_model) { + if (draft_model == nullptr || target_model == nullptr) { + return false; + } + + if (draft_model->arch != LLM_ARCH_DFLASH_DRAFT) { + return true; + } + + if (draft_model->tok_embd == nullptr) { + draft_model->tok_embd = target_model->tok_embd; + } + + if (draft_model->output == nullptr) { + draft_model->output = target_model->output ? target_model->output : target_model->tok_embd; + if (draft_model->output == nullptr) { + draft_model->output = draft_model->tok_embd; + } + } + + const bool uses_shared_tok = draft_model->tok_embd == target_model->tok_embd; + const bool uses_shared_output = draft_model->output == target_model->output || + draft_model->output == target_model->tok_embd; + + if (draft_model->output_mtp == nullptr && target_model->output_mtp != nullptr && uses_shared_tok && uses_shared_output) { + draft_model->output_mtp = target_model->output_mtp; + } + + const struct ggml_tensor * output = llama_model_dflash_output_tensor(draft_model); + if (draft_model->tok_embd != nullptr && output != nullptr) { + LLAMA_LOG_INFO("%s: DFlash IO mode=%s output_head=%s tensor=%s type=%s\n", + __func__, + llama_dflash_io_mode_name(llama_model_dflash_io_mode(draft_model, target_model)), + llama_dflash_output_head_kind(draft_model, target_model), + output->name[0] != '\0' ? 
output->name : "(unnamed)", + ggml_type_name(output->type)); + } + + return draft_model->tok_embd != nullptr && output != nullptr; +} + +static bool llama_set_dflash_target_features_impl( + struct llama_context * ctx, + const float * target_features, + size_t n_floats, + int32_t n_rows, + const llama_pos * target_positions, + bool copy_data, + const llama_dflash_window_update * window_update) { + const bool have_full_features = target_features != nullptr && n_floats > 0; + const bool have_append_features = window_update != nullptr && + window_update->append_features != nullptr && + window_update->append_floats > 0 && + window_update->append_rows > 0; + + if (ctx == nullptr || n_rows <= 0 || (!have_full_features && !have_append_features)) { + return false; + } + + auto & profile = ctx->dflash.profile; + const int64_t t_start_us = ggml_time_us(); + const int32_t row_width = have_full_features + ? (n_rows > 0 ? (int32_t) (n_floats / (size_t) n_rows) : 0) + : (window_update->append_rows > 0 ? (int32_t) (window_update->append_floats / (size_t) window_update->append_rows) : 0); + llama_pos first_pos = -1; + llama_pos last_pos = -1; + + if (have_full_features && copy_data) { + ctx->dflash.target.features_owned.assign(target_features, target_features + n_floats); + ctx->dflash.target.features = ctx->dflash.target.features_owned.data(); + } else if (have_full_features) { + ctx->dflash.target.features_owned.clear(); + ctx->dflash.target.features = target_features; + } else { + ctx->dflash.target.features_owned.clear(); + ctx->dflash.target.features = nullptr; + } + ctx->dflash.target.features_n_floats = have_full_features ? 
n_floats : 0; + ctx->dflash.target.features_n_rows = n_rows; + if (have_append_features && copy_data) { + ctx->dflash.target.append_features_owned.assign( + window_update->append_features, + window_update->append_features + window_update->append_floats); + ctx->dflash.target.append_features = ctx->dflash.target.append_features_owned.data(); + } else if (have_append_features) { + ctx->dflash.target.append_features_owned.clear(); + ctx->dflash.target.append_features = window_update->append_features; + } else { + ctx->dflash.target.append_features_owned.clear(); + ctx->dflash.target.append_features = nullptr; + } + ctx->dflash.target.append_features_n_floats = have_append_features ? window_update->append_floats : 0; + ctx->dflash.target.append_features_n_rows = have_append_features ? window_update->append_rows : 0; + ctx->dflash.target.version = window_update != nullptr && window_update->version > 0 + ? window_update->version + : ctx->dflash.target.version + 1; + ctx->dflash.target.keep_rows = window_update != nullptr + ? std::max(0, std::min(n_rows, window_update->keep_rows)) + : 0; + ctx->dflash.target.append_rows = window_update != nullptr + ? std::max(0, std::min(n_rows, window_update->append_rows)) + : n_rows; + ctx->dflash.target.replace = window_update != nullptr + ? window_update->replace + : true; + if (ctx->dflash.target.keep_rows + ctx->dflash.target.append_rows > n_rows) { + ctx->dflash.target.keep_rows = std::max(0, n_rows - ctx->dflash.target.append_rows); + } + + const int32_t cross_ctx = ctx->dflash.visible_cross_ctx > 0 + ? 
ctx->dflash.visible_cross_ctx + : std::max(1, (int32_t) ctx->cparams.n_ctx - (int32_t) ctx->model.hparams.dflash_block_size); + const llama_dflash_window_update cache_window_update = { + ctx->dflash.target.version, + ctx->dflash.target.keep_rows, + ctx->dflash.target.append_rows, + ctx->dflash.target.replace, + ctx->dflash.target.append_features, + ctx->dflash.target.append_features_n_floats, + }; + const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition_for_ctx(ctx, cache_window_update, n_rows); + + if (cache_plan.cache_up_to_date) { + ctx->dflash.kv.cache_view_n_filled = ctx->dflash.kv.cache_n_filled; + ctx->dflash.kv.cache_view_write_pos = ctx->dflash.kv.cache_write_pos; + ctx->dflash.kv.cache_view_valid = ctx->dflash.kv.cache_valid; + } else if (cross_ctx > 0) { + ctx->dflash.kv.cache_view_n_filled = cache_plan.next_n_filled; + ctx->dflash.kv.cache_view_write_pos = cache_plan.next_write_pos; + ctx->dflash.kv.cache_view_valid = cache_plan.next_n_filled > 0; + } + + if (target_positions != nullptr) { + if (copy_data) { + ctx->dflash.target.positions_owned.assign(target_positions, target_positions + n_rows); + ctx->dflash.target.positions = ctx->dflash.target.positions_owned.data(); + } else { + ctx->dflash.target.positions_owned.clear(); + ctx->dflash.target.positions = target_positions; + } + ctx->dflash.target.positions_n = (size_t) n_rows; + } else { + ctx->dflash.target.positions_owned.clear(); + ctx->dflash.target.positions = nullptr; + ctx->dflash.target.positions_n = 0; + } + + profile.set_target_copy_calls++; + profile.set_target_copy_us += (uint64_t) (ggml_time_us() - t_start_us); + profile.set_target_rows += (uint64_t) n_rows; + profile.set_target_copy_bytes += + (have_full_features ? n_floats : 0) * sizeof(float) + + (have_append_features ? window_update->append_floats : 0) * sizeof(float) + + (target_positions ? 
(size_t) n_rows * sizeof(llama_pos) : 0); + profile.last_n_rows = n_rows; + profile.last_width = row_width; + + if (target_positions == nullptr) { + profile.set_target_missing_positions++; + profile.last_pos_first = -1; + profile.last_pos_last = -1; + } else { + if (!llama_dflash_positions_strictly_increasing(target_positions, n_rows, first_pos, last_pos)) { + profile.set_target_non_monotonic_positions++; + } + profile.last_pos_first = first_pos; + profile.last_pos_last = last_pos; + } + + return true; +} + +bool llama_set_dflash_target_features_copy( + struct llama_context * ctx, + const float * target_features, + size_t n_floats, + int32_t n_rows, + const llama_pos * target_positions, + const llama_dflash_window_update * window_update) { + return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, true, window_update); +} + +bool llama_set_dflash_target_features_view( + struct llama_context * ctx, + const float * target_features, + size_t n_floats, + int32_t n_rows, + const llama_pos * target_positions, + const llama_dflash_window_update * window_update) { + return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, false, window_update); +} + +static void llama_record_dflash_capture_phase( + struct llama_context * ctx, + bool is_prompt_warmup, + int32_t row_count, + int32_t row_width) { + if (ctx == nullptr || row_count <= 0 || row_width <= 0) { + return; + } + + auto & profile = ctx->dflash.profile; + if (is_prompt_warmup) { + profile.capture_prompt_batches++; + if (profile.capture_prompt_last_rows > 0 && profile.capture_prompt_last_width > 0 && + (profile.capture_prompt_last_rows != row_count || profile.capture_prompt_last_width != row_width)) { + profile.capture_prompt_shape_changes++; + } + profile.capture_prompt_last_rows = row_count; + profile.capture_prompt_last_width = row_width; + } else { + profile.capture_verify_batches++; + if (profile.capture_verify_last_rows > 0 
&& profile.capture_verify_last_width > 0 && + (profile.capture_verify_last_rows != row_count || profile.capture_verify_last_width != row_width)) { + profile.capture_verify_shape_changes++; + } + profile.capture_verify_last_rows = row_count; + profile.capture_verify_last_width = row_width; + } +} + +static bool llama_dflash_parse_layer_id(const struct ggml_tensor * tensor, int32_t & layer_id) { + if (tensor == nullptr) { + return false; + } + + static constexpr const char * prefix = "l_out-"; + if (std::strncmp(tensor->name, prefix, std::strlen(prefix)) != 0) { + return false; + } + + char * end = nullptr; + const long raw = std::strtol(tensor->name + std::strlen(prefix), &end, 10); + if (end == tensor->name + std::strlen(prefix) || *end != '\0') { + return false; + } + + layer_id = (int32_t) raw; + if (layer_id >= 1000) { + layer_id %= 1000; + } + + return layer_id >= 0; +} + +static int32_t llama_dflash_find_layer_index(const struct llama_context * ctx, int32_t layer_id) { + if (ctx == nullptr || !ctx->dflash.capture) { + return -1; + } + + const auto & layer_ids = ctx->dflash.capture->layer_ids; + const auto it = std::find(layer_ids.begin(), layer_ids.end(), layer_id); + return it == layer_ids.end() ? -1 : (int32_t) std::distance(layer_ids.begin(), it); +} + +static bool llama_dflash_capture_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { + auto * ctx = static_cast(user_data); + if (ctx == nullptr || !ctx->dflash.capture) { + return false; + } + + int32_t layer_id = -1; + if (!llama_dflash_parse_layer_id(tensor, layer_id)) { + return false; + } + + const int32_t layer_idx = llama_dflash_find_layer_index(ctx, layer_id); + if (layer_idx < 0) { + return false; + } + + if (ask) { + return true; + } + + const int32_t row_width = (int32_t) tensor->ne[0]; + const int32_t row_count = row_width > 0 ? 
(int32_t) (ggml_nelements(tensor) / (int64_t) row_width) : 0; + if (row_width <= 0 || row_count <= 0) { + return false; + } + + auto & capture = *ctx->dflash.capture; + if (capture.capture_batch_id == 0) { + capture.capture_batch_id = 1; + } + if (capture.layer_seen_batch_id.size() != capture.layer_ids.size()) { + capture.layer_seen_batch_id.assign(capture.layer_ids.size(), 0); + } + + auto & rows = capture.layer_rows[(size_t) layer_idx]; + rows.resize((size_t) row_count * (size_t) row_width); + ggml_backend_tensor_get(tensor, rows.data(), 0, ggml_nbytes(tensor)); + capture.row_width = row_width; + capture.row_count = row_count; + capture.layer_seen_batch_id[(size_t) layer_idx] = capture.capture_batch_id; + return true; +} + +bool llama_set_dflash_capture_layers( + struct llama_context * ctx, + const int32_t * layer_ids, + int32_t n_layers) { + if (ctx == nullptr || layer_ids == nullptr || n_layers <= 0) { + return false; + } + + auto capture = std::make_unique(); + capture->layer_ids.assign(layer_ids, layer_ids + n_layers); + capture->layer_rows.resize((size_t) n_layers); + capture->layer_seen_batch_id.assign((size_t) n_layers, 0); + capture->prev_cb_eval = ctx->cparams.cb_eval; + capture->prev_cb_eval_user_data = ctx->cparams.cb_eval_user_data; + ctx->dflash.capture = std::move(capture); + ctx->dflash.feature_view_buffer.clear(); + + ctx->cparams.cb_eval = llama_dflash_capture_eval_callback; + ctx->cparams.cb_eval_user_data = ctx; + if (ctx->sched != nullptr) { + ggml_backend_sched_set_eval_callback(ctx->sched, ctx->cparams.cb_eval, ctx->cparams.cb_eval_user_data); + } + + return true; +} + +void llama_clear_dflash_capture(struct llama_context * ctx) { + if (ctx == nullptr) { + return; + } + + ggml_backend_sched_eval_callback prev_cb_eval = nullptr; + void * prev_cb_eval_user_data = nullptr; + if (ctx->dflash.capture) { + prev_cb_eval = ctx->dflash.capture->prev_cb_eval; + prev_cb_eval_user_data = ctx->dflash.capture->prev_cb_eval_user_data; + } + + 
ctx->dflash.capture.reset(); + ctx->dflash.feature_view_buffer.clear(); + + if (ctx->cparams.cb_eval == llama_dflash_capture_eval_callback && ctx->cparams.cb_eval_user_data == ctx) { + ctx->cparams.cb_eval = prev_cb_eval; + ctx->cparams.cb_eval_user_data = prev_cb_eval_user_data; + if (ctx->sched != nullptr) { + ggml_backend_sched_set_eval_callback(ctx->sched, prev_cb_eval, prev_cb_eval_user_data); + } + } +} + +void llama_begin_dflash_capture_batch(struct llama_context * ctx) { + if (ctx == nullptr || !ctx->dflash.capture) { + return; + } + + auto & capture = *ctx->dflash.capture; + capture.capture_batch_id++; + capture.row_count = 0; + capture.row_width = 0; + std::fill(capture.layer_seen_batch_id.begin(), capture.layer_seen_batch_id.end(), 0); +} + +void llama_finish_dflash_capture_batch( + struct llama_context * ctx, + bool is_prompt_warmup) { + if (ctx == nullptr || !ctx->dflash.capture) { + return; + } + + auto & capture = *ctx->dflash.capture; + llama_record_dflash_capture_phase(ctx, is_prompt_warmup, capture.row_count, capture.row_width); + + // Reset the batch-local reference shape so the next decode only compares layers within + // the same batch, not against the previous prompt/verify batch. 
+ capture.row_count = 0; + capture.row_width = 0; +} + +static bool llama_spec_prepare_dflash_capture( + struct llama_context * ctx, + int32_t & row_count, + int32_t & row_width, + int32_t & n_layers) { + if (ctx == nullptr || !ctx->dflash.capture) { + return false; + } + + auto & profile = ctx->dflash.profile; + profile.capture_prepare_calls++; + const int64_t t_sync_us = ggml_time_us(); + llama_synchronize(ctx); + profile.capture_prepare_sync_us += (uint64_t) (ggml_time_us() - t_sync_us); + + auto & capture = *ctx->dflash.capture; + row_count = capture.row_count; + row_width = capture.row_width; + n_layers = (int32_t) capture.layer_ids.size(); + if (row_count <= 0 || row_width <= 0 || n_layers <= 0 || capture.layer_rows.size() != (size_t) n_layers) { + profile.capture_prepare_failures++; + return false; + } + + if (capture.capture_batch_id == 0 || capture.layer_seen_batch_id.size() != (size_t) n_layers) { + profile.capture_prepare_failures++; + profile.capture_layer_batch_mismatch++; + if (profile.capture_layer_batch_mismatch <= 3) { + LLAMA_LOG_WARN("%s: DFlash capture batch markers are not initialized (batch_id=%llu layers=%zu expected=%d)\n", + __func__, + (unsigned long long) capture.capture_batch_id, + capture.layer_seen_batch_id.size(), + n_layers); + } + return false; + } + + for (int32_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) { + if (capture.layer_seen_batch_id[(size_t) layer_idx] != capture.capture_batch_id) { + profile.capture_prepare_failures++; + profile.capture_layer_batch_mismatch++; + if (profile.capture_layer_batch_mismatch <= 3) { + LLAMA_LOG_WARN("%s: DFlash capture is stale for layer %d (seen_batch=%llu current_batch=%llu rows=%d width=%d)\n", + __func__, + capture.layer_ids[(size_t) layer_idx], + (unsigned long long) capture.layer_seen_batch_id[(size_t) layer_idx], + (unsigned long long) capture.capture_batch_id, + row_count, + row_width); + } + return false; + } + + const auto & rows = capture.layer_rows[(size_t) layer_idx]; + if 
(rows.size() != (size_t) row_count * (size_t) row_width) { + profile.capture_prepare_failures++; + profile.capture_layer_shape_mismatch++; + if (profile.capture_layer_shape_mismatch <= 3) { + LLAMA_LOG_WARN("%s: DFlash capture rows mismatch for layer %d: got=%zu expected=%zu (rows=%d width=%d)\n", + __func__, capture.layer_ids[(size_t) layer_idx], rows.size(), + (size_t) row_count * (size_t) row_width, row_count, row_width); + } + return false; + } + } + + return true; +} + +static bool llama_dflash_contract_log_enabled() { + const char * env = std::getenv("IK_DFLASH_CONTRACT_LOG"); + if (env == nullptr || *env == '\0') { + return false; + } + + return std::strcmp(env, "0") != 0 && + std::strcmp(env, "false") != 0 && + std::strcmp(env, "off") != 0; +} + +template +static std::string llama_dflash_contract_format_values( + const std::vector & values, + size_t edge_count = 4) { + std::ostringstream oss; + oss << '['; + if (values.empty()) { + oss << ']'; + return oss.str(); + } + + const size_t head = std::min(edge_count, values.size()); + for (size_t i = 0; i < head; ++i) { + if (i > 0) { + oss << ','; + } + oss << values[i]; + } + + if (values.size() > edge_count * 2) { + oss << ",...,"; + for (size_t i = values.size() - edge_count; i < values.size(); ++i) { + if (i > values.size() - edge_count) { + oss << ','; + } + oss << values[i]; + } + } else { + for (size_t i = head; i < values.size(); ++i) { + oss << ',' << values[i]; + } + } + + oss << ']'; + return oss.str(); +} + +static std::vector llama_dflash_contract_collect_batch_positions( + const llama_batch & batch, + const std::vector & batch_indices) { + std::vector positions; + positions.reserve(batch_indices.size()); + for (int32_t batch_index : batch_indices) { + positions.push_back(batch.pos[batch_index]); + } + return positions; +} + +static void llama_dflash_contract_summarize_positions( + const std::vector & positions, + llama_pos & first_pos, + llama_pos & last_pos, + int32_t & gap_count, + int32_t & 
nonmono_count) { + first_pos = -1; + last_pos = -1; + gap_count = 0; + nonmono_count = 0; + if (positions.empty()) { + return; + } + + first_pos = positions.front(); + last_pos = positions.back(); + for (size_t i = 1; i < positions.size(); ++i) { + if (positions[i] <= positions[i - 1]) { + nonmono_count++; + } else if (positions[i] != positions[i - 1] + 1) { + gap_count++; + } + } +} + +static void llama_dflash_contract_log_feature_view( + const char * kind, + llama_seq_id seq_id, + const llama_batch & batch, + int32_t row_count, + int32_t row_width, + int32_t n_layers, + int32_t batch_row_offset, + const std::vector & row_indices, + const std::vector & batch_indices) { + if (!llama_dflash_contract_log_enabled()) { + return; + } + + static std::atomic counter = 0; + const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); + if (ordinal >= 8) { + return; + } + + const std::vector positions = llama_dflash_contract_collect_batch_positions(batch, batch_indices); + llama_pos first_pos = -1; + llama_pos last_pos = -1; + int32_t gap_count = 0; + int32_t nonmono_count = 0; + llama_dflash_contract_summarize_positions(positions, first_pos, last_pos, gap_count, nonmono_count); + + LLAMA_LOG_INFO("%s[%llu]: kind=%s seq=%d batch_tokens=%d capture_rows=%d row_width=%d layers=%d batch_row_offset=%d row_indices=%s batch_indices=%s batch_pos=%s pos=[%d..%d] gaps=%d nonmono=%d\n", + __func__, + (unsigned long long) (ordinal + 1), + kind, + (int) seq_id, + batch.n_tokens, + row_count, + row_width, + n_layers, + batch_row_offset, + llama_dflash_contract_format_values(row_indices).c_str(), + llama_dflash_contract_format_values(batch_indices).c_str(), + llama_dflash_contract_format_values(positions).c_str(), + (int) first_pos, + (int) last_pos, + gap_count, + nonmono_count); +} + +static void llama_dflash_contract_log_output_indices( + struct llama_context * ctx, + const std::vector & output_indices) { + if (!llama_dflash_contract_log_enabled()) { + return; + } + + 
static std::atomic counter = 0; + const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); + if (ordinal >= 8) { + return; + } + + int32_t row_count = 0; + int32_t row_width = 0; + int32_t n_layers = 0; + const bool have_capture = llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers); + + LLAMA_LOG_INFO("%s[%llu]: output_indices=%s capture_rows=%d row_width=%d layers=%d have_capture=%s\n", + __func__, + (unsigned long long) (ordinal + 1), + llama_dflash_contract_format_values(output_indices).c_str(), + row_count, + row_width, + n_layers, + have_capture ? "true" : "false"); +} + + static bool llama_spec_materialize_dflash_rows_prepared( + struct llama_context * ctx, + int32_t row_count, + int32_t row_width, + int32_t n_layers, + const std::vector & row_indices, + std::vector & rows_out, + int32_t & combined_width); + +static bool llama_spec_materialize_dflash_rows( + struct llama_context * ctx, + const std::vector & row_indices, + std::vector & rows_out, + int32_t & combined_width) { + int32_t row_count = 0; + int32_t row_width = 0; + int32_t n_layers = 0; + if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) { + if (ctx != nullptr) { + ctx->dflash.profile.capture_materialize_failures++; + } + return false; + } + + return llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, rows_out, combined_width); +} + +static bool llama_spec_materialize_dflash_rows_prepared( + struct llama_context * ctx, + int32_t row_count, + int32_t row_width, + int32_t n_layers, + const std::vector & row_indices, + std::vector & rows_out, + int32_t & combined_width) { + rows_out.clear(); + combined_width = 0; + if (ctx == nullptr || row_indices.empty()) { + return false; + } + + auto & profile = ctx->dflash.profile; + profile.capture_materialize_calls++; + const int64_t t_start_us = ggml_time_us(); + + if (row_count <= 0 || row_width <= 0 || n_layers <= 0 || ctx->dflash.capture == nullptr) { + 
profile.capture_materialize_failures++; + return false; + } + + combined_width = row_width * n_layers; + rows_out.resize((size_t) row_indices.size() * (size_t) combined_width); + + const auto & layer_rows = ctx->dflash.capture->layer_rows; + for (size_t out_row = 0; out_row < row_indices.size(); ++out_row) { + int32_t row_index = row_indices[out_row]; + if (row_index < 0) { + row_index += row_count; + } + if (row_index < 0 || row_index >= row_count) { + rows_out.clear(); + combined_width = 0; + profile.capture_materialize_failures++; + return false; + } + + float * dst = rows_out.data() + out_row * (size_t) combined_width; + for (int32_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) { + const float * src = layer_rows[(size_t) layer_idx].data() + (size_t) row_index * (size_t) row_width; + std::memcpy(dst + (size_t) layer_idx * (size_t) row_width, src, (size_t) row_width * sizeof(float)); + } + } + + profile.capture_materialize_us += (uint64_t) (ggml_time_us() - t_start_us); + profile.capture_materialize_rows += (uint64_t) row_indices.size(); + profile.capture_materialize_bytes += rows_out.size() * sizeof(float); + + return true; +} + + +bool llama_spec_get_dflash_feature_view( + struct llama_context * ctx, + const llama_batch & batch, + llama_spec_feature_view & view) { + if (ctx == nullptr || batch.n_tokens <= 0 || batch.pos == nullptr || batch.n_seq_id == nullptr || batch.seq_id == nullptr) { + return false; + } + + int32_t row_count = 0; + int32_t row_width = 0; + int32_t n_layers = 0; + if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) { + return false; + } + + const int32_t batch_row_offset = std::max(0, batch.n_tokens - row_count); + std::vector row_indices; + std::vector batch_indices; + row_indices.reserve((size_t) (batch.n_tokens - batch_row_offset)); + batch_indices.reserve((size_t) (batch.n_tokens - batch_row_offset)); + for (int32_t i = batch_row_offset; i < batch.n_tokens; ++i) { + row_indices.push_back(i - 
batch_row_offset); + batch_indices.push_back(i); + } + + if (row_indices.empty()) { + return false; + } + + view = {}; + view.kind = LLAMA_SPEC_FEATURE_HIDDEN_STATE; + if (!llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, ctx->dflash.feature_view_buffer, view.width)) { + return false; + } + + view.rows.reserve(batch_indices.size()); + for (int32_t batch_index : batch_indices) { + if (batch.n_seq_id[batch_index] <= 0 || batch.seq_id[batch_index] == nullptr) { + view.rows.clear(); + return false; + } + + view.rows.push_back({ + /* .seq_id = */ batch.seq_id[batch_index][0], + /* .pos = */ batch.pos[batch_index], + /* .data = */ ctx->dflash.feature_view_buffer.data() + view.rows.size() * (size_t) view.width, + }); + } + + llama_dflash_contract_log_feature_view( + "batch", + view.rows.empty() ? -1 : view.rows.front().seq_id, + batch, + row_count, + row_width, + n_layers, + batch_row_offset, + row_indices, + batch_indices); + + return true; +} + +bool llama_spec_get_dflash_feature_view_for_seq( + struct llama_context * ctx, + const llama_batch & batch, + llama_seq_id seq_id, + llama_spec_feature_view & view) { + if (ctx == nullptr || batch.n_tokens <= 0 || batch.pos == nullptr || batch.n_seq_id == nullptr || batch.seq_id == nullptr) { + return false; + } + + int32_t row_count = 0; + int32_t row_width = 0; + int32_t n_layers = 0; + if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) { + return false; + } + + const int32_t batch_row_offset = std::max(0, batch.n_tokens - row_count); + std::vector row_indices; + row_indices.reserve((size_t) batch.n_tokens); + std::vector batch_indices; + batch_indices.reserve((size_t) batch.n_tokens); + for (int32_t i = batch_row_offset; i < batch.n_tokens; ++i) { + if (batch.n_seq_id[i] <= 0 || batch.seq_id[i] == nullptr) { + return false; + } + + for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) { + if (batch.seq_id[i][j] == seq_id) { + row_indices.push_back(i - 
batch_row_offset); + batch_indices.push_back(i); + break; + } + } + } + + if (row_indices.empty()) { + return false; + } + + view = {}; + view.kind = LLAMA_SPEC_FEATURE_HIDDEN_STATE; + if (!llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, ctx->dflash.feature_view_buffer, view.width)) { + return false; + } + + view.rows.reserve(row_indices.size()); + for (size_t i = 0; i < batch_indices.size(); ++i) { + const int32_t batch_index = batch_indices[i]; + view.rows.push_back({ + /* .seq_id = */ seq_id, + /* .pos = */ batch.pos[batch_index], + /* .data = */ ctx->dflash.feature_view_buffer.data() + i * (size_t) view.width, + }); + } + + llama_dflash_contract_log_feature_view( + "seq", + seq_id, + batch, + row_count, + row_width, + n_layers, + batch_row_offset, + row_indices, + batch_indices); + + return true; +} + +bool llama_spec_copy_dflash_rows_from_output_indices( + struct llama_context * ctx, + const std::vector & output_indices, + std::vector & hidden_rows) { + int32_t combined_width = 0; + if (!llama_spec_materialize_dflash_rows(ctx, output_indices, hidden_rows, combined_width)) { + hidden_rows.clear(); + return false; + } + + llama_dflash_contract_log_output_indices(ctx, output_indices); + + return hidden_rows.size() == (size_t) output_indices.size() * (size_t) combined_width; +} diff --git a/src/llama-spec-features-dflash.h b/src/llama-spec-features-dflash.h new file mode 100644 index 00000000..02e709d1 --- /dev/null +++ b/src/llama-spec-features-dflash.h @@ -0,0 +1,279 @@ +#pragma once + +#include "llama.h" + +#include +#include +#include + +struct llama_context; +struct llama_model; +struct ggml_tensor; +struct llama_spec_feature_view; + +struct llama_dflash_profile_stats { + uint64_t decode_internal_chunks = 0; + uint64_t decode_graph_rebuilds = 0; + uint64_t decode_sync_profile_points = 0; + uint64_t decode_prelude_us = 0; + uint64_t decode_sched_reset_us = 0; + uint64_t decode_build_graph_us = 0; + uint64_t 
decode_sched_alloc_graph_us = 0; + uint64_t decode_set_inputs_us = 0; + uint64_t decode_graph_compute_us = 0; + uint64_t decode_result_us = 0; + uint64_t decode_embedding_us = 0; + uint64_t decode_final_sched_reset_us = 0; + + uint64_t decode_output_reserve_calls = 0; + uint64_t decode_output_reserve_us = 0; + uint64_t decode_output_reserve_reallocs = 0; + uint64_t decode_output_reserve_realloc_bytes = 0; + uint64_t decode_prepare_calls = 0; + uint64_t decode_prepare_us = 0; + uint64_t decode_prepare_failures = 0; + + uint64_t set_target_copy_calls = 0; + uint64_t set_target_copy_us = 0; + uint64_t set_target_rows = 0; + uint64_t set_target_copy_bytes = 0; + uint64_t set_target_missing_positions = 0; + uint64_t set_target_non_monotonic_positions = 0; + + uint64_t capture_prepare_calls = 0; + uint64_t capture_prepare_sync_us = 0; + uint64_t capture_prepare_failures = 0; + uint64_t capture_layer_shape_mismatch = 0; + uint64_t capture_layer_batch_mismatch = 0; + uint64_t capture_prompt_batches = 0; + uint64_t capture_prompt_shape_changes = 0; + uint64_t capture_verify_batches = 0; + uint64_t capture_verify_shape_changes = 0; + uint64_t capture_materialize_calls = 0; + uint64_t capture_materialize_rows = 0; + uint64_t capture_materialize_bytes = 0; + uint64_t capture_materialize_us = 0; + uint64_t capture_materialize_failures = 0; + + uint64_t graph_prepare_calls = 0; + uint64_t graph_prepare_total_us = 0; + uint64_t graph_feature_copy_us = 0; + uint64_t graph_pos_copy_us = 0; + uint64_t graph_mask_build_us = 0; + uint64_t graph_kv_cache_build_us = 0; + uint64_t graph_kv_cache_reserve_us = 0; + uint64_t graph_kv_cache_reset_us = 0; + uint64_t graph_kv_cache_alloc_us = 0; + uint64_t graph_kv_cache_feature_upload_us = 0; + uint64_t graph_kv_cache_pos_upload_us = 0; + uint64_t graph_kv_cache_compute_us = 0; + uint64_t graph_kv_cache_sync_us = 0; + uint64_t graph_kv_cache_read_concat_pad_us = 0; + uint64_t graph_kv_cache_read_concat_pad_calls = 0; + uint64_t 
graph_kv_cache_cached_bytes = 0; + uint64_t graph_kv_cache_calls = 0; + uint64_t graph_kv_workspace_build_us = 0; + uint64_t graph_kv_workspace_reserve_us = 0; + uint64_t graph_kv_workspace_reset_us = 0; + uint64_t graph_kv_workspace_alloc_us = 0; + uint64_t graph_kv_workspace_compute_us = 0; + uint64_t graph_kv_workspace_sync_us = 0; + uint64_t graph_kv_workspace_calls = 0; + uint64_t graph_kv_node_fused_target_calls = 0; + uint64_t graph_kv_node_fused_target_us = 0; + uint64_t graph_kv_node_k_proj_calls = 0; + uint64_t graph_kv_node_k_proj_us = 0; + uint64_t graph_kv_node_k_norm_calls = 0; + uint64_t graph_kv_node_k_norm_us = 0; + uint64_t graph_kv_node_k_rope_calls = 0; + uint64_t graph_kv_node_k_rope_us = 0; + uint64_t graph_kv_node_v_proj_calls = 0; + uint64_t graph_kv_node_v_proj_us = 0; + uint64_t graph_kv_node_k_store_calls = 0; + uint64_t graph_kv_node_k_store_us = 0; + uint64_t graph_kv_node_v_store_calls = 0; + uint64_t graph_kv_node_v_store_us = 0; + uint64_t graph_main_node_qcur_calls = 0; + uint64_t graph_main_node_qcur_us = 0; + uint64_t graph_main_node_k_draft_calls = 0; + uint64_t graph_main_node_k_draft_us = 0; + uint64_t graph_main_node_v_draft_calls = 0; + uint64_t graph_main_node_v_draft_us = 0; + uint64_t graph_main_node_k_ctx_view_calls = 0; + uint64_t graph_main_node_k_ctx_view_us = 0; + uint64_t graph_main_node_v_ctx_view_calls = 0; + uint64_t graph_main_node_v_ctx_view_us = 0; + uint64_t graph_main_node_k_concat_calls = 0; + uint64_t graph_main_node_k_concat_us = 0; + uint64_t graph_main_node_v_concat_calls = 0; + uint64_t graph_main_node_v_concat_us = 0; + uint64_t graph_main_node_k_pad_calls = 0; + uint64_t graph_main_node_k_pad_us = 0; + uint64_t graph_main_node_v_pad_calls = 0; + uint64_t graph_main_node_v_pad_us = 0; + uint64_t graph_main_node_k_perm_cont_calls = 0; + uint64_t graph_main_node_k_perm_cont_us = 0; + uint64_t graph_main_node_v_perm_cont_calls = 0; + uint64_t graph_main_node_v_perm_cont_us = 0; + uint64_t 
graph_main_node_flash_attn_calls = 0; + uint64_t graph_main_node_flash_attn_us = 0; + uint64_t graph_main_node_attn_out_calls = 0; + uint64_t graph_main_node_attn_out_us = 0; + uint64_t graph_main_node_ffn_calls = 0; + uint64_t graph_main_node_ffn_us = 0; + uint64_t graph_main_node_result_rows_calls = 0; + uint64_t graph_main_node_result_rows_us = 0; + uint64_t graph_main_node_result_norm_calls = 0; + uint64_t graph_main_node_result_norm_us = 0; + uint64_t graph_main_node_result_calls = 0; + uint64_t graph_main_node_result_us = 0; + uint64_t graph_feature_bytes = 0; + uint64_t graph_pos_bytes = 0; + uint64_t graph_mask_bytes = 0; + uint64_t graph_visible_kv_sum = 0; + uint64_t graph_visible_kv_max = 0; + uint64_t graph_pos_fallbacks = 0; + uint64_t graph_pos_non_monotonic = 0; + uint64_t graph_shape_failures = 0; + uint64_t graph_mask_overflow = 0; + + int32_t last_n_rows = 0; + int32_t last_width = 0; + int32_t last_cross_ctx = 0; + int32_t last_left_pad = 0; + int32_t last_n_tokens = 0; + int32_t last_n_kv_total = 0; + int32_t last_kv_cache_host_layers = 0; + int32_t capture_prompt_last_rows = 0; + int32_t capture_prompt_last_width = 0; + int32_t capture_verify_last_rows = 0; + int32_t capture_verify_last_width = 0; + llama_pos last_pos_first = -1; + llama_pos last_pos_last = -1; +}; + +struct llama_dflash_window_update { + uint64_t version = 0; + int32_t keep_rows = 0; + int32_t append_rows = 0; + bool replace = false; + const float * append_features = nullptr; + size_t append_floats = 0; +}; + +struct llama_dflash_kv_cache_transition { + bool cache_up_to_date = false; + bool rebuild_cache = false; + int32_t append_rows = 0; + int32_t next_n_filled = 0; + int32_t next_write_pos = 0; +}; + +static inline llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition( + int32_t cross_ctx, + int32_t current_n_filled, + int32_t current_write_pos, + bool cache_valid, + uint64_t applied_window_version, + uint64_t target_window_version, + int32_t keep_rows, + 
int32_t append_rows, + bool replace, + int32_t n_rows) { + llama_dflash_kv_cache_transition plan; + + const int32_t safe_cross_ctx = std::max(1, cross_ctx); + const int32_t bounded_n_filled = std::clamp(current_n_filled, 0, safe_cross_ctx); + const int32_t bounded_append_rows = std::clamp(append_rows, 0, n_rows); + const int32_t bounded_keep_rows = std::clamp(keep_rows, 0, n_rows); + const int32_t expected_keep_rows = std::min(bounded_n_filled, std::max(0, safe_cross_ctx - bounded_append_rows)); + + plan.cache_up_to_date = cache_valid && applied_window_version == target_window_version; + plan.rebuild_cache = !cache_valid || replace || bounded_append_rows <= 0 || bounded_append_rows > n_rows; + if (!plan.rebuild_cache && bounded_keep_rows != expected_keep_rows) { + plan.rebuild_cache = true; + } + + plan.append_rows = bounded_append_rows; + if (plan.cache_up_to_date) { + plan.next_n_filled = bounded_n_filled; + plan.next_write_pos = safe_cross_ctx > 0 + ? ((current_write_pos % safe_cross_ctx) + safe_cross_ctx) % safe_cross_ctx + : 0; + } else if (plan.rebuild_cache) { + plan.next_n_filled = std::min(safe_cross_ctx, n_rows); + plan.next_write_pos = plan.next_n_filled % safe_cross_ctx; + } else { + plan.next_n_filled = std::min(safe_cross_ctx, bounded_n_filled + bounded_append_rows); + plan.next_write_pos = (current_write_pos + bounded_append_rows) % safe_cross_ctx; + } + + return plan; +} + +llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition_for_ctx( + const struct llama_context * ctx, + const llama_dflash_window_update & window_update, + int32_t n_rows); + +void llama_dflash_profile_reset(struct llama_context * ctx); +void llama_reset_dflash_kv_cache_state(struct llama_context * ctx); +void llama_set_dflash_visible_cross_ctx(struct llama_context * ctx, int32_t cross_ctx); +int32_t llama_get_dflash_visible_cross_ctx(const struct llama_context * ctx); +bool llama_dflash_profile_get_stats(const struct llama_context * ctx, llama_dflash_profile_stats 
* stats); + +int32_t llama_model_dflash_block_size(const struct llama_model * model); +int32_t llama_model_dflash_mask_token_id(const struct llama_model * model); +int32_t llama_model_dflash_n_target_layers(const struct llama_model * model); +int32_t llama_model_dflash_n_target_features(const struct llama_model * model); +int32_t llama_model_dflash_target_layer_ids(const struct llama_model * model, int32_t * layer_ids, int32_t capacity); +int32_t llama_model_dflash_target_mask_token_id(const struct llama_model * model); +const struct ggml_tensor * llama_model_dflash_output_tensor(const struct llama_model * model); + +enum llama_dflash_io_mode { + LLAMA_DFLASH_IO_MODE_INVALID = 0, + LLAMA_DFLASH_IO_MODE_SHARED, + LLAMA_DFLASH_IO_MODE_SELF_CONTAINED, + LLAMA_DFLASH_IO_MODE_MIXED, +}; + +int32_t llama_model_dflash_io_mode(const struct llama_model * draft_model, const struct llama_model * target_model); +bool llama_model_dflash_io_tensors_match(const struct llama_model * draft_model, int32_t n_embd, int32_t n_vocab); +bool llama_model_share_dflash_io_tensors(struct llama_model * draft_model, const struct llama_model * target_model); + +bool llama_set_dflash_target_features_copy( + struct llama_context * ctx, + const float * target_features, + size_t n_floats, + int32_t n_rows, + const llama_pos * target_positions, + const llama_dflash_window_update * window_update = nullptr); + +bool llama_set_dflash_target_features_view( + struct llama_context * ctx, + const float * target_features, + size_t n_floats, + int32_t n_rows, + const llama_pos * target_positions, + const llama_dflash_window_update * window_update = nullptr); + +bool llama_set_dflash_capture_layers(struct llama_context * ctx, const int32_t * layer_ids, int32_t n_layers); +void llama_clear_dflash_capture(struct llama_context * ctx); +void llama_begin_dflash_capture_batch(struct llama_context * ctx); +void llama_finish_dflash_capture_batch(struct llama_context * ctx, bool is_prompt_warmup); + +bool 
llama_spec_get_dflash_feature_view( + struct llama_context * ctx, + const llama_batch & batch, + llama_spec_feature_view & view); + +bool llama_spec_get_dflash_feature_view_for_seq( + struct llama_context * ctx, + const llama_batch & batch, + llama_seq_id seq_id, + llama_spec_feature_view & view); + +bool llama_spec_copy_dflash_rows_from_output_indices( + struct llama_context * ctx, + const std::vector & output_indices, + std::vector & hidden_rows); diff --git a/src/llama-spec-features.cpp b/src/llama-spec-features.cpp index 00c4b6e2..933f3d15 100644 --- a/src/llama-spec-features.cpp +++ b/src/llama-spec-features.cpp @@ -10,30 +10,6 @@ #include "llama-model.h" #include "llama-context.h" -static bool llama_dflash_positions_strictly_increasing( - const llama_pos * positions, - int32_t n_rows, - llama_pos & first_pos, - llama_pos & last_pos) { - first_pos = -1; - last_pos = -1; - - if (positions == nullptr || n_rows <= 0) { - return false; - } - - first_pos = positions[0]; - last_pos = positions[n_rows - 1]; - - for (int32_t i = 1; i < n_rows; ++i) { - if (positions[i] <= positions[i - 1]) { - return false; - } - } - - return true; -} - uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx) { if (ctx == nullptr) { return 0; @@ -47,278 +23,6 @@ uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx) { return hparams.n_embd; } -void llama_dflash_profile_reset(struct llama_context * ctx) { - if (ctx == nullptr) { - return; - } - - ctx->dflash_profile = {}; -} - -void llama_reset_dflash_kv_cache_state(struct llama_context * ctx) { - if (ctx == nullptr) { - return; - } - - ctx->dflash_kv_cache_write_pos = 0; - ctx->dflash_kv_cache_n_filled = 0; - ctx->dflash_kv_cache_update_rows = 0; - ctx->dflash_kv_cache_view_write_pos = 0; - ctx->dflash_kv_cache_view_n_filled = 0; - ctx->dflash_kv_cache_applied_window_version = 0; - ctx->dflash_kv_cache_valid = false; - ctx->dflash_kv_cache_view_valid = false; - ctx->dflash_kv_workspace_write_pos = 0; - 
ctx->dflash_kv_workspace_n_filled = 0; - ctx->dflash_kv_workspace_applied_window_version = 0; - ctx->dflash_kv_workspace_valid = false; - ctx->dflash_kv_workspace_sync_pending = false; - - for (ggml_backend_buffer_t buf : ctx->dflash_cache_bufs) { - if (buf != nullptr) { - ggml_backend_buffer_clear(buf, 0); - } - } -} - -llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition_for_ctx( - const struct llama_context * ctx, - const llama_dflash_window_update & window_update, - int32_t n_rows) { - if (ctx == nullptr) { - llama_dflash_kv_cache_transition plan; - plan.rebuild_cache = true; - plan.append_rows = std::clamp(window_update.append_rows, 0, n_rows); - plan.next_n_filled = n_rows; - return plan; - } - - const int32_t cross_ctx = ctx->dflash_visible_cross_ctx > 0 - ? ctx->dflash_visible_cross_ctx - : std::max(1, (int32_t) ctx->cparams.n_ctx - (int32_t) ctx->model.hparams.dflash_block_size); - - return llama_plan_dflash_kv_cache_transition( - cross_ctx, - ctx->dflash_kv_cache_n_filled, - ctx->dflash_kv_cache_write_pos, - ctx->dflash_kv_cache_valid, - ctx->dflash_kv_cache_applied_window_version, - window_update.version, - window_update.keep_rows, - window_update.append_rows, - window_update.replace, - n_rows); -} - -void llama_set_dflash_visible_cross_ctx( - struct llama_context * ctx, - int32_t cross_ctx) { - if (ctx == nullptr) { - return; - } - - ctx->dflash_visible_cross_ctx = std::max(0, cross_ctx); -} - -int32_t llama_get_dflash_visible_cross_ctx( - const struct llama_context * ctx) { - return ctx != nullptr ? ctx->dflash_visible_cross_ctx : 0; -} - -bool llama_dflash_profile_get_stats( - const struct llama_context * ctx, - llama_dflash_profile_stats * stats) { - if (ctx == nullptr || stats == nullptr) { - return false; - } - - *stats = ctx->dflash_profile; - return true; -} - -int32_t llama_model_dflash_block_size(const struct llama_model * model) { - return model ? 
(int32_t) model->hparams.dflash_block_size : 0; -} - -int32_t llama_model_dflash_mask_token_id(const struct llama_model * model) { - return model ? (int32_t) model->hparams.dflash_mask_token_id : -1; -} - -int32_t llama_model_dflash_n_target_layers(const struct llama_model * model) { - return model ? (int32_t) model->hparams.dflash_n_target_layers : 0; -} - -int32_t llama_model_dflash_n_target_features(const struct llama_model * model) { - return model ? (int32_t) model->hparams.dflash_n_target_features : 0; -} - -int32_t llama_model_dflash_target_layer_ids( - const struct llama_model * model, - int32_t * layer_ids, - int32_t capacity) { - if (model == nullptr || layer_ids == nullptr || capacity <= 0) { - return 0; - } - - const int32_t n_layers = std::min((int32_t) model->hparams.dflash_n_target_layers, capacity); - for (int32_t i = 0; i < n_layers; ++i) { - layer_ids[i] = (int32_t) model->hparams.dflash_target_layer_ids[i]; - } - - return n_layers; -} - -int32_t llama_model_dflash_target_mask_token_id(const struct llama_model * model) { - if (model == nullptr) { - return (int32_t) LLAMA_TOKEN_NULL; - } - - return (int32_t) model->vocab.token_mask(); -} - -const struct ggml_tensor * llama_model_dflash_output_tensor( - const struct llama_model * model) { - if (model == nullptr) { - return nullptr; - } - - if (model->output_mtp != nullptr) { - return model->output_mtp; - } - - if (model->output != nullptr) { - return model->output; - } - - return model->tok_embd; -} - -static const char * llama_dflash_io_mode_name(int32_t io_mode) { - switch (io_mode) { - case LLAMA_DFLASH_IO_MODE_SHARED: - return "shared"; - case LLAMA_DFLASH_IO_MODE_SELF_CONTAINED: - return "self-contained"; - case LLAMA_DFLASH_IO_MODE_MIXED: - return "mixed"; - default: - return "invalid"; - } -} - -static const char * llama_dflash_output_head_kind( - const struct llama_model * draft_model, - const struct llama_model * target_model) { - const struct ggml_tensor * output = 
llama_model_dflash_output_tensor(draft_model); - if (output == nullptr) { - return "missing"; - } - - if (output == draft_model->tok_embd) { - return draft_model->tok_embd == (target_model ? target_model->tok_embd : nullptr) - ? "shared_token_embedding" - : "token_embedding"; - } - - if (draft_model->output_mtp != nullptr && output == draft_model->output_mtp) { - if (target_model != nullptr && target_model->output_mtp != nullptr && output == target_model->output_mtp) { - return "output_mtp"; - } - - if (std::strcmp(output->name, "output_extra.weight") == 0) { - return "output_extra"; - } - - return "output_mtp"; - } - - return "output"; -} - -int32_t llama_model_dflash_io_mode( - const struct llama_model * draft_model, - const struct llama_model * target_model) { - if (draft_model == nullptr || target_model == nullptr || draft_model->arch != LLM_ARCH_DFLASH_DRAFT) { - return LLAMA_DFLASH_IO_MODE_INVALID; - } - - const ggml_tensor * draft_output = llama_model_dflash_output_tensor(draft_model); - const ggml_tensor * target_output = llama_model_dflash_output_tensor(target_model); - if (draft_model->tok_embd == nullptr || draft_output == nullptr || target_model->tok_embd == nullptr || target_output == nullptr) { - return LLAMA_DFLASH_IO_MODE_INVALID; - } - - const bool shared_tok = draft_model->tok_embd == target_model->tok_embd; - const bool shared_output = draft_output == target_output; - if (shared_tok && shared_output) { - return LLAMA_DFLASH_IO_MODE_SHARED; - } - - if (!shared_tok && !shared_output) { - return LLAMA_DFLASH_IO_MODE_SELF_CONTAINED; - } - - return LLAMA_DFLASH_IO_MODE_MIXED; -} - -bool llama_model_dflash_io_tensors_match( - const struct llama_model * draft_model, - int32_t n_embd, - int32_t n_vocab) { - const ggml_tensor * output = llama_model_dflash_output_tensor(draft_model); - if (draft_model == nullptr || draft_model->tok_embd == nullptr || output == nullptr || n_embd <= 0 || n_vocab <= 0) { - return false; - } - - return (int32_t) 
draft_model->tok_embd->ne[0] == n_embd && - (int32_t) draft_model->tok_embd->ne[1] == n_vocab && - (int32_t) output->ne[0] == n_embd && - (int32_t) output->ne[1] == n_vocab; -} - -bool llama_model_share_dflash_io_tensors( - struct llama_model * draft_model, - const struct llama_model * target_model) { - if (draft_model == nullptr || target_model == nullptr) { - return false; - } - - if (draft_model->arch != LLM_ARCH_DFLASH_DRAFT) { - return true; - } - - if (draft_model->tok_embd == nullptr) { - draft_model->tok_embd = target_model->tok_embd; - } - - if (draft_model->output == nullptr) { - draft_model->output = target_model->output ? target_model->output : target_model->tok_embd; - if (draft_model->output == nullptr) { - draft_model->output = draft_model->tok_embd; - } - } - - const bool uses_shared_tok = draft_model->tok_embd == target_model->tok_embd; - const bool uses_shared_output = draft_model->output == target_model->output || - draft_model->output == target_model->tok_embd; - - if (draft_model->output_mtp == nullptr && target_model->output_mtp != nullptr && uses_shared_tok && uses_shared_output) { - draft_model->output_mtp = target_model->output_mtp; - } - - const struct ggml_tensor * output = llama_model_dflash_output_tensor(draft_model); - if (draft_model->tok_embd != nullptr && output != nullptr) { - LLAMA_LOG_INFO("%s: DFlash IO mode=%s output_head=%s tensor=%s type=%s\n", - __func__, - llama_dflash_io_mode_name(llama_model_dflash_io_mode(draft_model, target_model)), - llama_dflash_output_head_kind(draft_model, target_model), - output->name[0] != '\0' ? 
output->name : "(unnamed)", - ggml_type_name(output->type)); - } - - return draft_model->tok_embd != nullptr && output != nullptr; -} - bool llama_set_draft_input_hidden_state_copy( struct llama_context * ctx, const float * hidden_state, @@ -333,648 +37,6 @@ bool llama_set_draft_input_hidden_state_copy( return true; } -static bool llama_set_dflash_target_features_impl( - struct llama_context * ctx, - const float * target_features, - size_t n_floats, - int32_t n_rows, - const llama_pos * target_positions, - bool copy_data, - const llama_dflash_window_update * window_update) { - const bool have_full_features = target_features != nullptr && n_floats > 0; - const bool have_append_features = window_update != nullptr && - window_update->append_features != nullptr && - window_update->append_floats > 0 && - window_update->append_rows > 0; - - if (ctx == nullptr || n_rows <= 0 || (!have_full_features && !have_append_features)) { - return false; - } - - auto & profile = ctx->dflash_profile; - const int64_t t_start_us = ggml_time_us(); - const int32_t row_width = have_full_features - ? (n_rows > 0 ? (int32_t) (n_floats / (size_t) n_rows) : 0) - : (window_update->append_rows > 0 ? (int32_t) (window_update->append_floats / (size_t) window_update->append_rows) : 0); - llama_pos first_pos = -1; - llama_pos last_pos = -1; - - if (have_full_features && copy_data) { - ctx->dflash_target_features_owned.assign(target_features, target_features + n_floats); - ctx->dflash_target_features = ctx->dflash_target_features_owned.data(); - } else if (have_full_features) { - ctx->dflash_target_features_owned.clear(); - ctx->dflash_target_features = target_features; - } else { - ctx->dflash_target_features_owned.clear(); - ctx->dflash_target_features = nullptr; - } - ctx->dflash_target_features_n_floats = have_full_features ? 
n_floats : 0; - ctx->dflash_target_features_n_rows = n_rows; - if (have_append_features && copy_data) { - ctx->dflash_target_append_features_owned.assign( - window_update->append_features, - window_update->append_features + window_update->append_floats); - ctx->dflash_target_append_features = ctx->dflash_target_append_features_owned.data(); - } else if (have_append_features) { - ctx->dflash_target_append_features_owned.clear(); - ctx->dflash_target_append_features = window_update->append_features; - } else { - ctx->dflash_target_append_features_owned.clear(); - ctx->dflash_target_append_features = nullptr; - } - ctx->dflash_target_append_features_n_floats = have_append_features ? window_update->append_floats : 0; - ctx->dflash_target_append_features_n_rows = have_append_features ? window_update->append_rows : 0; - ctx->dflash_target_window_version = window_update != nullptr && window_update->version > 0 - ? window_update->version - : ctx->dflash_target_window_version + 1; - ctx->dflash_target_window_keep_rows = window_update != nullptr - ? std::max(0, std::min(n_rows, window_update->keep_rows)) - : 0; - ctx->dflash_target_window_append_rows = window_update != nullptr - ? std::max(0, std::min(n_rows, window_update->append_rows)) - : n_rows; - ctx->dflash_target_window_replace = window_update != nullptr - ? window_update->replace - : true; - if (ctx->dflash_target_window_keep_rows + ctx->dflash_target_window_append_rows > n_rows) { - ctx->dflash_target_window_keep_rows = std::max(0, n_rows - ctx->dflash_target_window_append_rows); - } - - const int32_t cross_ctx = ctx->dflash_visible_cross_ctx > 0 - ? 
ctx->dflash_visible_cross_ctx - : std::max(1, (int32_t) ctx->cparams.n_ctx - (int32_t) ctx->model.hparams.dflash_block_size); - const llama_dflash_window_update cache_window_update = { - ctx->dflash_target_window_version, - ctx->dflash_target_window_keep_rows, - ctx->dflash_target_window_append_rows, - ctx->dflash_target_window_replace, - ctx->dflash_target_append_features, - ctx->dflash_target_append_features_n_floats, - }; - const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition_for_ctx(ctx, cache_window_update, n_rows); - - if (cache_plan.cache_up_to_date) { - ctx->dflash_kv_cache_view_n_filled = ctx->dflash_kv_cache_n_filled; - ctx->dflash_kv_cache_view_write_pos = ctx->dflash_kv_cache_write_pos; - ctx->dflash_kv_cache_view_valid = ctx->dflash_kv_cache_valid; - } else if (cross_ctx > 0) { - ctx->dflash_kv_cache_view_n_filled = cache_plan.next_n_filled; - ctx->dflash_kv_cache_view_write_pos = cache_plan.next_write_pos; - ctx->dflash_kv_cache_view_valid = cache_plan.next_n_filled > 0; - } - - if (target_positions != nullptr) { - if (copy_data) { - ctx->dflash_target_positions_owned.assign(target_positions, target_positions + n_rows); - ctx->dflash_target_positions = ctx->dflash_target_positions_owned.data(); - } else { - ctx->dflash_target_positions_owned.clear(); - ctx->dflash_target_positions = target_positions; - } - ctx->dflash_target_positions_n = (size_t) n_rows; - } else { - ctx->dflash_target_positions_owned.clear(); - ctx->dflash_target_positions = nullptr; - ctx->dflash_target_positions_n = 0; - } - - profile.set_target_copy_calls++; - profile.set_target_copy_us += (uint64_t) (ggml_time_us() - t_start_us); - profile.set_target_rows += (uint64_t) n_rows; - profile.set_target_copy_bytes += - (have_full_features ? n_floats : 0) * sizeof(float) + - (have_append_features ? window_update->append_floats : 0) * sizeof(float) + - (target_positions ? 
(size_t) n_rows * sizeof(llama_pos) : 0); - profile.last_n_rows = n_rows; - profile.last_width = row_width; - - if (target_positions == nullptr) { - profile.set_target_missing_positions++; - profile.last_pos_first = -1; - profile.last_pos_last = -1; - } else { - if (!llama_dflash_positions_strictly_increasing(target_positions, n_rows, first_pos, last_pos)) { - profile.set_target_non_monotonic_positions++; - } - profile.last_pos_first = first_pos; - profile.last_pos_last = last_pos; - } - - return true; -} - -bool llama_set_dflash_target_features_copy( - struct llama_context * ctx, - const float * target_features, - size_t n_floats, - int32_t n_rows, - const llama_pos * target_positions, - const llama_dflash_window_update * window_update) { - return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, true, window_update); -} - -bool llama_set_dflash_target_features_view( - struct llama_context * ctx, - const float * target_features, - size_t n_floats, - int32_t n_rows, - const llama_pos * target_positions, - const llama_dflash_window_update * window_update) { - return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, false, window_update); -} - -static void llama_record_dflash_capture_phase( - struct llama_context * ctx, - bool is_prompt_warmup, - int32_t row_count, - int32_t row_width) { - if (ctx == nullptr || row_count <= 0 || row_width <= 0) { - return; - } - - auto & profile = ctx->dflash_profile; - if (is_prompt_warmup) { - profile.capture_prompt_batches++; - if (profile.capture_prompt_last_rows > 0 && profile.capture_prompt_last_width > 0 && - (profile.capture_prompt_last_rows != row_count || profile.capture_prompt_last_width != row_width)) { - profile.capture_prompt_shape_changes++; - } - profile.capture_prompt_last_rows = row_count; - profile.capture_prompt_last_width = row_width; - } else { - profile.capture_verify_batches++; - if (profile.capture_verify_last_rows > 0 
&& profile.capture_verify_last_width > 0 && - (profile.capture_verify_last_rows != row_count || profile.capture_verify_last_width != row_width)) { - profile.capture_verify_shape_changes++; - } - profile.capture_verify_last_rows = row_count; - profile.capture_verify_last_width = row_width; - } -} - -static bool llama_dflash_parse_layer_id(const struct ggml_tensor * tensor, int32_t & layer_id) { - if (tensor == nullptr) { - return false; - } - - static constexpr const char * prefix = "l_out-"; - if (std::strncmp(tensor->name, prefix, std::strlen(prefix)) != 0) { - return false; - } - - char * end = nullptr; - const long raw = std::strtol(tensor->name + std::strlen(prefix), &end, 10); - if (end == tensor->name + std::strlen(prefix) || *end != '\0') { - return false; - } - - layer_id = (int32_t) raw; - if (layer_id >= 1000) { - layer_id %= 1000; - } - - return layer_id >= 0; -} - -static int32_t llama_dflash_find_layer_index(const struct llama_context * ctx, int32_t layer_id) { - if (ctx == nullptr || !ctx->dflash_capture) { - return -1; - } - - const auto & layer_ids = ctx->dflash_capture->layer_ids; - const auto it = std::find(layer_ids.begin(), layer_ids.end(), layer_id); - return it == layer_ids.end() ? -1 : (int32_t) std::distance(layer_ids.begin(), it); -} - -static bool llama_dflash_capture_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { - auto * ctx = static_cast(user_data); - if (ctx == nullptr || !ctx->dflash_capture) { - return false; - } - - int32_t layer_id = -1; - if (!llama_dflash_parse_layer_id(tensor, layer_id)) { - return false; - } - - const int32_t layer_idx = llama_dflash_find_layer_index(ctx, layer_id); - if (layer_idx < 0) { - return false; - } - - if (ask) { - return true; - } - - const int32_t row_width = (int32_t) tensor->ne[0]; - const int32_t row_count = row_width > 0 ? 
(int32_t) (ggml_nelements(tensor) / (int64_t) row_width) : 0; - if (row_width <= 0 || row_count <= 0) { - return false; - } - - auto & capture = *ctx->dflash_capture; - if (capture.capture_batch_id == 0) { - capture.capture_batch_id = 1; - } - if (capture.layer_seen_batch_id.size() != capture.layer_ids.size()) { - capture.layer_seen_batch_id.assign(capture.layer_ids.size(), 0); - } - - auto & rows = capture.layer_rows[(size_t) layer_idx]; - rows.resize((size_t) row_count * (size_t) row_width); - ggml_backend_tensor_get(tensor, rows.data(), 0, ggml_nbytes(tensor)); - capture.row_width = row_width; - capture.row_count = row_count; - capture.layer_seen_batch_id[(size_t) layer_idx] = capture.capture_batch_id; - return true; -} - -bool llama_set_dflash_capture_layers( - struct llama_context * ctx, - const int32_t * layer_ids, - int32_t n_layers) { - if (ctx == nullptr || layer_ids == nullptr || n_layers <= 0) { - return false; - } - - auto capture = std::make_unique(); - capture->layer_ids.assign(layer_ids, layer_ids + n_layers); - capture->layer_rows.resize((size_t) n_layers); - capture->layer_seen_batch_id.assign((size_t) n_layers, 0); - capture->prev_cb_eval = ctx->cparams.cb_eval; - capture->prev_cb_eval_user_data = ctx->cparams.cb_eval_user_data; - ctx->dflash_capture = std::move(capture); - ctx->dflash_feature_view_buffer.clear(); - - ctx->cparams.cb_eval = llama_dflash_capture_eval_callback; - ctx->cparams.cb_eval_user_data = ctx; - if (ctx->sched != nullptr) { - ggml_backend_sched_set_eval_callback(ctx->sched, ctx->cparams.cb_eval, ctx->cparams.cb_eval_user_data); - } - - return true; -} - -void llama_clear_dflash_capture(struct llama_context * ctx) { - if (ctx == nullptr) { - return; - } - - ggml_backend_sched_eval_callback prev_cb_eval = nullptr; - void * prev_cb_eval_user_data = nullptr; - if (ctx->dflash_capture) { - prev_cb_eval = ctx->dflash_capture->prev_cb_eval; - prev_cb_eval_user_data = ctx->dflash_capture->prev_cb_eval_user_data; - } - - 
ctx->dflash_capture.reset(); - ctx->dflash_feature_view_buffer.clear(); - - if (ctx->cparams.cb_eval == llama_dflash_capture_eval_callback && ctx->cparams.cb_eval_user_data == ctx) { - ctx->cparams.cb_eval = prev_cb_eval; - ctx->cparams.cb_eval_user_data = prev_cb_eval_user_data; - if (ctx->sched != nullptr) { - ggml_backend_sched_set_eval_callback(ctx->sched, prev_cb_eval, prev_cb_eval_user_data); - } - } -} - -void llama_begin_dflash_capture_batch(struct llama_context * ctx) { - if (ctx == nullptr || !ctx->dflash_capture) { - return; - } - - auto & capture = *ctx->dflash_capture; - capture.capture_batch_id++; - capture.row_count = 0; - capture.row_width = 0; - std::fill(capture.layer_seen_batch_id.begin(), capture.layer_seen_batch_id.end(), 0); -} - -void llama_finish_dflash_capture_batch( - struct llama_context * ctx, - bool is_prompt_warmup) { - if (ctx == nullptr || !ctx->dflash_capture) { - return; - } - - auto & capture = *ctx->dflash_capture; - llama_record_dflash_capture_phase(ctx, is_prompt_warmup, capture.row_count, capture.row_width); - - // Reset the batch-local reference shape so the next decode only compares layers within - // the same batch, not against the previous prompt/verify batch. 
- capture.row_count = 0; - capture.row_width = 0; -} - -static bool llama_spec_prepare_dflash_capture( - struct llama_context * ctx, - int32_t & row_count, - int32_t & row_width, - int32_t & n_layers) { - if (ctx == nullptr || !ctx->dflash_capture) { - return false; - } - - auto & profile = ctx->dflash_profile; - profile.capture_prepare_calls++; - const int64_t t_sync_us = ggml_time_us(); - llama_synchronize(ctx); - profile.capture_prepare_sync_us += (uint64_t) (ggml_time_us() - t_sync_us); - - auto & capture = *ctx->dflash_capture; - row_count = capture.row_count; - row_width = capture.row_width; - n_layers = (int32_t) capture.layer_ids.size(); - if (row_count <= 0 || row_width <= 0 || n_layers <= 0 || capture.layer_rows.size() != (size_t) n_layers) { - profile.capture_prepare_failures++; - return false; - } - - if (capture.capture_batch_id == 0 || capture.layer_seen_batch_id.size() != (size_t) n_layers) { - profile.capture_prepare_failures++; - profile.capture_layer_batch_mismatch++; - if (profile.capture_layer_batch_mismatch <= 3) { - LLAMA_LOG_WARN("%s: DFlash capture batch markers are not initialized (batch_id=%llu layers=%zu expected=%d)\n", - __func__, - (unsigned long long) capture.capture_batch_id, - capture.layer_seen_batch_id.size(), - n_layers); - } - return false; - } - - for (int32_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) { - if (capture.layer_seen_batch_id[(size_t) layer_idx] != capture.capture_batch_id) { - profile.capture_prepare_failures++; - profile.capture_layer_batch_mismatch++; - if (profile.capture_layer_batch_mismatch <= 3) { - LLAMA_LOG_WARN("%s: DFlash capture is stale for layer %d (seen_batch=%llu current_batch=%llu rows=%d width=%d)\n", - __func__, - capture.layer_ids[(size_t) layer_idx], - (unsigned long long) capture.layer_seen_batch_id[(size_t) layer_idx], - (unsigned long long) capture.capture_batch_id, - row_count, - row_width); - } - return false; - } - - const auto & rows = capture.layer_rows[(size_t) layer_idx]; - if 
(rows.size() != (size_t) row_count * (size_t) row_width) { - profile.capture_prepare_failures++; - profile.capture_layer_shape_mismatch++; - if (profile.capture_layer_shape_mismatch <= 3) { - LLAMA_LOG_WARN("%s: DFlash capture rows mismatch for layer %d: got=%zu expected=%zu (rows=%d width=%d)\n", - __func__, capture.layer_ids[(size_t) layer_idx], rows.size(), - (size_t) row_count * (size_t) row_width, row_count, row_width); - } - return false; - } - } - - return true; -} - -static bool llama_dflash_contract_log_enabled() { - const char * env = std::getenv("IK_DFLASH_CONTRACT_LOG"); - if (env == nullptr || *env == '\0') { - return false; - } - - return std::strcmp(env, "0") != 0 && - std::strcmp(env, "false") != 0 && - std::strcmp(env, "off") != 0; -} - -template -static std::string llama_dflash_contract_format_values( - const std::vector & values, - size_t edge_count = 4) { - std::ostringstream oss; - oss << '['; - if (values.empty()) { - oss << ']'; - return oss.str(); - } - - const size_t head = std::min(edge_count, values.size()); - for (size_t i = 0; i < head; ++i) { - if (i > 0) { - oss << ','; - } - oss << values[i]; - } - - if (values.size() > edge_count * 2) { - oss << ",...,"; - for (size_t i = values.size() - edge_count; i < values.size(); ++i) { - if (i > values.size() - edge_count) { - oss << ','; - } - oss << values[i]; - } - } else { - for (size_t i = head; i < values.size(); ++i) { - oss << ',' << values[i]; - } - } - - oss << ']'; - return oss.str(); -} - -static std::vector llama_dflash_contract_collect_batch_positions( - const llama_batch & batch, - const std::vector & batch_indices) { - std::vector positions; - positions.reserve(batch_indices.size()); - for (int32_t batch_index : batch_indices) { - positions.push_back(batch.pos[batch_index]); - } - return positions; -} - -static void llama_dflash_contract_summarize_positions( - const std::vector & positions, - llama_pos & first_pos, - llama_pos & last_pos, - int32_t & gap_count, - int32_t & 
nonmono_count) { - first_pos = -1; - last_pos = -1; - gap_count = 0; - nonmono_count = 0; - if (positions.empty()) { - return; - } - - first_pos = positions.front(); - last_pos = positions.back(); - for (size_t i = 1; i < positions.size(); ++i) { - if (positions[i] <= positions[i - 1]) { - nonmono_count++; - } else if (positions[i] != positions[i - 1] + 1) { - gap_count++; - } - } -} - -static void llama_dflash_contract_log_feature_view( - const char * kind, - llama_seq_id seq_id, - const llama_batch & batch, - int32_t row_count, - int32_t row_width, - int32_t n_layers, - int32_t batch_row_offset, - const std::vector & row_indices, - const std::vector & batch_indices) { - if (!llama_dflash_contract_log_enabled()) { - return; - } - - static std::atomic counter = 0; - const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); - if (ordinal >= 8) { - return; - } - - const std::vector positions = llama_dflash_contract_collect_batch_positions(batch, batch_indices); - llama_pos first_pos = -1; - llama_pos last_pos = -1; - int32_t gap_count = 0; - int32_t nonmono_count = 0; - llama_dflash_contract_summarize_positions(positions, first_pos, last_pos, gap_count, nonmono_count); - - LLAMA_LOG_INFO("%s[%llu]: kind=%s seq=%d batch_tokens=%d capture_rows=%d row_width=%d layers=%d batch_row_offset=%d row_indices=%s batch_indices=%s batch_pos=%s pos=[%d..%d] gaps=%d nonmono=%d\n", - __func__, - (unsigned long long) (ordinal + 1), - kind, - (int) seq_id, - batch.n_tokens, - row_count, - row_width, - n_layers, - batch_row_offset, - llama_dflash_contract_format_values(row_indices).c_str(), - llama_dflash_contract_format_values(batch_indices).c_str(), - llama_dflash_contract_format_values(positions).c_str(), - (int) first_pos, - (int) last_pos, - gap_count, - nonmono_count); -} - -static void llama_dflash_contract_log_output_indices( - struct llama_context * ctx, - const std::vector & output_indices) { - if (!llama_dflash_contract_log_enabled()) { - return; - } - - 
static std::atomic counter = 0; - const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); - if (ordinal >= 8) { - return; - } - - int32_t row_count = 0; - int32_t row_width = 0; - int32_t n_layers = 0; - const bool have_capture = llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers); - - LLAMA_LOG_INFO("%s[%llu]: output_indices=%s capture_rows=%d row_width=%d layers=%d have_capture=%s\n", - __func__, - (unsigned long long) (ordinal + 1), - llama_dflash_contract_format_values(output_indices).c_str(), - row_count, - row_width, - n_layers, - have_capture ? "true" : "false"); -} - - static bool llama_spec_materialize_dflash_rows_prepared( - struct llama_context * ctx, - int32_t row_count, - int32_t row_width, - int32_t n_layers, - const std::vector & row_indices, - std::vector & rows_out, - int32_t & combined_width); - -static bool llama_spec_materialize_dflash_rows( - struct llama_context * ctx, - const std::vector & row_indices, - std::vector & rows_out, - int32_t & combined_width) { - int32_t row_count = 0; - int32_t row_width = 0; - int32_t n_layers = 0; - if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) { - if (ctx != nullptr) { - ctx->dflash_profile.capture_materialize_failures++; - } - return false; - } - - return llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, rows_out, combined_width); -} - -static bool llama_spec_materialize_dflash_rows_prepared( - struct llama_context * ctx, - int32_t row_count, - int32_t row_width, - int32_t n_layers, - const std::vector & row_indices, - std::vector & rows_out, - int32_t & combined_width) { - rows_out.clear(); - combined_width = 0; - if (ctx == nullptr || row_indices.empty()) { - return false; - } - - auto & profile = ctx->dflash_profile; - profile.capture_materialize_calls++; - const int64_t t_start_us = ggml_time_us(); - - if (row_count <= 0 || row_width <= 0 || n_layers <= 0 || ctx->dflash_capture == nullptr) { - 
profile.capture_materialize_failures++; - return false; - } - - combined_width = row_width * n_layers; - rows_out.resize((size_t) row_indices.size() * (size_t) combined_width); - - const auto & layer_rows = ctx->dflash_capture->layer_rows; - for (size_t out_row = 0; out_row < row_indices.size(); ++out_row) { - int32_t row_index = row_indices[out_row]; - if (row_index < 0) { - row_index += row_count; - } - if (row_index < 0 || row_index >= row_count) { - rows_out.clear(); - combined_width = 0; - profile.capture_materialize_failures++; - return false; - } - - float * dst = rows_out.data() + out_row * (size_t) combined_width; - for (int32_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) { - const float * src = layer_rows[(size_t) layer_idx].data() + (size_t) row_index * (size_t) row_width; - std::memcpy(dst + (size_t) layer_idx * (size_t) row_width, src, (size_t) row_width * sizeof(float)); - } - } - - profile.capture_materialize_us += (uint64_t) (ggml_time_us() - t_start_us); - profile.capture_materialize_rows += (uint64_t) row_indices.size(); - profile.capture_materialize_bytes += rows_out.size() * sizeof(float); - - return true; -} - static bool llama_spec_prepare_hidden_feature_view( struct llama_context * ctx, int32_t n_rows, @@ -1031,137 +93,6 @@ bool llama_spec_get_hidden_feature_view( return true; } -bool llama_spec_get_dflash_feature_view( - struct llama_context * ctx, - const llama_batch & batch, - llama_spec_feature_view & view) { - if (ctx == nullptr || batch.n_tokens <= 0 || batch.pos == nullptr || batch.n_seq_id == nullptr || batch.seq_id == nullptr) { - return false; - } - - int32_t row_count = 0; - int32_t row_width = 0; - int32_t n_layers = 0; - if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) { - return false; - } - - const int32_t batch_row_offset = std::max(0, batch.n_tokens - row_count); - std::vector row_indices; - std::vector batch_indices; - row_indices.reserve((size_t) (batch.n_tokens - batch_row_offset)); - 
batch_indices.reserve((size_t) (batch.n_tokens - batch_row_offset)); - for (int32_t i = batch_row_offset; i < batch.n_tokens; ++i) { - row_indices.push_back(i - batch_row_offset); - batch_indices.push_back(i); - } - - if (row_indices.empty()) { - return false; - } - - view = {}; - view.kind = LLAMA_SPEC_FEATURE_HIDDEN_STATE; - if (!llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, ctx->dflash_feature_view_buffer, view.width)) { - return false; - } - - view.rows.reserve(batch_indices.size()); - for (int32_t batch_index : batch_indices) { - if (batch.n_seq_id[batch_index] <= 0 || batch.seq_id[batch_index] == nullptr) { - view.rows.clear(); - return false; - } - - view.rows.push_back({ - /* .seq_id = */ batch.seq_id[batch_index][0], - /* .pos = */ batch.pos[batch_index], - /* .data = */ ctx->dflash_feature_view_buffer.data() + view.rows.size() * (size_t) view.width, - }); - } - - llama_dflash_contract_log_feature_view( - "batch", - view.rows.empty() ? 
-1 : view.rows.front().seq_id, - batch, - row_count, - row_width, - n_layers, - batch_row_offset, - row_indices, - batch_indices); - - return true; -} - -bool llama_spec_get_dflash_feature_view_for_seq( - struct llama_context * ctx, - const llama_batch & batch, - llama_seq_id seq_id, - llama_spec_feature_view & view) { - if (ctx == nullptr || batch.n_tokens <= 0 || batch.pos == nullptr || batch.n_seq_id == nullptr || batch.seq_id == nullptr) { - return false; - } - - int32_t row_count = 0; - int32_t row_width = 0; - int32_t n_layers = 0; - if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) { - return false; - } - - const int32_t batch_row_offset = std::max(0, batch.n_tokens - row_count); - std::vector row_indices; - row_indices.reserve((size_t) batch.n_tokens); - std::vector batch_indices; - batch_indices.reserve((size_t) batch.n_tokens); - for (int32_t i = batch_row_offset; i < batch.n_tokens; ++i) { - if (batch.n_seq_id[i] <= 0 || batch.seq_id[i] == nullptr) { - return false; - } - - for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) { - if (batch.seq_id[i][j] == seq_id) { - row_indices.push_back(i - batch_row_offset); - batch_indices.push_back(i); - break; - } - } - } - - if (row_indices.empty()) { - return false; - } - - view = {}; - view.kind = LLAMA_SPEC_FEATURE_HIDDEN_STATE; - if (!llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, ctx->dflash_feature_view_buffer, view.width)) { - return false; - } - - view.rows.reserve(row_indices.size()); - for (size_t i = 0; i < batch_indices.size(); ++i) { - const int32_t batch_index = batch_indices[i]; - view.rows.push_back({ - /* .seq_id = */ seq_id, - /* .pos = */ batch.pos[batch_index], - /* .data = */ ctx->dflash_feature_view_buffer.data() + i * (size_t) view.width, - }); - } - - llama_dflash_contract_log_feature_view( - "seq", - seq_id, - batch, - row_count, - row_width, - n_layers, - batch_row_offset, - row_indices, - batch_indices); - - return 
true; -} bool llama_spec_get_hidden_feature_view_for_seq( struct llama_context * ctx, @@ -1255,18 +186,3 @@ bool llama_spec_copy_hidden_rows_from_output_indices( return hidden_rows.size() == (size_t) output_indices.size() * view.width; } - -bool llama_spec_copy_dflash_rows_from_output_indices( - struct llama_context * ctx, - const std::vector & output_indices, - std::vector & hidden_rows) { - int32_t combined_width = 0; - if (!llama_spec_materialize_dflash_rows(ctx, output_indices, hidden_rows, combined_width)) { - hidden_rows.clear(); - return false; - } - - llama_dflash_contract_log_output_indices(ctx, output_indices); - - return hidden_rows.size() == (size_t) output_indices.size() * (size_t) combined_width; -} diff --git a/src/llama-spec-features.h b/src/llama-spec-features.h index 1c327049..b1342fed 100644 --- a/src/llama-spec-features.h +++ b/src/llama-spec-features.h @@ -2,7 +2,6 @@ #include "llama.h" -#include #include #include @@ -25,316 +24,20 @@ struct llama_spec_feature_view { std::vector rows; }; -struct llama_dflash_profile_stats { - uint64_t decode_internal_chunks = 0; - uint64_t decode_graph_rebuilds = 0; - uint64_t decode_sync_profile_points = 0; - uint64_t decode_prelude_us = 0; - uint64_t decode_sched_reset_us = 0; - uint64_t decode_build_graph_us = 0; - uint64_t decode_sched_alloc_graph_us = 0; - uint64_t decode_set_inputs_us = 0; - uint64_t decode_graph_compute_us = 0; - uint64_t decode_result_us = 0; - uint64_t decode_embedding_us = 0; - uint64_t decode_final_sched_reset_us = 0; - - uint64_t decode_output_reserve_calls = 0; - uint64_t decode_output_reserve_us = 0; - uint64_t decode_output_reserve_reallocs = 0; - uint64_t decode_output_reserve_realloc_bytes = 0; - uint64_t decode_prepare_calls = 0; - uint64_t decode_prepare_us = 0; - uint64_t decode_prepare_failures = 0; - - uint64_t set_target_copy_calls = 0; - uint64_t set_target_copy_us = 0; - uint64_t set_target_rows = 0; - uint64_t set_target_copy_bytes = 0; - uint64_t 
set_target_missing_positions = 0; - uint64_t set_target_non_monotonic_positions = 0; - - uint64_t capture_prepare_calls = 0; - uint64_t capture_prepare_sync_us = 0; - uint64_t capture_prepare_failures = 0; - uint64_t capture_layer_shape_mismatch = 0; - uint64_t capture_layer_batch_mismatch = 0; - uint64_t capture_prompt_batches = 0; - uint64_t capture_prompt_shape_changes = 0; - uint64_t capture_verify_batches = 0; - uint64_t capture_verify_shape_changes = 0; - uint64_t capture_materialize_calls = 0; - uint64_t capture_materialize_rows = 0; - uint64_t capture_materialize_bytes = 0; - uint64_t capture_materialize_us = 0; - uint64_t capture_materialize_failures = 0; - - uint64_t graph_prepare_calls = 0; - uint64_t graph_prepare_total_us = 0; - uint64_t graph_feature_copy_us = 0; - uint64_t graph_pos_copy_us = 0; - uint64_t graph_mask_build_us = 0; - uint64_t graph_kv_cache_build_us = 0; - uint64_t graph_kv_cache_reserve_us = 0; - uint64_t graph_kv_cache_reset_us = 0; - uint64_t graph_kv_cache_alloc_us = 0; - uint64_t graph_kv_cache_feature_upload_us = 0; - uint64_t graph_kv_cache_pos_upload_us = 0; - uint64_t graph_kv_cache_compute_us = 0; - uint64_t graph_kv_cache_sync_us = 0; - uint64_t graph_kv_cache_read_concat_pad_us = 0; - uint64_t graph_kv_cache_read_concat_pad_calls = 0; - uint64_t graph_kv_cache_cached_bytes = 0; - uint64_t graph_kv_cache_calls = 0; - uint64_t graph_kv_workspace_build_us = 0; - uint64_t graph_kv_workspace_reserve_us = 0; - uint64_t graph_kv_workspace_reset_us = 0; - uint64_t graph_kv_workspace_alloc_us = 0; - uint64_t graph_kv_workspace_compute_us = 0; - uint64_t graph_kv_workspace_sync_us = 0; - uint64_t graph_kv_workspace_calls = 0; - uint64_t graph_kv_node_fused_target_calls = 0; - uint64_t graph_kv_node_fused_target_us = 0; - uint64_t graph_kv_node_k_proj_calls = 0; - uint64_t graph_kv_node_k_proj_us = 0; - uint64_t graph_kv_node_k_norm_calls = 0; - uint64_t graph_kv_node_k_norm_us = 0; - uint64_t graph_kv_node_k_rope_calls = 0; - 
uint64_t graph_kv_node_k_rope_us = 0; - uint64_t graph_kv_node_v_proj_calls = 0; - uint64_t graph_kv_node_v_proj_us = 0; - uint64_t graph_kv_node_k_store_calls = 0; - uint64_t graph_kv_node_k_store_us = 0; - uint64_t graph_kv_node_v_store_calls = 0; - uint64_t graph_kv_node_v_store_us = 0; - uint64_t graph_main_node_qcur_calls = 0; - uint64_t graph_main_node_qcur_us = 0; - uint64_t graph_main_node_k_draft_calls = 0; - uint64_t graph_main_node_k_draft_us = 0; - uint64_t graph_main_node_v_draft_calls = 0; - uint64_t graph_main_node_v_draft_us = 0; - uint64_t graph_main_node_k_ctx_view_calls = 0; - uint64_t graph_main_node_k_ctx_view_us = 0; - uint64_t graph_main_node_v_ctx_view_calls = 0; - uint64_t graph_main_node_v_ctx_view_us = 0; - uint64_t graph_main_node_k_concat_calls = 0; - uint64_t graph_main_node_k_concat_us = 0; - uint64_t graph_main_node_v_concat_calls = 0; - uint64_t graph_main_node_v_concat_us = 0; - uint64_t graph_main_node_k_pad_calls = 0; - uint64_t graph_main_node_k_pad_us = 0; - uint64_t graph_main_node_v_pad_calls = 0; - uint64_t graph_main_node_v_pad_us = 0; - uint64_t graph_main_node_k_perm_cont_calls = 0; - uint64_t graph_main_node_k_perm_cont_us = 0; - uint64_t graph_main_node_v_perm_cont_calls = 0; - uint64_t graph_main_node_v_perm_cont_us = 0; - uint64_t graph_main_node_flash_attn_calls = 0; - uint64_t graph_main_node_flash_attn_us = 0; - uint64_t graph_main_node_attn_out_calls = 0; - uint64_t graph_main_node_attn_out_us = 0; - uint64_t graph_main_node_ffn_calls = 0; - uint64_t graph_main_node_ffn_us = 0; - uint64_t graph_main_node_result_rows_calls = 0; - uint64_t graph_main_node_result_rows_us = 0; - uint64_t graph_main_node_result_norm_calls = 0; - uint64_t graph_main_node_result_norm_us = 0; - uint64_t graph_main_node_result_calls = 0; - uint64_t graph_main_node_result_us = 0; - uint64_t graph_feature_bytes = 0; - uint64_t graph_pos_bytes = 0; - uint64_t graph_mask_bytes = 0; - uint64_t graph_visible_kv_sum = 0; - uint64_t 
graph_visible_kv_max = 0; - uint64_t graph_pos_fallbacks = 0; - uint64_t graph_pos_non_monotonic = 0; - uint64_t graph_shape_failures = 0; - uint64_t graph_mask_overflow = 0; - - int32_t last_n_rows = 0; - int32_t last_width = 0; - int32_t last_cross_ctx = 0; - int32_t last_left_pad = 0; - int32_t last_n_tokens = 0; - int32_t last_n_kv_total = 0; - int32_t last_kv_cache_host_layers = 0; - int32_t capture_prompt_last_rows = 0; - int32_t capture_prompt_last_width = 0; - int32_t capture_verify_last_rows = 0; - int32_t capture_verify_last_width = 0; - llama_pos last_pos_first = -1; - llama_pos last_pos_last = -1; -}; - -struct llama_dflash_window_update { - uint64_t version = 0; - int32_t keep_rows = 0; - int32_t append_rows = 0; - bool replace = false; - const float * append_features = nullptr; - size_t append_floats = 0; -}; - -struct llama_dflash_kv_cache_transition { - bool cache_up_to_date = false; - bool rebuild_cache = false; - int32_t append_rows = 0; - int32_t next_n_filled = 0; - int32_t next_write_pos = 0; -}; - -static inline llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition( - int32_t cross_ctx, - int32_t current_n_filled, - int32_t current_write_pos, - bool cache_valid, - uint64_t applied_window_version, - uint64_t target_window_version, - int32_t keep_rows, - int32_t append_rows, - bool replace, - int32_t n_rows) { - llama_dflash_kv_cache_transition plan; - - const int32_t safe_cross_ctx = std::max(1, cross_ctx); - const int32_t bounded_n_filled = std::clamp(current_n_filled, 0, safe_cross_ctx); - const int32_t bounded_append_rows = std::clamp(append_rows, 0, n_rows); - const int32_t bounded_keep_rows = std::clamp(keep_rows, 0, n_rows); - const int32_t expected_keep_rows = std::min(bounded_n_filled, std::max(0, safe_cross_ctx - bounded_append_rows)); - - plan.cache_up_to_date = cache_valid && applied_window_version == target_window_version; - plan.rebuild_cache = !cache_valid || replace || bounded_append_rows <= 0 || 
bounded_append_rows > n_rows; - if (!plan.rebuild_cache && bounded_keep_rows != expected_keep_rows) { - plan.rebuild_cache = true; - } - - plan.append_rows = bounded_append_rows; - if (plan.cache_up_to_date) { - plan.next_n_filled = bounded_n_filled; - plan.next_write_pos = safe_cross_ctx > 0 - ? ((current_write_pos % safe_cross_ctx) + safe_cross_ctx) % safe_cross_ctx - : 0; - } else if (plan.rebuild_cache) { - plan.next_n_filled = std::min(safe_cross_ctx, n_rows); - plan.next_write_pos = plan.next_n_filled % safe_cross_ctx; - } else { - plan.next_n_filled = std::min(safe_cross_ctx, bounded_n_filled + bounded_append_rows); - plan.next_write_pos = (current_write_pos + bounded_append_rows) % safe_cross_ctx; - } - - return plan; -} - -llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition_for_ctx( - const struct llama_context * ctx, - const llama_dflash_window_update & window_update, - int32_t n_rows); +#include "llama-spec-features-dflash.h" uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx); -void llama_dflash_profile_reset(struct llama_context * ctx); - -void llama_reset_dflash_kv_cache_state(struct llama_context * ctx); - -void llama_set_dflash_visible_cross_ctx( - struct llama_context * ctx, - int32_t cross_ctx); - -int32_t llama_get_dflash_visible_cross_ctx( - const struct llama_context * ctx); - -bool llama_dflash_profile_get_stats( - const struct llama_context * ctx, - llama_dflash_profile_stats * stats); - -int32_t llama_model_dflash_block_size(const struct llama_model * model); - -int32_t llama_model_dflash_mask_token_id(const struct llama_model * model); - -int32_t llama_model_dflash_n_target_layers(const struct llama_model * model); - -int32_t llama_model_dflash_n_target_features(const struct llama_model * model); - -int32_t llama_model_dflash_target_layer_ids( - const struct llama_model * model, - int32_t * layer_ids, - int32_t capacity); - -enum llama_dflash_io_mode { - LLAMA_DFLASH_IO_MODE_INVALID = 0, - 
LLAMA_DFLASH_IO_MODE_SHARED, - LLAMA_DFLASH_IO_MODE_SELF_CONTAINED, - LLAMA_DFLASH_IO_MODE_MIXED, -}; - -int32_t llama_model_dflash_target_mask_token_id(const struct llama_model * model); - -int32_t llama_model_dflash_io_mode( - const struct llama_model * draft_model, - const struct llama_model * target_model); - -const struct ggml_tensor * llama_model_dflash_output_tensor( - const struct llama_model * model); - -bool llama_model_dflash_io_tensors_match( - const struct llama_model * draft_model, - int32_t n_embd, - int32_t n_vocab); - -bool llama_model_share_dflash_io_tensors( - struct llama_model * draft_model, - const struct llama_model * target_model); - bool llama_set_draft_input_hidden_state_copy( struct llama_context * ctx, const float * hidden_state, size_t n_floats); -bool llama_set_dflash_target_features_copy( - struct llama_context * ctx, - const float * target_features, - size_t n_floats, - int32_t n_rows, - const llama_pos * target_positions, - const llama_dflash_window_update * window_update = nullptr); - -bool llama_set_dflash_target_features_view( - struct llama_context * ctx, - const float * target_features, - size_t n_floats, - int32_t n_rows, - const llama_pos * target_positions, - const llama_dflash_window_update * window_update = nullptr); - -bool llama_set_dflash_capture_layers( - struct llama_context * ctx, - const int32_t * layer_ids, - int32_t n_layers); - -void llama_clear_dflash_capture(struct llama_context * ctx); - -void llama_begin_dflash_capture_batch(struct llama_context * ctx); - -void llama_finish_dflash_capture_batch( - struct llama_context * ctx, - bool is_prompt_warmup); - bool llama_spec_get_hidden_feature_view( struct llama_context * ctx, const llama_batch & batch, llama_spec_feature_view & view); -bool llama_spec_get_dflash_feature_view( - struct llama_context * ctx, - const llama_batch & batch, - llama_spec_feature_view & view); - -bool llama_spec_get_dflash_feature_view_for_seq( - struct llama_context * ctx, - const 
llama_batch & batch, - llama_seq_id seq_id, - llama_spec_feature_view & view); - bool llama_spec_get_hidden_feature_view_for_seq( struct llama_context * ctx, const llama_batch & batch, @@ -352,8 +55,3 @@ bool llama_spec_copy_hidden_rows_from_output_indices( struct llama_context * ctx, const std::vector & output_indices, std::vector & hidden_rows); - -bool llama_spec_copy_dflash_rows_from_output_indices( - struct llama_context * ctx, - const std::vector & output_indices, - std::vector & hidden_rows); diff --git a/src/llama.cpp b/src/llama.cpp index a1b63a73..75482d80 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -18,6 +18,7 @@ #include "llama-hparams.h" #include "llama-context.h" #include "llama-spec-features.h" +#include "llama-dflash.h" #include "llama-quantize.h" #include "unicode.h" @@ -515,20 +516,6 @@ static bool llama_dflash_main_node_eval_callback(struct ggml_tensor * tensor, bo return prev_result || tracked; } -static bool llama_dflash_use_kv_workspace_experiment() { - return llama_env_flag_enabled("IK_DFLASH_KV_WORKSPACE"); -} - -static void llama_sync_dflash_workspace_if_pending(struct llama_context & lctx) { - if (!lctx.dflash_kv_workspace_sync_pending || lctx.dflash_workspace_sched == nullptr) { - return; - } - - const int64_t t_workspace_sync_us = ggml_time_us(); - ggml_backend_sched_synchronize(lctx.dflash_workspace_sched); - lctx.dflash_profile.graph_kv_workspace_sync_us += (uint64_t) (ggml_time_us() - t_workspace_sync_us); - lctx.dflash_kv_workspace_sync_pending = false; -} // extract ip and port from RPC[ip:port] for rpc and keep other device names static std::vector extract_device_from_rpc_device(std::vector devices) { @@ -924,259 +911,6 @@ void llama_context::reset_scheduler() { prev_mtp.reset(); } -static ggml_backend_buffer_type_t llama_dflash_kv_cache_layer_buft(const llama_context & lctx, int32_t il) { - if (il >= 0 && (size_t) il < lctx.model.buft_layer.size() && lctx.model.buft_layer[(size_t) il].buft != nullptr) { - return 
lctx.model.buft_layer[(size_t) il].buft; - } - - if (il >= 0 && (size_t) il < lctx.model.layers.size()) { - const ggml_tensor * wk = lctx.model.layers[(size_t) il].wk; - if (wk != nullptr && wk->buffer != nullptr) { - return ggml_backend_buffer_get_type(wk->buffer); - } - } - - return llama_default_buffer_type_cpu(true); -} - -static ggml_backend_t llama_backend_for_tensor(const llama_context & lctx, const ggml_tensor * tensor) { - if (tensor == nullptr) { - return nullptr; - } - - ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; - if (buf == nullptr) { - return nullptr; - } - - ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf); - for (ggml_backend_t backend : lctx.backends) { - ggml_backend_buffer_type_t backend_buft = ggml_backend_is_cpu(backend) - ? llama_default_buffer_type_cpu(true) - : ggml_backend_get_default_buffer_type(backend); - if (backend_buft == buft) { - return backend; - } - } - - return nullptr; -} - -bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { - const bool use_kv_workspace = llama_env_flag_enabled("IK_DFLASH_KV_WORKSPACE"); - const int32_t target_cross_ctx = std::max(1, cross_ctx); - const int32_t target_token_capacity = std::max(1, (int32_t) model.hparams.dflash_block_size); - const int32_t target_workspace_n_kv_total = GGML_PAD(target_cross_ctx + target_token_capacity, cparams.flash_attn ? 256 : 32); - const int32_t n_layer = model.hparams.n_layer; - const int64_t n_embd_head_k = model.hparams.n_embd_head_k(0); - const int64_t n_embd_head_v = model.hparams.n_embd_head_v(0); - const int64_t n_head_kv = model.hparams.n_head_kv(); - - if (dflash_cache_ctx != nullptr && !dflash_k_ctx_cache.empty()) { - const bool cache_matches = (int32_t) dflash_k_ctx_cache.size() == n_layer && - dflash_k_ctx_cache.front() != nullptr && - (int32_t) dflash_k_ctx_cache.front()->ne[2] == target_cross_ctx; - const bool workspace_matches = use_kv_workspace - ? 
((int32_t) dflash_k_ctx_workspace.size() == n_layer && - dflash_k_ctx_workspace.front() != nullptr && - (int32_t) dflash_k_ctx_workspace.front()->ne[1] == target_workspace_n_kv_total) - : dflash_k_ctx_workspace.empty() && dflash_v_ctx_workspace.empty(); - - if (cache_matches && workspace_matches) { - return true; - } - - free_dflash_kv_cache_tensors(); - if (dflash_sched != nullptr) { - ggml_backend_sched_free(dflash_sched); - dflash_sched = nullptr; - } - if (dflash_workspace_sched != nullptr) { - ggml_backend_sched_free(dflash_workspace_sched); - dflash_workspace_sched = nullptr; - } - dflash_kv_graph = nullptr; - dflash_kv_workspace_graph = nullptr; - dflash_kv_graph_rows = 0; - dflash_kv_graph_write_pos = 0; - dflash_kv_workspace_graph_rows = 0; - dflash_kv_workspace_graph_write_pos = 0; - dflash_kv_workspace_reserved_rows = 0; - dflash_buf_compute_meta.clear(); - dflash_workspace_buf_compute_meta.clear(); - } - - ggml_init_params params = { - /*.mem_size =*/ (size_t) ((use_kv_workspace ? 4 : 2) * std::max(1, n_layer)) * ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - - dflash_cache_ctx = ggml_init(params); - if (dflash_cache_ctx == nullptr) { - return false; - } - - dflash_k_ctx_cache.resize((size_t) n_layer); - dflash_v_ctx_cache.resize((size_t) n_layer); - dflash_k_ctx_workspace.clear(); - dflash_v_ctx_workspace.clear(); - if (use_kv_workspace) { - dflash_k_ctx_workspace.resize((size_t) n_layer); - dflash_v_ctx_workspace.resize((size_t) n_layer); - } - dflash_cache_bufs.clear(); - dflash_cache_bufs.reserve((size_t) std::max(1, n_layer) * (use_kv_workspace ? 
4 : 2)); - int32_t host_layers = 0; - const char * first_buft_name = nullptr; - const char * last_buft_name = nullptr; - for (int32_t il = 0; il < n_layer; ++il) { - ggml_backend_buffer_type_t layer_buft = llama_dflash_kv_cache_layer_buft(*this, il); - if (ggml_backend_buft_is_host(layer_buft)) { - host_layers++; - } - if (first_buft_name == nullptr) { - first_buft_name = ggml_backend_buft_name(layer_buft); - } - last_buft_name = ggml_backend_buft_name(layer_buft); - - dflash_k_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_k, n_head_kv, target_cross_ctx); - dflash_v_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_v, n_head_kv, target_cross_ctx); - if (dflash_k_ctx_cache[(size_t) il] == nullptr || dflash_v_ctx_cache[(size_t) il] == nullptr) { - free_dflash_kv_cache_tensors(); - return false; - } - - ggml_set_input(dflash_k_ctx_cache[(size_t) il]); - ggml_set_input(dflash_v_ctx_cache[(size_t) il]); - ggml_format_name(dflash_k_ctx_cache[(size_t) il], "dflash_k_ctx_cache_%d", il); - ggml_format_name(dflash_v_ctx_cache[(size_t) il], "dflash_v_ctx_cache_%d", il); - - const size_t k_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_k_ctx_cache[(size_t) il]); - ggml_backend_buffer_t k_buf = ggml_backend_buft_alloc_buffer(layer_buft, k_bytes); - if (k_buf == nullptr) { - free_dflash_kv_cache_tensors(); - return false; - } - ggml_backend_buffer_set_usage(k_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); - ggml_backend_tensor_alloc(k_buf, dflash_k_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(k_buf)); - ggml_backend_buffer_clear(k_buf, 0); - dflash_cache_bufs.push_back(k_buf); - - const size_t v_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_v_ctx_cache[(size_t) il]); - ggml_backend_buffer_t v_buf = ggml_backend_buft_alloc_buffer(layer_buft, v_bytes); - if (v_buf == nullptr) { - free_dflash_kv_cache_tensors(); - return false; - } - ggml_backend_buffer_set_usage(v_buf, 
GGML_BACKEND_BUFFER_USAGE_COMPUTE); - ggml_backend_tensor_alloc(v_buf, dflash_v_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(v_buf)); - ggml_backend_buffer_clear(v_buf, 0); - dflash_cache_bufs.push_back(v_buf); - - if (use_kv_workspace) { - dflash_k_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_k, target_workspace_n_kv_total, n_head_kv); - dflash_v_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_v, target_workspace_n_kv_total, n_head_kv); - if (dflash_k_ctx_workspace[(size_t) il] == nullptr || dflash_v_ctx_workspace[(size_t) il] == nullptr) { - free_dflash_kv_cache_tensors(); - return false; - } - - ggml_set_input(dflash_k_ctx_workspace[(size_t) il]); - ggml_set_input(dflash_v_ctx_workspace[(size_t) il]); - ggml_format_name(dflash_k_ctx_workspace[(size_t) il], "dflash_k_ctx_workspace_%d", il); - ggml_format_name(dflash_v_ctx_workspace[(size_t) il], "dflash_v_ctx_workspace_%d", il); - - const size_t k_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_k_ctx_workspace[(size_t) il]); - ggml_backend_buffer_t k_workspace_buf = ggml_backend_buft_alloc_buffer(layer_buft, k_workspace_bytes); - if (k_workspace_buf == nullptr) { - free_dflash_kv_cache_tensors(); - return false; - } - ggml_backend_buffer_set_usage(k_workspace_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); - ggml_backend_tensor_alloc(k_workspace_buf, dflash_k_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(k_workspace_buf)); - ggml_backend_buffer_clear(k_workspace_buf, 0); - dflash_cache_bufs.push_back(k_workspace_buf); - - const size_t v_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_v_ctx_workspace[(size_t) il]); - ggml_backend_buffer_t v_workspace_buf = ggml_backend_buft_alloc_buffer(layer_buft, v_workspace_bytes); - if (v_workspace_buf == nullptr) { - free_dflash_kv_cache_tensors(); - return false; - } - ggml_backend_buffer_set_usage(v_workspace_buf, 
GGML_BACKEND_BUFFER_USAGE_COMPUTE); - ggml_backend_tensor_alloc(v_workspace_buf, dflash_v_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(v_workspace_buf)); - ggml_backend_buffer_clear(v_workspace_buf, 0); - dflash_cache_bufs.push_back(v_workspace_buf); - } - } - - dflash_profile.last_kv_cache_host_layers = host_layers; - dflash_kv_workspace_token_capacity = use_kv_workspace ? target_token_capacity : 0; - dflash_kv_workspace_n_kv_total = use_kv_workspace ? target_workspace_n_kv_total : 0; - llama_reset_dflash_kv_cache_state(this); - LLAMA_LOG_INFO("%s: DFlash K/V cache placement cross_ctx=%d host_layers=%d/%d first=%s last=%s\n", - __func__, - target_cross_ctx, - host_layers, - n_layer, - first_buft_name != nullptr ? first_buft_name : "(none)", - last_buft_name != nullptr ? last_buft_name : "(none)"); - - return true; -} - -void llama_context::free_dflash_kv_cache_tensors() { - dflash_k_ctx_cache.clear(); - dflash_v_ctx_cache.clear(); - dflash_k_ctx_workspace.clear(); - dflash_v_ctx_workspace.clear(); - dflash_kv_cache_write_pos = 0; - dflash_kv_cache_n_filled = 0; - dflash_kv_cache_update_rows = 0; - dflash_kv_cache_reserved_rows = 0; - dflash_kv_cache_view_write_pos = 0; - dflash_kv_cache_view_n_filled = 0; - dflash_kv_cache_applied_window_version = 0; - dflash_kv_cache_valid = false; - dflash_kv_cache_view_valid = false; - dflash_kv_workspace_write_pos = 0; - dflash_kv_workspace_n_filled = 0; - dflash_kv_workspace_reserved_rows = 0; - dflash_kv_workspace_token_capacity = 0; - dflash_kv_workspace_n_kv_total = 0; - dflash_kv_workspace_applied_window_version = 0; - dflash_kv_workspace_valid = false; - dflash_kv_workspace_sync_pending = false; - dflash_kv_graph = nullptr; - dflash_kv_workspace_graph = nullptr; - dflash_kv_graph_rows = 0; - dflash_kv_graph_write_pos = 0; - dflash_kv_workspace_graph_rows = 0; - dflash_kv_workspace_graph_write_pos = 0; - dflash_kv_input_target_features = nullptr; - dflash_kv_input_pos_ctx = nullptr; - dflash_kq_mask_tensor = 
nullptr; - dflash_kq_mask_swa_tensor = nullptr; - - if (dflash_workspace_sched != nullptr) { - ggml_backend_sched_synchronize(dflash_workspace_sched); - ggml_backend_sched_free(dflash_workspace_sched); - dflash_workspace_sched = nullptr; - } - - for (ggml_backend_buffer_t buf : dflash_cache_bufs) { - if (buf != nullptr) { - ggml_backend_buffer_free(buf); - } - } - dflash_cache_bufs.clear(); - if (dflash_cache_ctx != nullptr) { - ggml_free(dflash_cache_ctx); - dflash_cache_ctx = nullptr; - } -} - bool llama_context::can_reuse_graph(const llama_batch & u_batch) { if (!cparams.graph_reuse) return false; //if (kv_self.save_per_step_ssm) return false; @@ -5631,584 +5365,6 @@ static bool dflash_layer_has_attention_bias(const llama_layer & layer) { layer.bkv != nullptr; } -static bool validate_dflash_graph_contract(const llama_context & lctx) { - const auto & model = lctx.model; - const auto & hparams = model.hparams; - - auto rope_dim_for_layer = [&hparams](int32_t il) -> uint32_t { - if (hparams.rope_dim_per_layer[(size_t) il] != 0) { - return hparams.rope_dim_per_layer[(size_t) il]; - } - - return hparams.swa_layers[(size_t) il] ? hparams.n_rot_swa : hparams.n_rot; - }; - - auto rope_base_for_layer = [&hparams](int32_t il) -> float { - if (hparams.has_rope_freq_base_per_layer) { - return hparams.rope_freq_base_per_layer[(size_t) il]; - } - - return hparams.swa_layers[(size_t) il] ? hparams.rope_freq_base_train_swa : hparams.rope_freq_base_train; - }; - - auto rope_scale_for_layer = [&hparams](int32_t il) -> float { - return hparams.swa_layers[(size_t) il] ? 
hparams.rope_freq_scale_train_swa : hparams.rope_freq_scale_train; - }; - - const uint32_t ref_n_head = hparams.n_head(0); - const uint32_t ref_n_head_kv = hparams.n_head_kv(0); - const uint32_t ref_n_embd_head_k = hparams.n_embd_head_k(0); - const uint32_t ref_n_embd_head_v = hparams.n_embd_head_v(0); - const uint32_t ref_rope_dim = rope_dim_for_layer(0); - const float ref_rope_base = rope_base_for_layer(0); - const float ref_rope_scale = rope_scale_for_layer(0); - - for (int32_t il = 0; il < (int32_t) hparams.n_layer; ++il) { - if (hparams.n_head((uint32_t) il) != ref_n_head || - hparams.n_head_kv((uint32_t) il) != ref_n_head_kv || - hparams.n_embd_head_k(il) != ref_n_embd_head_k || - hparams.n_embd_head_v(il) != ref_n_embd_head_v) { - LLAMA_LOG_ERROR("%s: DFlash graph assumes layer-invariant head config, but layer %d differs (n_head=%u/%u n_head_kv=%u/%u head_k=%u/%u head_v=%u/%u)\n", - __func__, - il, - hparams.n_head((uint32_t) il), ref_n_head, - hparams.n_head_kv((uint32_t) il), ref_n_head_kv, - hparams.n_embd_head_k(il), ref_n_embd_head_k, - hparams.n_embd_head_v(il), ref_n_embd_head_v); - return false; - } - - const uint32_t rope_dim = rope_dim_for_layer(il); - const float rope_base = rope_base_for_layer(il); - const float rope_scale = rope_scale_for_layer(il); - if (rope_dim != ref_rope_dim || std::fabs(rope_base - ref_rope_base) > 1e-6f || std::fabs(rope_scale - ref_rope_scale) > 1e-6f) { - LLAMA_LOG_ERROR("%s: DFlash graph assumes layer-invariant RoPE config, but layer %d differs (dim=%u/%u base=%g/%g scale=%g/%g)\n", - __func__, - il, - rope_dim, ref_rope_dim, - (double) rope_base, (double) ref_rope_base, - (double) rope_scale, (double) ref_rope_scale); - return false; - } - - if (model.layers[(size_t) il].attn_norm == nullptr || - model.layers[(size_t) il].attn_q_norm == nullptr || - model.layers[(size_t) il].attn_k_norm == nullptr) { - LLAMA_LOG_ERROR("%s: DFlash graph requires attn_norm, attn_q_norm, and attn_k_norm weights, but layer %d is missing 
one or more of them\n", - __func__, il); - return false; - } - - const bool has_q_norm = model.layers[(size_t) il].attn_q_norm != nullptr; - const bool has_k_norm = model.layers[(size_t) il].attn_k_norm != nullptr; - if (has_q_norm != has_k_norm) { - LLAMA_LOG_ERROR("%s: DFlash graph requires symmetric Q/K norm presence, but layer %d has q_norm=%d k_norm=%d\n", - __func__, il, (int) has_q_norm, (int) has_k_norm); - return false; - } - - if (model.layers[(size_t) il].attn_norm_b != nullptr || - model.layers[(size_t) il].attn_q_norm_b != nullptr || - model.layers[(size_t) il].attn_k_norm_b != nullptr) { - LLAMA_LOG_ERROR("%s: DFlash graph does not implement norm-bias tensors, but layer %d requires attn_norm_b/q_norm_b/k_norm_b\n", - __func__, il); - return false; - } - - if (dflash_layer_has_attention_bias(model.layers[(size_t) il])) { - LLAMA_LOG_ERROR("%s: DFlash graph does not implement attention bias tensors, but layer %d requires them\n", - __func__, il); - return false; - } - } - - return true; -} - -static bool prepare_dflash_graph_inputs( - struct llama_context & lctx, - uint32_t n_tokens) { - const bool use_kv_cache = llama_env_flag_enabled("IK_DFLASH_KV_CACHE"); - const bool use_kv_workspace = use_kv_cache && llama_dflash_use_kv_workspace_experiment(); - const bool kv_node_timing = llama_env_flag_enabled("IK_DFLASH_KV_NODE_TIMING"); - auto & profile = lctx.dflash_profile; - const int32_t cross_ctx = lctx.dflash_visible_cross_ctx > 0 - ? 
lctx.dflash_visible_cross_ctx - : std::max(1, (int32_t) lctx.cparams.n_ctx - (int32_t) lctx.model.hparams.dflash_block_size); - ggml_tensor * kq_mask = lctx.dflash_kq_mask_tensor; - ggml_tensor * kq_mask_swa = lctx.dflash_kq_mask_swa_tensor; - - if (kq_mask == nullptr) { - LLAMA_LOG_ERROR("%s: DFlash graph inputs are not initialized\n", __func__); - return false; - } - - if (!validate_dflash_graph_contract(lctx)) { - profile.graph_shape_failures++; - return false; - } - - if (use_kv_cache) { - if (!lctx.ensure_dflash_kv_cache_tensors(cross_ctx) || lctx.dflash_k_ctx_cache.empty() || lctx.dflash_v_ctx_cache.empty()) { - LLAMA_LOG_ERROR("%s: DFlash K/V cache inputs are not initialized\n", __func__); - return false; - } - } else if (lctx.inp_dflash_target_features == nullptr || lctx.inp_dflash_pos_ctx == nullptr) { - LLAMA_LOG_ERROR("%s: DFlash inline inputs are not initialized\n", __func__); - return false; - } - - const float * src = lctx.dflash_target_features; - const float * append_src = lctx.dflash_target_append_features; - const llama_pos * src_pos = lctx.dflash_target_positions; - const size_t total_floats = lctx.dflash_target_features_n_floats; - const size_t append_floats = lctx.dflash_target_append_features_n_floats; - const size_t total_positions = lctx.dflash_target_positions_n; - const int32_t n_rows = lctx.dflash_target_features_n_rows; - const int32_t append_rows_available = lctx.dflash_target_append_features_n_rows; - const int32_t width = (int32_t) lctx.model.hparams.dflash_n_target_features; - const int32_t graph_cross_ctx = use_kv_cache - ? (lctx.dflash_k_ctx_cache.front() != nullptr ? (int32_t) lctx.dflash_k_ctx_cache.front()->ne[2] : 0) - : (lctx.inp_dflash_target_features != nullptr ? 
(int32_t) lctx.inp_dflash_target_features->ne[1] : 0); - const int32_t n_mask_tokens = (int32_t) kq_mask->ne[1]; - const int32_t n_kv_total = (int32_t) kq_mask->ne[0]; - const int64_t t_total_us = ggml_time_us(); - - profile.graph_prepare_calls++; - profile.last_n_rows = n_rows; - profile.last_width = width; - profile.last_cross_ctx = cross_ctx; - profile.last_n_tokens = (int32_t) n_tokens; - profile.last_n_kv_total = n_kv_total; - - if (use_kv_workspace) { - llama_sync_dflash_workspace_if_pending(lctx); - } - - if (graph_cross_ctx != cross_ctx) { - profile.graph_shape_failures++; - - LLAMA_LOG_ERROR("%s: DFlash graph cross_ctx drift (graph=%d configured=%d)\n", - __func__, graph_cross_ctx, cross_ctx); - return false; - } - if (n_rows <= 0) { - profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: missing DFlash target feature rows\n", __func__); - return false; - } - - const bool have_full_src = src != nullptr && total_floats == (size_t) n_rows * (size_t) width; - if (n_rows > cross_ctx || (src != nullptr && !have_full_src)) { - profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: invalid DFlash target feature shape (rows=%d width=%d floats=%zu cross_ctx=%d)\n", - __func__, n_rows, width, total_floats, cross_ctx); - return false; - } - - if (!use_kv_cache && !have_full_src) { - profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: missing contiguous DFlash target features for inline path\n", __func__); - return false; - } - - if (n_kv_total < cross_ctx + (int32_t) n_tokens) { - profile.graph_mask_overflow++; - LLAMA_LOG_ERROR("%s: invalid DFlash mask shape (n_kv_total=%d < cross_ctx+n_tokens=%d)\n", - __func__, n_kv_total, cross_ctx + (int32_t) n_tokens); - return false; - } - - const int32_t left_pad = cross_ctx - n_rows; - profile.last_left_pad = left_pad; - if (!use_kv_cache) { - const size_t padded_floats = (size_t) cross_ctx * (size_t) width; - const size_t dst_offset = (size_t) left_pad * (size_t) width; - const int64_t t_feature_us = ggml_time_us(); - 
if (lctx.dflash_target_features_padded.size() != padded_floats) { - lctx.dflash_target_features_padded.resize(padded_floats); - } - if (left_pad == 0 && total_floats == padded_floats) { - std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin()); - } else { - if (dst_offset > 0) { - std::fill(lctx.dflash_target_features_padded.begin(), - lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset, 0.0f); - } - std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset); - } - profile.graph_feature_copy_us += (uint64_t) (ggml_time_us() - t_feature_us); - profile.graph_feature_bytes += padded_floats * sizeof(float); - } - - const int64_t t_pos_us = ggml_time_us(); - lctx.dflash_pos_ctx_data.resize((size_t) cross_ctx); - std::fill(lctx.dflash_pos_ctx_data.begin(), lctx.dflash_pos_ctx_data.end(), 0); - if (src_pos == nullptr || total_positions != (size_t) n_rows) { - profile.graph_pos_fallbacks++; - profile.graph_shape_failures++; - profile.last_pos_first = -1; - profile.last_pos_last = -1; - if (profile.graph_pos_fallbacks <= 3) { - LLAMA_LOG_ERROR("%s: missing DFlash target positions (rows=%d positions=%zu cross_ctx=%d)\n", - __func__, n_rows, total_positions, cross_ctx); - } - return false; - } - - profile.last_pos_first = src_pos[0]; - profile.last_pos_last = src_pos[n_rows - 1]; - for (int32_t i = 1; i < n_rows; ++i) { - if (src_pos[i] <= src_pos[i - 1]) { - profile.graph_pos_non_monotonic++; - profile.graph_shape_failures++; - if (profile.graph_pos_non_monotonic <= 3) { - LLAMA_LOG_ERROR("%s: DFlash target positions are not strictly increasing (rows=%d first=%d last=%d)\n", - __func__, n_rows, (int) src_pos[0], (int) src_pos[n_rows - 1]); - } - return false; - } - } - std::copy(src_pos, src_pos + n_rows, lctx.dflash_pos_ctx_data.begin() + (ptrdiff_t) left_pad); - profile.graph_pos_copy_us += (uint64_t) (ggml_time_us() - t_pos_us); - profile.graph_pos_bytes += lctx.dflash_pos_ctx_data.size() 
* sizeof(llama_pos); - - if (use_kv_cache) { - const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition( - cross_ctx, - lctx.dflash_kv_cache_n_filled, - lctx.dflash_kv_cache_write_pos, - lctx.dflash_kv_cache_valid, - lctx.dflash_kv_cache_applied_window_version, - lctx.dflash_target_window_version, - lctx.dflash_target_window_keep_rows, - lctx.dflash_target_window_append_rows, - lctx.dflash_target_window_replace, - n_rows); - - const bool have_append_src = append_src != nullptr && - append_rows_available == cache_plan.append_rows && - append_floats == (size_t) cache_plan.append_rows * (size_t) width; - - const int32_t update_rows = cache_plan.cache_up_to_date - ? 0 - : (cache_plan.rebuild_cache ? n_rows : cache_plan.append_rows); - const size_t max_nodes = lctx.model.max_nodes((int) std::max(1, cross_ctx)) + 24 * lctx.model.hparams.n_layer; - const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false); - if (lctx.dflash_buf_compute_meta.size() != meta_size) { - lctx.dflash_buf_compute_meta.resize(meta_size); - } - - if (lctx.dflash_sched == nullptr || lctx.dflash_kv_cache_reserved_rows != cross_ctx) { - std::vector backend_buft; - backend_buft.reserve(lctx.backends.size()); - for (auto * backend : lctx.backends) { - if (ggml_backend_is_cpu(backend)) { - backend_buft.push_back(llama_default_buffer_type_cpu(true)); - } else { - backend_buft.push_back(ggml_backend_get_default_buffer_type(backend)); - } - } - - if (lctx.dflash_sched != nullptr) { - ggml_backend_sched_free(lctx.dflash_sched); - lctx.dflash_sched = nullptr; - } - lctx.dflash_kv_graph = nullptr; - lctx.dflash_kv_graph_rows = 0; - lctx.dflash_kv_graph_write_pos = 0; - - const int32_t saved_update_rows = lctx.dflash_kv_cache_update_rows; - lctx.dflash_kv_cache_update_rows = cross_ctx; - const int64_t t_build_us = ggml_time_us(); - ggml_cgraph * gf_reserve = llm_build_context::llama_build_graph_dflash_kv_cache(lctx); - 
profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us); - lctx.dflash_kv_cache_update_rows = saved_update_rows; - if (gf_reserve == nullptr) { - profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache reserve graph\n", __func__); - return false; - } - - const int64_t t_reserve_us = ggml_time_us(); - lctx.dflash_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false); - const bool reserved = lctx.dflash_sched != nullptr && ggml_backend_sched_reserve(lctx.dflash_sched, gf_reserve); - profile.graph_kv_cache_reserve_us += (uint64_t) (ggml_time_us() - t_reserve_us); - if (!reserved) { - profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V scheduler\n", __func__); - return false; - } - lctx.dflash_kv_cache_reserved_rows = cross_ctx; - } - - if (update_rows > 0) { - const float * update_src = nullptr; - if (have_append_src && update_rows == cache_plan.append_rows) { - update_src = append_src; - } else if (have_full_src) { - update_src = src + (size_t) (n_rows - update_rows) * (size_t) width; - } - const llama_pos * update_pos = src_pos + (n_rows - update_rows); - - if (update_src == nullptr) { - profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: missing DFlash appended target features for cached update (rows=%d append_rows=%d floats=%zu)\n", - __func__, n_rows, update_rows, append_floats); - return false; - } - - if (cache_plan.rebuild_cache) { - llama_reset_dflash_kv_cache_state(&lctx); - } - - lctx.dflash_kv_cache_update_rows = update_rows; - ggml_cgraph * gf_kv = nullptr; - const bool can_reuse_kv_graph = lctx.dflash_kv_graph != nullptr && - lctx.dflash_kv_graph_rows == update_rows && - lctx.dflash_kv_graph_write_pos == lctx.dflash_kv_cache_write_pos; - if (can_reuse_kv_graph) { - gf_kv = lctx.dflash_kv_graph; - } else { - const int64_t t_build_us = ggml_time_us(); - gf_kv = 
llm_build_context::llama_build_graph_dflash_kv_cache(lctx); - profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us); - if (gf_kv == nullptr || lctx.dflash_kv_input_target_features == nullptr || lctx.dflash_kv_input_pos_ctx == nullptr) { - profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__); - return false; - } - - const int64_t t_reset_us = ggml_time_us(); - ggml_backend_sched_reset(lctx.dflash_sched); - profile.graph_kv_cache_reset_us += (uint64_t) (ggml_time_us() - t_reset_us); - - const int64_t t_alloc_us = ggml_time_us(); - ggml_backend_sched_alloc_graph(lctx.dflash_sched, gf_kv); - profile.graph_kv_cache_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us); - - lctx.dflash_kv_graph = gf_kv; - lctx.dflash_kv_graph_rows = update_rows; - lctx.dflash_kv_graph_write_pos = lctx.dflash_kv_cache_write_pos; - } - - ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_target_features); - const int64_t t_feature_upload_us = ggml_time_us(); - if (kv_feature_backend != nullptr) { - ggml_backend_tensor_set_async(kv_feature_backend, lctx.dflash_kv_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); - } else { - ggml_backend_tensor_set(lctx.dflash_kv_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); - } - profile.graph_kv_cache_feature_upload_us += (uint64_t) (ggml_time_us() - t_feature_upload_us); - profile.graph_feature_bytes += (size_t) update_rows * (size_t) width * sizeof(float); - - ggml_backend_t kv_pos_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_pos_ctx); - const int64_t t_pos_upload_us = ggml_time_us(); - if (kv_pos_backend != nullptr) { - ggml_backend_tensor_set_async(kv_pos_backend, lctx.dflash_kv_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); - } else { - ggml_backend_tensor_set(lctx.dflash_kv_input_pos_ctx, update_pos, 0, 
ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); - } - profile.graph_kv_cache_pos_upload_us += (uint64_t) (ggml_time_us() - t_pos_upload_us); - - const int64_t t_kv_cache_us = ggml_time_us(); - llama_dflash_kv_node_profiler kv_node_profiler; - if (kv_node_timing) { - kv_node_profiler.profile = &profile; - ggml_backend_sched_set_eval_callback(lctx.dflash_sched, llama_dflash_kv_node_eval_callback, &kv_node_profiler); - } - llama_graph_compute_sched(lctx, lctx.dflash_sched, gf_kv, lctx.cparams.n_threads); - if (kv_node_timing) { - ggml_backend_sched_set_eval_callback(lctx.dflash_sched, nullptr, nullptr); - } - profile.graph_kv_cache_compute_us += (uint64_t) (ggml_time_us() - t_kv_cache_us); - - const int64_t t_sync_us = ggml_time_us(); - ggml_backend_sched_synchronize(lctx.dflash_sched); - profile.graph_kv_cache_sync_us += (uint64_t) (ggml_time_us() - t_sync_us); - profile.graph_kv_cache_calls++; - - lctx.dflash_kv_cache_n_filled = std::min(cross_ctx, lctx.dflash_kv_cache_n_filled + update_rows); - lctx.dflash_kv_cache_write_pos = (lctx.dflash_kv_cache_write_pos + update_rows) % cross_ctx; - lctx.dflash_kv_cache_applied_window_version = lctx.dflash_target_window_version; - lctx.dflash_kv_cache_valid = true; - lctx.dflash_kv_cache_view_n_filled = lctx.dflash_kv_cache_n_filled; - lctx.dflash_kv_cache_view_write_pos = lctx.dflash_kv_cache_write_pos; - lctx.dflash_kv_cache_view_valid = true; - } - - if (use_kv_workspace && lctx.dflash_kv_cache_view_valid && - !lctx.dflash_k_ctx_workspace.empty() && !lctx.dflash_v_ctx_workspace.empty()) { - const bool need_workspace_refresh = !lctx.dflash_kv_workspace_valid || - lctx.dflash_kv_workspace_n_filled != lctx.dflash_kv_cache_view_n_filled || - lctx.dflash_kv_workspace_write_pos != lctx.dflash_kv_cache_view_write_pos || - lctx.dflash_kv_workspace_applied_window_version != lctx.dflash_kv_cache_applied_window_version; - - if (need_workspace_refresh) { - const size_t max_nodes = lctx.model.max_nodes((int) std::max(1, cross_ctx)) + 16 * 
lctx.model.hparams.n_layer; - const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false); - if (lctx.dflash_workspace_buf_compute_meta.size() != meta_size) { - lctx.dflash_workspace_buf_compute_meta.resize(meta_size); - } - - ggml_cgraph * gf_workspace = nullptr; - const bool can_reuse_workspace_graph = lctx.dflash_kv_workspace_graph != nullptr && - lctx.dflash_kv_workspace_graph_rows == lctx.dflash_kv_cache_view_n_filled && - lctx.dflash_kv_workspace_graph_write_pos == lctx.dflash_kv_cache_view_write_pos; - - if (can_reuse_workspace_graph) { - gf_workspace = lctx.dflash_kv_workspace_graph; - } else { - const int64_t t_build_us = ggml_time_us(); - gf_workspace = llm_build_context::llama_build_graph_dflash_kv_workspace(lctx); - profile.graph_kv_workspace_build_us += (uint64_t) (ggml_time_us() - t_build_us); - if (gf_workspace == nullptr) { - profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: failed to build DFlash K/V workspace graph\n", __func__); - return false; - } - - std::vector backend_buft; - backend_buft.reserve(lctx.backends.size()); - for (auto * backend : lctx.backends) { - if (ggml_backend_is_cpu(backend)) { - backend_buft.push_back(llama_default_buffer_type_cpu(true)); - } else { - backend_buft.push_back(ggml_backend_get_default_buffer_type(backend)); - } - } - - if (lctx.dflash_workspace_sched == nullptr) { - lctx.dflash_workspace_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false); - } - - if (lctx.dflash_kv_workspace_reserved_rows != cross_ctx) { - const bool saved_view_valid = lctx.dflash_kv_cache_view_valid; - const int32_t saved_view_rows = lctx.dflash_kv_cache_view_n_filled; - const int32_t saved_view_write_pos = lctx.dflash_kv_cache_view_write_pos; - - lctx.dflash_kv_cache_view_valid = true; - lctx.dflash_kv_cache_view_n_filled = cross_ctx; - lctx.dflash_kv_cache_view_write_pos = cross_ctx > 1 ? 
1 : 0; - - const int64_t t_reserve_build_us = ggml_time_us(); - ggml_cgraph * gf_workspace_reserve = llm_build_context::llama_build_graph_dflash_kv_workspace(lctx); - profile.graph_kv_workspace_build_us += (uint64_t) (ggml_time_us() - t_reserve_build_us); - - lctx.dflash_kv_cache_view_valid = saved_view_valid; - lctx.dflash_kv_cache_view_n_filled = saved_view_rows; - lctx.dflash_kv_cache_view_write_pos = saved_view_write_pos; - - const int64_t t_reserve_us = ggml_time_us(); - const bool reserved = lctx.dflash_workspace_sched != nullptr && - gf_workspace_reserve != nullptr && - ggml_backend_sched_reserve(lctx.dflash_workspace_sched, gf_workspace_reserve); - profile.graph_kv_workspace_reserve_us += (uint64_t) (ggml_time_us() - t_reserve_us); - if (!reserved) { - profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V workspace scheduler\n", __func__); - return false; - } - - lctx.dflash_kv_workspace_reserved_rows = cross_ctx; - } - - const int64_t t_reset_us = ggml_time_us(); - ggml_backend_sched_reset(lctx.dflash_workspace_sched); - profile.graph_kv_workspace_reset_us += (uint64_t) (ggml_time_us() - t_reset_us); - - const int64_t t_alloc_us = ggml_time_us(); - ggml_backend_sched_alloc_graph(lctx.dflash_workspace_sched, gf_workspace); - profile.graph_kv_workspace_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us); - - lctx.dflash_kv_workspace_graph = gf_workspace; - lctx.dflash_kv_workspace_graph_rows = lctx.dflash_kv_cache_view_n_filled; - lctx.dflash_kv_workspace_graph_write_pos = lctx.dflash_kv_cache_view_write_pos; - } - - const int64_t t_workspace_us = ggml_time_us(); - llama_graph_compute_sched(lctx, lctx.dflash_workspace_sched, gf_workspace, lctx.cparams.n_threads); - profile.graph_kv_workspace_compute_us += (uint64_t) (ggml_time_us() - t_workspace_us); - lctx.dflash_kv_workspace_sync_pending = true; - profile.graph_kv_workspace_calls++; - - lctx.dflash_kv_workspace_n_filled = lctx.dflash_kv_cache_view_n_filled; - 
lctx.dflash_kv_workspace_write_pos = lctx.dflash_kv_cache_view_write_pos; - lctx.dflash_kv_workspace_applied_window_version = lctx.dflash_kv_cache_applied_window_version; - lctx.dflash_kv_workspace_valid = true; - } - } - } else { - ggml_backend_tensor_set(lctx.inp_dflash_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.inp_dflash_target_features)); - ggml_backend_tensor_set(lctx.inp_dflash_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.inp_dflash_pos_ctx)); - } - - const int64_t t_mask_us = ggml_time_us(); - const int32_t full_visible_first = left_pad; - const int32_t full_visible_last = cross_ctx + (int32_t) n_tokens - 1; - lctx.dflash_kq_mask_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY); - int32_t visible_kv_max = 0; - for (uint32_t j = 0; j < n_tokens; ++j) { - float * row = lctx.dflash_kq_mask_data.data() + (size_t) j * (size_t) n_kv_total; - const int32_t visible_kv = cross_ctx + (int32_t) n_tokens; - visible_kv_max = std::max(visible_kv_max, visible_kv); - profile.graph_visible_kv_sum += (uint64_t) visible_kv; - for (int32_t i = full_visible_first; i <= full_visible_last; ++i) { - row[i] = 0.0f; - } - } - ggml_backend_tensor_set(kq_mask, lctx.dflash_kq_mask_data.data(), 0, ggml_nbytes(kq_mask)); - profile.graph_mask_build_us += (uint64_t) (ggml_time_us() - t_mask_us); - profile.graph_mask_bytes += ggml_nbytes(kq_mask); - - if (kq_mask_swa != nullptr) { - lctx.dflash_kq_mask_swa_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY); - const int32_t swa_window = (int32_t) lctx.model.hparams.n_swa; - const int32_t draft_pos_base = (int32_t) profile.last_pos_last; - for (uint32_t j = 0; j < n_tokens; ++j) { - float * row = lctx.dflash_kq_mask_swa_data.data() + (size_t) j * (size_t) n_kv_total; - const int32_t q_pos = draft_pos_base + (int32_t) j; - - for (int32_t k = left_pad; k < cross_ctx; ++k) { - const int32_t k_pos = (int32_t) lctx.dflash_pos_ctx_data[(size_t) k]; - if 
(q_pos - k_pos < swa_window) { - row[k] = 0.0f; - } - } - - for (int32_t k = cross_ctx; k < cross_ctx + (int32_t) n_tokens; ++k) { - const int32_t block_k = k - cross_ctx; - if (block_k <= (int32_t) j) { - row[k] = 0.0f; - } - } - } - - ggml_backend_tensor_set(kq_mask_swa, lctx.dflash_kq_mask_swa_data.data(), 0, ggml_nbytes(kq_mask_swa)); - profile.graph_mask_bytes += ggml_nbytes(kq_mask_swa); - } - - profile.graph_visible_kv_max = std::max(profile.graph_visible_kv_max, (uint64_t) visible_kv_max); - profile.graph_prepare_total_us += (uint64_t) (ggml_time_us() - t_total_us); - - if (profile.graph_prepare_calls == 1) { - int32_t n_swa_layers = 0; - for (int32_t il = 0; il < lctx.model.hparams.n_layer; ++il) { - n_swa_layers += lctx.model.hparams.swa_layers[(size_t) il] ? 1 : 0; - } - - LLAMA_LOG_INFO("%s: DFlash graph contract rows=%d width=%d cross_ctx=%d n_tokens=%u left_pad=%d n_kv_total=%d draft_n_ctx=%u pos=%s [%d..%d] full_mask=[%d..%d] swa_window=%u swa_layers=%d\n", - __func__, n_rows, width, cross_ctx, n_tokens, left_pad, n_kv_total, lctx.cparams.n_ctx, - (src_pos != nullptr && total_positions == (size_t) n_rows) ? 
"target" : "synthetic", - (int) profile.last_pos_first, (int) profile.last_pos_last, - full_visible_first, full_visible_last, - lctx.model.hparams.n_swa, - n_swa_layers); - } - - return true; -} - // decode a batch of tokens by evaluating the transformer // // - lctx: llama context @@ -6548,7 +5704,7 @@ static int llama_decode_internal( if (dflash_profile != nullptr) { dflash_profile->decode_prepare_calls++; const int64_t t_prepare_dflash_us = ggml_time_us(); - if (!prepare_dflash_graph_inputs(lctx, n_tokens)) { + if (!llama_prepare_dflash_graph_inputs(lctx, n_tokens)) { dflash_profile->decode_prepare_failures++; dflash_profile->decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us); return GGML_STATUS_FAILED; From 08e4590dcb00a17e4b11e17b35a0c46a088287d3 Mon Sep 17 00:00:00 2001 From: SamuelOliveirads Date: Thu, 4 Jun 2026 20:45:12 -0300 Subject: [PATCH 09/13] implement gpu argmax --- common/speculative-impl.h | 7 ++++++- include/llama.h | 6 +++++- src/graphs/build_dflash.cpp | 6 ++++++ src/llama-context.h | 6 ++++++ src/llama.cpp | 39 +++++++++++++++++++++++++++++++++++++ 5 files changed, 62 insertions(+), 2 deletions(-) diff --git a/common/speculative-impl.h b/common/speculative-impl.h index 47603461..ccec2e9e 100644 --- a/common/speculative-impl.h +++ b/common/speculative-impl.h @@ -321,7 +321,12 @@ struct common_speculative_state_dflash : public common_speculative_state { result.reserve((size_t) n_keep); const int64_t t_sample_us = ggml_time_us(); for (int32_t i = 0; i < n_keep; ++i) { - result.push_back(common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr)); + // Use argmax in GPU when available + llama_token id = llama_get_dflash_draft_token_ith(ctx_dft, i); + if (id == LLAMA_TOKEN_NULL) { + id = common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr); + } + result.push_back(id); } t_draft_sample_us += (uint64_t) (ggml_time_us() - t_sample_us); diff --git a/include/llama.h b/include/llama.h index f89d82ef..754a0643 
100644 --- a/include/llama.h +++ b/include/llama.h @@ -53,7 +53,7 @@ #define LLAMA_STATE_SEQ_VERSION 3 #define LLAMA_SERVER_MAGIC 0x6c6d7376u // 'lmsv' -#define LLAMA_SERVER_VERSION 1 +#define LLAMA_SERVER_VERSION 1 #ifdef __cplusplus extern "C" { @@ -1096,6 +1096,10 @@ extern "C" { // returns NULL for invalid ids. LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); + // Get the argmax token ID for DFlash draft position i without materializing full logits. + // Returns LLAMA_TOKEN_NULL if argmax is not available (falls back to logits path). + LLAMA_API llama_token llama_get_dflash_draft_token_ith(struct llama_context * ctx, int32_t i); + // Get all output token embeddings. // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model, // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously diff --git a/src/graphs/build_dflash.cpp b/src/graphs/build_dflash.cpp index 80c45c1e..4cbfc147 100644 --- a/src/graphs/build_dflash.cpp +++ b/src/graphs/build_dflash.cpp @@ -565,5 +565,11 @@ ggml_cgraph * llm_build_context::build_dflash() { cb(result, "result_output", -1); ggml_build_forward_expand(gf, result); + lctx.dflash_draft_tokens_tensor = nullptr; + ggml_tensor * draft_tokens = ggml_argmax(ctx0, result); + ggml_set_name(draft_tokens, "draft_argmax"); + ggml_build_forward_expand(gf, draft_tokens); + lctx.dflash_draft_tokens_tensor = draft_tokens; + return gf; } diff --git a/src/llama-context.h b/src/llama-context.h index ebd4ded3..fec7edbb 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -393,6 +393,12 @@ struct llama_context { int32_t & dflash_visible_cross_ctx = dflash.visible_cross_ctx; std::vector & dflash_k_ctx_cache = dflash.kv.k_ctx_cache; std::vector & dflash_v_ctx_cache = dflash.kv.v_ctx_cache; + + // Argmax token IDs from the DFlash draft graph, computed via GPU argmax. + // Populated in llama_decode_internal after graph compute. 
+ std::vector dflash_draft_tokens; + struct ggml_tensor * dflash_draft_tokens_tensor = nullptr; + std::vector & dflash_k_ctx_workspace = dflash.kv.k_ctx_workspace; std::vector & dflash_v_ctx_workspace = dflash.kv.v_ctx_workspace; struct ggml_context * & dflash_cache_ctx = dflash.kv.cache_ctx; diff --git a/src/llama.cpp b/src/llama.cpp index 75482d80..0e2430e4 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5716,6 +5716,17 @@ static int llama_decode_internal( struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; struct ggml_tensor * embd = nullptr; + // DFlash GPU argmax draft_argmax node + if (lctx.dflash_draft_tokens_tensor != nullptr && + strcmp(res->name, "result_output") != 0) { + for (int i = gf->n_nodes - 2; i >= 0; --i) { + if (strcmp(gf->nodes[i]->name, "result_output") == 0) { + res = gf->nodes[i]; + break; + } + } + } + if (lctx.n_outputs == 0) { // no output res = nullptr; @@ -5813,7 +5824,28 @@ static int llama_decode_internal( // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} + lctx.dflash_draft_tokens.clear(); + if (lctx.dflash_draft_tokens_tensor != nullptr) { + ggml_backend_t backend_argmax = ggml_backend_sched_get_tensor_backend( + lctx.sched, lctx.dflash_draft_tokens_tensor); + if (backend_argmax != nullptr) { + const int64_t n_tokens_argmax = lctx.dflash_draft_tokens_tensor->ne[0]; + lctx.dflash_draft_tokens.resize((size_t) n_tokens_argmax); + ggml_backend_tensor_get_async(backend_argmax, + lctx.dflash_draft_tokens_tensor, + lctx.dflash_draft_tokens.data(), 0, + (size_t) n_tokens_argmax * sizeof(int32_t)); + } + } + // extract logits + { + const bool dflash_skip_logits = (lctx.model.arch == LLM_ARCH_DFLASH_DRAFT + && !lctx.dflash_draft_tokens.empty()); + if (dflash_skip_logits) { + res = nullptr; + } + } if (res) { #if IK_PRINT_TIMING tim1 = ggml_time_us(); @@ -10068,6 +10100,13 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { } } +llama_token llama_get_dflash_draft_token_ith(struct llama_context * ctx, int32_t i) { + 
if ((size_t) i >= ctx->dflash_draft_tokens.size()) { + return LLAMA_TOKEN_NULL; + } + return ctx->dflash_draft_tokens[(size_t) i]; +} + float * llama_get_embeddings(struct llama_context * ctx) { llama_synchronize(ctx); From 3b1a0f88d5829539cfc23ea27656558f137d9839 Mon Sep 17 00:00:00 2001 From: SamuelOliveirads Date: Sat, 13 Jun 2026 20:14:08 -0300 Subject: [PATCH 10/13] Add logging for DFlash statistics and clean up workspace handling --- common/speculative-impl.h | 10 +- common/speculative.cpp | 6 +- src/graphs/build_dflash.cpp | 215 ++-------- src/llama-dflash.cpp | 620 ++++++++++++++--------------- src/llama-spec-features-dflash.cpp | 20 +- 5 files changed, 352 insertions(+), 519 deletions(-) diff --git a/common/speculative-impl.h b/common/speculative-impl.h index 48d810e7..dbf8cfb1 100644 --- a/common/speculative-impl.h +++ b/common/speculative-impl.h @@ -266,7 +266,6 @@ struct common_speculative_state_dflash : public common_speculative_state { return; } - const bool use_kv_cache = dflash_use_kv_cache_experiment(); const float * target_features = nullptr; size_t target_feature_floats = 0; llama_dflash_window_update window_update = { @@ -277,16 +276,13 @@ struct common_speculative_state_dflash : public common_speculative_state { target_window_append_features.empty() ? nullptr : target_window_append_features.data(), target_window_append_features.size(), }; - const llama_dflash_kv_cache_transition cache_plan = use_kv_cache - ? 
llama_plan_dflash_kv_cache_transition_for_ctx(ctx_dft, window_update, target_window_rows) - : llama_dflash_kv_cache_transition{}; + const llama_dflash_kv_cache_transition cache_plan = + llama_plan_dflash_kv_cache_transition_for_ctx(ctx_dft, window_update, target_window_rows); - if (!use_kv_cache || cache_plan.rebuild_cache) { + if (cache_plan.rebuild_cache) { dflash_materialize_target_window_features(*this); target_features = target_window.data(); target_feature_floats = target_window.size(); - } - if (use_kv_cache && cache_plan.rebuild_cache) { window_update.append_features = target_window.data(); window_update.append_floats = target_window.size(); window_update.append_rows = target_window_rows; diff --git a/common/speculative.cpp b/common/speculative.cpp index d0825387..b491c244 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -309,8 +309,8 @@ static bool dflash_contract_log_enabled() { std::strcmp(env, "off") != 0; } -static bool dflash_use_kv_cache_experiment() { - const char * env = std::getenv("IK_DFLASH_KV_CACHE"); +static bool dflash_stats_log_enabled() { + const char * env = std::getenv("IK_DFLASH_STATS_LOG"); if (env == nullptr || *env == '\0') { return false; } @@ -1318,7 +1318,7 @@ void common_speculative_print_stats(const common_speculative * spec, double slot if (impl->type == COMMON_SPECULATIVE_TYPE_DFLASH) { const auto * dflash_state = dynamic_cast(impl.get()); - if (dflash_state != nullptr) { + if (dflash_state != nullptr && dflash_stats_log_enabled()) { llama_dflash_profile_stats capture_stats; llama_dflash_profile_stats graph_stats; const bool have_capture = llama_dflash_profile_get_stats(dflash_state->ctx_tgt, &capture_stats); diff --git a/src/graphs/build_dflash.cpp b/src/graphs/build_dflash.cpp index 4cbfc147..cbb03403 100644 --- a/src/graphs/build_dflash.cpp +++ b/src/graphs/build_dflash.cpp @@ -3,29 +3,6 @@ #include "../llama-model.h" #include -#include - -static bool dflash_use_kv_cache_experiment() { - const char * env = 
std::getenv("IK_DFLASH_KV_CACHE"); - if (env == nullptr || *env == '\0') { - return false; - } - - return std::strcmp(env, "0") != 0 && - std::strcmp(env, "false") != 0 && - std::strcmp(env, "off") != 0; -} - -static bool dflash_use_kv_workspace_experiment() { - const char * env = std::getenv("IK_DFLASH_KV_WORKSPACE"); - if (env == nullptr || *env == '\0') { - return false; - } - - return std::strcmp(env, "0") != 0 && - std::strcmp(env, "false") != 0 && - std::strcmp(env, "off") != 0; -} ggml_cgraph * llm_build_context::build_dflash_kv_workspace() { const int64_t n_embd_head_k = hparams.n_embd_head_k(0); @@ -272,13 +249,10 @@ ggml_cgraph * llm_build_context::build_dflash() { const int64_t n_embd_head_v = hparams.n_embd_head_v(0); const int64_t n_target_features = hparams.dflash_n_target_features; auto & profile = lctx.dflash_profile; - const bool use_kv_cache = dflash_use_kv_cache_experiment(); - const bool use_kv_workspace = use_kv_cache && dflash_use_kv_workspace_experiment(); const int64_t ctx_len = lctx.dflash_visible_cross_ctx > 0 ? (int64_t) lctx.dflash_visible_cross_ctx : std::max(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size); - const int32_t cache_rows = use_kv_cache ? std::clamp(lctx.dflash_kv_cache_view_n_filled, 0, (int32_t) ctx_len) : 0; - const int32_t cache_write_pos = use_kv_cache && ctx_len > 0 + const int32_t cache_write_pos = ctx_len > 0 ? ((lctx.dflash_kv_cache_view_write_pos % (int32_t) ctx_len) + (int32_t) ctx_len) % (int32_t) ctx_len : 0; const int64_t n_kv_total = GGML_PAD(ctx_len + n_tokens, flash_attn ? 
256 : 32); @@ -286,8 +260,8 @@ ggml_cgraph * llm_build_context::build_dflash() { GGML_ASSERT(n_embd_head_k == n_embd_head_v); GGML_ASSERT(n_target_features > 0); - GGML_ASSERT(!use_kv_cache || lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len)); - GGML_ASSERT(!use_kv_cache || (cache_write_pos >= 0 && cache_write_pos < ctx_len)); + GGML_ASSERT(lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len)); + GGML_ASSERT(cache_write_pos >= 0 && cache_write_pos < ctx_len); ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max(n_tokens, ctx_len)) + 32 * n_layer, false); @@ -316,22 +290,6 @@ ggml_cgraph * llm_build_context::build_dflash() { dflash_kq_mask_swa = flash_attn ? ggml_cast(ctx0, lctx.inp_dflash_kq_mask_swa, GGML_TYPE_F16) : lctx.inp_dflash_kq_mask_swa; } - ggml_tensor * fused_target = nullptr; - ggml_tensor * pos_ctx = nullptr; - if (!use_kv_cache) { - lctx.inp_dflash_target_features = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_target_features, ctx_len); - ggml_set_input(lctx.inp_dflash_target_features); - cb(lctx.inp_dflash_target_features, "dflash_target_features", -1); - - lctx.inp_dflash_pos_ctx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctx_len); - ggml_set_input(lctx.inp_dflash_pos_ctx); - cb(lctx.inp_dflash_pos_ctx, "dflash_pos_ctx", -1); - - fused_target = llm_build_lora_mm(lctx, ctx0, model.dflash_fc, lctx.inp_dflash_target_features); - fused_target = llm_build_norm(ctx0, fused_target, hparams, model.dflash_hidden_norm, nullptr, LLM_NORM_RMS, cb, -1); - pos_ctx = lctx.inp_dflash_pos_ctx; - } - ggml_tensor * tok_embd = model.tok_embd; if (tok_embd == nullptr) { tok_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab); @@ -370,147 +328,58 @@ ggml_cgraph * llm_build_context::build_dflash() { Vcur_noise = ggml_reshape_3d(ctx0, Vcur_noise, n_embd_head_v, n_head_kv, n_tokens); cb(Vcur_noise, "Vcur_noise", il); - const int64_t t_cache_read_us = use_kv_cache ? 
ggml_time_us() : 0; - ggml_tensor * Kcur_ctx = nullptr; - ggml_tensor * Vcur_ctx = nullptr; - const bool have_workspace_ctx = use_kv_workspace && - (size_t) il < lctx.dflash_k_ctx_workspace.size() && - (size_t) il < lctx.dflash_v_ctx_workspace.size() && - lctx.dflash_k_ctx_workspace[(size_t) il] != nullptr && - lctx.dflash_v_ctx_workspace[(size_t) il] != nullptr; + const int64_t t_cache_read_us = ggml_time_us(); + GGML_ASSERT((size_t) il < lctx.dflash_k_ctx_workspace.size()); + GGML_ASSERT((size_t) il < lctx.dflash_v_ctx_workspace.size()); + GGML_ASSERT(lctx.dflash_k_ctx_workspace[(size_t) il] != nullptr); + GGML_ASSERT(lctx.dflash_v_ctx_workspace[(size_t) il] != nullptr); - if (have_workspace_ctx) { - Kcur_ctx = ggml_view_3d(ctx0, lctx.dflash_k_ctx_workspace[(size_t) il], - lctx.dflash_k_ctx_workspace[(size_t) il]->ne[0], - ctx_len, - lctx.dflash_k_ctx_workspace[(size_t) il]->ne[2], - lctx.dflash_k_ctx_workspace[(size_t) il]->nb[1], - lctx.dflash_k_ctx_workspace[(size_t) il]->nb[2], - 0); - Vcur_ctx = ggml_view_3d(ctx0, lctx.dflash_v_ctx_workspace[(size_t) il], - lctx.dflash_v_ctx_workspace[(size_t) il]->ne[0], - ctx_len, - lctx.dflash_v_ctx_workspace[(size_t) il]->ne[2], - lctx.dflash_v_ctx_workspace[(size_t) il]->nb[1], - lctx.dflash_v_ctx_workspace[(size_t) il]->nb[2], - 0); - cb(Kcur_ctx, "Kcur_ctx_workspace", il); - cb(Vcur_ctx, "Vcur_ctx_workspace", il); - } else if (use_kv_cache) { - auto build_ordered_cache_view = [&](ggml_tensor * cache) -> ggml_tensor * { - if (!lctx.dflash_kv_cache_view_valid || cache_rows <= 0) { - return cache; - } + ggml_tensor * Kcur_ctx = ggml_view_3d(ctx0, lctx.dflash_k_ctx_workspace[(size_t) il], + lctx.dflash_k_ctx_workspace[(size_t) il]->ne[0], + ctx_len, + lctx.dflash_k_ctx_workspace[(size_t) il]->ne[2], + lctx.dflash_k_ctx_workspace[(size_t) il]->nb[1], + lctx.dflash_k_ctx_workspace[(size_t) il]->nb[2], + 0); + ggml_tensor * Vcur_ctx = ggml_view_3d(ctx0, lctx.dflash_v_ctx_workspace[(size_t) il], + 
lctx.dflash_v_ctx_workspace[(size_t) il]->ne[0], + ctx_len, + lctx.dflash_v_ctx_workspace[(size_t) il]->ne[2], + lctx.dflash_v_ctx_workspace[(size_t) il]->nb[1], + lctx.dflash_v_ctx_workspace[(size_t) il]->nb[2], + 0); + cb(Kcur_ctx, "Kcur_ctx_workspace", il); + cb(Vcur_ctx, "Vcur_ctx_workspace", il); - if (cache_rows < ctx_len) { - ggml_tensor * zero_pad = ggml_view_3d(ctx0, cache, - cache->ne[0], - cache->ne[1], - ctx_len - cache_rows, - cache->nb[1], - cache->nb[2], - (size_t) cache_rows * cache->nb[2]); - ggml_tensor * valid = ggml_view_3d(ctx0, cache, - cache->ne[0], - cache->ne[1], - cache_rows, - cache->nb[1], - cache->nb[2], - 0); - return ggml_concat(ctx0, zero_pad, valid, 2); - } + ggml_tensor * Kcur_draft = ggml_cont(ctx0, ggml_permute(ctx0, Kcur_noise, 0, 2, 1, 3)); + ggml_tensor * Vcur_draft = ggml_cont(ctx0, ggml_permute(ctx0, Vcur_noise, 0, 2, 1, 3)); + cb(Kcur_draft, "dflash_main_k_perm_cont", il); + cb(Vcur_draft, "dflash_main_v_perm_cont", il); - if (cache_write_pos == 0) { - return cache; - } + ggml_tensor * Kcur = ggml_concat(ctx0, Kcur_ctx, Kcur_draft, 1); + ggml_tensor * Vcur = ggml_concat(ctx0, Vcur_ctx, Vcur_draft, 1); + cb(Kcur, "dflash_main_k_concat", il); + cb(Vcur, "dflash_main_v_concat", il); - ggml_tensor * tail = ggml_view_3d(ctx0, cache, - cache->ne[0], - cache->ne[1], - ctx_len - cache_write_pos, - cache->nb[1], - cache->nb[2], - (size_t) cache_write_pos * cache->nb[2]); - ggml_tensor * head = ggml_view_3d(ctx0, cache, - cache->ne[0], - cache->ne[1], - cache_write_pos, - cache->nb[1], - cache->nb[2], - 0); - return ggml_concat(ctx0, tail, head, 2); - }; - - Kcur_ctx = build_ordered_cache_view(lctx.dflash_k_ctx_cache[(size_t) il]); - Vcur_ctx = build_ordered_cache_view(lctx.dflash_v_ctx_cache[(size_t) il]); - cb(Kcur_ctx, "Kcur_ctx_cache", il); - cb(Vcur_ctx, "Vcur_ctx_cache", il); - } else { - Kcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused_target); - Kcur_ctx = ggml_reshape_3d(ctx0, Kcur_ctx, n_embd_head_k, 
n_head_kv, ctx_len); - Kcur_ctx = llm_build_norm(ctx0, Kcur_ctx, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il); - Kcur_ctx = ggml_rope_ext(ctx0, Kcur_ctx, pos_ctx, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - - Vcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, fused_target); - Vcur_ctx = ggml_reshape_3d(ctx0, Vcur_ctx, n_embd_head_v, n_head_kv, ctx_len); - cb(Kcur_ctx, "Kcur_ctx", il); - cb(Vcur_ctx, "Vcur_ctx", il); + if (n_kv_pad > 0) { + Kcur = ggml_pad(ctx0, Kcur, 0, (int) n_kv_pad, 0, 0); + Vcur = ggml_pad(ctx0, Vcur, 0, (int) n_kv_pad, 0, 0); + cb(Kcur, "dflash_main_k_pad", il); + cb(Vcur, "dflash_main_v_pad", il); } - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - if (have_workspace_ctx) { - ggml_tensor * Kcur_draft = ggml_cont(ctx0, ggml_permute(ctx0, Kcur_noise, 0, 2, 1, 3)); - ggml_tensor * Vcur_draft = ggml_cont(ctx0, ggml_permute(ctx0, Vcur_noise, 0, 2, 1, 3)); - cb(Kcur_draft, "dflash_main_k_perm_cont", il); - cb(Vcur_draft, "dflash_main_v_perm_cont", il); - - Kcur = ggml_concat(ctx0, Kcur_ctx, Kcur_draft, 1); - Vcur = ggml_concat(ctx0, Vcur_ctx, Vcur_draft, 1); - cb(Kcur, "dflash_main_k_concat", il); - cb(Vcur, "dflash_main_v_concat", il); - - if (n_kv_pad > 0) { - Kcur = ggml_pad(ctx0, Kcur, 0, (int) n_kv_pad, 0, 0); - Vcur = ggml_pad(ctx0, Vcur, 0, (int) n_kv_pad, 0, 0); - cb(Kcur, "dflash_main_k_pad", il); - cb(Vcur, "dflash_main_v_pad", il); - } - } else { - ggml_tensor * Kcur_concat = ggml_concat(ctx0, Kcur_ctx, Kcur_noise, 2); - ggml_tensor * Vcur_concat = ggml_concat(ctx0, Vcur_ctx, Vcur_noise, 2); - cb(Kcur_concat, "dflash_main_k_concat", il); - cb(Vcur_concat, "dflash_main_v_concat", il); - - Kcur = Kcur_concat; - Vcur = Vcur_concat; - if (n_kv_pad > 0) { - Kcur = ggml_pad(ctx0, Kcur, 0, 0, (int) n_kv_pad, 0); - Vcur = ggml_pad(ctx0, Vcur, 0, 0, (int) n_kv_pad, 0); - cb(Kcur, "dflash_main_k_pad", il); - cb(Vcur, 
"dflash_main_v_pad", il); - } - } - if (use_kv_cache) { - profile.graph_kv_cache_read_concat_pad_us += (uint64_t) (ggml_time_us() - t_cache_read_us); - profile.graph_kv_cache_read_concat_pad_calls++; - profile.graph_kv_cache_cached_bytes += ggml_nbytes(lctx.dflash_k_ctx_cache[(size_t) il]) + ggml_nbytes(lctx.dflash_v_ctx_cache[(size_t) il]); - } + profile.graph_kv_cache_read_concat_pad_us += (uint64_t) (ggml_time_us() - t_cache_read_us); + profile.graph_kv_cache_read_concat_pad_calls++; + profile.graph_kv_cache_cached_bytes += ggml_nbytes(lctx.dflash_k_ctx_cache[(size_t) il]) + ggml_nbytes(lctx.dflash_v_ctx_cache[(size_t) il]); cb(Qcur, "Qcur", il); ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - ggml_tensor * k = have_workspace_ctx ? Kcur : ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); - ggml_tensor * v = have_workspace_ctx ? Vcur : ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 0, 2, 1, 3)); + ggml_tensor * k = Kcur; + ggml_tensor * v = Vcur; ggml_tensor * dflash_kq_mask_l = (hparams.swa_layers[il] && dflash_kq_mask_swa != nullptr) ? dflash_kq_mask_swa : dflash_kq_mask_full; cb(q, "q", il); - if (!have_workspace_ctx) { - cb(k, "dflash_main_k_perm_cont", il); - cb(v, "dflash_main_v_perm_cont", il); - } cur = ggml_flash_attn_ext(ctx0, q, k, v, dflash_kq_mask_l, kq_scale, hparams.f_max_alibi_bias, hparams.attn_soft_cap ? 
hparams.f_attn_logit_softcapping : 0.0f); diff --git a/src/llama-dflash.cpp b/src/llama-dflash.cpp index aed84a25..9230840d 100644 --- a/src/llama-dflash.cpp +++ b/src/llama-dflash.cpp @@ -23,6 +23,10 @@ static bool llama_env_flag_enabled_local(const char * name) { std::strcmp(env, "off") != 0; } +static bool llama_dflash_stats_log_enabled() { + return llama_env_flag_enabled_local("IK_DFLASH_STATS_LOG"); +} + enum llama_dflash_kv_node_kind { LLAMA_DFLASH_KV_NODE_NONE = 0, LLAMA_DFLASH_KV_NODE_FUSED_TARGET, @@ -359,10 +363,6 @@ static bool llama_dflash_main_node_eval_callback(struct ggml_tensor * tensor, bo return prev_result || tracked; } -static bool llama_dflash_use_kv_workspace_experiment() { - return llama_env_flag_enabled_local("IK_DFLASH_KV_WORKSPACE"); -} - void llama_sync_dflash_workspace_if_pending(struct llama_context & lctx) { if (!lctx.dflash_kv_workspace_sync_pending || lctx.dflash_workspace_sched == nullptr) { return; @@ -413,7 +413,6 @@ static ggml_backend_t llama_backend_for_tensor(const llama_context & lctx, const } bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { - const bool use_kv_workspace = llama_env_flag_enabled_local("IK_DFLASH_KV_WORKSPACE"); const int32_t target_cross_ctx = std::max(1, cross_ctx); const int32_t target_token_capacity = std::max(1, (int32_t) model.hparams.dflash_block_size); const int32_t target_workspace_n_kv_total = GGML_PAD(target_cross_ctx + target_token_capacity, cparams.flash_attn ? 256 : 32); @@ -426,11 +425,9 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { const bool cache_matches = (int32_t) dflash_k_ctx_cache.size() == n_layer && dflash_k_ctx_cache.front() != nullptr && (int32_t) dflash_k_ctx_cache.front()->ne[2] == target_cross_ctx; - const bool workspace_matches = use_kv_workspace - ? 
((int32_t) dflash_k_ctx_workspace.size() == n_layer && - dflash_k_ctx_workspace.front() != nullptr && - (int32_t) dflash_k_ctx_workspace.front()->ne[1] == target_workspace_n_kv_total) - : dflash_k_ctx_workspace.empty() && dflash_v_ctx_workspace.empty(); + const bool workspace_matches = (int32_t) dflash_k_ctx_workspace.size() == n_layer && + dflash_k_ctx_workspace.front() != nullptr && + (int32_t) dflash_k_ctx_workspace.front()->ne[1] == target_workspace_n_kv_total; if (cache_matches && workspace_matches) { return true; @@ -457,7 +454,7 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { } ggml_init_params params = { - /*.mem_size =*/ (size_t) ((use_kv_workspace ? 4 : 2) * std::max(1, n_layer)) * ggml_tensor_overhead(), + /*.mem_size =*/ (size_t) (4 * std::max(1, n_layer)) * ggml_tensor_overhead(), /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; @@ -471,12 +468,10 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { dflash_v_ctx_cache.resize((size_t) n_layer); dflash_k_ctx_workspace.clear(); dflash_v_ctx_workspace.clear(); - if (use_kv_workspace) { - dflash_k_ctx_workspace.resize((size_t) n_layer); - dflash_v_ctx_workspace.resize((size_t) n_layer); - } + dflash_k_ctx_workspace.resize((size_t) n_layer); + dflash_v_ctx_workspace.resize((size_t) n_layer); dflash_cache_bufs.clear(); - dflash_cache_bufs.reserve((size_t) std::max(1, n_layer) * (use_kv_workspace ? 
4 : 2)); + dflash_cache_bufs.reserve((size_t) std::max(1, n_layer) * 4); int32_t host_layers = 0; const char * first_buft_name = nullptr; const char * last_buft_name = nullptr; @@ -524,54 +519,54 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { ggml_backend_buffer_clear(v_buf, 0); dflash_cache_bufs.push_back(v_buf); - if (use_kv_workspace) { - dflash_k_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_k, target_workspace_n_kv_total, n_head_kv); - dflash_v_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_v, target_workspace_n_kv_total, n_head_kv); - if (dflash_k_ctx_workspace[(size_t) il] == nullptr || dflash_v_ctx_workspace[(size_t) il] == nullptr) { - free_dflash_kv_cache_tensors(); - return false; - } - - ggml_set_input(dflash_k_ctx_workspace[(size_t) il]); - ggml_set_input(dflash_v_ctx_workspace[(size_t) il]); - ggml_format_name(dflash_k_ctx_workspace[(size_t) il], "dflash_k_ctx_workspace_%d", il); - ggml_format_name(dflash_v_ctx_workspace[(size_t) il], "dflash_v_ctx_workspace_%d", il); - - const size_t k_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_k_ctx_workspace[(size_t) il]); - ggml_backend_buffer_t k_workspace_buf = ggml_backend_buft_alloc_buffer(layer_buft, k_workspace_bytes); - if (k_workspace_buf == nullptr) { - free_dflash_kv_cache_tensors(); - return false; - } - ggml_backend_buffer_set_usage(k_workspace_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); - ggml_backend_tensor_alloc(k_workspace_buf, dflash_k_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(k_workspace_buf)); - ggml_backend_buffer_clear(k_workspace_buf, 0); - dflash_cache_bufs.push_back(k_workspace_buf); - - const size_t v_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_v_ctx_workspace[(size_t) il]); - ggml_backend_buffer_t v_workspace_buf = ggml_backend_buft_alloc_buffer(layer_buft, v_workspace_bytes); - if (v_workspace_buf == 
nullptr) { - free_dflash_kv_cache_tensors(); - return false; - } - ggml_backend_buffer_set_usage(v_workspace_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); - ggml_backend_tensor_alloc(v_workspace_buf, dflash_v_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(v_workspace_buf)); - ggml_backend_buffer_clear(v_workspace_buf, 0); - dflash_cache_bufs.push_back(v_workspace_buf); + dflash_k_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_k, target_workspace_n_kv_total, n_head_kv); + dflash_v_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_v, target_workspace_n_kv_total, n_head_kv); + if (dflash_k_ctx_workspace[(size_t) il] == nullptr || dflash_v_ctx_workspace[(size_t) il] == nullptr) { + free_dflash_kv_cache_tensors(); + return false; } + + ggml_set_input(dflash_k_ctx_workspace[(size_t) il]); + ggml_set_input(dflash_v_ctx_workspace[(size_t) il]); + ggml_format_name(dflash_k_ctx_workspace[(size_t) il], "dflash_k_ctx_workspace_%d", il); + ggml_format_name(dflash_v_ctx_workspace[(size_t) il], "dflash_v_ctx_workspace_%d", il); + + const size_t k_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_k_ctx_workspace[(size_t) il]); + ggml_backend_buffer_t k_workspace_buf = ggml_backend_buft_alloc_buffer(layer_buft, k_workspace_bytes); + if (k_workspace_buf == nullptr) { + free_dflash_kv_cache_tensors(); + return false; + } + ggml_backend_buffer_set_usage(k_workspace_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); + ggml_backend_tensor_alloc(k_workspace_buf, dflash_k_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(k_workspace_buf)); + ggml_backend_buffer_clear(k_workspace_buf, 0); + dflash_cache_bufs.push_back(k_workspace_buf); + + const size_t v_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_v_ctx_workspace[(size_t) il]); + ggml_backend_buffer_t v_workspace_buf = ggml_backend_buft_alloc_buffer(layer_buft, v_workspace_bytes); + if (v_workspace_buf == 
nullptr) { + free_dflash_kv_cache_tensors(); + return false; + } + ggml_backend_buffer_set_usage(v_workspace_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); + ggml_backend_tensor_alloc(v_workspace_buf, dflash_v_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(v_workspace_buf)); + ggml_backend_buffer_clear(v_workspace_buf, 0); + dflash_cache_bufs.push_back(v_workspace_buf); } dflash_profile.last_kv_cache_host_layers = host_layers; - dflash_kv_workspace_token_capacity = use_kv_workspace ? target_token_capacity : 0; - dflash_kv_workspace_n_kv_total = use_kv_workspace ? target_workspace_n_kv_total : 0; + dflash_kv_workspace_token_capacity = target_token_capacity; + dflash_kv_workspace_n_kv_total = target_workspace_n_kv_total; llama_reset_dflash_kv_cache_state(this); - LLAMA_LOG_INFO("%s: DFlash K/V cache placement cross_ctx=%d host_layers=%d/%d first=%s last=%s\n", - __func__, - target_cross_ctx, - host_layers, - n_layer, - first_buft_name != nullptr ? first_buft_name : "(none)", - last_buft_name != nullptr ? last_buft_name : "(none)"); + if (llama_dflash_stats_log_enabled()) { + LLAMA_LOG_INFO("%s: DFlash K/V cache placement cross_ctx=%d host_layers=%d/%d first=%s last=%s\n", + __func__, + target_cross_ctx, + host_layers, + n_layer, + first_buft_name != nullptr ? first_buft_name : "(none)", + last_buft_name != nullptr ? 
last_buft_name : "(none)"); + } return true; } @@ -758,8 +753,6 @@ static bool validate_dflash_graph_contract(const llama_context & lctx) { bool llama_prepare_dflash_graph_inputs( struct llama_context & lctx, uint32_t n_tokens) { - const bool use_kv_cache = llama_env_flag_enabled_local("IK_DFLASH_KV_CACHE"); - const bool use_kv_workspace = use_kv_cache && llama_dflash_use_kv_workspace_experiment(); const bool kv_node_timing = llama_env_flag_enabled_local("IK_DFLASH_KV_NODE_TIMING"); auto & profile = lctx.dflash_profile; const int32_t cross_ctx = lctx.dflash_visible_cross_ctx > 0 @@ -778,13 +771,8 @@ bool llama_prepare_dflash_graph_inputs( return false; } - if (use_kv_cache) { - if (!lctx.ensure_dflash_kv_cache_tensors(cross_ctx) || lctx.dflash_k_ctx_cache.empty() || lctx.dflash_v_ctx_cache.empty()) { - LLAMA_LOG_ERROR("%s: DFlash K/V cache inputs are not initialized\n", __func__); - return false; - } - } else if (lctx.inp_dflash_target_features == nullptr || lctx.inp_dflash_pos_ctx == nullptr) { - LLAMA_LOG_ERROR("%s: DFlash inline inputs are not initialized\n", __func__); + if (!lctx.ensure_dflash_kv_cache_tensors(cross_ctx) || lctx.dflash_k_ctx_cache.empty() || lctx.dflash_v_ctx_cache.empty()) { + LLAMA_LOG_ERROR("%s: DFlash K/V cache inputs are not initialized\n", __func__); return false; } @@ -797,9 +785,9 @@ bool llama_prepare_dflash_graph_inputs( const int32_t n_rows = lctx.dflash_target_features_n_rows; const int32_t append_rows_available = lctx.dflash_target_append_features_n_rows; const int32_t width = (int32_t) lctx.model.hparams.dflash_n_target_features; - const int32_t graph_cross_ctx = use_kv_cache - ? (lctx.dflash_k_ctx_cache.front() != nullptr ? (int32_t) lctx.dflash_k_ctx_cache.front()->ne[2] : 0) - : (lctx.inp_dflash_target_features != nullptr ? (int32_t) lctx.inp_dflash_target_features->ne[1] : 0); + const int32_t graph_cross_ctx = lctx.dflash_k_ctx_cache.front() != nullptr + ? 
(int32_t) lctx.dflash_k_ctx_cache.front()->ne[2] + : 0; const int32_t n_mask_tokens = (int32_t) kq_mask->ne[1]; const int32_t n_kv_total = (int32_t) kq_mask->ne[0]; const int64_t t_total_us = ggml_time_us(); @@ -811,9 +799,7 @@ bool llama_prepare_dflash_graph_inputs( profile.last_n_tokens = (int32_t) n_tokens; profile.last_n_kv_total = n_kv_total; - if (use_kv_workspace) { - llama_sync_dflash_workspace_if_pending(lctx); - } + llama_sync_dflash_workspace_if_pending(lctx); if (graph_cross_ctx != cross_ctx) { profile.graph_shape_failures++; @@ -836,12 +822,6 @@ bool llama_prepare_dflash_graph_inputs( return false; } - if (!use_kv_cache && !have_full_src) { - profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: missing contiguous DFlash target features for inline path\n", __func__); - return false; - } - if (n_kv_total < cross_ctx + (int32_t) n_tokens) { profile.graph_mask_overflow++; LLAMA_LOG_ERROR("%s: invalid DFlash mask shape (n_kv_total=%d < cross_ctx+n_tokens=%d)\n", @@ -851,25 +831,6 @@ bool llama_prepare_dflash_graph_inputs( const int32_t left_pad = cross_ctx - n_rows; profile.last_left_pad = left_pad; - if (!use_kv_cache) { - const size_t padded_floats = (size_t) cross_ctx * (size_t) width; - const size_t dst_offset = (size_t) left_pad * (size_t) width; - const int64_t t_feature_us = ggml_time_us(); - if (lctx.dflash_target_features_padded.size() != padded_floats) { - lctx.dflash_target_features_padded.resize(padded_floats); - } - if (left_pad == 0 && total_floats == padded_floats) { - std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin()); - } else { - if (dst_offset > 0) { - std::fill(lctx.dflash_target_features_padded.begin(), - lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset, 0.0f); - } - std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset); - } - profile.graph_feature_copy_us += (uint64_t) (ggml_time_us() - t_feature_us); - profile.graph_feature_bytes += 
padded_floats * sizeof(float); - } const int64_t t_pos_us = ggml_time_us(); lctx.dflash_pos_ctx_data.resize((size_t) cross_ctx); @@ -903,274 +864,269 @@ bool llama_prepare_dflash_graph_inputs( profile.graph_pos_copy_us += (uint64_t) (ggml_time_us() - t_pos_us); profile.graph_pos_bytes += lctx.dflash_pos_ctx_data.size() * sizeof(llama_pos); - if (use_kv_cache) { - const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition( - cross_ctx, - lctx.dflash_kv_cache_n_filled, - lctx.dflash_kv_cache_write_pos, - lctx.dflash_kv_cache_valid, - lctx.dflash_kv_cache_applied_window_version, - lctx.dflash_target_window_version, - lctx.dflash_target_window_keep_rows, - lctx.dflash_target_window_append_rows, - lctx.dflash_target_window_replace, - n_rows); + const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition( + cross_ctx, + lctx.dflash_kv_cache_n_filled, + lctx.dflash_kv_cache_write_pos, + lctx.dflash_kv_cache_valid, + lctx.dflash_kv_cache_applied_window_version, + lctx.dflash_target_window_version, + lctx.dflash_target_window_keep_rows, + lctx.dflash_target_window_append_rows, + lctx.dflash_target_window_replace, + n_rows); - const bool have_append_src = append_src != nullptr && - append_rows_available == cache_plan.append_rows && - append_floats == (size_t) cache_plan.append_rows * (size_t) width; + const bool have_append_src = append_src != nullptr && + append_rows_available == cache_plan.append_rows && + append_floats == (size_t) cache_plan.append_rows * (size_t) width; - const int32_t update_rows = cache_plan.cache_up_to_date - ? 0 - : (cache_plan.rebuild_cache ? 
n_rows : cache_plan.append_rows); - const size_t max_nodes = lctx.model.max_nodes((int) std::max(1, cross_ctx)) + 24 * lctx.model.hparams.n_layer; - const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false); - if (lctx.dflash_buf_compute_meta.size() != meta_size) { - lctx.dflash_buf_compute_meta.resize(meta_size); + const int32_t update_rows = cache_plan.cache_up_to_date + ? 0 + : (cache_plan.rebuild_cache ? n_rows : cache_plan.append_rows); + const size_t max_nodes = lctx.model.max_nodes((int) std::max(1, cross_ctx)) + 24 * lctx.model.hparams.n_layer; + const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false); + if (lctx.dflash_buf_compute_meta.size() != meta_size) { + lctx.dflash_buf_compute_meta.resize(meta_size); + } + + if (lctx.dflash_sched == nullptr || lctx.dflash_kv_cache_reserved_rows != cross_ctx) { + std::vector backend_buft; + backend_buft.reserve(lctx.backends.size()); + for (auto * backend : lctx.backends) { + if (ggml_backend_is_cpu(backend)) { + backend_buft.push_back(llama_default_buffer_type_cpu(true)); + } else { + backend_buft.push_back(ggml_backend_get_default_buffer_type(backend)); + } } - if (lctx.dflash_sched == nullptr || lctx.dflash_kv_cache_reserved_rows != cross_ctx) { - std::vector backend_buft; - backend_buft.reserve(lctx.backends.size()); - for (auto * backend : lctx.backends) { - if (ggml_backend_is_cpu(backend)) { - backend_buft.push_back(llama_default_buffer_type_cpu(true)); - } else { - backend_buft.push_back(ggml_backend_get_default_buffer_type(backend)); - } - } + if (lctx.dflash_sched != nullptr) { + ggml_backend_sched_free(lctx.dflash_sched); + lctx.dflash_sched = nullptr; + } + lctx.dflash_kv_graph = nullptr; + lctx.dflash_kv_graph_rows = 0; + lctx.dflash_kv_graph_write_pos = 0; - if (lctx.dflash_sched != nullptr) { - ggml_backend_sched_free(lctx.dflash_sched); - lctx.dflash_sched = nullptr; - } - lctx.dflash_kv_graph = nullptr; - 
lctx.dflash_kv_graph_rows = 0; - lctx.dflash_kv_graph_write_pos = 0; + const int32_t saved_update_rows = lctx.dflash_kv_cache_update_rows; + lctx.dflash_kv_cache_update_rows = cross_ctx; + const int64_t t_build_us = ggml_time_us(); + ggml_cgraph * gf_reserve = llm_build_context::llama_build_graph_dflash_kv_cache(lctx); + profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us); + lctx.dflash_kv_cache_update_rows = saved_update_rows; + if (gf_reserve == nullptr) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache reserve graph\n", __func__); + return false; + } - const int32_t saved_update_rows = lctx.dflash_kv_cache_update_rows; - lctx.dflash_kv_cache_update_rows = cross_ctx; + const int64_t t_reserve_us = ggml_time_us(); + lctx.dflash_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false); + const bool reserved = lctx.dflash_sched != nullptr && ggml_backend_sched_reserve(lctx.dflash_sched, gf_reserve); + profile.graph_kv_cache_reserve_us += (uint64_t) (ggml_time_us() - t_reserve_us); + if (!reserved) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V scheduler\n", __func__); + return false; + } + lctx.dflash_kv_cache_reserved_rows = cross_ctx; + } + + if (update_rows > 0) { + const float * update_src = nullptr; + if (have_append_src && update_rows == cache_plan.append_rows) { + update_src = append_src; + } else if (have_full_src) { + update_src = src + (size_t) (n_rows - update_rows) * (size_t) width; + } + const llama_pos * update_pos = src_pos + (n_rows - update_rows); + + if (update_src == nullptr) { + profile.graph_shape_failures++; + LLAMA_LOG_ERROR("%s: missing DFlash appended target features for cached update (rows=%d append_rows=%d floats=%zu)\n", + __func__, n_rows, update_rows, append_floats); + return false; + } + + if (cache_plan.rebuild_cache) { + llama_reset_dflash_kv_cache_state(&lctx); + } + + 
lctx.dflash_kv_cache_update_rows = update_rows; + ggml_cgraph * gf_kv = nullptr; + const bool can_reuse_kv_graph = lctx.dflash_kv_graph != nullptr && + lctx.dflash_kv_graph_rows == update_rows && + lctx.dflash_kv_graph_write_pos == lctx.dflash_kv_cache_write_pos; + if (can_reuse_kv_graph) { + gf_kv = lctx.dflash_kv_graph; + } else { const int64_t t_build_us = ggml_time_us(); - ggml_cgraph * gf_reserve = llm_build_context::llama_build_graph_dflash_kv_cache(lctx); + gf_kv = llm_build_context::llama_build_graph_dflash_kv_cache(lctx); profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us); - lctx.dflash_kv_cache_update_rows = saved_update_rows; - if (gf_reserve == nullptr) { + if (gf_kv == nullptr || lctx.dflash_kv_input_target_features == nullptr || lctx.dflash_kv_input_pos_ctx == nullptr) { profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache reserve graph\n", __func__); + LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__); return false; } - const int64_t t_reserve_us = ggml_time_us(); - lctx.dflash_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false); - const bool reserved = lctx.dflash_sched != nullptr && ggml_backend_sched_reserve(lctx.dflash_sched, gf_reserve); - profile.graph_kv_cache_reserve_us += (uint64_t) (ggml_time_us() - t_reserve_us); - if (!reserved) { - profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V scheduler\n", __func__); - return false; - } - lctx.dflash_kv_cache_reserved_rows = cross_ctx; + const int64_t t_reset_us = ggml_time_us(); + ggml_backend_sched_reset(lctx.dflash_sched); + profile.graph_kv_cache_reset_us += (uint64_t) (ggml_time_us() - t_reset_us); + + const int64_t t_alloc_us = ggml_time_us(); + ggml_backend_sched_alloc_graph(lctx.dflash_sched, gf_kv); + profile.graph_kv_cache_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us); + + lctx.dflash_kv_graph = gf_kv; 
+ lctx.dflash_kv_graph_rows = update_rows; + lctx.dflash_kv_graph_write_pos = lctx.dflash_kv_cache_write_pos; } - if (update_rows > 0) { - const float * update_src = nullptr; - if (have_append_src && update_rows == cache_plan.append_rows) { - update_src = append_src; - } else if (have_full_src) { - update_src = src + (size_t) (n_rows - update_rows) * (size_t) width; - } - const llama_pos * update_pos = src_pos + (n_rows - update_rows); + ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_target_features); + const int64_t t_feature_upload_us = ggml_time_us(); + if (kv_feature_backend != nullptr) { + ggml_backend_tensor_set_async(kv_feature_backend, lctx.dflash_kv_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); + } else { + ggml_backend_tensor_set(lctx.dflash_kv_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); + } + profile.graph_kv_cache_feature_upload_us += (uint64_t) (ggml_time_us() - t_feature_upload_us); + profile.graph_feature_bytes += (size_t) update_rows * (size_t) width * sizeof(float); - if (update_src == nullptr) { - profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: missing DFlash appended target features for cached update (rows=%d append_rows=%d floats=%zu)\n", - __func__, n_rows, update_rows, append_floats); - return false; + ggml_backend_t kv_pos_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_pos_ctx); + const int64_t t_pos_upload_us = ggml_time_us(); + if (kv_pos_backend != nullptr) { + ggml_backend_tensor_set_async(kv_pos_backend, lctx.dflash_kv_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); + } else { + ggml_backend_tensor_set(lctx.dflash_kv_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); + } + profile.graph_kv_cache_pos_upload_us += (uint64_t) (ggml_time_us() - t_pos_upload_us); + + const int64_t t_kv_cache_us = ggml_time_us(); + llama_dflash_kv_node_profiler 
kv_node_profiler; + if (kv_node_timing) { + kv_node_profiler.profile = &profile; + ggml_backend_sched_set_eval_callback(lctx.dflash_sched, llama_dflash_kv_node_eval_callback, &kv_node_profiler); + } + llama_graph_compute_sched(lctx, lctx.dflash_sched, gf_kv, lctx.cparams.n_threads); + if (kv_node_timing) { + ggml_backend_sched_set_eval_callback(lctx.dflash_sched, nullptr, nullptr); + } + profile.graph_kv_cache_compute_us += (uint64_t) (ggml_time_us() - t_kv_cache_us); + + const int64_t t_sync_us = ggml_time_us(); + ggml_backend_sched_synchronize(lctx.dflash_sched); + profile.graph_kv_cache_sync_us += (uint64_t) (ggml_time_us() - t_sync_us); + profile.graph_kv_cache_calls++; + + lctx.dflash_kv_cache_n_filled = std::min(cross_ctx, lctx.dflash_kv_cache_n_filled + update_rows); + lctx.dflash_kv_cache_write_pos = (lctx.dflash_kv_cache_write_pos + update_rows) % cross_ctx; + lctx.dflash_kv_cache_applied_window_version = lctx.dflash_target_window_version; + lctx.dflash_kv_cache_valid = true; + lctx.dflash_kv_cache_view_n_filled = lctx.dflash_kv_cache_n_filled; + lctx.dflash_kv_cache_view_write_pos = lctx.dflash_kv_cache_write_pos; + lctx.dflash_kv_cache_view_valid = true; + } + + if (lctx.dflash_kv_cache_view_valid && + !lctx.dflash_k_ctx_workspace.empty() && !lctx.dflash_v_ctx_workspace.empty()) { + const bool need_workspace_refresh = !lctx.dflash_kv_workspace_valid || + lctx.dflash_kv_workspace_n_filled != lctx.dflash_kv_cache_view_n_filled || + lctx.dflash_kv_workspace_write_pos != lctx.dflash_kv_cache_view_write_pos || + lctx.dflash_kv_workspace_applied_window_version != lctx.dflash_kv_cache_applied_window_version; + + if (need_workspace_refresh) { + const size_t max_nodes = lctx.model.max_nodes((int) std::max(1, cross_ctx)) + 16 * lctx.model.hparams.n_layer; + const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false); + if (lctx.dflash_workspace_buf_compute_meta.size() != meta_size) { + 
lctx.dflash_workspace_buf_compute_meta.resize(meta_size); } - if (cache_plan.rebuild_cache) { - llama_reset_dflash_kv_cache_state(&lctx); - } + ggml_cgraph * gf_workspace = nullptr; + const bool can_reuse_workspace_graph = lctx.dflash_kv_workspace_graph != nullptr && + lctx.dflash_kv_workspace_graph_rows == lctx.dflash_kv_cache_view_n_filled && + lctx.dflash_kv_workspace_graph_write_pos == lctx.dflash_kv_cache_view_write_pos; - lctx.dflash_kv_cache_update_rows = update_rows; - ggml_cgraph * gf_kv = nullptr; - const bool can_reuse_kv_graph = lctx.dflash_kv_graph != nullptr && - lctx.dflash_kv_graph_rows == update_rows && - lctx.dflash_kv_graph_write_pos == lctx.dflash_kv_cache_write_pos; - if (can_reuse_kv_graph) { - gf_kv = lctx.dflash_kv_graph; + if (can_reuse_workspace_graph) { + gf_workspace = lctx.dflash_kv_workspace_graph; } else { const int64_t t_build_us = ggml_time_us(); - gf_kv = llm_build_context::llama_build_graph_dflash_kv_cache(lctx); - profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us); - if (gf_kv == nullptr || lctx.dflash_kv_input_target_features == nullptr || lctx.dflash_kv_input_pos_ctx == nullptr) { + gf_workspace = llm_build_context::llama_build_graph_dflash_kv_workspace(lctx); + profile.graph_kv_workspace_build_us += (uint64_t) (ggml_time_us() - t_build_us); + if (gf_workspace == nullptr) { profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__); + LLAMA_LOG_ERROR("%s: failed to build DFlash K/V workspace graph\n", __func__); return false; } - const int64_t t_reset_us = ggml_time_us(); - ggml_backend_sched_reset(lctx.dflash_sched); - profile.graph_kv_cache_reset_us += (uint64_t) (ggml_time_us() - t_reset_us); - - const int64_t t_alloc_us = ggml_time_us(); - ggml_backend_sched_alloc_graph(lctx.dflash_sched, gf_kv); - profile.graph_kv_cache_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us); - - lctx.dflash_kv_graph = gf_kv; - lctx.dflash_kv_graph_rows = update_rows; - 
lctx.dflash_kv_graph_write_pos = lctx.dflash_kv_cache_write_pos; - } - - ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_target_features); - const int64_t t_feature_upload_us = ggml_time_us(); - if (kv_feature_backend != nullptr) { - ggml_backend_tensor_set_async(kv_feature_backend, lctx.dflash_kv_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); - } else { - ggml_backend_tensor_set(lctx.dflash_kv_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); - } - profile.graph_kv_cache_feature_upload_us += (uint64_t) (ggml_time_us() - t_feature_upload_us); - profile.graph_feature_bytes += (size_t) update_rows * (size_t) width * sizeof(float); - - ggml_backend_t kv_pos_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_pos_ctx); - const int64_t t_pos_upload_us = ggml_time_us(); - if (kv_pos_backend != nullptr) { - ggml_backend_tensor_set_async(kv_pos_backend, lctx.dflash_kv_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); - } else { - ggml_backend_tensor_set(lctx.dflash_kv_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); - } - profile.graph_kv_cache_pos_upload_us += (uint64_t) (ggml_time_us() - t_pos_upload_us); - - const int64_t t_kv_cache_us = ggml_time_us(); - llama_dflash_kv_node_profiler kv_node_profiler; - if (kv_node_timing) { - kv_node_profiler.profile = &profile; - ggml_backend_sched_set_eval_callback(lctx.dflash_sched, llama_dflash_kv_node_eval_callback, &kv_node_profiler); - } - llama_graph_compute_sched(lctx, lctx.dflash_sched, gf_kv, lctx.cparams.n_threads); - if (kv_node_timing) { - ggml_backend_sched_set_eval_callback(lctx.dflash_sched, nullptr, nullptr); - } - profile.graph_kv_cache_compute_us += (uint64_t) (ggml_time_us() - t_kv_cache_us); - - const int64_t t_sync_us = ggml_time_us(); - ggml_backend_sched_synchronize(lctx.dflash_sched); - profile.graph_kv_cache_sync_us += (uint64_t) 
(ggml_time_us() - t_sync_us); - profile.graph_kv_cache_calls++; - - lctx.dflash_kv_cache_n_filled = std::min(cross_ctx, lctx.dflash_kv_cache_n_filled + update_rows); - lctx.dflash_kv_cache_write_pos = (lctx.dflash_kv_cache_write_pos + update_rows) % cross_ctx; - lctx.dflash_kv_cache_applied_window_version = lctx.dflash_target_window_version; - lctx.dflash_kv_cache_valid = true; - lctx.dflash_kv_cache_view_n_filled = lctx.dflash_kv_cache_n_filled; - lctx.dflash_kv_cache_view_write_pos = lctx.dflash_kv_cache_write_pos; - lctx.dflash_kv_cache_view_valid = true; - } - - if (use_kv_workspace && lctx.dflash_kv_cache_view_valid && - !lctx.dflash_k_ctx_workspace.empty() && !lctx.dflash_v_ctx_workspace.empty()) { - const bool need_workspace_refresh = !lctx.dflash_kv_workspace_valid || - lctx.dflash_kv_workspace_n_filled != lctx.dflash_kv_cache_view_n_filled || - lctx.dflash_kv_workspace_write_pos != lctx.dflash_kv_cache_view_write_pos || - lctx.dflash_kv_workspace_applied_window_version != lctx.dflash_kv_cache_applied_window_version; - - if (need_workspace_refresh) { - const size_t max_nodes = lctx.model.max_nodes((int) std::max(1, cross_ctx)) + 16 * lctx.model.hparams.n_layer; - const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false); - if (lctx.dflash_workspace_buf_compute_meta.size() != meta_size) { - lctx.dflash_workspace_buf_compute_meta.resize(meta_size); + std::vector backend_buft; + backend_buft.reserve(lctx.backends.size()); + for (auto * backend : lctx.backends) { + if (ggml_backend_is_cpu(backend)) { + backend_buft.push_back(llama_default_buffer_type_cpu(true)); + } else { + backend_buft.push_back(ggml_backend_get_default_buffer_type(backend)); + } } - ggml_cgraph * gf_workspace = nullptr; - const bool can_reuse_workspace_graph = lctx.dflash_kv_workspace_graph != nullptr && - lctx.dflash_kv_workspace_graph_rows == lctx.dflash_kv_cache_view_n_filled && - lctx.dflash_kv_workspace_graph_write_pos == 
lctx.dflash_kv_cache_view_write_pos; + if (lctx.dflash_workspace_sched == nullptr) { + lctx.dflash_workspace_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false); + } - if (can_reuse_workspace_graph) { - gf_workspace = lctx.dflash_kv_workspace_graph; - } else { - const int64_t t_build_us = ggml_time_us(); - gf_workspace = llm_build_context::llama_build_graph_dflash_kv_workspace(lctx); - profile.graph_kv_workspace_build_us += (uint64_t) (ggml_time_us() - t_build_us); - if (gf_workspace == nullptr) { + if (lctx.dflash_kv_workspace_reserved_rows != cross_ctx) { + const bool saved_view_valid = lctx.dflash_kv_cache_view_valid; + const int32_t saved_view_rows = lctx.dflash_kv_cache_view_n_filled; + const int32_t saved_view_write_pos = lctx.dflash_kv_cache_view_write_pos; + + lctx.dflash_kv_cache_view_valid = true; + lctx.dflash_kv_cache_view_n_filled = cross_ctx; + lctx.dflash_kv_cache_view_write_pos = cross_ctx > 1 ? 1 : 0; + + const int64_t t_reserve_build_us = ggml_time_us(); + ggml_cgraph * gf_workspace_reserve = llm_build_context::llama_build_graph_dflash_kv_workspace(lctx); + profile.graph_kv_workspace_build_us += (uint64_t) (ggml_time_us() - t_reserve_build_us); + + lctx.dflash_kv_cache_view_valid = saved_view_valid; + lctx.dflash_kv_cache_view_n_filled = saved_view_rows; + lctx.dflash_kv_cache_view_write_pos = saved_view_write_pos; + + const int64_t t_reserve_us = ggml_time_us(); + const bool reserved = lctx.dflash_workspace_sched != nullptr && + gf_workspace_reserve != nullptr && + ggml_backend_sched_reserve(lctx.dflash_workspace_sched, gf_workspace_reserve); + profile.graph_kv_workspace_reserve_us += (uint64_t) (ggml_time_us() - t_reserve_us); + if (!reserved) { profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: failed to build DFlash K/V workspace graph\n", __func__); + LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V workspace scheduler\n", __func__); return false; } - std::vector 
backend_buft; - backend_buft.reserve(lctx.backends.size()); - for (auto * backend : lctx.backends) { - if (ggml_backend_is_cpu(backend)) { - backend_buft.push_back(llama_default_buffer_type_cpu(true)); - } else { - backend_buft.push_back(ggml_backend_get_default_buffer_type(backend)); - } - } - - if (lctx.dflash_workspace_sched == nullptr) { - lctx.dflash_workspace_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false); - } - - if (lctx.dflash_kv_workspace_reserved_rows != cross_ctx) { - const bool saved_view_valid = lctx.dflash_kv_cache_view_valid; - const int32_t saved_view_rows = lctx.dflash_kv_cache_view_n_filled; - const int32_t saved_view_write_pos = lctx.dflash_kv_cache_view_write_pos; - - lctx.dflash_kv_cache_view_valid = true; - lctx.dflash_kv_cache_view_n_filled = cross_ctx; - lctx.dflash_kv_cache_view_write_pos = cross_ctx > 1 ? 1 : 0; - - const int64_t t_reserve_build_us = ggml_time_us(); - ggml_cgraph * gf_workspace_reserve = llm_build_context::llama_build_graph_dflash_kv_workspace(lctx); - profile.graph_kv_workspace_build_us += (uint64_t) (ggml_time_us() - t_reserve_build_us); - - lctx.dflash_kv_cache_view_valid = saved_view_valid; - lctx.dflash_kv_cache_view_n_filled = saved_view_rows; - lctx.dflash_kv_cache_view_write_pos = saved_view_write_pos; - - const int64_t t_reserve_us = ggml_time_us(); - const bool reserved = lctx.dflash_workspace_sched != nullptr && - gf_workspace_reserve != nullptr && - ggml_backend_sched_reserve(lctx.dflash_workspace_sched, gf_workspace_reserve); - profile.graph_kv_workspace_reserve_us += (uint64_t) (ggml_time_us() - t_reserve_us); - if (!reserved) { - profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V workspace scheduler\n", __func__); - return false; - } - - lctx.dflash_kv_workspace_reserved_rows = cross_ctx; - } - - const int64_t t_reset_us = ggml_time_us(); - ggml_backend_sched_reset(lctx.dflash_workspace_sched); - 
profile.graph_kv_workspace_reset_us += (uint64_t) (ggml_time_us() - t_reset_us); - - const int64_t t_alloc_us = ggml_time_us(); - ggml_backend_sched_alloc_graph(lctx.dflash_workspace_sched, gf_workspace); - profile.graph_kv_workspace_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us); - - lctx.dflash_kv_workspace_graph = gf_workspace; - lctx.dflash_kv_workspace_graph_rows = lctx.dflash_kv_cache_view_n_filled; - lctx.dflash_kv_workspace_graph_write_pos = lctx.dflash_kv_cache_view_write_pos; + lctx.dflash_kv_workspace_reserved_rows = cross_ctx; } - const int64_t t_workspace_us = ggml_time_us(); - llama_graph_compute_sched(lctx, lctx.dflash_workspace_sched, gf_workspace, lctx.cparams.n_threads); - profile.graph_kv_workspace_compute_us += (uint64_t) (ggml_time_us() - t_workspace_us); - lctx.dflash_kv_workspace_sync_pending = true; - profile.graph_kv_workspace_calls++; + const int64_t t_reset_us = ggml_time_us(); + ggml_backend_sched_reset(lctx.dflash_workspace_sched); + profile.graph_kv_workspace_reset_us += (uint64_t) (ggml_time_us() - t_reset_us); - lctx.dflash_kv_workspace_n_filled = lctx.dflash_kv_cache_view_n_filled; - lctx.dflash_kv_workspace_write_pos = lctx.dflash_kv_cache_view_write_pos; - lctx.dflash_kv_workspace_applied_window_version = lctx.dflash_kv_cache_applied_window_version; - lctx.dflash_kv_workspace_valid = true; + const int64_t t_alloc_us = ggml_time_us(); + ggml_backend_sched_alloc_graph(lctx.dflash_workspace_sched, gf_workspace); + profile.graph_kv_workspace_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us); + + lctx.dflash_kv_workspace_graph = gf_workspace; + lctx.dflash_kv_workspace_graph_rows = lctx.dflash_kv_cache_view_n_filled; + lctx.dflash_kv_workspace_graph_write_pos = lctx.dflash_kv_cache_view_write_pos; } + + const int64_t t_workspace_us = ggml_time_us(); + llama_graph_compute_sched(lctx, lctx.dflash_workspace_sched, gf_workspace, lctx.cparams.n_threads); + profile.graph_kv_workspace_compute_us += (uint64_t) (ggml_time_us() - 
t_workspace_us); + lctx.dflash_kv_workspace_sync_pending = true; + profile.graph_kv_workspace_calls++; + + lctx.dflash_kv_workspace_n_filled = lctx.dflash_kv_cache_view_n_filled; + lctx.dflash_kv_workspace_write_pos = lctx.dflash_kv_cache_view_write_pos; + lctx.dflash_kv_workspace_applied_window_version = lctx.dflash_kv_cache_applied_window_version; + lctx.dflash_kv_workspace_valid = true; } - } else { - ggml_backend_tensor_set(lctx.inp_dflash_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.inp_dflash_target_features)); - ggml_backend_tensor_set(lctx.inp_dflash_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.inp_dflash_pos_ctx)); } const int64_t t_mask_us = ggml_time_us(); @@ -1221,7 +1177,7 @@ bool llama_prepare_dflash_graph_inputs( profile.graph_visible_kv_max = std::max(profile.graph_visible_kv_max, (uint64_t) visible_kv_max); profile.graph_prepare_total_us += (uint64_t) (ggml_time_us() - t_total_us); - if (profile.graph_prepare_calls == 1) { + if (profile.graph_prepare_calls == 1 && llama_dflash_stats_log_enabled()) { int32_t n_swa_layers = 0; for (int32_t il = 0; il < lctx.model.hparams.n_layer; ++il) { n_swa_layers += lctx.model.hparams.swa_layers[(size_t) il] ? 
1 : 0; diff --git a/src/llama-spec-features-dflash.cpp b/src/llama-spec-features-dflash.cpp index 088f6b2d..2df71006 100644 --- a/src/llama-spec-features-dflash.cpp +++ b/src/llama-spec-features-dflash.cpp @@ -10,6 +10,14 @@ #include "llama-model.h" #include "llama-context.h" +static bool llama_dflash_stats_log_enabled() { + const char * env = std::getenv("IK_DFLASH_STATS_LOG"); + return env != nullptr && *env != '\0' && + std::strcmp(env, "0") != 0 && + std::strcmp(env, "false") != 0 && + std::strcmp(env, "off") != 0; +} + static bool llama_dflash_positions_strictly_increasing( const llama_pos * positions, int32_t n_rows, @@ -295,12 +303,16 @@ bool llama_model_share_dflash_io_tensors( const struct ggml_tensor * output = llama_model_dflash_output_tensor(draft_model); if (draft_model->tok_embd != nullptr && output != nullptr) { - LLAMA_LOG_INFO("%s: DFlash IO mode=%s output_head=%s tensor=%s type=%s\n", + LLAMA_LOG_INFO("%s: DFlash ready io=%s output_head=%s\n", __func__, llama_dflash_io_mode_name(llama_model_dflash_io_mode(draft_model, target_model)), - llama_dflash_output_head_kind(draft_model, target_model), - output->name[0] != '\0' ? output->name : "(unnamed)", - ggml_type_name(output->type)); + llama_dflash_output_head_kind(draft_model, target_model)); + if (llama_dflash_stats_log_enabled()) { + LLAMA_LOG_INFO("%s: DFlash IO tensor=%s type=%s\n", + __func__, + output->name[0] != '\0' ? 
output->name : "(unnamed)", + ggml_type_name(output->type)); + } } return draft_model->tok_embd != nullptr && output != nullptr; From 0d75eee35a892d37d0d196e37d7e2e9b5a090f05 Mon Sep 17 00:00:00 2001 From: SamuelOliveirads Date: Sun, 14 Jun 2026 16:02:02 -0300 Subject: [PATCH 11/13] remove duplicated code and unnecesary refactor --- common/speculative-dflash-impl.h | 872 ++++++++++++++ common/speculative-impl.h | 1743 --------------------------- common/speculative.cpp | 1767 +++++++++++++++++++++------- convert_hf_to_gguf.py | 228 +++- examples/server/server-context.cpp | 111 +- src/llama-dflash-profile.h | 340 ++++++ src/llama-dflash.cpp | 337 +----- src/llama-spec-features-dflash.cpp | 30 + src/llama-spec-features-dflash.h | 10 + src/llama-spec-features.cpp | 5 - src/llama.cpp | 338 +----- 11 files changed, 2818 insertions(+), 2963 deletions(-) create mode 100644 common/speculative-dflash-impl.h delete mode 100644 common/speculative-impl.h create mode 100644 src/llama-dflash-profile.h diff --git a/common/speculative-dflash-impl.h b/common/speculative-dflash-impl.h new file mode 100644 index 00000000..c644ddbb --- /dev/null +++ b/common/speculative-dflash-impl.h @@ -0,0 +1,872 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +static bool common_speculative_are_dflash_compatible( + const llama_model * model_tgt, + const llama_model * model_dft) { + const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); + const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); + + if (llama_vocab_type(vocab_tgt) != llama_vocab_type(vocab_dft)) { + LOG_DBG("%s: DFlash draft model vocab type must match the target model\n", __func__); + return false; + } + + const bool add_bos_tgt = llama_vocab_get_add_bos(vocab_tgt); + const bool add_bos_dft = llama_vocab_get_add_bos(vocab_dft); + const bool add_eos_tgt = llama_vocab_get_add_eos(vocab_tgt); + const bool add_eos_dft = llama_vocab_get_add_eos(vocab_dft); + const llama_token 
bos_tgt = llama_vocab_bos(vocab_tgt); + const llama_token bos_dft = llama_vocab_bos(vocab_dft); + const llama_token eos_tgt = llama_vocab_eos(vocab_tgt); + const llama_token eos_dft = llama_vocab_eos(vocab_dft); + + if (add_bos_tgt != add_bos_dft || add_eos_tgt != add_eos_dft || + (add_bos_tgt && bos_tgt != bos_dft) || + (add_eos_tgt && eos_tgt != eos_dft)) { + LOG_DBG("%s: DFlash draft special tokens must match the target model (add_bos=%d/%d add_eos=%d/%d bos=%d/%d eos=%d/%d)\n", + __func__, + (int) add_bos_tgt, + (int) add_bos_dft, + (int) add_eos_tgt, + (int) add_eos_dft, + (int) bos_tgt, + (int) bos_dft, + (int) eos_tgt, + (int) eos_dft); + return false; + } + + const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt); + const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft); + const int vocab_diff = n_vocab_tgt > n_vocab_dft + ? n_vocab_tgt - n_vocab_dft + : n_vocab_dft - n_vocab_tgt; + + if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { + LOG_DBG("%s: DFlash draft vocab size differs too much from the target model (%d vs %d)\n", + __func__, n_vocab_dft, n_vocab_tgt); + return false; + } + + for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { + const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); + const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); + + if (std::strcmp(token_text_tgt, token_text_dft) != 0) { + LOG_DBG("%s: DFlash draft token %d differs - target '%s', draft '%s'\n", __func__, i, + common_token_to_piece(vocab_tgt, i).c_str(), + common_token_to_piece(vocab_dft, i).c_str()); + return false; + } + } + + return true; +} + +static bool dflash_contract_log_enabled() { + const char * env = std::getenv("IK_DFLASH_CONTRACT_LOG"); + if (env == nullptr || *env == '\0') { + return false; + } + + return std::strcmp(env, "0") != 0 && + std::strcmp(env, "false") != 0 && + std::strcmp(env, "off") != 0; +} + +static bool dflash_stats_log_enabled() { + const char * env = 
std::getenv("IK_DFLASH_STATS_LOG"); + if (env == nullptr || *env == '\0') { + return false; + } + + return std::strcmp(env, "0") != 0 && + std::strcmp(env, "false") != 0 && + std::strcmp(env, "off") != 0; +} + +template +static std::string dflash_contract_format_values( + const std::vector & values, + size_t edge_count = 4) { + std::ostringstream oss; + oss << '['; + if (values.empty()) { + oss << ']'; + return oss.str(); + } + + const size_t head = std::min(edge_count, values.size()); + for (size_t i = 0; i < head; ++i) { + if (i > 0) { + oss << ','; + } + oss << values[i]; + } + + if (values.size() > edge_count * 2) { + oss << ",...,"; + for (size_t i = values.size() - edge_count; i < values.size(); ++i) { + if (i > values.size() - edge_count) { + oss << ','; + } + oss << values[i]; + } + } else { + for (size_t i = head; i < values.size(); ++i) { + oss << ',' << values[i]; + } + } + + oss << ']'; + return oss.str(); +} + +struct dflash_contract_pos_summary { + llama_pos first = -1; + llama_pos last = -1; + int32_t gap_count = 0; + int32_t nonmono_count = 0; +}; + +static dflash_contract_pos_summary dflash_contract_summarize_positions( + const std::vector & positions) { + dflash_contract_pos_summary summary; + if (positions.empty()) { + return summary; + } + + summary.first = positions.front(); + summary.last = positions.back(); + for (size_t i = 1; i < positions.size(); ++i) { + if (positions[i] <= positions[i - 1]) { + summary.nonmono_count++; + } else if (positions[i] != positions[i - 1] + 1) { + summary.gap_count++; + } + } + + return summary; +} + +struct common_speculative_state_dflash; + +static void dflash_contract_log_append( + const common_speculative_state_dflash & state, + llama_seq_id seq_id, + const std::vector & new_positions); +static void dflash_contract_log_draft( + const common_speculative_state_dflash & state, + int32_t n_keep, + size_t result_size); +static void dflash_materialize_target_window_features(common_speculative_state_dflash & 
state); + +// DFlash runtime state and draft path. +struct common_speculative_state_dflash : public common_speculative_state { + llama_context * ctx_tgt; + llama_context * ctx_dft; + + llama_batch batch = {}; + + int32_t block_size = 0; + int32_t mask_token_id = -1; + int32_t n_target_features = 0; + int32_t cross_ctx = 0; + bool ready = false; + + std::vector target_layer_ids; + std::vector target_window; + std::vector target_window_pos; + std::vector target_window_stage; + std::vector target_window_pos_stage; + std::vector target_window_ring; + std::vector target_window_append_features; + int32_t target_window_rows = 0; + int32_t target_window_ring_write_pos = 0; + int32_t target_window_ring_filled = 0; + uint64_t target_window_version = 0; + int32_t target_window_keep_rows = 0; + int32_t target_window_append_rows = 0; + bool target_window_replace = false; + bool target_window_materialized = false; + llama_pos last_target_pos = -1; + size_t n_window_updates = 0; + size_t n_rows_seen = 0; + size_t n_rows_dropped = 0; + size_t n_context_shifts = 0; + size_t n_draft_empty = 0; + size_t n_set_target_fail = 0; + size_t n_decode_fail = 0; + llama_pos last_draft_pos_base = -1; + + uint64_t t_draft_decode_us = 0; + uint64_t t_draft_sample_us = 0; + uint64_t t_warmup_collect_us = 0; + uint64_t t_warmup_append_us = 0; + uint64_t t_accept_output_copy_us = 0; + uint64_t t_accept_commit_us = 0; + uint64_t t_accept_append_us = 0; + uint64_t t_accept_append_filter_us = 0; + uint64_t t_accept_append_window_alloc_us = 0; + uint64_t t_accept_append_replace_us = 0; + uint64_t t_accept_append_keep_old_us = 0; + uint64_t t_accept_append_new_rows_us = 0; + uint64_t t_accept_append_commit_detail_us = 0; + uint64_t t_accept_append_log_us = 0; + size_t n_warmup_collect_calls = 0; + size_t n_warmup_collect_rows = 0; + size_t n_warmup_append_calls = 0; + size_t n_warmup_append_rows = 0; + size_t n_accept_output_copy_calls = 0; + size_t n_accept_output_copy_rows = 0; + size_t 
n_accept_commit_calls = 0; + size_t n_accept_commit_rows = 0; + size_t n_accept_append_calls = 0; + size_t n_accept_append_rows = 0; + size_t n_accept_append_replace_calls = 0; + size_t n_accept_append_slide_calls = 0; + + common_speculative_state_dflash( + enum common_speculative_type type, + llama_context * ctx_tgt, + llama_context * ctx_dft, + int32_t cross_ctx) + : common_speculative_state(type) + , ctx_tgt(ctx_tgt) + , ctx_dft(ctx_dft) + , cross_ctx(std::max(1, cross_ctx)) + { + const llama_model * model_tgt = llama_get_model(ctx_tgt); + const llama_model * model_dft = llama_get_model(ctx_dft); + + if (!common_speculative_are_dflash_compatible(model_tgt, model_dft)) { + LOG_ERR("%s: DFlash draft model vocab/tokenizer is incompatible with the target model\n", __func__); + return; + } + + block_size = llama_model_dflash_block_size(model_dft); + mask_token_id = llama_model_dflash_mask_token_id(model_dft); + n_target_features = llama_model_dflash_n_target_features(model_dft); + const int32_t n_target_layers = llama_model_dflash_n_target_layers(model_dft); + + if (block_size <= 0 || mask_token_id < 0 || n_target_features <= 0 || n_target_layers <= 0) { + LOG_ERR("%s: invalid DFlash metadata (block_size=%d, mask_token_id=%d, n_target_features=%d, n_target_layers=%d)\n", + __func__, block_size, mask_token_id, n_target_features, n_target_layers); + return; + } + + target_layer_ids.resize((size_t) n_target_layers); + if (llama_model_dflash_target_layer_ids(model_dft, target_layer_ids.data(), n_target_layers) != n_target_layers) { + LOG_ERR("%s: failed to read DFlash target layer ids\n", __func__); + target_layer_ids.clear(); + return; + } + + const auto * vocab_tgt = llama_model_get_vocab(model_tgt); + const auto * vocab_dft = llama_model_get_vocab(model_dft); + const int32_t target_vocab_size = llama_vocab_n_tokens(vocab_tgt); + const int32_t draft_vocab_size = llama_vocab_n_tokens(vocab_dft); + const int32_t target_hidden_size = llama_model_n_embd(model_tgt); + const 
int32_t draft_hidden_size = llama_model_n_embd(model_dft); + const int32_t target_mask_token_id = llama_model_dflash_target_mask_token_id(model_tgt); + const int32_t expected_n_target_features = target_hidden_size > 0 ? target_hidden_size * n_target_layers : 0; + + if (target_mask_token_id != (int32_t) LLAMA_TOKEN_NULL && mask_token_id != target_mask_token_id) { + LOG_ERR("%s: DFlash mask token mismatch (draft=%d target=%d)\n", + __func__, mask_token_id, target_mask_token_id); + return; + } + + if (target_hidden_size <= 0 || draft_hidden_size <= 0) { + LOG_ERR("%s: invalid DFlash hidden sizes (draft=%d target=%d)\n", + __func__, draft_hidden_size, target_hidden_size); + return; + } + + if (expected_n_target_features <= 0 || n_target_features != expected_n_target_features) { + LOG_ERR("%s: DFlash target feature width mismatch (metadata=%d expected=%d target_hidden=%d target_layers=%d)\n", + __func__, n_target_features, expected_n_target_features, target_hidden_size, n_target_layers); + return; + } + + std::vector sorted_target_layer_ids = target_layer_ids; + std::sort(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()); + if (std::adjacent_find(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()) != sorted_target_layer_ids.end()) { + LOG_ERR("%s: duplicate DFlash target layer ids survived into runtime validation\n", __func__); + target_layer_ids.clear(); + return; + } + + const int32_t n_target_model_layers = llama_n_layer(model_tgt); + for (int32_t layer_id : target_layer_ids) { + if (layer_id < 0 || layer_id >= n_target_model_layers) { + LOG_ERR("%s: invalid DFlash target layer id %d for target model with %d layers\n", + __func__, layer_id, n_target_model_layers); + target_layer_ids.clear(); + return; + } + } + + const int32_t io_mode = llama_model_dflash_io_mode(model_dft, model_tgt); + if (io_mode == LLAMA_DFLASH_IO_MODE_INVALID) { + LOG_ERR("%s: DFlash draft is missing required IO tensors after target sharing\n", __func__); + return; + } 
+ + if (io_mode == LLAMA_DFLASH_IO_MODE_MIXED) { + LOG_ERR("%s: DFlash IO contract must be fully shared or fully self-contained, but resolved to mixed mode\n", __func__); + return; + } + + if (io_mode == LLAMA_DFLASH_IO_MODE_SELF_CONTAINED && !llama_model_dflash_io_tensors_match(model_dft, target_hidden_size, target_vocab_size)) { + LOG_ERR("%s: DFlash self-contained IO tensors do not match the target hidden/vocab contract (target_hidden=%d target_vocab=%d)\n", + __func__, + target_hidden_size, + target_vocab_size); + return; + } + + if (!llama_set_dflash_capture_layers(ctx_tgt, target_layer_ids.data(), (int32_t) target_layer_ids.size())) { + LOG_ERR("%s: failed to configure DFlash target capture callback\n", __func__); + return; + } + + batch = llama_batch_init(std::max(1, block_size), 0, 1); + target_window.reserve((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_stage.reserve((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_ring.resize((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_append_features.reserve((size_t) this->cross_ctx * (size_t) n_target_features); + target_window_pos.reserve((size_t) this->cross_ctx); + target_window_pos_stage.reserve((size_t) this->cross_ctx); + ready = true; + + llama_set_dflash_visible_cross_ctx(ctx_dft, this->cross_ctx); + llama_dflash_profile_reset(ctx_tgt); + llama_dflash_profile_reset(ctx_dft); + + std::ostringstream layers_oss; + for (size_t i = 0; i < target_layer_ids.size(); ++i) { + if (i > 0) { + layers_oss << ","; + } + layers_oss << target_layer_ids[i]; + } + + const char * io_mode_name = io_mode == LLAMA_DFLASH_IO_MODE_SHARED ? 
"shared" : "self-contained"; + LOG_INF("%s: DFlash context ready (n_ctx=%d, block_size=%d, cross_ctx=%d, n_target_features=%d, target_layer_ids=[%s])\n", + __func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features, layers_oss.str().c_str()); + LOG_INF("%s: DFlash artifact io=%s draft_vocab=%d target_vocab=%d draft_hidden=%d target_hidden=%d mask_token_id=%d target_mask_token_id=%d\n", + __func__, io_mode_name, draft_vocab_size, target_vocab_size, draft_hidden_size, target_hidden_size, mask_token_id, target_mask_token_id); + } + + ~common_speculative_state_dflash() override { + llama_clear_dflash_capture(ctx_tgt); + if (ctx_dft) { + llama_free(ctx_dft); + } + if (batch.token != nullptr) { + llama_batch_free(batch); + } + } + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + llama_kv_cache_clear(ctx_dft); + llama_reset_dflash_kv_cache_state(ctx_dft); + n_window_updates = 0; + n_rows_seen = 0; + n_rows_dropped = 0; + n_context_shifts = 0; + n_draft_empty = 0; + n_set_target_fail = 0; + n_decode_fail = 0; + last_draft_pos_base = -1; + t_draft_decode_us = 0; + t_draft_sample_us = 0; + t_warmup_collect_us = 0; + t_warmup_append_us = 0; + t_accept_output_copy_us = 0; + t_accept_commit_us = 0; + t_accept_append_us = 0; + t_accept_append_filter_us = 0; + t_accept_append_window_alloc_us = 0; + t_accept_append_replace_us = 0; + t_accept_append_keep_old_us = 0; + t_accept_append_new_rows_us = 0; + t_accept_append_commit_detail_us = 0; + t_accept_append_log_us = 0; + n_warmup_collect_calls = 0; + n_warmup_collect_rows = 0; + n_warmup_append_calls = 0; + n_warmup_append_rows = 0; + n_accept_output_copy_calls = 0; + n_accept_output_copy_rows = 0; + n_accept_commit_calls = 0; + n_accept_commit_rows = 0; + n_accept_append_calls = 0; + n_accept_append_rows = 0; + n_accept_append_replace_calls = 0; + n_accept_append_slide_calls = 0; + llama_dflash_profile_reset(ctx_tgt); + llama_dflash_profile_reset(ctx_dft); + } + + void draft( + 
const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + GGML_UNUSED(prompt_tgt); + + result.clear(); + if (!ready || target_window_rows <= 0) { + n_draft_empty++; + return; + } + + const int32_t n_keep = std::min(params.n_max, block_size - 1); + if (n_keep <= 0) { + return; + } + + const float * target_features = nullptr; + size_t target_feature_floats = 0; + llama_dflash_window_update window_update = { + target_window_version, + target_window_keep_rows, + target_window_append_rows, + target_window_replace, + target_window_append_features.empty() ? nullptr : target_window_append_features.data(), + target_window_append_features.size(), + }; + const llama_dflash_kv_cache_transition cache_plan = + llama_plan_dflash_kv_cache_transition_for_ctx(ctx_dft, window_update, target_window_rows); + + if (cache_plan.rebuild_cache) { + dflash_materialize_target_window_features(*this); + target_features = target_window.data(); + target_feature_floats = target_window.size(); + window_update.append_features = target_window.data(); + window_update.append_floats = target_window.size(); + window_update.append_rows = target_window_rows; + } + + if (!llama_set_dflash_target_features_view(ctx_dft, target_features, target_feature_floats, target_window_rows, target_window_pos.data(), &window_update)) { + LOG_ERR("%s: failed to set DFlash target features\n", __func__); + n_set_target_fail++; + return; + } + + llama_kv_cache_clear(ctx_dft); + batch.n_tokens = 0; + const int32_t batch_len = n_keep + 1; + const llama_pos draft_pos_base = last_target_pos >= 0 ? last_target_pos + 1 : (llama_pos) target_window_rows; + const llama_pos seed_pos = last_target_pos >= 0 ? 
last_target_pos : draft_pos_base - 1; + last_draft_pos_base = draft_pos_base; + common_batch_add(batch, id_last, seed_pos, { 0 }, false); + for (int32_t i = 1; i < batch_len; ++i) { + common_batch_add(batch, mask_token_id, draft_pos_base + (i - 1), { 0 }, i <= n_keep); + } + + const int64_t t_decode_us = ggml_time_us(); + if (llama_decode(ctx_dft, batch) != 0) { + LOG_ERR("%s: llama_decode() failed for DFlash draft batch\n", __func__); + n_decode_fail++; + batch.n_tokens = 0; + return; + } + t_draft_decode_us += (uint64_t) (ggml_time_us() - t_decode_us); + + result.reserve((size_t) n_keep); + const int64_t t_sample_us = ggml_time_us(); + for (int32_t i = 0; i < n_keep; ++i) { + llama_token id = llama_get_dflash_draft_token_ith(ctx_dft, i); + if (id == LLAMA_TOKEN_NULL) { + id = common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr); + } + result.push_back(id); + } + t_draft_sample_us += (uint64_t) (ggml_time_us() - t_sample_us); + + batch.n_tokens = 0; + dflash_contract_log_draft(*this, n_keep, result.size()); + } + + void accept(uint16_t n_accepted) override { + GGML_UNUSED(n_accepted); + } +}; + +static void dflash_contract_log_append( + const common_speculative_state_dflash & state, + llama_seq_id seq_id, + const std::vector & new_positions) { + if (!dflash_contract_log_enabled()) { + return; + } + + static std::atomic counter = 0; + const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); + if (ordinal >= 8) { + return; + } + + const dflash_contract_pos_summary incoming = dflash_contract_summarize_positions(new_positions); + const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos); + + LOG_INF("dflash contract append[%llu]: seq=%d incoming_rows=%zu incoming_pos=%s pos=[%d..%d] gaps=%d nonmono=%d window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d\n", + (unsigned long long) (ordinal + 1), + (int) seq_id, + new_positions.size(), + 
dflash_contract_format_values(new_positions).c_str(), + (int) incoming.first, + (int) incoming.last, + incoming.gap_count, + incoming.nonmono_count, + state.target_window_rows, + dflash_contract_format_values(state.target_window_pos).c_str(), + (int) window.first, + (int) window.last, + window.gap_count, + window.nonmono_count, + (int) state.last_target_pos); +} + +static void dflash_contract_log_draft( + const common_speculative_state_dflash & state, + int32_t n_keep, + size_t result_size) { + if (!dflash_contract_log_enabled()) { + return; + } + + static std::atomic counter = 0; + const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); + if (ordinal >= 8) { + return; + } + + const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos); + llama_dflash_profile_stats graph_stats = {}; + llama_dflash_profile_get_stats(state.ctx_dft, &graph_stats); + const int draft_delta = (state.last_target_pos >= 0 && state.last_draft_pos_base >= 0) + ? (int) (state.last_draft_pos_base - state.last_target_pos) + : -1; + const llama_pos seed_pos = state.last_target_pos; + const llama_pos mask_first_pos = state.last_draft_pos_base; + const llama_pos mask_last_pos = state.last_draft_pos_base >= 0 + ? 
state.last_draft_pos_base + n_keep - 1 + : -1; + + LOG_INF("dflash contract draft[%llu]: window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d seed_pos=%d mask_pos=[%d..%d] sample_rows=[1..%d] output_rows=[1..%d] draft_pos_base=%d delta=%d n_keep=%d result=%zu set_target(missing/nonmono)=%llu/%llu graph(fallback/nonmono)=%llu/%llu graph_pos=[%d..%d]\n", + (unsigned long long) (ordinal + 1), + state.target_window_rows, + dflash_contract_format_values(state.target_window_pos).c_str(), + (int) window.first, + (int) window.last, + window.gap_count, + window.nonmono_count, + (int) state.last_target_pos, + (int) seed_pos, + (int) mask_first_pos, + (int) mask_last_pos, + n_keep, + n_keep, + (int) state.last_draft_pos_base, + draft_delta, + n_keep, + result_size, + (unsigned long long) graph_stats.set_target_missing_positions, + (unsigned long long) graph_stats.set_target_non_monotonic_positions, + (unsigned long long) graph_stats.graph_pos_fallbacks, + (unsigned long long) graph_stats.graph_pos_non_monotonic, + (int) graph_stats.last_pos_first, + (int) graph_stats.last_pos_last); +} + +struct dflash_append_breakdown { + uint64_t filter_us = 0; + uint64_t window_alloc_us = 0; + uint64_t replace_us = 0; + uint64_t keep_old_us = 0; + uint64_t new_rows_us = 0; + uint64_t commit_us = 0; + uint64_t log_us = 0; + bool replace_call = false; +}; + +static void dflash_record_window_update( + common_speculative_state_dflash & state, + int32_t keep_rows, + int32_t append_rows, + bool replace) { + state.target_window_keep_rows = std::max(0, keep_rows); + state.target_window_append_rows = std::max(0, append_rows); + state.target_window_replace = replace; + state.target_window_version++; +} + +static void dflash_ring_reset_rows( + common_speculative_state_dflash & state, + const float * rows, + int32_t n_rows) { + const size_t row_width = (size_t) state.n_target_features; + if (n_rows <= 0 || rows == nullptr) { + state.target_window_ring_write_pos = 0; + 
state.target_window_ring_filled = 0; + return; + } + + if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) { + state.target_window_ring.resize((size_t) state.cross_ctx * row_width); + } + + std::memcpy(state.target_window_ring.data(), rows, (size_t) n_rows * row_width * sizeof(float)); + state.target_window_ring_write_pos = n_rows % state.cross_ctx; + state.target_window_ring_filled = n_rows; + state.target_window_materialized = false; +} + +static void dflash_ring_append_rows( + common_speculative_state_dflash & state, + const float * rows, + int32_t n_rows) { + const size_t row_width = (size_t) state.n_target_features; + if (n_rows <= 0 || rows == nullptr) { + return; + } + + if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) { + state.target_window_ring.resize((size_t) state.cross_ctx * row_width); + } + + int32_t write_pos = state.target_window_ring_write_pos; + int32_t remaining = n_rows; + const float * src = rows; + while (remaining > 0) { + const int32_t chunk_rows = std::min(remaining, state.cross_ctx - write_pos); + std::memcpy( + state.target_window_ring.data() + (size_t) write_pos * row_width, + src, + (size_t) chunk_rows * row_width * sizeof(float)); + src += (size_t) chunk_rows * row_width; + remaining -= chunk_rows; + write_pos = (write_pos + chunk_rows) % state.cross_ctx; + } + + state.target_window_ring_write_pos = write_pos; + state.target_window_ring_filled = std::min(state.cross_ctx, state.target_window_ring_filled + n_rows); + state.target_window_materialized = false; +} + +static void dflash_materialize_target_window_features(common_speculative_state_dflash & state) { + if (state.target_window_materialized || state.target_window_rows <= 0) { + return; + } + + const size_t row_width = (size_t) state.n_target_features; + state.target_window.resize((size_t) state.target_window_rows * row_width); + + const int32_t read_start = (state.target_window_ring_write_pos - state.target_window_rows + 
state.cross_ctx) % state.cross_ctx; + const int32_t first_rows = std::min(state.target_window_rows, state.cross_ctx - read_start); + std::memcpy( + state.target_window.data(), + state.target_window_ring.data() + (size_t) read_start * row_width, + (size_t) first_rows * row_width * sizeof(float)); + + const int32_t second_rows = state.target_window_rows - first_rows; + if (second_rows > 0) { + std::memcpy( + state.target_window.data() + (size_t) first_rows * row_width, + state.target_window_ring.data(), + (size_t) second_rows * row_width * sizeof(float)); + } + + state.target_window_materialized = true; +} + +static bool dflash_append_target_features( + common_speculative_state_dflash & state, + const common_speculative_feature_view & features, + const llama_batch & batch, + llama_seq_id seq_id, + dflash_append_breakdown * breakdown = nullptr) { + GGML_UNUSED(batch); + + if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || + features.width != state.n_target_features || + features.rows.empty() || + state.cross_ctx <= 0) { + return false; + } + + const size_t row_width = (size_t) state.n_target_features; + std::vector new_rows; + std::vector new_positions; + new_rows.reserve(features.rows.size() * row_width); + new_positions.reserve(features.rows.size()); + + const int64_t t_filter_us = ggml_time_us(); + for (const auto & row : features.rows) { + if (row.seq_id != seq_id || row.data == nullptr) { + continue; + } + + new_positions.push_back(row.pos); + new_rows.insert(new_rows.end(), row.data, row.data + row_width); + } + if (breakdown != nullptr) { + breakdown->filter_us += (uint64_t) (ggml_time_us() - t_filter_us); + } + + if (new_positions.empty()) { + return false; + } + + const int32_t n_rows = (int32_t) new_positions.size(); + state.n_window_updates++; + state.n_rows_seen += (size_t) n_rows; + if (n_rows >= state.cross_ctx) { + state.n_rows_dropped += (size_t) state.target_window_rows + (size_t) (n_rows - state.cross_ctx); + const int32_t keep_from = 
n_rows - state.cross_ctx; + const int64_t t_replace_us = ggml_time_us(); + state.target_window_pos.assign(new_positions.begin() + keep_from, new_positions.end()); + state.target_window_append_features.assign( + new_rows.begin() + (ptrdiff_t) keep_from * (ptrdiff_t) row_width, + new_rows.end()); + dflash_ring_reset_rows(state, state.target_window_append_features.data(), state.cross_ctx); + if (breakdown != nullptr) { + breakdown->replace_us += (uint64_t) (ggml_time_us() - t_replace_us); + breakdown->replace_call = true; + } + + const int64_t t_commit_us = ggml_time_us(); + state.target_window_rows = state.cross_ctx; + state.target_window_ring_filled = state.target_window_rows; + state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); + dflash_record_window_update(state, 0, state.target_window_rows, true); + if (breakdown != nullptr) { + breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us); + } + + const int64_t t_log_us = ggml_time_us(); + dflash_contract_log_append(state, seq_id, new_positions); + if (breakdown != nullptr) { + breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us); + } + return true; + } + + const int32_t keep_old_rows = std::min(state.target_window_rows, state.cross_ctx - n_rows); + state.n_rows_dropped += (size_t) std::max(0, state.target_window_rows - keep_old_rows); + const int64_t t_window_alloc_us = ggml_time_us(); + std::vector & next_window_pos = state.target_window_pos_stage; + next_window_pos.resize((size_t) (keep_old_rows + n_rows)); + if (breakdown != nullptr) { + breakdown->window_alloc_us += (uint64_t) (ggml_time_us() - t_window_alloc_us); + } + + if (keep_old_rows > 0) { + const int64_t t_keep_old_us = ggml_time_us(); + std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), next_window_pos.begin()); + if (breakdown != nullptr) { + breakdown->keep_old_us += (uint64_t) (ggml_time_us() - t_keep_old_us); + } + } + + const int64_t t_new_rows_us = 
ggml_time_us(); + state.target_window_append_features.assign(new_rows.begin(), new_rows.end()); + dflash_ring_append_rows(state, state.target_window_append_features.data(), n_rows); + std::copy(new_positions.begin(), new_positions.end(), next_window_pos.begin() + keep_old_rows); + if (breakdown != nullptr) { + breakdown->new_rows_us += (uint64_t) (ggml_time_us() - t_new_rows_us); + } + + const int64_t t_commit_us = ggml_time_us(); + state.target_window_pos.swap(next_window_pos); + next_window_pos.clear(); + state.target_window_rows = keep_old_rows + n_rows; + state.target_window_ring_filled = state.target_window_rows; + state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); + dflash_record_window_update(state, keep_old_rows, n_rows, false); + if (breakdown != nullptr) { + breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us); + } + + const int64_t t_log_us = ggml_time_us(); + dflash_contract_log_append(state, seq_id, new_positions); + if (breakdown != nullptr) { + breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us); + } + return true; +} + +static void dflash_clear_target_features(common_speculative_state_dflash & state) { + state.target_window.clear(); + state.target_window_pos.clear(); + state.target_window_stage.clear(); + state.target_window_pos_stage.clear(); + state.target_window_append_features.clear(); + state.target_window_rows = 0; + state.target_window_ring_write_pos = 0; + state.target_window_ring_filled = 0; + state.target_window_keep_rows = 0; + state.target_window_append_rows = 0; + state.target_window_replace = false; + state.target_window_materialized = false; + state.last_target_pos = -1; + llama_reset_dflash_kv_cache_state(state.ctx_dft); +} + +static void dflash_context_shift( + common_speculative_state_dflash & state, + llama_pos kv_keep, + llama_pos kv_discard, + llama_pos kv_past) { + if (kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window_pos.empty()) { + return; 
+ } + + dflash_materialize_target_window_features(state); + + const size_t row_width = (size_t) state.n_target_features; + const llama_pos discard_begin = kv_keep; + const llama_pos discard_end = kv_keep + kv_discard; + + std::vector shifted_rows; + std::vector shifted_positions; + shifted_rows.reserve(state.target_window.size()); + shifted_positions.reserve(state.target_window_pos.size()); + + for (int32_t row = 0; row < state.target_window_rows; ++row) { + llama_pos pos = state.target_window_pos[(size_t) row]; + if (pos >= discard_begin && pos < discard_end) { + continue; + } + + if (pos >= discard_end && pos < kv_past) { + pos -= kv_discard; + } + + const float * row_src = state.target_window.data() + (size_t) row * row_width; + shifted_rows.insert(shifted_rows.end(), row_src, row_src + row_width); + shifted_positions.push_back(pos); + } + + state.target_window = std::move(shifted_rows); + state.target_window_pos = std::move(shifted_positions); + state.target_window_rows = (int32_t) state.target_window_pos.size(); + dflash_ring_reset_rows(state, state.target_window.data(), state.target_window_rows); + state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); + dflash_record_window_update(state, 0, state.target_window_rows, true); + llama_reset_dflash_kv_cache_state(state.ctx_dft); + state.n_context_shifts++; +} diff --git a/common/speculative-impl.h b/common/speculative-impl.h deleted file mode 100644 index dbf8cfb1..00000000 --- a/common/speculative-impl.h +++ /dev/null @@ -1,1743 +0,0 @@ -// DFlash runtime state and draft path. 
-struct common_speculative_state_dflash : public common_speculative_state { - llama_context * ctx_tgt; - llama_context * ctx_dft; - - llama_batch batch = {}; - - int32_t block_size = 0; - int32_t mask_token_id = -1; - int32_t n_target_features = 0; - int32_t cross_ctx = 0; - bool ready = false; - - std::vector target_layer_ids; - std::vector target_window; - std::vector target_window_pos; - std::vector target_window_stage; - std::vector target_window_pos_stage; - std::vector target_window_ring; - std::vector target_window_append_features; - int32_t target_window_rows = 0; - int32_t target_window_ring_write_pos = 0; - int32_t target_window_ring_filled = 0; - uint64_t target_window_version = 0; - int32_t target_window_keep_rows = 0; - int32_t target_window_append_rows = 0; - bool target_window_replace = false; - bool target_window_materialized = false; - llama_pos last_target_pos = -1; - size_t n_window_updates = 0; - size_t n_rows_seen = 0; - size_t n_rows_dropped = 0; - size_t n_context_shifts = 0; - size_t n_draft_empty = 0; - size_t n_set_target_fail = 0; - size_t n_decode_fail = 0; - llama_pos last_draft_pos_base = -1; - - uint64_t t_draft_decode_us = 0; - uint64_t t_draft_sample_us = 0; - uint64_t t_warmup_collect_us = 0; - uint64_t t_warmup_append_us = 0; - uint64_t t_accept_output_copy_us = 0; - uint64_t t_accept_commit_us = 0; - uint64_t t_accept_append_us = 0; - uint64_t t_accept_append_filter_us = 0; - uint64_t t_accept_append_window_alloc_us = 0; - uint64_t t_accept_append_replace_us = 0; - uint64_t t_accept_append_keep_old_us = 0; - uint64_t t_accept_append_new_rows_us = 0; - uint64_t t_accept_append_commit_detail_us = 0; - uint64_t t_accept_append_log_us = 0; - size_t n_warmup_collect_calls = 0; - size_t n_warmup_collect_rows = 0; - size_t n_warmup_append_calls = 0; - size_t n_warmup_append_rows = 0; - size_t n_accept_output_copy_calls = 0; - size_t n_accept_output_copy_rows = 0; - size_t n_accept_commit_calls = 0; - size_t n_accept_commit_rows = 0; - 
size_t n_accept_append_calls = 0; - size_t n_accept_append_rows = 0; - size_t n_accept_append_replace_calls = 0; - size_t n_accept_append_slide_calls = 0; - - common_speculative_state_dflash( - enum common_speculative_type type, - llama_context * ctx_tgt, - llama_context * ctx_dft, - int32_t cross_ctx) - : common_speculative_state(type) - , ctx_tgt(ctx_tgt) - , ctx_dft(ctx_dft) - , cross_ctx(std::max(1, cross_ctx)) - { - const llama_model * model_tgt = llama_get_model(ctx_tgt); - const llama_model * model_dft = llama_get_model(ctx_dft); - - if (!common_speculative_are_dflash_compatible(model_tgt, model_dft)) { - LOG_ERR("%s: DFlash draft model vocab/tokenizer is incompatible with the target model\n", __func__); - return; - } - - block_size = llama_model_dflash_block_size(model_dft); - mask_token_id = llama_model_dflash_mask_token_id(model_dft); - n_target_features = llama_model_dflash_n_target_features(model_dft); - const int32_t n_target_layers = llama_model_dflash_n_target_layers(model_dft); - - if (block_size <= 0 || mask_token_id < 0 || n_target_features <= 0 || n_target_layers <= 0) { - LOG_ERR("%s: invalid DFlash metadata (block_size=%d, mask_token_id=%d, n_target_features=%d, n_target_layers=%d)\n", - __func__, block_size, mask_token_id, n_target_features, n_target_layers); - return; - } - - target_layer_ids.resize((size_t) n_target_layers); - if (llama_model_dflash_target_layer_ids(model_dft, target_layer_ids.data(), n_target_layers) != n_target_layers) { - LOG_ERR("%s: failed to read DFlash target layer ids\n", __func__); - target_layer_ids.clear(); - return; - } - - const auto * vocab_tgt = llama_model_get_vocab(model_tgt); - const auto * vocab_dft = llama_model_get_vocab(model_dft); - const int32_t target_vocab_size = llama_vocab_n_tokens(vocab_tgt); - const int32_t draft_vocab_size = llama_vocab_n_tokens(vocab_dft); - const int32_t target_hidden_size = llama_model_n_embd(model_tgt); - const int32_t draft_hidden_size = llama_model_n_embd(model_dft); - 
const int32_t target_mask_token_id = llama_model_dflash_target_mask_token_id(model_tgt); - const int32_t expected_n_target_features = target_hidden_size > 0 ? target_hidden_size * n_target_layers : 0; - - if (target_mask_token_id != (int32_t) LLAMA_TOKEN_NULL && mask_token_id != target_mask_token_id) { - LOG_ERR("%s: DFlash mask token mismatch (draft=%d target=%d)\n", - __func__, mask_token_id, target_mask_token_id); - return; - } - - if (target_hidden_size <= 0 || draft_hidden_size <= 0) { - LOG_ERR("%s: invalid DFlash hidden sizes (draft=%d target=%d)\n", - __func__, draft_hidden_size, target_hidden_size); - return; - } - - if (expected_n_target_features <= 0 || n_target_features != expected_n_target_features) { - LOG_ERR("%s: DFlash target feature width mismatch (metadata=%d expected=%d target_hidden=%d target_layers=%d)\n", - __func__, n_target_features, expected_n_target_features, target_hidden_size, n_target_layers); - return; - } - - std::vector sorted_target_layer_ids = target_layer_ids; - std::sort(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()); - if (std::adjacent_find(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()) != sorted_target_layer_ids.end()) { - LOG_ERR("%s: duplicate DFlash target layer ids survived into runtime validation\n", __func__); - target_layer_ids.clear(); - return; - } - - const int32_t n_target_model_layers = llama_n_layer(model_tgt); - for (int32_t layer_id : target_layer_ids) { - if (layer_id < 0 || layer_id >= n_target_model_layers) { - LOG_ERR("%s: invalid DFlash target layer id %d for target model with %d layers\n", - __func__, layer_id, n_target_model_layers); - target_layer_ids.clear(); - return; - } - } - - const int32_t io_mode = llama_model_dflash_io_mode(model_dft, model_tgt); - if (io_mode == LLAMA_DFLASH_IO_MODE_INVALID) { - LOG_ERR("%s: DFlash draft is missing required IO tensors after target sharing\n", __func__); - return; - } - - if (io_mode == LLAMA_DFLASH_IO_MODE_MIXED) { - 
LOG_ERR("%s: DFlash IO contract must be fully shared or fully self-contained, but resolved to mixed mode\n", __func__); - return; - } - - if (io_mode == LLAMA_DFLASH_IO_MODE_SELF_CONTAINED && !llama_model_dflash_io_tensors_match(model_dft, target_hidden_size, target_vocab_size)) { - LOG_ERR("%s: DFlash self-contained IO tensors do not match the target hidden/vocab contract (target_hidden=%d target_vocab=%d)\n", - __func__, - target_hidden_size, - target_vocab_size); - return; - } - - if (!llama_set_dflash_capture_layers(ctx_tgt, target_layer_ids.data(), (int32_t) target_layer_ids.size())) { - LOG_ERR("%s: failed to configure DFlash target capture callback\n", __func__); - return; - } - - batch = llama_batch_init(std::max(1, block_size), 0, 1); - target_window.reserve((size_t) this->cross_ctx * (size_t) n_target_features); - target_window_stage.reserve((size_t) this->cross_ctx * (size_t) n_target_features); - target_window_ring.resize((size_t) this->cross_ctx * (size_t) n_target_features); - target_window_append_features.reserve((size_t) this->cross_ctx * (size_t) n_target_features); - target_window_pos.reserve((size_t) this->cross_ctx); - target_window_pos_stage.reserve((size_t) this->cross_ctx); - ready = true; - - llama_set_dflash_visible_cross_ctx(ctx_dft, this->cross_ctx); - llama_dflash_profile_reset(ctx_tgt); - llama_dflash_profile_reset(ctx_dft); - - std::ostringstream layers_oss; - for (size_t i = 0; i < target_layer_ids.size(); ++i) { - if (i > 0) { - layers_oss << ","; - } - layers_oss << target_layer_ids[i]; - } - - const char * io_mode_name = io_mode == LLAMA_DFLASH_IO_MODE_SHARED ? 
"shared" : "self-contained"; - LOG_INF("%s: DFlash context ready (n_ctx=%d, block_size=%d, cross_ctx=%d, n_target_features=%d, target_layer_ids=[%s])\n", - __func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features, layers_oss.str().c_str()); - LOG_INF("%s: DFlash artifact io=%s draft_vocab=%d target_vocab=%d draft_hidden=%d target_hidden=%d mask_token_id=%d target_mask_token_id=%d\n", - __func__, io_mode_name, draft_vocab_size, target_vocab_size, draft_hidden_size, target_hidden_size, mask_token_id, target_mask_token_id); - } - - ~common_speculative_state_dflash() override { - llama_clear_dflash_capture(ctx_tgt); - if (ctx_dft) { - llama_free(ctx_dft); - } - if (batch.token != nullptr) { - llama_batch_free(batch); - } - } - - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - llama_kv_cache_clear(ctx_dft); - llama_reset_dflash_kv_cache_state(ctx_dft); - n_window_updates = 0; - n_rows_seen = 0; - n_rows_dropped = 0; - n_context_shifts = 0; - n_draft_empty = 0; - n_set_target_fail = 0; - n_decode_fail = 0; - last_draft_pos_base = -1; - t_draft_decode_us = 0; - t_draft_sample_us = 0; - t_warmup_collect_us = 0; - t_warmup_append_us = 0; - t_accept_output_copy_us = 0; - t_accept_commit_us = 0; - t_accept_append_us = 0; - t_accept_append_filter_us = 0; - t_accept_append_window_alloc_us = 0; - t_accept_append_replace_us = 0; - t_accept_append_keep_old_us = 0; - t_accept_append_new_rows_us = 0; - t_accept_append_commit_detail_us = 0; - t_accept_append_log_us = 0; - n_warmup_collect_calls = 0; - n_warmup_collect_rows = 0; - n_warmup_append_calls = 0; - n_warmup_append_rows = 0; - n_accept_output_copy_calls = 0; - n_accept_output_copy_rows = 0; - n_accept_commit_calls = 0; - n_accept_commit_rows = 0; - n_accept_append_calls = 0; - n_accept_append_rows = 0; - n_accept_append_replace_calls = 0; - n_accept_append_slide_calls = 0; - llama_dflash_profile_reset(ctx_tgt); - llama_dflash_profile_reset(ctx_dft); - } - - void draft( - 
const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - GGML_UNUSED(prompt_tgt); - - result.clear(); - if (!ready || target_window_rows <= 0) { - n_draft_empty++; - return; - } - - const int32_t n_keep = std::min(params.n_max, block_size - 1); - if (n_keep <= 0) { - return; - } - - const float * target_features = nullptr; - size_t target_feature_floats = 0; - llama_dflash_window_update window_update = { - target_window_version, - target_window_keep_rows, - target_window_append_rows, - target_window_replace, - target_window_append_features.empty() ? nullptr : target_window_append_features.data(), - target_window_append_features.size(), - }; - const llama_dflash_kv_cache_transition cache_plan = - llama_plan_dflash_kv_cache_transition_for_ctx(ctx_dft, window_update, target_window_rows); - - if (cache_plan.rebuild_cache) { - dflash_materialize_target_window_features(*this); - target_features = target_window.data(); - target_feature_floats = target_window.size(); - window_update.append_features = target_window.data(); - window_update.append_floats = target_window.size(); - window_update.append_rows = target_window_rows; - } - - if (!llama_set_dflash_target_features_view(ctx_dft, target_features, target_feature_floats, target_window_rows, target_window_pos.data(), &window_update)) { - LOG_ERR("%s: failed to set DFlash target features\n", __func__); - n_set_target_fail++; - return; - } - - llama_kv_cache_clear(ctx_dft); - batch.n_tokens = 0; - const int32_t batch_len = n_keep + 1; - const llama_pos draft_pos_base = last_target_pos >= 0 ? last_target_pos + 1 : (llama_pos) target_window_rows; - const llama_pos seed_pos = last_target_pos >= 0 ? 
last_target_pos : draft_pos_base - 1; - last_draft_pos_base = draft_pos_base; - common_batch_add(batch, id_last, seed_pos, { 0 }, false); - for (int32_t i = 1; i < batch_len; ++i) { - common_batch_add(batch, mask_token_id, draft_pos_base + (i - 1), { 0 }, i <= n_keep); - } - - const int64_t t_decode_us = ggml_time_us(); - if (llama_decode(ctx_dft, batch) != 0) { - LOG_ERR("%s: llama_decode() failed for DFlash draft batch\n", __func__); - n_decode_fail++; - batch.n_tokens = 0; - return; - } - t_draft_decode_us += (uint64_t) (ggml_time_us() - t_decode_us); - - result.reserve((size_t) n_keep); - const int64_t t_sample_us = ggml_time_us(); - for (int32_t i = 0; i < n_keep; ++i) { - // Use argmax in GPU when available - llama_token id = llama_get_dflash_draft_token_ith(ctx_dft, i); - if (id == LLAMA_TOKEN_NULL) { - id = common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr); - } - result.push_back(id); - } - t_draft_sample_us += (uint64_t) (ggml_time_us() - t_sample_us); - - batch.n_tokens = 0; - dflash_contract_log_draft(*this, n_keep, result.size()); - } - - void accept(uint16_t n_accepted) override { - GGML_UNUSED(n_accepted); - } -}; - -static void dflash_contract_log_append( - const common_speculative_state_dflash & state, - llama_seq_id seq_id, - const std::vector & new_positions) { - if (!dflash_contract_log_enabled()) { - return; - } - - static std::atomic counter = 0; - const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); - if (ordinal >= 8) { - return; - } - - const dflash_contract_pos_summary incoming = dflash_contract_summarize_positions(new_positions); - const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos); - - LOG_INF("dflash contract append[%llu]: seq=%d incoming_rows=%zu incoming_pos=%s pos=[%d..%d] gaps=%d nonmono=%d window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d\n", - (unsigned long long) (ordinal + 1), - (int) seq_id, - 
new_positions.size(), - dflash_contract_format_values(new_positions).c_str(), - (int) incoming.first, - (int) incoming.last, - incoming.gap_count, - incoming.nonmono_count, - state.target_window_rows, - dflash_contract_format_values(state.target_window_pos).c_str(), - (int) window.first, - (int) window.last, - window.gap_count, - window.nonmono_count, - (int) state.last_target_pos); -} - -static void dflash_contract_log_draft( - const common_speculative_state_dflash & state, - int32_t n_keep, - size_t result_size) { - if (!dflash_contract_log_enabled()) { - return; - } - - static std::atomic counter = 0; - const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); - if (ordinal >= 8) { - return; - } - - const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos); - llama_dflash_profile_stats graph_stats = {}; - llama_dflash_profile_get_stats(state.ctx_dft, &graph_stats); - const int draft_delta = (state.last_target_pos >= 0 && state.last_draft_pos_base >= 0) - ? (int) (state.last_draft_pos_base - state.last_target_pos) - : -1; - const llama_pos seed_pos = state.last_target_pos; - const llama_pos mask_first_pos = state.last_draft_pos_base; - const llama_pos mask_last_pos = state.last_draft_pos_base >= 0 - ? 
state.last_draft_pos_base + n_keep - 1 - : -1; - - LOG_INF("dflash contract draft[%llu]: window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d seed_pos=%d mask_pos=[%d..%d] sample_rows=[1..%d] output_rows=[1..%d] draft_pos_base=%d delta=%d n_keep=%d result=%zu set_target(missing/nonmono)=%llu/%llu graph(fallback/nonmono)=%llu/%llu graph_pos=[%d..%d]\n", - (unsigned long long) (ordinal + 1), - state.target_window_rows, - dflash_contract_format_values(state.target_window_pos).c_str(), - (int) window.first, - (int) window.last, - window.gap_count, - window.nonmono_count, - (int) state.last_target_pos, - (int) seed_pos, - (int) mask_first_pos, - (int) mask_last_pos, - n_keep, - n_keep, - (int) state.last_draft_pos_base, - draft_delta, - n_keep, - result_size, - (unsigned long long) graph_stats.set_target_missing_positions, - (unsigned long long) graph_stats.set_target_non_monotonic_positions, - (unsigned long long) graph_stats.graph_pos_fallbacks, - (unsigned long long) graph_stats.graph_pos_non_monotonic, - (int) graph_stats.last_pos_first, - (int) graph_stats.last_pos_last); -} - -struct common_speculative_state_draft : public common_speculative_state { - llama_context * ctx_tgt; // only used for retokenizing from ctx_dft - llama_context * ctx_dft; - - common_sampler * smpl; - - llama_batch batch; - llama_tokens prompt_dft; - - bool vocab_cmpt = true; // whether retokenization is needed - std::unordered_map vocab_map; - - common_speculative_state_draft( - enum common_speculative_type type, - llama_context * ctx_tgt, - llama_context * ctx_dft, - const std::vector> & replacements) - : common_speculative_state(type) - , ctx_tgt(ctx_tgt) - , ctx_dft(ctx_dft) - { - batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); - smpl = nullptr; - { - struct common_params_sampling params; - params.top_k = 10; - params.samplers_sequence = { - llama_sampler_type::TOP_K, - llama_sampler_type::DIST, // needed to get probabilities - }; - smpl = 
common_sampler_init(llama_get_model(ctx_dft), params); - } - - vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); - LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt); - - if (!vocab_cmpt) { - LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n"); - - for (const auto & pair : replacements) { - vocab_map[pair.first] = pair.second; - } - } - } - - ~common_speculative_state_draft() override { - llama_free(ctx_dft); - - common_sampler_free(smpl); - - llama_batch_free(batch); - } - - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - auto * spec = this; - - auto & batch = spec->batch; - auto & ctx_tgt = spec->ctx_tgt; - auto & ctx_dft = spec->ctx_dft; - auto & smpl = spec->smpl; - auto & prompt_dft = spec->prompt_dft; - - int reuse_i = 0; - int reuse_n = 0; - - const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max; - - llama_tokens prompt_cnv; - if (!spec->vocab_cmpt) { - // convert id_last to draft vocab. 
llama_detokenize is called directly to avoid an allocation - const auto * model_tgt = llama_get_model(ctx_tgt); - const auto * vocab_tgt = llama_model_get_vocab(model_tgt); - - std::string text; - - text = common_detokenize(ctx_tgt, prompt_tgt, true); - text = replace_to_dft(text); - - LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str()); - - prompt_cnv = common_tokenize(ctx_dft, text, false, true); - - - - int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false); - GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last"); - - text.resize(-n_chars); - llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false); - text = replace_to_dft(text); - - LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str()); - id_last = common_tokenize(ctx_dft, text, false, true)[0]; - } - - const llama_tokens & prompt_cur = spec->vocab_cmpt ? prompt_tgt : prompt_cnv; - - const int i_start = std::max(0, (int) prompt_cur.size() - n_ctx); - - // reuse as much as possible from the old draft context - // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt - for (int i = 0; i < (int) prompt_dft.size(); ++i) { - int cur = 0; - while (i_start + cur < (int) prompt_cur.size() && - i + cur < (int) prompt_dft.size() && - prompt_cur[i_start + cur] == prompt_dft[i + cur]) { - cur++; - } - - if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) { - reuse_i = i; - reuse_n = cur; - } - } - - LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size()); - - result.clear(); - result.reserve(params.n_max); - - if (reuse_n == 0) { - llama_kv_cache_clear(ctx_dft); - prompt_dft.clear(); - } else { - // this happens when a previous draft has been discarded (for example, due to being too small), but the - // target model agreed with it. 
in this case, we simply pass back the previous results to save compute - if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { - for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { - result.push_back(prompt_dft[i]); - - if (params.n_max <= (int) result.size()) { - break; - } - } - - return; - } - - if (reuse_i > 0) { - llama_kv_cache_seq_rm (ctx_dft, 0, 0, reuse_i); - llama_kv_cache_seq_add(ctx_dft, 0, reuse_i, -1, -reuse_i); - - prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); - } - - if (reuse_n < (int) prompt_dft.size()) { - llama_kv_cache_seq_rm (ctx_dft, 0, reuse_n, -1); - prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); - } - } - - // prepare a batch to evaluate any new tokens in the prompt - common_batch_clear(batch); - - for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) { - //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]); - common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false); - - prompt_dft.push_back(prompt_cur[i]); - } - - // we should rarely end-up here during normal decoding - if (batch.n_tokens > 0) { - //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); - - llama_decode(ctx_dft, batch); - } - - const llama_pos n_past = prompt_dft.size(); - - LOG_DBG("%s: n_past = %d\n", __func__, n_past); - - common_batch_clear(batch); - common_batch_add (batch, id_last, n_past, { 0 }, true); - - prompt_dft.push_back(id_last); - - //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); - - llama_decode(ctx_dft, batch); - - common_sampler_reset(smpl); - - // sample n_draft tokens from the draft model - for (int i = 0; i < params.n_max; ++i) { - common_batch_clear(batch); - - common_sampler_sample(smpl, ctx_dft, 0, true); - - const auto * cur_p = common_sampler_get_candidates(smpl, true); - - for (int k = 
0; k < std::min(3, (int) cur_p->size); ++k) { - LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); - } - - // add drafted token for each sequence - const llama_token id = cur_p->data[0].id; - - common_sampler_accept(smpl, nullptr, id, true); - - // only collect very high-confidence draft tokens - if (cur_p->data[0].p < params.p_min) { - if (i == 0) { - result.push_back(id); - } - break; - } - - result.push_back(id); - - if (params.n_max <= (int) result.size()) { - break; - } - - - common_batch_add(batch, id, n_past + i + 1, { 0 }, true); - - // evaluate the drafted tokens on the draft model - llama_decode(ctx_dft, batch); - - prompt_dft.push_back(id); - } - - if (!spec->vocab_cmpt) { - std::string detokenized = common_detokenize(ctx_dft, result, true); - detokenized = replace_to_tgt(detokenized); - LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str()); - result = common_tokenize(ctx_tgt, detokenized, false, true); - if (result.size() > (size_t)params.n_max) { - result.resize(params.n_max); - } - } - } - - void accept(uint16_t n_accepted) override { - // noop - GGML_UNUSED(n_accepted); - } - - std::string replace_to_dft(const std::string & input) const { - std::string result = input; - - for (const auto & pair : this->vocab_map) { - size_t pos = result.find(pair.first); - while (pos != std::string::npos) { - result.replace(pos, pair.first.length(), pair.second); - pos = result.find(pair.first, pos + pair.second.length()); - } - } - - return result; - } - - std::string replace_to_tgt(const std::string & input) const { - std::string result = input; - - for (const auto & pair : this->vocab_map) { - size_t pos = result.find(pair.second); - while (pos != std::string::npos) { - result.replace(pos, pair.second.length(), pair.first); - pos = result.find(pair.second, pos + pair.first.length()); - } - } - - return result; - } -}; - -struct 
common_speculative_state_eagle3 : public common_speculative_state { - common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {} - - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & draft_tokens) override { - // TODO: implement - GGML_UNUSED(params); - GGML_UNUSED(prompt_tgt); - GGML_UNUSED(id_last); - GGML_UNUSED(draft_tokens); - } - - void accept(uint16_t n_accepted) override { - // noop - GGML_UNUSED(n_accepted); - } -}; - -// state of self-speculation (simple implementation, not ngram-map) -struct common_speculative_state_ngram_simple : public common_speculative_state { - common_ngram_simple_config config; - - common_speculative_state_ngram_simple( - enum common_speculative_type type, - common_ngram_simple_config config) - : common_speculative_state(type), config(config) {} - - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - - result = common_ngram_simple_draft(config, prompt_tgt, id_last); - GGML_UNUSED(params); - } - - void accept(uint16_t n_accepted) override { - // noop - GGML_UNUSED(n_accepted); - } -}; - -struct common_speculative_state_ngram_map_k : public common_speculative_state { - // draft ngram map for speculative decoding without draft model - common_ngram_map map; - - common_speculative_state_ngram_map_k( - enum common_speculative_type type, - common_ngram_map map) - : common_speculative_state(type), map(std::move(map)) {} - - void begin(const llama_tokens & prompt) override { - common_ngram_map_begin(map, prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) 
override { - common_ngram_map_draft(map, prompt_tgt, id_last, result); - GGML_UNUSED(params); - } - - void accept(uint16_t n_accepted) override { - common_ngram_map_accept(map, n_accepted); - } -}; - -struct common_speculative_state_ngram_mod : public common_speculative_state { - common_ngram_mod & mod; - - // the last position in the prompt that was added to the ngram container - size_t i_last = 0; - - // length of the last drafted n‑gram (number of tokens returned by draft) - size_t n_draft_last = 0; - - // consecutive accept rounds with low acceptance fraction (< 0.5) - int n_low = 0; - - // enable trace logging if LLAMA_TRACE is set - const bool verbose; - - common_speculative_state_ngram_mod(enum common_speculative_type type, common_ngram_mod & mod) - : common_speculative_state(type), mod(mod), verbose(std::getenv("LLAMA_TRACE") != nullptr) { - static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t)); - } - - void begin(const llama_tokens & prompt) override { - i_last = 0; - - n_draft_last = 0; - n_low = 0; - - const size_t n = mod.get_n(); - - if (prompt.size() < n) { - return; - } - - for (size_t i = 0; i < prompt.size() - n; ++i) { - mod.add(prompt.data() + i); - } - - i_last = prompt.size() - n; - - const double f = (double)mod.get_used() / (double)mod.size(); - LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f); - - constexpr double f_thold = 0.25; - if (f > f_thold) { - LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold); - - mod.reset(); - } - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - GGML_UNUSED(params); - - n_draft_last = 0; - - const size_t cur_len = prompt_tgt.size(); - if (cur_len < mod.get_n()) { - return; - } - - const size_t n = mod.get_n(); - - // add new ngrams in chunks - if (i_last + 32 < cur_len) { - for (size_t i = 
i_last; i < cur_len - n; ++i) { - mod.add(prompt_tgt.data() + i); - } - - i_last = cur_len - n; - } - - result.resize(n + params.n_max); - for (size_t i = 0; i < n - 1; ++i) { - result[i] = prompt_tgt[cur_len - n + 1 + i]; - } - result[n - 1] = id_last; - - for (int i = 0; i < params.n_max; ++i) { - const llama_token token = mod.get(result.data() + i); - if (token == common_ngram_mod::EMPTY) { - if (i < params.n_min) { - result.clear(); - return; - } - - result.resize(n + i); - break; - } - result[n + i] = token; - } - - // only return the m tokens that were drafted - for (size_t i = 0; n + i < result.size(); ++i) { - result[i] = result[n + i]; - } - result.resize(result.size() - n); - - // store length of drafted n‑gram for later acceptance analysis - n_draft_last = result.size(); - } - - void accept(uint16_t n_accepted) override { - if (verbose) { - LOG_INF("%s: accepted %d tokens from %zu drafted tokens\n", __func__, n_accepted, n_draft_last); - } - - // compute acceptance fraction if we have a recorded draft length - if (n_draft_last > 0) { - const double f_acc = (double)n_accepted / (double)n_draft_last; - if (f_acc < 0.5) { - n_low++; - if (n_low >= 3) { - LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low); - - mod.reset(); - n_low = 0; - i_last = 0; - } - } else { - n_low = 0; - } - } - } -}; - -struct common_speculative_state_ngram_cache : public common_speculative_state { - uint16_t n_draft; - bool save_dynamic; - bool save_static; - - common_ngram_cache ngram_cache_context; - common_ngram_cache ngram_cache_dynamic; - common_ngram_cache ngram_cache_static; - - size_t cache_size = 0; // number of tokens in n-gram cache - - common_speculative_state_ngram_cache( - const enum common_speculative_type type, - const std::string & path_static, - const std::string & path_dynamic, - uint16_t n_draft, - bool save_dynamic, - bool save_static) - : common_speculative_state(type) - , n_draft(n_draft) - , save_dynamic(save_dynamic) - , 
save_static(save_static) - { - if (!path_static.empty()) { - try { - ngram_cache_static = common_ngram_cache_load(path_static); - } catch (...) { - LOG_ERR("failed to open static lookup cache: %s", path_static.c_str()); - GGML_ABORT("Couldn't read static lookup cache"); - } - } - - if (!path_dynamic.empty()) { - try { - ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); - } catch (...) { - LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str()); - GGML_ABORT("Couldn't read dynamic lookup cache"); - } - } - } - - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - GGML_UNUSED(params); - - if (cache_size < prompt_tgt.size() + 1) { - llama_tokens tokens_new; - tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); - for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { - tokens_new.push_back(prompt_tgt[j]); - } - tokens_new.push_back(id_last); // add the last token - - // Update context ngram cache with new prompt_tgt: - common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, - tokens_new, tokens_new.size(), false); - cache_size = prompt_tgt.size() + 1; - } - - llama_tokens inp; - inp.reserve(prompt_tgt.size() + 1); - for (size_t j = 0; j < prompt_tgt.size(); ++j) { - inp.push_back(prompt_tgt[j]); - } - inp.push_back(id_last); - - result.push_back(id_last); - - common_ngram_cache_draft(inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, - ngram_cache_context, - ngram_cache_dynamic, - ngram_cache_static); - - if (result.size() > 0) { - // delete first token in result (which is the id_last token) - result.erase(result.begin()); - } - } - - void accept(uint16_t n_accepted) override { - // TODO: noop - GGML_UNUSED(n_accepted); - } -}; - -struct common_speculative_state_suffix : public common_speculative_state { - common_suffix_tree 
tree; - common_suffix_tree corpus_tree; - bool has_corpus = false; - size_t cache_size = 0; - - // Acceptance feedback - size_t n_draft_last = 0; - bool had_accept = false; - int n_low = 0; - float base_p_min = 0.1f; - float eff_p_min = 0.1f; - - common_speculative_state_suffix( - enum common_speculative_type type, - int max_depth, - const std::string & corpus_path, - const llama_model * model) - : common_speculative_state(type) - , tree(max_depth) - , corpus_tree(max_depth) - { - if (!corpus_path.empty()) { - std::function(const std::string &)> tokenize_fn; - if (model) { - tokenize_fn = [model](const std::string & text) -> std::vector { - return common_tokenize(model, text, false, true); - }; - } - has_corpus = corpus_tree.load_corpus(corpus_path, tokenize_fn); - } - } - - void begin(const llama_tokens & prompt) override { - cache_size = 0; - n_draft_last = 0; - had_accept = false; - n_low = 0; - GGML_UNUSED(prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - - base_p_min = params.p_min; - if (n_draft_last > 0 && !had_accept) { - if (++n_low >= 3) { - eff_p_min = std::min(eff_p_min + 0.1f, 0.5f); - n_low = 0; - } - } - had_accept = false; - - if (cache_size < prompt_tgt.size() + 1) { - llama_tokens tokens_new; - tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); - for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { - tokens_new.push_back(prompt_tgt[j]); - } - tokens_new.push_back(id_last); - - tree.extend(tokens_new.data(), (int)tokens_new.size()); - cache_size = prompt_tgt.size() + 1; - } - - const int ctx_len = std::min((int)(prompt_tgt.size() + 1), tree.max_depth()); - llama_tokens context; - context.reserve(ctx_len); - const int ctx_start = (int)prompt_tgt.size() + 1 - ctx_len; - for (int j = ctx_start; j < (int)prompt_tgt.size(); ++j) { - context.push_back(prompt_tgt[j]); - } - context.push_back(id_last); - const int min_match_len = 
std::max(1, params.suffix_min_match_len); - - result = tree.speculate( - context.data(), (int)context.size(), - params.n_max, - eff_p_min, - 1, - min_match_len); - - if (has_corpus) { - auto corpus_result = corpus_tree.speculate( - context.data(), (int)context.size(), - params.n_max, - eff_p_min, - 1, - min_match_len); - if (corpus_result.size() > result.size()) { - result = std::move(corpus_result); - } - } - - n_draft_last = result.size(); - } - - void accept(uint16_t n_accepted) override { - if (n_draft_last == 0) { - return; - } - had_accept = true; - const double f_acc = (double)n_accepted / (double)n_draft_last; - if (f_acc < 0.5) { - if (++n_low >= 3) { - eff_p_min = std::min(eff_p_min + 0.1f, 0.5f); - n_low = 0; - } - } else { - n_low = 0; - if (eff_p_min > base_p_min) { - eff_p_min = std::max(eff_p_min - 0.05f, base_p_min); - } - } - } -}; - -struct common_speculative { - std::vector configs; // resolved stage config for each implementation - std::vector> impls; // list of implementations to use and their states - common_speculative_checkpoint checkpoint; - common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) - std::unique_ptr tuner; - int last_n_drafted = 0; - int64_t t_step_start_us = 0; -}; - -static bool common_speculative_stage_chain_matches( - const std::vector & stages, - const std::vector & configs) { - if (stages.size() != configs.size()) { - return false; - } - - for (size_t i = 0; i < stages.size(); ++i) { - if (stages[i].type != configs[i].type) { - return false; - } - } - - return true; -} - -static common_params_speculative common_speculative_get_runtime_params( - const common_speculative_config & config, - const common_params_speculative & params, - const common_speculative_stage_params & stage) { - common_params_speculative result = config.params; - - result.type = config.type; - result.n_max = stage.has_n_max_override() ? stage.n_max : params.n_max; - result.n_min = stage.has_n_min_override() ? 
stage.n_min : params.n_min; - result.p_min = stage.has_p_min_override() ? stage.p_min : params.p_min; - - if (config.type == COMMON_SPECULATIVE_TYPE_SUFFIX) { - result.suffix_min_match_len = stage.has_suffix_min_match_len_override() - ? stage.suffix_min_match_len - : params.suffix_min_match_len; - } - - result.n_max = std::max(result.n_max, 0); - result.n_min = std::max(0, std::min(result.n_min, result.n_max)); - result.stages.clear(); - - return result; -} - -static common_ngram_map get_common_ngram_map(const common_speculative_config & config) { - uint16_t size_key = config.params.ngram_size_n; - uint16_t size_value = config.params.ngram_size_m; - bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); - uint16_t min_hits = config.params.ngram_min_hits; - - return common_ngram_map(size_key, size_value, key_only, min_hits); -} - -static common_speculative_state_ngram_cache create_state_ngram_cache( - const std::string & path_static, const std::string & path_dynamic, - const common_speculative_config & config) { - uint16_t n_draft = 8; // TODO get from config? - - // TODO bool param in common/common.h to set save_static/save_dynamic? 
- bool save_static = false; - bool save_dynamic = false; - - common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic); - - return state; -} - -std::string common_speculative_type_name_str() { - std::string result; - for (size_t i = 0; i < common_speculative_types.size(); i++) { - if (i > 0) { - result += ", "; - } - result += common_speculative_type_to_str(common_speculative_types[i]); - } - return result; -} - -std::string common_speculative_type_to_str(enum common_speculative_type type) { - switch (type) { - case COMMON_SPECULATIVE_TYPE_NONE: return "none"; - case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft"; - case COMMON_SPECULATIVE_TYPE_DFLASH: return "dflash"; - case COMMON_SPECULATIVE_TYPE_MTP: return "mtp"; - case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3"; - case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod"; - case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache"; - case COMMON_SPECULATIVE_TYPE_SUFFIX: return "suffix"; - default: return "unknown"; - } -} - -enum common_speculative_type common_speculative_type_from_name(const std::string & name) { - std::string normalized = name; - std::replace(normalized.begin(), normalized.end(), '-', '_'); - - const auto it = common_speculative_type_from_name_map.find(normalized); - if (it == common_speculative_type_from_name_map.end()) { - return COMMON_SPECULATIVE_TYPE_COUNT; - } - return it->second; -} - -bool common_speculative_is_compat(llama_context * ctx_tgt) { - bool res = true; - - llama_kv_cache_clear(ctx_tgt); - - // eval 2 tokens to check if the context is compatible - std::vector tmp; - tmp.push_back(0); - tmp.push_back(0); - - int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0)); - if (ret != 
0) { - LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret); - res = false; - goto done; - } - - // try to remove the last tokens - if (!llama_kv_cache_seq_rm(ctx_tgt, 0, 1, -1)) { - LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); - res = false; - goto done; - } - -done: - llama_kv_cache_clear(ctx_tgt); - llama_synchronize(ctx_tgt); - - return res; -} - -// initialization of the speculative decoding system -// -common_speculative * common_speculative_init( - common_params_speculative & params, - llama_context * ctx_tgt) { - std::string chain_error; - if (!common_speculative_validate_chain(params, &chain_error)) { - LOG_ERR("%s: invalid speculative stage chain: %s\n", __func__, chain_error.c_str()); - return nullptr; - } - - const auto stages = params.get_resolved_stages(); - if (params.model_dft && llama_model_is_gemma4_mtp_assistant(params.model_dft)) { - const bool has_draft_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) { - return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT; - }); - - if (has_draft_stage) { - LOG_ERR("%s: Gemma4 assistant models only support MTP stages; omit -md for self-spec-only runs or use -mtp/--spec-stage mtp for assistant-backed MTP\n", __func__); - return nullptr; - } - } - - const bool has_dflash_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) { - return stage.type == COMMON_SPECULATIVE_TYPE_DFLASH; - }); - - const bool needs_draft_ctx = std::any_of(stages.begin(), stages.end(), [¶ms](const common_speculative_stage_params & stage) { - return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT || - stage.type == COMMON_SPECULATIVE_TYPE_DFLASH || - (stage.type == COMMON_SPECULATIVE_TYPE_MTP && params.model_dft != nullptr); - }); - - llama_context * ctx_dft = nullptr; - if (needs_draft_ctx) { - if (!params.model_dft) { - LOG_ERR("%s: draft speculative stage requires a loaded draft model\n", 
__func__); - return nullptr; - } - - llama_context_params cparams_dft = params.cparams_dft; - - if (has_dflash_stage) { - if (!llama_model_share_dflash_io_tensors(params.model_dft, llama_get_model(ctx_tgt))) { - LOG_ERR("%s: failed to share target IO tensors with DFlash draft model\n", __func__); - return nullptr; - } - - int32_t max_cross_ctx = 0; - for (const auto & stage : stages) { - if (stage.type != COMMON_SPECULATIVE_TYPE_DFLASH) { - continue; - } - - max_cross_ctx = std::max(max_cross_ctx, params.with_stage_overrides(stage).dflash_cross_ctx); - } - - const int32_t block_size = llama_model_dflash_block_size(params.model_dft); - if (block_size <= 0) { - LOG_ERR("%s: invalid DFlash draft block size\n", __func__); - return nullptr; - } - - const int64_t required_n_ctx = (int64_t) max_cross_ctx + (int64_t) block_size; - if (required_n_ctx > std::numeric_limits::max()) { - LOG_ERR("%s: invalid DFlash draft context size cross_ctx=%d block_size=%d required_n_ctx=%lld\n", - __func__, max_cross_ctx, block_size, (long long) required_n_ctx); - return nullptr; - } - - cparams_dft.n_ctx = (uint32_t) required_n_ctx; - } - - ctx_dft = llama_init_from_model(params.model_dft, cparams_dft); - if (ctx_dft == nullptr) { - LOG_ERR("%s", "failed to create draft context\n"); - return nullptr; - } - } - - // Compute the implementations to use based on the resolved stage chain. 
- std::vector configs = {}; - configs.reserve(stages.size()); - - for (const auto & stage : stages) { - common_params_speculative stage_params = params.with_stage_overrides(stage); - - if (stage.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD && !stage_params.ngram_mod) { - stage_params.ngram_mod = std::make_shared(stage_params.ngram_size_n, 4*1024*1024); - - LOG_INF("%s: initialized ngram_mod with n=%d, size=%zu (%.3f MB)\n", __func__, - stage_params.ngram_size_n, stage_params.ngram_mod->size(), - (float)(stage_params.ngram_mod->size_bytes())/1024/1024); - - if (stage_params.ngram_size_n < 16) { - LOG_WRN("%s: ngram_mod n=%d is too small - poor quality is possible, see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, stage_params.ngram_size_n); - } - } - - configs.push_back(common_speculative_config(stage, stage_params)); - } - - if (!configs.empty() && llama_model_has_recurrent(llama_get_model(ctx_tgt))) { - const int ckpt_tokens = std::max(1, params.get_max_stage_n_max() + 1); - const int actual_mode = llama_spec_ckpt_init(ctx_tgt, params.recurrent_ckpt_mode, ckpt_tokens); - if (actual_mode == LLAMA_SPEC_CKPT_NONE) { - LOG_ERR("%s: failed to prepare recurrent checkpoint mode '%s' during speculative init (max_tokens=%d)\n", - __func__, - params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_PER_STEP ? "per-step" : - params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_GPU_FALLBACK ? "gpu-fallback" : - params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_CPU ? 
"cpu" : "auto", - ckpt_tokens); - if (ctx_dft != nullptr) { - llama_free(ctx_dft); - } - return nullptr; - } - llama_spec_ckpt_discard(ctx_tgt); - params.recurrent_ckpt_mode = actual_mode; - } - - std::vector> impls = {}; - - for (const common_speculative_config & config : configs) { - LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str()); - switch (config.type) { - case COMMON_SPECULATIVE_TYPE_NONE: - break; - case COMMON_SPECULATIVE_TYPE_DRAFT: { - impls.push_back(std::make_unique(config.type, - /* .ctx_tgt = */ ctx_tgt, - /* .ctx_dft = */ ctx_dft, - /* .replacements = */ config.params.replacements - )); - break; - } - case COMMON_SPECULATIVE_TYPE_DFLASH: { - auto state = std::make_unique( - config.type, - ctx_tgt, - ctx_dft, - config.params.dflash_cross_ctx); - if (!state->ready) { - LOG_ERR("%s: failed to initialize DFlash speculative state\n", __func__); - return nullptr; - } - impls.push_back(std::move(state)); - ctx_dft = nullptr; - break; - } - case COMMON_SPECULATIVE_TYPE_MTP: { - llama_context * ctx_mtp = ctx_dft; - if (!ctx_mtp) { - const llama_model * model = llama_get_model(ctx_tgt); - ctx_mtp = llama_init_from_model(const_cast(model), config.params.cparams_dft); - if (!ctx_mtp) { - LOG_ERR("%s: failed to create MTP context\n", __func__); - return nullptr; - } - } - ctx_dft = nullptr; - - const bool use_constant_draft_positions = llama_model_is_gemma4_mtp_assistant(llama_get_model(ctx_mtp)); - impls.push_back(std::make_unique( - config.type, ctx_tgt, ctx_mtp, use_constant_draft_positions)); - break; - } - case COMMON_SPECULATIVE_TYPE_EAGLE3: { - impls.push_back(std::make_unique(config.type)); - break; - } - case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { - common_ngram_map ngram_map = get_common_ngram_map(config); - - uint16_t ngram_size_key = ngram_map.size_key; - uint16_t mgram_size_value = ngram_map.size_value; - - auto config_simple = common_ngram_simple_config { - /* .size_ngram = */ ngram_size_key, 
- /* .size_mgram = */ mgram_size_value - }; - auto state = std::make_unique( - /* .type = */ config.type, - /* .state = */ config_simple - ); - impls.push_back(std::move(state)); - break; - } - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: { - impls.push_back(std::make_unique( - (config.type), - get_common_ngram_map(config) - )); - break; - } - case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: { - GGML_ASSERT(config.params.ngram_mod); - impls.push_back(std::make_unique(config.type, *config.params.ngram_mod)); - break; - } - case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: { - auto state = create_state_ngram_cache( - config.params.lookup_cache_static, config.params.lookup_cache_dynamic, config); - impls.push_back(std::make_unique(state)); - break; - } - case COMMON_SPECULATIVE_TYPE_SUFFIX: { - int depth = config.params.suffix_max_depth > 0 ? config.params.suffix_max_depth : 64; - const llama_model * model = llama_get_model(ctx_tgt); - impls.push_back(std::make_unique( - config.type, depth, config.params.suffix_corpus, model)); - break; - } - default: - break; - } - } - - if (impls.empty()) { - LOG_WRN("%s", "no implementations specified for speculative decoding\n"); - return nullptr; - } - - auto * result = new common_speculative { - /* .configs = */ std::move(configs), - /* .impls = */ std::move(impls) - }; - - // initialize autotune if requested - if (params.autotune && params.has_composite_stage_chain()) { - LOG_WRN("Autotune disabled — explicit speculative stage chains are not supported yet\n"); - } else if (params.autotune && !result->impls.empty()) { - auto actual_type = result->impls[0]->type; - if (actual_type != COMMON_SPECULATIVE_TYPE_NONE && - actual_type != COMMON_SPECULATIVE_TYPE_EAGLE3) { - result->tuner = std::make_unique(); - result->tuner->init(actual_type, params, llama_get_model(ctx_tgt)); - LOG_DBG("Autotune initialized for %s, tuning %zu parameters\n", - common_speculative_type_to_str(actual_type).c_str(), - 
result->tuner->coords.size()); - } else { - LOG_WRN("Autotune disabled — speculative type %s is not supported for autotuning\n", - common_speculative_type_to_str(actual_type).c_str()); - } - } - - return result; -} - -void common_speculative_free(common_speculative * spec) { - if (spec == nullptr) { - return; - } - - spec->checkpoint.clear(); - delete spec; -} - -void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt) { - if (spec == nullptr) { - return; - } - - for (auto & impl : spec->impls) { - common_time_meas tm(impl->t_begin_us, !impl->gen_perf); - impl->begin(prompt); - impl->n_call_begin++; - } -} - -llama_tokens common_speculative_draft( - common_speculative * spec, - common_params_speculative & params, - const llama_tokens & prompt_tgt, // specified in target model vocab - llama_token id_last, - llama_pos draft_base_pos, - llama_seq_id draft_seq_id) { - llama_tokens result; - - spec->t_step_start_us = ggml_time_us(); - - // apply autotune proposal if enabled - if (spec->tuner && spec->tuner->enabled) { - spec->tuner->propose(params); - } - - const auto runtime_stages = params.get_resolved_stages(); - const bool use_runtime_stage_overrides = common_speculative_stage_chain_matches(runtime_stages, spec->configs); - - spec->curr_impl = nullptr; // reset current implementation - - for (size_t i = 0; i < spec->impls.size(); ++i) { - auto & impl = spec->impls[i]; - const auto & runtime_stage = use_runtime_stage_overrides ? 
runtime_stages[i] : spec->configs[i].stage; - common_params_speculative impl_params = common_speculative_get_runtime_params(spec->configs[i], params, runtime_stage); - result.clear(); - - { - common_time_meas tm(impl->t_draft_us, !impl->gen_perf); - impl->draft(impl_params, prompt_tgt, id_last, draft_base_pos, draft_seq_id, result); - impl->n_call_draft++; - } - - if (result.empty()) { - continue; - } - - if (common_speculative_type_is_self_spec(impl->type) && impl_params.n_min > 0 && (int)result.size() < impl_params.n_min) { - LOG_DBG("%s: impl %s drafted %zu tokens, below fallback threshold %d - trying next implementation\n", - __func__, common_speculative_type_to_str(impl->type).c_str(), result.size(), impl_params.n_min); - result.clear(); - continue; - } - LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, - common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(), - impl.get()->n_call_draft, result.size()); - - spec->curr_impl = impl.get(); - impl->n_gen_drafts++; - impl->n_gen_tokens += result.size(); - - break; // We have a draft, so break out of the loop and return it. - } - - // store draft count for tuner feedback - if (spec->tuner && spec->tuner->enabled) { - spec->last_n_drafted = (int)result.size(); - } - - return result; -} - -void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { - if (spec->tuner && spec->tuner->enabled && spec->t_step_start_us > 0) { - int64_t step_time_us = ggml_time_us() - spec->t_step_start_us; - double step_tps = (step_time_us > 100) - ? 
(n_accepted + 1.0) * 1e6 / (double)step_time_us - : 0.0; - spec->tuner->accept_feedback(n_accepted, spec->last_n_drafted, step_tps); - spec->t_step_start_us = 0; - } - - common_speculative_state * impl = spec->curr_impl; - - if (!impl) { - return; - } - - { - common_time_meas tm(impl->t_accept_us, !impl->gen_perf); - if (n_accepted > 0) { - impl->n_acc_drafts++; - impl->n_acc_tokens += n_accepted; - } - - impl->accept(n_accepted); - impl->n_call_accept++; - } - - if (impl->type != COMMON_SPECULATIVE_TYPE_MTP) { - if (auto * mtp_state = common_speculative_get_mtp_state(spec); mtp_state != nullptr) { - mtp_invalidate_cached_drafts(*mtp_state); - } - } -} - -static bool common_speculative_has_type(const common_speculative * spec, common_speculative_type type) { - if (spec == nullptr) { - return false; - } - - return std::any_of(spec->configs.begin(), spec->configs.end(), [type](const common_speculative_config & config) { - return config.type == type; - }); -} - -static int common_speculative_ctx_mtp_n_embd(llama_context * ctx) { - return ctx ? 
(int) llama_mtp_state_n_embd(ctx) : 0; -} - -static bool common_speculative_batch_token_has_seq_id( - const llama_batch & batch, - int token_index, - llama_seq_id seq_id) { - if (batch.n_seq_id == nullptr || batch.seq_id == nullptr || batch.n_seq_id[token_index] <= 0 || batch.seq_id[token_index] == nullptr) { - return false; - } - - for (int i = 0; i < batch.n_seq_id[token_index]; ++i) { - if (batch.seq_id[token_index][i] == seq_id) { - return true; - } - } - - return false; -} - -static bool common_speculative_batch_is_exact_single_seq( - const llama_batch & batch, - llama_seq_id seq_id) { - if (batch.n_tokens <= 0 || batch.n_seq_id == nullptr || batch.seq_id == nullptr) { - return false; - } - - for (int i = 0; i < batch.n_tokens; ++i) { - if (batch.n_seq_id[i] != 1 || batch.seq_id[i] == nullptr || batch.seq_id[i][0] != seq_id) { - return false; - } - } - - return true; -} - -static int common_speculative_copy_seq_batch( - const llama_batch & batch, - llama_seq_id seq_id, - llama_batch & seq_batch) { - if (batch.token == nullptr || batch.pos == nullptr) { - return -1; - } - - if (batch.n_tokens < 1) { - return 0; - } - - std::vector token_indices; - token_indices.reserve(batch.n_tokens); - for (int i = 0; i < batch.n_tokens; ++i) { - if (common_speculative_batch_token_has_seq_id(batch, i, seq_id)) { - token_indices.push_back(i); - } - } - - if (token_indices.empty()) { - return 0; - } - - seq_batch = llama_batch_init((int) token_indices.size(), 0, 1); - for (const int i : token_indices) { - common_batch_add(seq_batch, batch.token[i], batch.pos[i], { seq_id }, batch.logits != nullptr && batch.logits[i]); - } - - return (int) token_indices.size(); -} - -static bool common_speculative_feature_view_copy_batch_rows( - const common_speculative_feature_view & view, - const llama_batch & batch, - llama_seq_id seq_id, - std::vector * hidden_rows) { - if (hidden_rows == nullptr || view.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || view.width <= 0 || batch.n_tokens <= 
0 || batch.pos == nullptr) { - return false; - } - - std::unordered_map rows_by_pos; - rows_by_pos.reserve(view.rows.size()); - for (const auto & row : view.rows) { - if (row.seq_id == seq_id && row.data != nullptr) { - rows_by_pos[row.pos] = row.data; - } - } - - hidden_rows->clear(); - hidden_rows->reserve((size_t) batch.n_tokens * view.width); - for (int i = 0; i < batch.n_tokens; ++i) { - auto it = rows_by_pos.find(batch.pos[i]); - if (it == rows_by_pos.end()) { - hidden_rows->clear(); - return false; - } - - hidden_rows->insert(hidden_rows->end(), it->second, it->second + view.width); - } - - return hidden_rows->size() == (size_t) batch.n_tokens * view.width; -} diff --git a/common/speculative.cpp b/common/speculative.cpp index b491c244..8f15f8ef 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -135,69 +135,6 @@ static bool common_speculative_are_compatible( return true; } -static bool common_speculative_are_dflash_compatible( - const llama_model * model_tgt, - const llama_model * model_dft) { - const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); - const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); - - if (llama_vocab_type(vocab_tgt) != llama_vocab_type(vocab_dft)) { - LOG_DBG("%s: DFlash draft model vocab type must match the target model\n", __func__); - return false; - } - - const bool add_bos_tgt = llama_vocab_get_add_bos(vocab_tgt); - const bool add_bos_dft = llama_vocab_get_add_bos(vocab_dft); - const bool add_eos_tgt = llama_vocab_get_add_eos(vocab_tgt); - const bool add_eos_dft = llama_vocab_get_add_eos(vocab_dft); - const llama_token bos_tgt = llama_vocab_bos(vocab_tgt); - const llama_token bos_dft = llama_vocab_bos(vocab_dft); - const llama_token eos_tgt = llama_vocab_eos(vocab_tgt); - const llama_token eos_dft = llama_vocab_eos(vocab_dft); - - if (add_bos_tgt != add_bos_dft || add_eos_tgt != add_eos_dft || - (add_bos_tgt && bos_tgt != bos_dft) || - (add_eos_tgt && eos_tgt != eos_dft)) { - LOG_DBG("%s: 
DFlash draft special tokens must match the target model (add_bos=%d/%d add_eos=%d/%d bos=%d/%d eos=%d/%d)\n", - __func__, - (int) add_bos_tgt, - (int) add_bos_dft, - (int) add_eos_tgt, - (int) add_eos_dft, - (int) bos_tgt, - (int) bos_dft, - (int) eos_tgt, - (int) eos_dft); - return false; - } - - const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt); - const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft); - const int vocab_diff = n_vocab_tgt > n_vocab_dft - ? n_vocab_tgt - n_vocab_dft - : n_vocab_dft - n_vocab_tgt; - - if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { - LOG_DBG("%s: DFlash draft vocab size differs too much from the target model (%d vs %d)\n", - __func__, n_vocab_dft, n_vocab_tgt); - return false; - } - - for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { - const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); - const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); - - if (std::strcmp(token_text_tgt, token_text_dft) != 0) { - LOG_DBG("%s: DFlash draft token %d differs - target '%s', draft '%s'\n", __func__, i, - common_token_to_piece(vocab_tgt, i).c_str(), - common_token_to_piece(vocab_dft, i).c_str()); - return false; - } - } - - return true; -} - // state of an implementation of speculative decoding // // each implementation has a unique type and a state that is implementation-specific @@ -251,27 +188,11 @@ struct common_speculative_state { struct common_speculative_state_mtp; struct common_speculative_state_dflash; -static void dflash_contract_log_append( - const common_speculative_state_dflash & state, - llama_seq_id seq_id, - const std::vector & new_positions); -static void dflash_contract_log_draft( - const common_speculative_state_dflash & state, - int32_t n_keep, - size_t result_size); - static common_speculative_state_mtp * common_speculative_get_mtp_state(common_speculative * spec); static const common_speculative_state_mtp * common_speculative_get_mtp_state(const 
common_speculative * spec); static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec); static const common_speculative_state_dflash * common_speculative_get_dflash_state(const common_speculative * spec); static int32_t common_speculative_feature_width(const common_speculative * spec); -static void dflash_materialize_target_window_features(common_speculative_state_dflash & state); -static void dflash_ring_reset_rows(common_speculative_state_dflash & state, const float * rows, int32_t n_rows); -static void dflash_append_target_features( - common_speculative_state_dflash & state, - const float * feature_rows, - int32_t n_rows); -static void dflash_clear_target_features(common_speculative_state_dflash & state); static void mtp_invalidate_cached_drafts(common_speculative_state_mtp & state); static bool common_speculative_checkpoint_save( common_speculative_checkpoint & ckpt, @@ -298,92 +219,6 @@ static std::vector mtp_speculative_gen_draft( static int32_t mtp_update_kv_cache(struct llama_context * ctx, const llama_batch & batch, bool is_prompt_warmup); -static bool dflash_contract_log_enabled() { - const char * env = std::getenv("IK_DFLASH_CONTRACT_LOG"); - if (env == nullptr || *env == '\0') { - return false; - } - - return std::strcmp(env, "0") != 0 && - std::strcmp(env, "false") != 0 && - std::strcmp(env, "off") != 0; -} - -static bool dflash_stats_log_enabled() { - const char * env = std::getenv("IK_DFLASH_STATS_LOG"); - if (env == nullptr || *env == '\0') { - return false; - } - - return std::strcmp(env, "0") != 0 && - std::strcmp(env, "false") != 0 && - std::strcmp(env, "off") != 0; -} - -template -static std::string dflash_contract_format_values( - const std::vector & values, - size_t edge_count = 4) { - std::ostringstream oss; - oss << '['; - if (values.empty()) { - oss << ']'; - return oss.str(); - } - - const size_t head = std::min(edge_count, values.size()); - for (size_t i = 0; i < head; ++i) { - if (i > 0) { - 
oss << ','; - } - oss << values[i]; - } - - if (values.size() > edge_count * 2) { - oss << ",...,"; - for (size_t i = values.size() - edge_count; i < values.size(); ++i) { - if (i > values.size() - edge_count) { - oss << ','; - } - oss << values[i]; - } - } else { - for (size_t i = head; i < values.size(); ++i) { - oss << ',' << values[i]; - } - } - - oss << ']'; - return oss.str(); -} - -struct dflash_contract_pos_summary { - llama_pos first = -1; - llama_pos last = -1; - int32_t gap_count = 0; - int32_t nonmono_count = 0; -}; - -static dflash_contract_pos_summary dflash_contract_summarize_positions( - const std::vector & positions) { - dflash_contract_pos_summary summary; - if (positions.empty()) { - return summary; - } - - summary.first = positions.front(); - summary.last = positions.back(); - for (size_t i = 1; i < positions.size(); ++i) { - if (positions[i] <= positions[i - 1]) { - summary.nonmono_count++; - } else if (positions[i] != positions[i - 1] + 1) { - summary.gap_count++; - } - } - - return summary; -} - struct mtp_last_embd { std::vector embd; float prob = 0.0f; @@ -500,7 +335,1327 @@ struct common_speculative_state_mtp : public common_speculative_state { } }; -#include "speculative-impl.h" +#include "speculative-dflash-impl.h" + +struct common_speculative_state_draft : public common_speculative_state { + llama_context * ctx_tgt; // only used for retokenizing from ctx_dft + llama_context * ctx_dft; + + common_sampler * smpl; + + llama_batch batch; + llama_tokens prompt_dft; + + bool vocab_cmpt = true; // whether retokenization is needed + std::unordered_map vocab_map; + + common_speculative_state_draft( + enum common_speculative_type type, + llama_context * ctx_tgt, + llama_context * ctx_dft, + const std::vector> & replacements) + : common_speculative_state(type) + , ctx_tgt(ctx_tgt) + , ctx_dft(ctx_dft) + { + batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); + smpl = nullptr; + { + struct common_params_sampling params; + params.top_k = 10; + 
params.samplers_sequence = { + llama_sampler_type::TOP_K, + llama_sampler_type::DIST, // needed to get probabilities + }; + smpl = common_sampler_init(llama_get_model(ctx_dft), params); + } + + vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); + LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt); + + if (!vocab_cmpt) { + LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n"); + + for (const auto & pair : replacements) { + vocab_map[pair.first] = pair.second; + } + } + } + + ~common_speculative_state_draft() override { + llama_free(ctx_dft); + + common_sampler_free(smpl); + + llama_batch_free(batch); + } + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + auto * spec = this; + + auto & batch = spec->batch; + auto & ctx_tgt = spec->ctx_tgt; + auto & ctx_dft = spec->ctx_dft; + auto & smpl = spec->smpl; + auto & prompt_dft = spec->prompt_dft; + + int reuse_i = 0; + int reuse_n = 0; + + const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max; + + llama_tokens prompt_cnv; + if (!spec->vocab_cmpt) { + // convert id_last to draft vocab. 
llama_detokenize is called directly to avoid an allocation + const auto * model_tgt = llama_get_model(ctx_tgt); + const auto * vocab_tgt = llama_model_get_vocab(model_tgt); + + std::string text; + + text = common_detokenize(ctx_tgt, prompt_tgt, true); + text = replace_to_dft(text); + + LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str()); + + prompt_cnv = common_tokenize(ctx_dft, text, false, true); + + + + int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false); + GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last"); + + text.resize(-n_chars); + llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false); + text = replace_to_dft(text); + + LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str()); + id_last = common_tokenize(ctx_dft, text, false, true)[0]; + } + + const llama_tokens & prompt_cur = spec->vocab_cmpt ? prompt_tgt : prompt_cnv; + + const int i_start = std::max(0, (int) prompt_cur.size() - n_ctx); + + // reuse as much as possible from the old draft context + // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt + for (int i = 0; i < (int) prompt_dft.size(); ++i) { + int cur = 0; + while (i_start + cur < (int) prompt_cur.size() && + i + cur < (int) prompt_dft.size() && + prompt_cur[i_start + cur] == prompt_dft[i + cur]) { + cur++; + } + + if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) { + reuse_i = i; + reuse_n = cur; + } + } + + LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size()); + + result.clear(); + result.reserve(params.n_max); + + if (reuse_n == 0) { + llama_kv_cache_clear(ctx_dft); + prompt_dft.clear(); + } else { + // this happens when a previous draft has been discarded (for example, due to being too small), but the + // target model agreed with it. 
in this case, we simply pass back the previous results to save compute + if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { + for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { + result.push_back(prompt_dft[i]); + + if (params.n_max <= (int) result.size()) { + break; + } + } + + return; + } + + if (reuse_i > 0) { + llama_kv_cache_seq_rm (ctx_dft, 0, 0, reuse_i); + llama_kv_cache_seq_add(ctx_dft, 0, reuse_i, -1, -reuse_i); + + prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); + } + + if (reuse_n < (int) prompt_dft.size()) { + llama_kv_cache_seq_rm (ctx_dft, 0, reuse_n, -1); + prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); + } + } + + // prepare a batch to evaluate any new tokens in the prompt + common_batch_clear(batch); + + for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) { + //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]); + common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false); + + prompt_dft.push_back(prompt_cur[i]); + } + + // we should rarely end-up here during normal decoding + if (batch.n_tokens > 0) { + //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); + + llama_decode(ctx_dft, batch); + } + + const llama_pos n_past = prompt_dft.size(); + + LOG_DBG("%s: n_past = %d\n", __func__, n_past); + + common_batch_clear(batch); + common_batch_add (batch, id_last, n_past, { 0 }, true); + + prompt_dft.push_back(id_last); + + //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); + + llama_decode(ctx_dft, batch); + + common_sampler_reset(smpl); + + // sample n_draft tokens from the draft model + for (int i = 0; i < params.n_max; ++i) { + common_batch_clear(batch); + + common_sampler_sample(smpl, ctx_dft, 0, true); + + const auto * cur_p = common_sampler_get_candidates(smpl, true); + + for (int k = 
0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + } + + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; + + common_sampler_accept(smpl, nullptr, id, true); + + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < params.p_min) { + if (i == 0) { + result.push_back(id); + } + break; + } + + result.push_back(id); + + if (params.n_max <= (int) result.size()) { + break; + } + + + common_batch_add(batch, id, n_past + i + 1, { 0 }, true); + + // evaluate the drafted tokens on the draft model + llama_decode(ctx_dft, batch); + + prompt_dft.push_back(id); + } + + if (!spec->vocab_cmpt) { + std::string detokenized = common_detokenize(ctx_dft, result, true); + detokenized = replace_to_tgt(detokenized); + LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str()); + result = common_tokenize(ctx_tgt, detokenized, false, true); + if (result.size() > (size_t)params.n_max) { + result.resize(params.n_max); + } + } + } + + void accept(uint16_t n_accepted) override { + // noop + GGML_UNUSED(n_accepted); + } + + std::string replace_to_dft(const std::string & input) const { + std::string result = input; + + for (const auto & pair : this->vocab_map) { + size_t pos = result.find(pair.first); + while (pos != std::string::npos) { + result.replace(pos, pair.first.length(), pair.second); + pos = result.find(pair.first, pos + pair.second.length()); + } + } + + return result; + } + + std::string replace_to_tgt(const std::string & input) const { + std::string result = input; + + for (const auto & pair : this->vocab_map) { + size_t pos = result.find(pair.second); + while (pos != std::string::npos) { + result.replace(pos, pair.second.length(), pair.first); + pos = result.find(pair.second, pos + pair.first.length()); + } + } + + return result; + } +}; + +struct 
common_speculative_state_eagle3 : public common_speculative_state { + common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {} + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & draft_tokens) override { + // TODO: implement + GGML_UNUSED(params); + GGML_UNUSED(prompt_tgt); + GGML_UNUSED(id_last); + GGML_UNUSED(draft_tokens); + } + + void accept(uint16_t n_accepted) override { + // noop + GGML_UNUSED(n_accepted); + } +}; + +// state of self-speculation (simple implementation, not ngram-map) +struct common_speculative_state_ngram_simple : public common_speculative_state { + common_ngram_simple_config config; + + common_speculative_state_ngram_simple( + enum common_speculative_type type, + common_ngram_simple_config config) + : common_speculative_state(type), config(config) {} + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + + result = common_ngram_simple_draft(config, prompt_tgt, id_last); + GGML_UNUSED(params); + } + + void accept(uint16_t n_accepted) override { + // noop + GGML_UNUSED(n_accepted); + } +}; + +struct common_speculative_state_ngram_map_k : public common_speculative_state { + // draft ngram map for speculative decoding without draft model + common_ngram_map map; + + common_speculative_state_ngram_map_k( + enum common_speculative_type type, + common_ngram_map map) + : common_speculative_state(type), map(std::move(map)) {} + + void begin(const llama_tokens & prompt) override { + common_ngram_map_begin(map, prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) 
override { + common_ngram_map_draft(map, prompt_tgt, id_last, result); + GGML_UNUSED(params); + } + + void accept(uint16_t n_accepted) override { + common_ngram_map_accept(map, n_accepted); + } +}; + +struct common_speculative_state_ngram_mod : public common_speculative_state { + common_ngram_mod & mod; + + // the last position in the prompt that was added to the ngram container + size_t i_last = 0; + + // length of the last drafted n‑gram (number of tokens returned by draft) + size_t n_draft_last = 0; + + // consecutive accept rounds with low acceptance fraction (< 0.5) + int n_low = 0; + + // enable trace logging if LLAMA_TRACE is set + const bool verbose; + + common_speculative_state_ngram_mod(enum common_speculative_type type, common_ngram_mod & mod) + : common_speculative_state(type), mod(mod), verbose(std::getenv("LLAMA_TRACE") != nullptr) { + static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t)); + } + + void begin(const llama_tokens & prompt) override { + i_last = 0; + + n_draft_last = 0; + n_low = 0; + + const size_t n = mod.get_n(); + + if (prompt.size() < n) { + return; + } + + for (size_t i = 0; i < prompt.size() - n; ++i) { + mod.add(prompt.data() + i); + } + + i_last = prompt.size() - n; + + const double f = (double)mod.get_used() / (double)mod.size(); + LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f); + + constexpr double f_thold = 0.25; + if (f > f_thold) { + LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold); + + mod.reset(); + } + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + GGML_UNUSED(params); + + n_draft_last = 0; + + const size_t cur_len = prompt_tgt.size(); + if (cur_len < mod.get_n()) { + return; + } + + const size_t n = mod.get_n(); + + // add new ngrams in chunks + if (i_last + 32 < cur_len) { + for (size_t i = 
i_last; i < cur_len - n; ++i) { + mod.add(prompt_tgt.data() + i); + } + + i_last = cur_len - n; + } + + result.resize(n + params.n_max); + for (size_t i = 0; i < n - 1; ++i) { + result[i] = prompt_tgt[cur_len - n + 1 + i]; + } + result[n - 1] = id_last; + + for (int i = 0; i < params.n_max; ++i) { + const llama_token token = mod.get(result.data() + i); + if (token == common_ngram_mod::EMPTY) { + if (i < params.n_min) { + result.clear(); + return; + } + + result.resize(n + i); + break; + } + result[n + i] = token; + } + + // only return the m tokens that were drafted + for (size_t i = 0; n + i < result.size(); ++i) { + result[i] = result[n + i]; + } + result.resize(result.size() - n); + + // store length of drafted n‑gram for later acceptance analysis + n_draft_last = result.size(); + } + + void accept(uint16_t n_accepted) override { + if (verbose) { + LOG_INF("%s: accepted %d tokens from %zu drafted tokens\n", __func__, n_accepted, n_draft_last); + } + + // compute acceptance fraction if we have a recorded draft length + if (n_draft_last > 0) { + const double f_acc = (double)n_accepted / (double)n_draft_last; + if (f_acc < 0.5) { + n_low++; + if (n_low >= 3) { + LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low); + + mod.reset(); + n_low = 0; + i_last = 0; + } + } else { + n_low = 0; + } + } + } +}; + +struct common_speculative_state_ngram_cache : public common_speculative_state { + uint16_t n_draft; + bool save_dynamic; + bool save_static; + + common_ngram_cache ngram_cache_context; + common_ngram_cache ngram_cache_dynamic; + common_ngram_cache ngram_cache_static; + + size_t cache_size = 0; // number of tokens in n-gram cache + + common_speculative_state_ngram_cache( + const enum common_speculative_type type, + const std::string & path_static, + const std::string & path_dynamic, + uint16_t n_draft, + bool save_dynamic, + bool save_static) + : common_speculative_state(type) + , n_draft(n_draft) + , save_dynamic(save_dynamic) + , 
save_static(save_static) + { + if (!path_static.empty()) { + try { + ngram_cache_static = common_ngram_cache_load(path_static); + } catch (...) { + LOG_ERR("failed to open static lookup cache: %s", path_static.c_str()); + GGML_ABORT("Couldn't read static lookup cache"); + } + } + + if (!path_dynamic.empty()) { + try { + ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); + } catch (...) { + LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str()); + GGML_ABORT("Couldn't read dynamic lookup cache"); + } + } + } + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + GGML_UNUSED(params); + + if (cache_size < prompt_tgt.size() + 1) { + llama_tokens tokens_new; + tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); + for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { + tokens_new.push_back(prompt_tgt[j]); + } + tokens_new.push_back(id_last); // add the last token + + // Update context ngram cache with new prompt_tgt: + common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + tokens_new, tokens_new.size(), false); + cache_size = prompt_tgt.size() + 1; + } + + llama_tokens inp; + inp.reserve(prompt_tgt.size() + 1); + for (size_t j = 0; j < prompt_tgt.size(); ++j) { + inp.push_back(prompt_tgt[j]); + } + inp.push_back(id_last); + + result.push_back(id_last); + + common_ngram_cache_draft(inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + ngram_cache_context, + ngram_cache_dynamic, + ngram_cache_static); + + if (result.size() > 0) { + // delete first token in result (which is the id_last token) + result.erase(result.begin()); + } + } + + void accept(uint16_t n_accepted) override { + // TODO: noop + GGML_UNUSED(n_accepted); + } +}; + +struct common_speculative_state_suffix : public common_speculative_state { + common_suffix_tree 
tree; + common_suffix_tree corpus_tree; + bool has_corpus = false; + size_t cache_size = 0; + + // Acceptance feedback + size_t n_draft_last = 0; + bool had_accept = false; + int n_low = 0; + float base_p_min = 0.1f; + float eff_p_min = 0.1f; + + common_speculative_state_suffix( + enum common_speculative_type type, + int max_depth, + const std::string & corpus_path, + const llama_model * model) + : common_speculative_state(type) + , tree(max_depth) + , corpus_tree(max_depth) + { + if (!corpus_path.empty()) { + std::function(const std::string &)> tokenize_fn; + if (model) { + tokenize_fn = [model](const std::string & text) -> std::vector { + return common_tokenize(model, text, false, true); + }; + } + has_corpus = corpus_tree.load_corpus(corpus_path, tokenize_fn); + } + } + + void begin(const llama_tokens & prompt) override { + cache_size = 0; + n_draft_last = 0; + had_accept = false; + n_low = 0; + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + + base_p_min = params.p_min; + if (n_draft_last > 0 && !had_accept) { + if (++n_low >= 3) { + eff_p_min = std::min(eff_p_min + 0.1f, 0.5f); + n_low = 0; + } + } + had_accept = false; + + if (cache_size < prompt_tgt.size() + 1) { + llama_tokens tokens_new; + tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); + for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { + tokens_new.push_back(prompt_tgt[j]); + } + tokens_new.push_back(id_last); + + tree.extend(tokens_new.data(), (int)tokens_new.size()); + cache_size = prompt_tgt.size() + 1; + } + + const int ctx_len = std::min((int)(prompt_tgt.size() + 1), tree.max_depth()); + llama_tokens context; + context.reserve(ctx_len); + const int ctx_start = (int)prompt_tgt.size() + 1 - ctx_len; + for (int j = ctx_start; j < (int)prompt_tgt.size(); ++j) { + context.push_back(prompt_tgt[j]); + } + context.push_back(id_last); + const int min_match_len = 
std::max(1, params.suffix_min_match_len); + + result = tree.speculate( + context.data(), (int)context.size(), + params.n_max, + eff_p_min, + 1, + min_match_len); + + if (has_corpus) { + auto corpus_result = corpus_tree.speculate( + context.data(), (int)context.size(), + params.n_max, + eff_p_min, + 1, + min_match_len); + if (corpus_result.size() > result.size()) { + result = std::move(corpus_result); + } + } + + n_draft_last = result.size(); + } + + void accept(uint16_t n_accepted) override { + if (n_draft_last == 0) { + return; + } + had_accept = true; + const double f_acc = (double)n_accepted / (double)n_draft_last; + if (f_acc < 0.5) { + if (++n_low >= 3) { + eff_p_min = std::min(eff_p_min + 0.1f, 0.5f); + n_low = 0; + } + } else { + n_low = 0; + if (eff_p_min > base_p_min) { + eff_p_min = std::max(eff_p_min - 0.05f, base_p_min); + } + } + } +}; + +struct common_speculative { + std::vector configs; // resolved stage config for each implementation + std::vector> impls; // list of implementations to use and their states + common_speculative_checkpoint checkpoint; + common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) + std::unique_ptr tuner; + int last_n_drafted = 0; + int64_t t_step_start_us = 0; +}; + +static bool common_speculative_stage_chain_matches( + const std::vector & stages, + const std::vector & configs) { + if (stages.size() != configs.size()) { + return false; + } + + for (size_t i = 0; i < stages.size(); ++i) { + if (stages[i].type != configs[i].type) { + return false; + } + } + + return true; +} + +static common_params_speculative common_speculative_get_runtime_params( + const common_speculative_config & config, + const common_params_speculative & params, + const common_speculative_stage_params & stage) { + common_params_speculative result = config.params; + + result.type = config.type; + result.n_max = stage.has_n_max_override() ? stage.n_max : params.n_max; + result.n_min = stage.has_n_min_override() ? 
stage.n_min : params.n_min; + result.p_min = stage.has_p_min_override() ? stage.p_min : params.p_min; + + if (config.type == COMMON_SPECULATIVE_TYPE_SUFFIX) { + result.suffix_min_match_len = stage.has_suffix_min_match_len_override() + ? stage.suffix_min_match_len + : params.suffix_min_match_len; + } + + result.n_max = std::max(result.n_max, 0); + result.n_min = std::max(0, std::min(result.n_min, result.n_max)); + result.stages.clear(); + + return result; +} + +static common_ngram_map get_common_ngram_map(const common_speculative_config & config) { + uint16_t size_key = config.params.ngram_size_n; + uint16_t size_value = config.params.ngram_size_m; + bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); + uint16_t min_hits = config.params.ngram_min_hits; + + return common_ngram_map(size_key, size_value, key_only, min_hits); +} + +static common_speculative_state_ngram_cache create_state_ngram_cache( + const std::string & path_static, const std::string & path_dynamic, + const common_speculative_config & config) { + uint16_t n_draft = 8; // TODO get from config? + + // TODO bool param in common/common.h to set save_static/save_dynamic? 
+ bool save_static = false; + bool save_dynamic = false; + + common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic); + + return state; +} + +std::string common_speculative_type_name_str() { + std::string result; + for (size_t i = 0; i < common_speculative_types.size(); i++) { + if (i > 0) { + result += ", "; + } + result += common_speculative_type_to_str(common_speculative_types[i]); + } + return result; +} + +std::string common_speculative_type_to_str(enum common_speculative_type type) { + switch (type) { + case COMMON_SPECULATIVE_TYPE_NONE: return "none"; + case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft"; + case COMMON_SPECULATIVE_TYPE_DFLASH: return "dflash"; + case COMMON_SPECULATIVE_TYPE_MTP: return "mtp"; + case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3"; + case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod"; + case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache"; + case COMMON_SPECULATIVE_TYPE_SUFFIX: return "suffix"; + default: return "unknown"; + } +} + +enum common_speculative_type common_speculative_type_from_name(const std::string & name) { + std::string normalized = name; + std::replace(normalized.begin(), normalized.end(), '-', '_'); + + const auto it = common_speculative_type_from_name_map.find(normalized); + if (it == common_speculative_type_from_name_map.end()) { + return COMMON_SPECULATIVE_TYPE_COUNT; + } + return it->second; +} + +bool common_speculative_is_compat(llama_context * ctx_tgt) { + bool res = true; + + llama_kv_cache_clear(ctx_tgt); + + // eval 2 tokens to check if the context is compatible + std::vector tmp; + tmp.push_back(0); + tmp.push_back(0); + + int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0)); + if (ret != 
0) { + LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret); + res = false; + goto done; + } + + // try to remove the last tokens + if (!llama_kv_cache_seq_rm(ctx_tgt, 0, 1, -1)) { + LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); + res = false; + goto done; + } + +done: + llama_kv_cache_clear(ctx_tgt); + llama_synchronize(ctx_tgt); + + return res; +} + +// initialization of the speculative decoding system +// +common_speculative * common_speculative_init( + common_params_speculative & params, + llama_context * ctx_tgt) { + std::string chain_error; + if (!common_speculative_validate_chain(params, &chain_error)) { + LOG_ERR("%s: invalid speculative stage chain: %s\n", __func__, chain_error.c_str()); + return nullptr; + } + + const auto stages = params.get_resolved_stages(); + if (params.model_dft && llama_model_is_gemma4_mtp_assistant(params.model_dft)) { + const bool has_draft_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) { + return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT; + }); + + if (has_draft_stage) { + LOG_ERR("%s: Gemma4 assistant models only support MTP stages; omit -md for self-spec-only runs or use -mtp/--spec-stage mtp for assistant-backed MTP\n", __func__); + return nullptr; + } + } + + const bool has_dflash_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) { + return stage.type == COMMON_SPECULATIVE_TYPE_DFLASH; + }); + + const bool needs_draft_ctx = std::any_of(stages.begin(), stages.end(), [¶ms](const common_speculative_stage_params & stage) { + return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT || + stage.type == COMMON_SPECULATIVE_TYPE_DFLASH || + (stage.type == COMMON_SPECULATIVE_TYPE_MTP && params.model_dft != nullptr); + }); + + llama_context * ctx_dft = nullptr; + if (needs_draft_ctx) { + if (!params.model_dft) { + LOG_ERR("%s: draft speculative stage requires a loaded draft model\n", 
__func__); + return nullptr; + } + + llama_context_params cparams_dft = params.cparams_dft; + + if (has_dflash_stage) { + if (!llama_model_share_dflash_io_tensors(params.model_dft, llama_get_model(ctx_tgt))) { + LOG_ERR("%s: failed to share target IO tensors with DFlash draft model\n", __func__); + return nullptr; + } + + int32_t max_cross_ctx = 0; + for (const auto & stage : stages) { + if (stage.type != COMMON_SPECULATIVE_TYPE_DFLASH) { + continue; + } + + max_cross_ctx = std::max(max_cross_ctx, params.with_stage_overrides(stage).dflash_cross_ctx); + } + + const int32_t block_size = llama_model_dflash_block_size(params.model_dft); + if (block_size <= 0) { + LOG_ERR("%s: invalid DFlash draft block size\n", __func__); + return nullptr; + } + + const int64_t required_n_ctx = (int64_t) max_cross_ctx + (int64_t) block_size; + if (required_n_ctx > std::numeric_limits::max()) { + LOG_ERR("%s: invalid DFlash draft context size cross_ctx=%d block_size=%d required_n_ctx=%lld\n", + __func__, max_cross_ctx, block_size, (long long) required_n_ctx); + return nullptr; + } + + cparams_dft.n_ctx = (uint32_t) required_n_ctx; + } + + ctx_dft = llama_init_from_model(params.model_dft, cparams_dft); + if (ctx_dft == nullptr) { + LOG_ERR("%s", "failed to create draft context\n"); + return nullptr; + } + } + + // Compute the implementations to use based on the resolved stage chain. 
+ std::vector configs = {}; + configs.reserve(stages.size()); + + for (const auto & stage : stages) { + common_params_speculative stage_params = params.with_stage_overrides(stage); + + if (stage.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD && !stage_params.ngram_mod) { + stage_params.ngram_mod = std::make_shared(stage_params.ngram_size_n, 4*1024*1024); + + LOG_INF("%s: initialized ngram_mod with n=%d, size=%zu (%.3f MB)\n", __func__, + stage_params.ngram_size_n, stage_params.ngram_mod->size(), + (float)(stage_params.ngram_mod->size_bytes())/1024/1024); + + if (stage_params.ngram_size_n < 16) { + LOG_WRN("%s: ngram_mod n=%d is too small - poor quality is possible, see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, stage_params.ngram_size_n); + } + } + + configs.push_back(common_speculative_config(stage, stage_params)); + } + + if (!configs.empty() && llama_model_has_recurrent(llama_get_model(ctx_tgt))) { + const int ckpt_tokens = std::max(1, params.get_max_stage_n_max() + 1); + const int actual_mode = llama_spec_ckpt_init(ctx_tgt, params.recurrent_ckpt_mode, ckpt_tokens); + if (actual_mode == LLAMA_SPEC_CKPT_NONE) { + LOG_ERR("%s: failed to prepare recurrent checkpoint mode '%s' during speculative init (max_tokens=%d)\n", + __func__, + params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_PER_STEP ? "per-step" : + params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_GPU_FALLBACK ? "gpu-fallback" : + params.recurrent_ckpt_mode == LLAMA_SPEC_CKPT_CPU ? 
"cpu" : "auto", + ckpt_tokens); + if (ctx_dft != nullptr) { + llama_free(ctx_dft); + } + return nullptr; + } + llama_spec_ckpt_discard(ctx_tgt); + params.recurrent_ckpt_mode = actual_mode; + } + + std::vector> impls = {}; + + for (const common_speculative_config & config : configs) { + LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str()); + switch (config.type) { + case COMMON_SPECULATIVE_TYPE_NONE: + break; + case COMMON_SPECULATIVE_TYPE_DRAFT: { + impls.push_back(std::make_unique(config.type, + /* .ctx_tgt = */ ctx_tgt, + /* .ctx_dft = */ ctx_dft, + /* .replacements = */ config.params.replacements + )); + break; + } + case COMMON_SPECULATIVE_TYPE_DFLASH: { + auto state = std::make_unique( + config.type, + ctx_tgt, + ctx_dft, + config.params.dflash_cross_ctx); + if (!state->ready) { + LOG_ERR("%s: failed to initialize DFlash speculative state\n", __func__); + return nullptr; + } + impls.push_back(std::move(state)); + ctx_dft = nullptr; + break; + } + case COMMON_SPECULATIVE_TYPE_MTP: { + llama_context * ctx_mtp = ctx_dft; + if (!ctx_mtp) { + const llama_model * model = llama_get_model(ctx_tgt); + ctx_mtp = llama_init_from_model(const_cast(model), config.params.cparams_dft); + if (!ctx_mtp) { + LOG_ERR("%s: failed to create MTP context\n", __func__); + return nullptr; + } + } + ctx_dft = nullptr; + + const bool use_constant_draft_positions = llama_model_is_gemma4_mtp_assistant(llama_get_model(ctx_mtp)); + impls.push_back(std::make_unique( + config.type, ctx_tgt, ctx_mtp, use_constant_draft_positions)); + break; + } + case COMMON_SPECULATIVE_TYPE_EAGLE3: { + impls.push_back(std::make_unique(config.type)); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { + common_ngram_map ngram_map = get_common_ngram_map(config); + + uint16_t ngram_size_key = ngram_map.size_key; + uint16_t mgram_size_value = ngram_map.size_value; + + auto config_simple = common_ngram_simple_config { + /* .size_ngram = */ ngram_size_key, 
+ /* .size_mgram = */ mgram_size_value + }; + auto state = std::make_unique( + /* .type = */ config.type, + /* .state = */ config_simple + ); + impls.push_back(std::move(state)); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: { + impls.push_back(std::make_unique( + (config.type), + get_common_ngram_map(config) + )); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: { + GGML_ASSERT(config.params.ngram_mod); + impls.push_back(std::make_unique(config.type, *config.params.ngram_mod)); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: { + auto state = create_state_ngram_cache( + config.params.lookup_cache_static, config.params.lookup_cache_dynamic, config); + impls.push_back(std::make_unique(state)); + break; + } + case COMMON_SPECULATIVE_TYPE_SUFFIX: { + int depth = config.params.suffix_max_depth > 0 ? config.params.suffix_max_depth : 64; + const llama_model * model = llama_get_model(ctx_tgt); + impls.push_back(std::make_unique( + config.type, depth, config.params.suffix_corpus, model)); + break; + } + default: + break; + } + } + + if (impls.empty()) { + LOG_WRN("%s", "no implementations specified for speculative decoding\n"); + return nullptr; + } + + auto * result = new common_speculative { + /* .configs = */ std::move(configs), + /* .impls = */ std::move(impls) + }; + + // initialize autotune if requested + if (params.autotune && params.has_composite_stage_chain()) { + LOG_WRN("Autotune disabled — explicit speculative stage chains are not supported yet\n"); + } else if (params.autotune && !result->impls.empty()) { + auto actual_type = result->impls[0]->type; + if (actual_type != COMMON_SPECULATIVE_TYPE_NONE && + actual_type != COMMON_SPECULATIVE_TYPE_EAGLE3) { + result->tuner = std::make_unique(); + result->tuner->init(actual_type, params, llama_get_model(ctx_tgt)); + LOG_DBG("Autotune initialized for %s, tuning %zu parameters\n", + common_speculative_type_to_str(actual_type).c_str(), + 
result->tuner->coords.size()); + } else { + LOG_WRN("Autotune disabled — speculative type %s is not supported for autotuning\n", + common_speculative_type_to_str(actual_type).c_str()); + } + } + + return result; +} + +void common_speculative_free(common_speculative * spec) { + if (spec == nullptr) { + return; + } + + spec->checkpoint.clear(); + delete spec; +} + +void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt) { + if (spec == nullptr) { + return; + } + + for (auto & impl : spec->impls) { + common_time_meas tm(impl->t_begin_us, !impl->gen_perf); + impl->begin(prompt); + impl->n_call_begin++; + } +} + +llama_tokens common_speculative_draft( + common_speculative * spec, + common_params_speculative & params, + const llama_tokens & prompt_tgt, // specified in target model vocab + llama_token id_last, + llama_pos draft_base_pos, + llama_seq_id draft_seq_id) { + llama_tokens result; + + spec->t_step_start_us = ggml_time_us(); + + // apply autotune proposal if enabled + if (spec->tuner && spec->tuner->enabled) { + spec->tuner->propose(params); + } + + const auto runtime_stages = params.get_resolved_stages(); + const bool use_runtime_stage_overrides = common_speculative_stage_chain_matches(runtime_stages, spec->configs); + + spec->curr_impl = nullptr; // reset current implementation + + for (size_t i = 0; i < spec->impls.size(); ++i) { + auto & impl = spec->impls[i]; + const auto & runtime_stage = use_runtime_stage_overrides ? 
runtime_stages[i] : spec->configs[i].stage; + common_params_speculative impl_params = common_speculative_get_runtime_params(spec->configs[i], params, runtime_stage); + result.clear(); + + { + common_time_meas tm(impl->t_draft_us, !impl->gen_perf); + impl->draft(impl_params, prompt_tgt, id_last, draft_base_pos, draft_seq_id, result); + impl->n_call_draft++; + } + + if (result.empty()) { + continue; + } + + if (common_speculative_type_is_self_spec(impl->type) && impl_params.n_min > 0 && (int)result.size() < impl_params.n_min) { + LOG_DBG("%s: impl %s drafted %zu tokens, below fallback threshold %d - trying next implementation\n", + __func__, common_speculative_type_to_str(impl->type).c_str(), result.size(), impl_params.n_min); + result.clear(); + continue; + } + LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, + common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(), + impl.get()->n_call_draft, result.size()); + + spec->curr_impl = impl.get(); + impl->n_gen_drafts++; + impl->n_gen_tokens += result.size(); + + break; // We have a draft, so break out of the loop and return it. + } + + // store draft count for tuner feedback + if (spec->tuner && spec->tuner->enabled) { + spec->last_n_drafted = (int)result.size(); + } + + return result; +} + +void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { + if (spec->tuner && spec->tuner->enabled && spec->t_step_start_us > 0) { + int64_t step_time_us = ggml_time_us() - spec->t_step_start_us; + double step_tps = (step_time_us > 100) + ? 
(n_accepted + 1.0) * 1e6 / (double)step_time_us + : 0.0; + spec->tuner->accept_feedback(n_accepted, spec->last_n_drafted, step_tps); + spec->t_step_start_us = 0; + } + + common_speculative_state * impl = spec->curr_impl; + + if (!impl) { + return; + } + + { + common_time_meas tm(impl->t_accept_us, !impl->gen_perf); + if (n_accepted > 0) { + impl->n_acc_drafts++; + impl->n_acc_tokens += n_accepted; + } + + impl->accept(n_accepted); + impl->n_call_accept++; + } + + if (impl->type != COMMON_SPECULATIVE_TYPE_MTP) { + if (auto * mtp_state = common_speculative_get_mtp_state(spec); mtp_state != nullptr) { + mtp_invalidate_cached_drafts(*mtp_state); + } + } +} + +static bool common_speculative_has_type(const common_speculative * spec, common_speculative_type type) { + if (spec == nullptr) { + return false; + } + + return std::any_of(spec->configs.begin(), spec->configs.end(), [type](const common_speculative_config & config) { + return config.type == type; + }); +} + +static int common_speculative_ctx_mtp_n_embd(llama_context * ctx) { + return ctx ? 
(int) llama_mtp_state_n_embd(ctx) : 0; +} + +static bool common_speculative_batch_token_has_seq_id( + const llama_batch & batch, + int token_index, + llama_seq_id seq_id) { + if (batch.n_seq_id == nullptr || batch.seq_id == nullptr || batch.n_seq_id[token_index] <= 0 || batch.seq_id[token_index] == nullptr) { + return false; + } + + for (int i = 0; i < batch.n_seq_id[token_index]; ++i) { + if (batch.seq_id[token_index][i] == seq_id) { + return true; + } + } + + return false; +} + +static bool common_speculative_batch_is_exact_single_seq( + const llama_batch & batch, + llama_seq_id seq_id) { + if (batch.n_tokens <= 0 || batch.n_seq_id == nullptr || batch.seq_id == nullptr) { + return false; + } + + for (int i = 0; i < batch.n_tokens; ++i) { + if (batch.n_seq_id[i] != 1 || batch.seq_id[i] == nullptr || batch.seq_id[i][0] != seq_id) { + return false; + } + } + + return true; +} + +static int common_speculative_copy_seq_batch( + const llama_batch & batch, + llama_seq_id seq_id, + llama_batch & seq_batch) { + if (batch.token == nullptr || batch.pos == nullptr) { + return -1; + } + + if (batch.n_tokens < 1) { + return 0; + } + + std::vector token_indices; + token_indices.reserve(batch.n_tokens); + for (int i = 0; i < batch.n_tokens; ++i) { + if (common_speculative_batch_token_has_seq_id(batch, i, seq_id)) { + token_indices.push_back(i); + } + } + + if (token_indices.empty()) { + return 0; + } + + seq_batch = llama_batch_init((int) token_indices.size(), 0, 1); + for (const int i : token_indices) { + common_batch_add(seq_batch, batch.token[i], batch.pos[i], { seq_id }, batch.logits != nullptr && batch.logits[i]); + } + + return (int) token_indices.size(); +} + +static bool common_speculative_feature_view_copy_batch_rows( + const common_speculative_feature_view & view, + const llama_batch & batch, + llama_seq_id seq_id, + std::vector * hidden_rows) { + if (hidden_rows == nullptr || view.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || view.width <= 0 || batch.n_tokens <= 
0 || batch.pos == nullptr) { + return false; + } + + std::unordered_map rows_by_pos; + rows_by_pos.reserve(view.rows.size()); + for (const auto & row : view.rows) { + if (row.seq_id == seq_id && row.data != nullptr) { + rows_by_pos[row.pos] = row.data; + } + } + + hidden_rows->clear(); + hidden_rows->reserve((size_t) batch.n_tokens * view.width); + for (int i = 0; i < batch.n_tokens; ++i) { + auto it = rows_by_pos.find(batch.pos[i]); + if (it == rows_by_pos.end()) { + hidden_rows->clear(); + return false; + } + + hidden_rows->insert(hidden_rows->end(), it->second, it->second + view.width); + } + + return hidden_rows->size() == (size_t) batch.n_tokens * view.width; +} static bool common_speculative_capture_target_features( common_speculative * spec, @@ -1767,286 +2922,6 @@ static void mtp_clear_target_hidden(common_speculative_state_mtp & state, llama_ state.draft_cache_by_seq.erase(seq_id); } -// DFlash target-window replay and maintenance helpers. -struct dflash_append_breakdown { - uint64_t filter_us = 0; - uint64_t window_alloc_us = 0; - uint64_t replace_us = 0; - uint64_t keep_old_us = 0; - uint64_t new_rows_us = 0; - uint64_t commit_us = 0; - uint64_t log_us = 0; - bool replace_call = false; -}; - -static void dflash_record_window_update( - common_speculative_state_dflash & state, - int32_t keep_rows, - int32_t append_rows, - bool replace) { - state.target_window_keep_rows = std::max(0, keep_rows); - state.target_window_append_rows = std::max(0, append_rows); - state.target_window_replace = replace; - state.target_window_version++; -} - -static void dflash_ring_reset_rows( - common_speculative_state_dflash & state, - const float * rows, - int32_t n_rows) { - const size_t row_width = (size_t) state.n_target_features; - if (n_rows <= 0 || rows == nullptr) { - state.target_window_ring_write_pos = 0; - state.target_window_ring_filled = 0; - return; - } - - if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) { - 
state.target_window_ring.resize((size_t) state.cross_ctx * row_width); - } - - std::memcpy(state.target_window_ring.data(), rows, (size_t) n_rows * row_width * sizeof(float)); - state.target_window_ring_write_pos = n_rows % state.cross_ctx; - state.target_window_ring_filled = n_rows; - state.target_window_materialized = false; -} - -static void dflash_ring_append_rows( - common_speculative_state_dflash & state, - const float * rows, - int32_t n_rows) { - const size_t row_width = (size_t) state.n_target_features; - if (n_rows <= 0 || rows == nullptr) { - return; - } - - if (state.target_window_ring.size() != (size_t) state.cross_ctx * row_width) { - state.target_window_ring.resize((size_t) state.cross_ctx * row_width); - } - - int32_t write_pos = state.target_window_ring_write_pos; - int32_t remaining = n_rows; - const float * src = rows; - while (remaining > 0) { - const int32_t chunk_rows = std::min(remaining, state.cross_ctx - write_pos); - std::memcpy( - state.target_window_ring.data() + (size_t) write_pos * row_width, - src, - (size_t) chunk_rows * row_width * sizeof(float)); - src += (size_t) chunk_rows * row_width; - remaining -= chunk_rows; - write_pos = (write_pos + chunk_rows) % state.cross_ctx; - } - - state.target_window_ring_write_pos = write_pos; - state.target_window_ring_filled = std::min(state.cross_ctx, state.target_window_ring_filled + n_rows); - state.target_window_materialized = false; -} - -static void dflash_materialize_target_window_features(common_speculative_state_dflash & state) { - if (state.target_window_materialized || state.target_window_rows <= 0) { - return; - } - - const size_t row_width = (size_t) state.n_target_features; - state.target_window.resize((size_t) state.target_window_rows * row_width); - - const int32_t read_start = (state.target_window_ring_write_pos - state.target_window_rows + state.cross_ctx) % state.cross_ctx; - const int32_t first_rows = std::min(state.target_window_rows, state.cross_ctx - read_start); - 
std::memcpy( - state.target_window.data(), - state.target_window_ring.data() + (size_t) read_start * row_width, - (size_t) first_rows * row_width * sizeof(float)); - - const int32_t second_rows = state.target_window_rows - first_rows; - if (second_rows > 0) { - std::memcpy( - state.target_window.data() + (size_t) first_rows * row_width, - state.target_window_ring.data(), - (size_t) second_rows * row_width * sizeof(float)); - } - - state.target_window_materialized = true; -} - -static bool dflash_append_target_features( - common_speculative_state_dflash & state, - const common_speculative_feature_view & features, - const llama_batch & batch, - llama_seq_id seq_id, - dflash_append_breakdown * breakdown = nullptr) { - GGML_UNUSED(batch); - - if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || - features.width != state.n_target_features || - features.rows.empty() || - state.cross_ctx <= 0) { - return false; - } - - const size_t row_width = (size_t) state.n_target_features; - std::vector new_rows; - std::vector new_positions; - new_rows.reserve(features.rows.size() * row_width); - new_positions.reserve(features.rows.size()); - - const int64_t t_filter_us = ggml_time_us(); - for (const auto & row : features.rows) { - if (row.seq_id != seq_id || row.data == nullptr) { - continue; - } - - new_positions.push_back(row.pos); - new_rows.insert(new_rows.end(), row.data, row.data + row_width); - } - if (breakdown != nullptr) { - breakdown->filter_us += (uint64_t) (ggml_time_us() - t_filter_us); - } - - if (new_positions.empty()) { - return false; - } - - const int32_t n_rows = (int32_t) new_positions.size(); - state.n_window_updates++; - state.n_rows_seen += (size_t) n_rows; - if (n_rows >= state.cross_ctx) { - state.n_rows_dropped += (size_t) state.target_window_rows + (size_t) (n_rows - state.cross_ctx); - const int32_t keep_from = n_rows - state.cross_ctx; - const int64_t t_replace_us = ggml_time_us(); - state.target_window_pos.assign(new_positions.begin() + 
keep_from, new_positions.end()); - state.target_window_append_features.assign( - new_rows.begin() + (ptrdiff_t) keep_from * (ptrdiff_t) row_width, - new_rows.end()); - dflash_ring_reset_rows(state, state.target_window_append_features.data(), state.cross_ctx); - if (breakdown != nullptr) { - breakdown->replace_us += (uint64_t) (ggml_time_us() - t_replace_us); - breakdown->replace_call = true; - } - - const int64_t t_commit_us = ggml_time_us(); - state.target_window_rows = state.cross_ctx; - state.target_window_ring_filled = state.target_window_rows; - state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); - dflash_record_window_update(state, 0, state.target_window_rows, true); - if (breakdown != nullptr) { - breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us); - } - - const int64_t t_log_us = ggml_time_us(); - dflash_contract_log_append(state, seq_id, new_positions); - if (breakdown != nullptr) { - breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us); - } - return true; - } - - const int32_t keep_old_rows = std::min(state.target_window_rows, state.cross_ctx - n_rows); - state.n_rows_dropped += (size_t) std::max(0, state.target_window_rows - keep_old_rows); - const int64_t t_window_alloc_us = ggml_time_us(); - std::vector & next_window_pos = state.target_window_pos_stage; - next_window_pos.resize((size_t) (keep_old_rows + n_rows)); - if (breakdown != nullptr) { - breakdown->window_alloc_us += (uint64_t) (ggml_time_us() - t_window_alloc_us); - } - - if (keep_old_rows > 0) { - const int64_t t_keep_old_us = ggml_time_us(); - std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), next_window_pos.begin()); - if (breakdown != nullptr) { - breakdown->keep_old_us += (uint64_t) (ggml_time_us() - t_keep_old_us); - } - } - - const int64_t t_new_rows_us = ggml_time_us(); - state.target_window_append_features.assign(new_rows.begin(), new_rows.end()); - dflash_ring_append_rows(state, 
state.target_window_append_features.data(), n_rows); - std::copy(new_positions.begin(), new_positions.end(), next_window_pos.begin() + keep_old_rows); - if (breakdown != nullptr) { - breakdown->new_rows_us += (uint64_t) (ggml_time_us() - t_new_rows_us); - } - - const int64_t t_commit_us = ggml_time_us(); - state.target_window_pos.swap(next_window_pos); - next_window_pos.clear(); - state.target_window_rows = keep_old_rows + n_rows; - state.target_window_ring_filled = state.target_window_rows; - state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); - dflash_record_window_update(state, keep_old_rows, n_rows, false); - if (breakdown != nullptr) { - breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us); - } - - const int64_t t_log_us = ggml_time_us(); - dflash_contract_log_append(state, seq_id, new_positions); - if (breakdown != nullptr) { - breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us); - } - return true; -} - -static void dflash_clear_target_features(common_speculative_state_dflash & state) { - state.target_window.clear(); - state.target_window_pos.clear(); - state.target_window_stage.clear(); - state.target_window_pos_stage.clear(); - state.target_window_append_features.clear(); - state.target_window_rows = 0; - state.target_window_ring_write_pos = 0; - state.target_window_ring_filled = 0; - state.target_window_keep_rows = 0; - state.target_window_append_rows = 0; - state.target_window_replace = false; - state.target_window_materialized = false; - state.last_target_pos = -1; - llama_reset_dflash_kv_cache_state(state.ctx_dft); -} - -static void dflash_context_shift( - common_speculative_state_dflash & state, - llama_pos kv_keep, - llama_pos kv_discard, - llama_pos kv_past) { - if (kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window_pos.empty()) { - return; - } - - dflash_materialize_target_window_features(state); - - const size_t row_width = (size_t) state.n_target_features; - const 
llama_pos discard_begin = kv_keep; - const llama_pos discard_end = kv_keep + kv_discard; - - std::vector shifted_rows; - std::vector shifted_positions; - shifted_rows.reserve(state.target_window.size()); - shifted_positions.reserve(state.target_window_pos.size()); - - for (int32_t row = 0; row < state.target_window_rows; ++row) { - llama_pos pos = state.target_window_pos[(size_t) row]; - if (pos >= discard_begin && pos < discard_end) { - continue; - } - - if (pos >= discard_end && pos < kv_past) { - pos -= kv_discard; - } - - const float * row_src = state.target_window.data() + (size_t) row * row_width; - shifted_rows.insert(shifted_rows.end(), row_src, row_src + row_width); - shifted_positions.push_back(pos); - } - - state.target_window = std::move(shifted_rows); - state.target_window_pos = std::move(shifted_positions); - state.target_window_rows = (int32_t) state.target_window_pos.size(); - dflash_ring_reset_rows(state, state.target_window.data(), state.target_window_rows); - state.last_target_pos = state.target_window_pos.empty() ? 
-1 : state.target_window_pos.back(); - dflash_record_window_update(state, 0, state.target_window_rows, true); - llama_reset_dflash_kv_cache_state(state.ctx_dft); - state.n_context_shifts++; -} - static bool common_speculative_capture_target_features(common_speculative * spec, const common_speculative_feature_view & features) { auto * mtp_state = common_speculative_get_mtp_state(spec); if (mtp_state == nullptr || features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || features.width <= 0) { diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3cd8f9fe..e615d85d 100644 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -570,6 +570,9 @@ class Model: if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273": # ref: https://huggingface.co/alvarobartt/grok-2-tokenizer res = "grok-2" + if chkhsh == "972da7b59cec44d1f0a490a86c96df53859e486e481563e5dddac155013d87ac": + # ref: https://huggingface.co/poolside/Laguna-XS.2 + res = "laguna" if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B res = "llama-bpe" @@ -603,6 +606,9 @@ class Model: if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8": # ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01 res = "command-r" + if chkhsh == "52df12b4c8d4176e7481aab4b6e8454d1fd0a210a04a574f6d4e067d10e23c3e": + # ref: https://huggingface.co/CohereLabs/North-Mini-Code-1.0 + res = "cohere2_moe" if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea": # ref: https://huggingface.co/Qwen/Qwen1.5-7B res = "qwen2" @@ -684,6 +690,9 @@ class Model: if chkhsh == "f4f37b6c8eb9ea29b3eac6bb8c8487c5ab7885f8d8022e67edc1c68ce8403e95": # ref: https://huggingface.co/MiniMaxAI/MiniMax-M2 res = "minimax-m2" + if chkhsh == "9dcf830ee9990cdbf78cc523a5f7bd9ad8f3f9890c2d3581d2785ad10f07049d": + # ref: https://huggingface.co/JetBrains/Mellum2-12B-A2.5B-Base + res = "mellum2" if 
res is None: logger.warning("\n") logger.warning("**************************************************************************************") @@ -1580,7 +1589,12 @@ class LlamaModel(Model): special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): + saved_intermediate_size = self.hparams.get("intermediate_size") + saved_num_experts_per_tok = self.hparams.pop("num_experts_per_tok") + self.hparams["intermediate_size"] = self.hparams["prefix_dense_intermediate_size"] super().set_gguf_parameters() + self.hparams["intermediate_size"] = saved_intermediate_size + self.hparams["num_experts_per_tok"] = saved_num_experts_per_tok hparams = self.hparams self.gguf_writer.add_vocab_size(hparams["vocab_size"]) @@ -3735,7 +3749,7 @@ class Gemma4Model(Gemma4BaseModel): return [(self.map_tensor_name(name), data_torch)] -@Model.register("Gemma4AssistantForCausalLM") +@Model.register("Gemma4AssistantForCausalLM", "Gemma4UnifiedAssistantForCausalLM") class Gemma4AssistantModel(Gemma4BaseModel): model_arch = gguf.MODEL_ARCH.GEMMA4_MTP @@ -3950,6 +3964,86 @@ class CommandR2Model(Model): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) +@Model.register("Cohere2MoeForCausalLM") +class Cohere2MoeModel(Model): + model_arch = gguf.MODEL_ARCH.COHERE2_MOE + + _experts: list[dict[str, Tensor]] | None = None + + def set_gguf_parameters(self): + saved_intermediate_size = self.hparams["intermediate_size"] + saved_num_experts_per_tok = self.hparams.pop("num_experts_per_tok") + self.hparams["intermediate_size"] = self.hparams["prefix_dense_intermediate_size"] + super().set_gguf_parameters() + self.hparams["intermediate_size"] = saved_intermediate_size + self.hparams["num_experts_per_tok"] = saved_num_experts_per_tok + hparams = self.hparams + + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_logit_scale(hparams.get("logit_scale", 1.0)) + self.gguf_writer.add_sliding_window(hparams["sliding_window"]) + 
self.gguf_writer.add_sliding_window_pattern([ + layer_type == "sliding_attention" + for layer_type in hparams["layer_types"] + ]) + self.gguf_writer.add_rope_dimension_count(hparams.get("head_dim", hparams["hidden_size"] // hparams["num_attention_heads"])) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + + self.gguf_writer.add_expert_feed_forward_length(hparams["intermediate_size"]) + self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) + self.gguf_writer.add_expert_count(hparams["num_experts"]) + self.gguf_writer.add_expert_used_count(hparams["num_experts_per_tok"]) + self.gguf_writer.add_expert_weights_norm(bool(hparams.get("norm_topk_prob", False))) + + expert_selection_fn = hparams.get("expert_selection_fn", "softmax") + if expert_selection_fn != "sigmoid": + raise ValueError(f"Unsupported Cohere2-MoE expert_selection_fn={expert_selection_fn!r}") + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + + if hparams.get("num_shared_experts", 0) != 0: + raise ValueError("Cohere2-MoE shared experts are not supported in this GGUF converter yet") + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Cohere2-MoE HF tensors already use the interleaved RoPE layout expected here. + + if ".mlp.experts." 
in name: + n_experts = self.hparams["num_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) < n_experts * 3: + return [] + + tensors: list[tuple[str, Tensor]] = [] + for src, dst in [ + ("gate_proj", "gate_proj"), + ("down_proj", "down_proj"), + ("up_proj", "up_proj"), + ]: + datas: list[Tensor] = [] + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{src}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + merged_name = f"model.layers.{bid}.mlp.experts.{dst}.weight" + tensors.append((self.map_tensor_name(merged_name), torch.stack(datas, dim=0))) + yield from tensors + return + + if name == "model.embed_tokens.weight": + yield self.map_tensor_name(name), data_torch + if self.tensor_names is None or "lm_head.weight" not in self.tensor_names: + yield self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT, suffix=".weight"), data_torch + return + + yield self.map_tensor_name(name), data_torch + + @Model.register("OlmoForCausalLM") @Model.register("OLMoForCausalLM") class OlmoModel(Model): @@ -5307,6 +5401,138 @@ class BailingMoeV2Model(Model): raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("LagunaForCausalLM") +class LagunaModel(Model): + model_arch = gguf.MODEL_ARCH.LAGUNA + + _experts: list[dict[str, Tensor]] | None = None + + def set_gguf_parameters(self): + hparams = self.hparams + arch = gguf.MODEL_ARCH_NAMES[self.model_arch] + n_layers = int(hparams["num_hidden_layers"]) + n_head_base = int(hparams["num_attention_heads"]) + n_kv_base = int(hparams.get("num_key_value_heads", n_head_base)) + head_dim = int(hparams.get("head_dim", hparams["hidden_size"] // n_head_base)) + + heads_per_layer = hparams.get("num_attention_heads_per_layer") + kv_per_layer = hparams.get("num_key_value_heads_per_layer") + + head_arr: list[int] = [] + kv_arr: list[int] = 
[] + for i in range(n_layers): + head_arr.append(int(heads_per_layer[i]) if heads_per_layer is not None else n_head_base) + kv_arr.append(int(kv_per_layer[i]) if kv_per_layer is not None else n_kv_base) + + rope_params = hparams.get("rope_parameters", {}) + full_rope = rope_params.get("full_attention", rope_params) + swa_rope = rope_params.get("sliding_attention", {}) + + self.gguf_writer.add_context_length(int(hparams["max_position_embeddings"])) + self.gguf_writer.add_embedding_length(int(hparams["hidden_size"])) + self.gguf_writer.add_block_count(n_layers) + self.gguf_writer.add_feed_forward_length(int(hparams["intermediate_size"])) + self.gguf_writer.add_head_count(head_arr) + if all(n_kv == kv_arr[0] for n_kv in kv_arr): + self.gguf_writer.add_head_count_kv(kv_arr[0]) + else: + self.gguf_writer.add_head_count_kv(kv_arr) + self.gguf_writer.add_key_length(head_dim) + self.gguf_writer.add_value_length(head_dim) + self.gguf_writer.add_layer_norm_rms_eps(float(hparams["rms_norm_eps"])) + self.gguf_writer.add_file_type(self.ftype) + + self.gguf_writer.add_sliding_window(int(hparams["sliding_window"])) + self.gguf_writer.add_rope_dimension_count(head_dim // 2) + self.gguf_writer.add_uint32(f"{arch}.rope.dimension_count_swa", head_dim) + self.gguf_writer.add_rope_freq_base(float(full_rope.get("rope_theta", 500000.0))) + self.gguf_writer.add_float32(f"{arch}.rope.freq_base_swa", float(swa_rope.get("rope_theta", 10000.0))) + if full_rope.get("rope_type") == "yarn": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(float(full_rope.get("factor", 1.0))) + self.gguf_writer.add_rope_scaling_orig_ctx_len(int(full_rope.get( + "original_max_position_embeddings", + rope_params.get("original_max_position_embeddings", hparams["max_position_embeddings"]), + ))) + self.gguf_writer.add_rope_scaling_yarn_ext_factor(float(full_rope.get("factor", 1.0))) + 
self.gguf_writer.add_rope_scaling_yarn_attn_factor(float(full_rope.get("attention_factor", 1.0))) + self.gguf_writer.add_rope_scaling_yarn_beta_fast(float(full_rope.get("beta_fast", 32.0))) + self.gguf_writer.add_rope_scaling_yarn_beta_slow(float(full_rope.get("beta_slow", 1.0))) + + self.gguf_writer.add_expert_count(int(hparams["num_experts"])) + self.gguf_writer.add_expert_used_count(int(hparams["num_experts_per_tok"])) + self.gguf_writer.add_expert_feed_forward_length(int(hparams["moe_intermediate_size"])) + if (shared_dim := hparams.get("shared_expert_intermediate_size")) is not None and int(shared_dim) > 0: + self.gguf_writer.add_expert_shared_feed_forward_length(int(shared_dim)) + if (routing_scale := hparams.get("moe_routed_scaling_factor")) is not None: + self.gguf_writer.add_expert_weights_scale(float(routing_scale)) + self.gguf_writer.add_expert_weights_norm(True) + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + + leading_dense = 0 + for mlp_type in hparams.get("mlp_layer_types", []): + if mlp_type != "dense": + break + leading_dense += 1 + self.gguf_writer.add_uint32(f"{arch}.leading_dense_block_count", leading_dense) + + if hparams.get("moe_apply_router_weight_on_input", False): + raise ValueError("moe_apply_router_weight_on_input=True is not supported for Laguna") + + def set_vocab(self) -> None: + super().set_vocab() + if isinstance(eos_token_id := self.hparams.get("eos_token_id"), list) and len(eos_token_id) > 1: + # Poolside uses token 24 () as a turn boundary. 
+ self.gguf_writer.add_eot_token_id(int(eos_token_id[1])) + template_file = self.dir_model / "chat_template.jinja" + if template_file.is_file(): + self.gguf_writer.add_chat_template(template_file.read_text(encoding="utf-8")) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if bid is not None and name in ( + f"model.layers.{bid}.mlp.experts.e_score_correction_bias", + f"model.layers.{bid}.mlp.experts.e_score_correction", + ): + # The C++ loader asks for this tensor through the ".bias" suffix. + # Keep the Laguna converter aligned with existing community GGUFs. + yield f"blk.{bid}.exp_probs_b.bias", data_torch + return + + if name.endswith(".self_attn.g_proj.weight"): + # HF stores the head-wise attention gate with a singleton dimension. + data_torch = data_torch.squeeze().contiguous() + + if bid is not None and re.match(r"model\.layers\.\d+\.mlp\.experts\.\d+\.(gate_proj|up_proj|down_proj)\.weight$", name): + n_experts = int(self.find_hparam(["num_experts"])) + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + if len(self._experts[bid]) < n_experts * 3: + return + + for w_name in ("down_proj", "gate_proj", "up_proj"): + datas: list[Tensor] = [] + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + merged = torch.stack(datas, dim=0) + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + yield from super().modify_tensors(merged, merged_name, bid) + return + + yield from super().modify_tensors(data_torch, name, bid) + + def prepare_tensors(self): + super().prepare_tensors() + if self._experts is not None: + experts = [k for d in self._experts for k in d.keys()] + if experts: + raise ValueError(f"Unprocessed experts: {experts}") + + ###### CONVERSION LOGIC ###### diff --git 
a/examples/server/server-context.cpp b/examples/server/server-context.cpp index dcf3469d..7be87f97 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -6,6 +6,7 @@ #include "common.h" #include "llama.h" +#include "llama-spec-features.h" #include "log.h" #include "sampling.h" #include "speculative.h" @@ -13,12 +14,8 @@ #include "mtmd-helper.h" #include -#include -#include -#include #include #include -#include #include static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1, int32_t offset = 0) { @@ -49,83 +46,6 @@ static void log_text(const gpt_params & params_base, const std::string & text) { } } -static bool server_dflash_contract_log_enabled() { - const char * env = std::getenv("IK_DFLASH_CONTRACT_LOG"); - if (env == nullptr || *env == '\0') { - return false; - } - - return std::strcmp(env, "0") != 0 && - std::strcmp(env, "false") != 0 && - std::strcmp(env, "off") != 0; -} - -static std::string server_dflash_contract_format_indices( - const std::vector & values, - size_t edge_count = 4) { - std::ostringstream oss; - oss << '['; - if (values.empty()) { - oss << ']'; - return oss.str(); - } - - const size_t head = std::min(edge_count, values.size()); - for (size_t i = 0; i < head; ++i) { - if (i > 0) { - oss << ','; - } - oss << values[i]; - } - - if (values.size() > edge_count * 2) { - oss << ",...,"; - for (size_t i = values.size() - edge_count; i < values.size(); ++i) { - if (i > values.size() - edge_count) { - oss << ','; - } - oss << values[i]; - } - } else { - for (size_t i = head; i < values.size(); ++i) { - oss << ',' << values[i]; - } - } - - oss << ']'; - return oss.str(); -} - -static void server_dflash_contract_log_accept( - const server_slot & slot, - common_speculative_type spec_type_used, - const char * path, - bool any_rejected, - size_t n_draft, - const std::vector & ids, - llama_pos pos_base, - 
const std::vector & output_indices) { - if (!server_dflash_contract_log_enabled() || spec_type_used != COMMON_SPECULATIVE_TYPE_DFLASH) { - return; - } - - static std::atomic counter = 0; - const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); - if (ordinal >= 8) { - return; - } - - LLAMA_LOG_INFO("dflash contract accept[%llu]: slot=%d path=%s rejected=%s drafted=%zu accepted=%zu pos_base=%d output_indices=%s\n", - (unsigned long long) (ordinal + 1), - slot.id, - path, - any_rejected ? "true" : "false", - n_draft, - ids.size(), - (int) pos_base, - server_dflash_contract_format_indices(output_indices).c_str()); -} - static bool server_slot_prompt_batch_overlaps( const server_slot & slot, int32_t batch_i0, @@ -158,12 +78,13 @@ static bool server_response_needs_chat_parse(oaicompat_type oaicompat) { oaicompat == OAICOMPAT_TYPE_RESP; } -static bool server_speculative_has_dflash(const common_params_speculative & spec) { - return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH); +static bool server_speculative_uses_target_features(const common_params_speculative & spec) { + return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) || + spec.has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH); } -static bool server_speculative_has_target_features(const common_params_speculative & spec) { - return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) || server_speculative_has_dflash(spec); +static bool server_speculative_requires_single_slot(const common_params_speculative & spec) { + return spec.has_stage_chain(); } static bool server_speculative_same_stage_types( @@ -295,8 +216,8 @@ bool server_context::load_model(const gpt_params& params_) { common_speculative_prepare_startup(params_base, false); - if (server_speculative_has_dflash(params_base.speculative) && params_base.n_parallel > 1) { - LOG_ERROR("DFlash is currently limited to a single server slot (-np 1).\n", { + if (server_speculative_requires_single_slot(params_base.speculative) && 
params_base.n_parallel > 1) { + LOG_ERROR("Speculative decoding is currently limited to a single server slot (-np 1).\n", { {"n_parallel", params_base.n_parallel}, }); return false; @@ -4127,7 +4048,7 @@ void server_context::speculative_decoding_accept() { } std::vector accepted_output_indices; - if (server_speculative_has_target_features(slot.params.speculative)) { + if (server_speculative_uses_target_features(slot.params.speculative)) { if (!ids.empty()) { accepted_output_indices.assign(slot.i_batch_dft.begin(), slot.i_batch_dft.begin() + ids.size()); } @@ -4160,14 +4081,14 @@ void server_context::speculative_decoding_accept() { const common_speculative_checkpoint * ckpt = common_speculative_get_checkpoint(slot.spec); const bool will_restore = any_rejected && ckpt != nullptr && ckpt->valid; - if (server_speculative_has_target_features(slot.params.speculative) && !accepted_output_indices.empty()) { - server_dflash_contract_log_accept( - slot, - spec_type_used, + if (server_speculative_uses_target_features(slot.params.speculative) && !accepted_output_indices.empty()) { + llama_dflash_contract_log_accept( + slot.id, + spec_type_used == COMMON_SPECULATIVE_TYPE_DFLASH, will_restore ? 
"restore" : "direct", any_rejected, n_draft, - ids, + ids.size(), spec_pos_base, accepted_output_indices); } @@ -4556,9 +4477,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) { continue; // continue loop of n_batch } - if (server_speculative_has_target_features(params_base.speculative)) { + if (server_speculative_uses_target_features(params_base.speculative)) { for (auto & slot : slots) { - if (!slot.spec || !server_speculative_has_target_features(slot.params.speculative)) { + if (!slot.spec || !server_speculative_uses_target_features(slot.params.speculative)) { continue; } diff --git a/src/llama-dflash-profile.h b/src/llama-dflash-profile.h new file mode 100644 index 00000000..4d998488 --- /dev/null +++ b/src/llama-dflash-profile.h @@ -0,0 +1,340 @@ +#pragma once + +#include +#include + +enum llama_dflash_kv_node_kind { + LLAMA_DFLASH_KV_NODE_NONE = 0, + LLAMA_DFLASH_KV_NODE_FUSED_TARGET, + LLAMA_DFLASH_KV_NODE_K_PROJ, + LLAMA_DFLASH_KV_NODE_K_NORM, + LLAMA_DFLASH_KV_NODE_K_ROPE, + LLAMA_DFLASH_KV_NODE_V_PROJ, + LLAMA_DFLASH_KV_NODE_K_STORE, + LLAMA_DFLASH_KV_NODE_V_STORE, +}; + +enum llama_dflash_main_node_kind { + LLAMA_DFLASH_MAIN_NODE_NONE = 0, + LLAMA_DFLASH_MAIN_NODE_QCUR, + LLAMA_DFLASH_MAIN_NODE_K_DRAFT, + LLAMA_DFLASH_MAIN_NODE_V_DRAFT, + LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW, + LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW, + LLAMA_DFLASH_MAIN_NODE_K_CONCAT, + LLAMA_DFLASH_MAIN_NODE_V_CONCAT, + LLAMA_DFLASH_MAIN_NODE_K_PAD, + LLAMA_DFLASH_MAIN_NODE_V_PAD, + LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT, + LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT, + LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN, + LLAMA_DFLASH_MAIN_NODE_ATTN_OUT, + LLAMA_DFLASH_MAIN_NODE_FFN, + LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS, + LLAMA_DFLASH_MAIN_NODE_RESULT_NORM, + LLAMA_DFLASH_MAIN_NODE_RESULT, +}; + +struct llama_dflash_kv_node_profiler { + llama_dflash_profile_stats * profile = nullptr; + int64_t t_start_us = 0; + llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE; +}; + +struct 
llama_dflash_main_node_profiler { + llama_dflash_profile_stats * profile = nullptr; + ggml_backend_sched_eval_callback prev_callback = nullptr; + void * prev_user_data = nullptr; + bool prev_active = false; + int64_t t_start_us = 0; + llama_dflash_main_node_kind active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; +}; + +static inline bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) { + if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') { + return false; + } + + return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0; +} + +static inline bool llama_dflash_tensor_name_matches_label(const struct ggml_tensor * tensor, const char * label) { + if (!llama_dflash_tensor_name_has_prefix(tensor, label)) { + return false; + } + + const size_t label_len = std::strlen(label); + const char next = tensor->name[label_len]; + return next == '\0' || next == '-'; +} + +static inline llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) { + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_fused_target")) { + return LLAMA_DFLASH_KV_NODE_FUSED_TARGET; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_proj")) { + return LLAMA_DFLASH_KV_NODE_K_PROJ; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_norm")) { + return LLAMA_DFLASH_KV_NODE_K_NORM; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_rope")) { + return LLAMA_DFLASH_KV_NODE_K_ROPE; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_proj")) { + return LLAMA_DFLASH_KV_NODE_V_PROJ; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_store")) { + return LLAMA_DFLASH_KV_NODE_K_STORE; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_store")) { + return LLAMA_DFLASH_KV_NODE_V_STORE; + } + + return LLAMA_DFLASH_KV_NODE_NONE; +} + +static inline void llama_dflash_kv_node_profile_add( + llama_dflash_profile_stats & 
profile, + llama_dflash_kv_node_kind kind, + uint64_t elapsed_us) { + switch (kind) { + case LLAMA_DFLASH_KV_NODE_FUSED_TARGET: + profile.graph_kv_node_fused_target_calls++; + profile.graph_kv_node_fused_target_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_PROJ: + profile.graph_kv_node_k_proj_calls++; + profile.graph_kv_node_k_proj_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_NORM: + profile.graph_kv_node_k_norm_calls++; + profile.graph_kv_node_k_norm_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_ROPE: + profile.graph_kv_node_k_rope_calls++; + profile.graph_kv_node_k_rope_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_V_PROJ: + profile.graph_kv_node_v_proj_calls++; + profile.graph_kv_node_v_proj_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_K_STORE: + profile.graph_kv_node_k_store_calls++; + profile.graph_kv_node_k_store_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_V_STORE: + profile.graph_kv_node_v_store_calls++; + profile.graph_kv_node_v_store_us += elapsed_us; + break; + case LLAMA_DFLASH_KV_NODE_NONE: + break; + } +} + +static inline llama_dflash_main_node_kind llama_dflash_main_node_kind_from_tensor(const struct ggml_tensor * tensor) { + if (llama_dflash_tensor_name_has_prefix(tensor, "Qcur")) { + return LLAMA_DFLASH_MAIN_NODE_QCUR; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_noise")) { + return LLAMA_DFLASH_MAIN_NODE_K_DRAFT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_noise")) { + return LLAMA_DFLASH_MAIN_NODE_V_DRAFT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_ctx_cache")) { + return LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_ctx_cache")) { + return LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_concat")) { + return LLAMA_DFLASH_MAIN_NODE_K_CONCAT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_concat")) { + return 
LLAMA_DFLASH_MAIN_NODE_V_CONCAT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_pad")) { + return LLAMA_DFLASH_MAIN_NODE_K_PAD; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_pad")) { + return LLAMA_DFLASH_MAIN_NODE_V_PAD; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_perm_cont")) { + return LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_perm_cont")) { + return LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "flash_attn_reshaped")) { + return LLAMA_DFLASH_MAIN_NODE_NONE; + } + if (llama_dflash_tensor_name_matches_label(tensor, "flash_attn")) { + return LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "kqv_out")) { + return LLAMA_DFLASH_MAIN_NODE_ATTN_OUT; + } + if (llama_dflash_tensor_name_has_prefix(tensor, "ffn_out")) { + return LLAMA_DFLASH_MAIN_NODE_FFN; + } + if (llama_dflash_tensor_name_matches_label(tensor, "result_output_rows")) { + return LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS; + } + if (llama_dflash_tensor_name_matches_label(tensor, "result_norm")) { + return LLAMA_DFLASH_MAIN_NODE_RESULT_NORM; + } + if (llama_dflash_tensor_name_matches_label(tensor, "output")) { + return LLAMA_DFLASH_MAIN_NODE_RESULT; + } + if (llama_dflash_tensor_name_matches_label(tensor, "result_output")) { + return LLAMA_DFLASH_MAIN_NODE_RESULT; + } + + return LLAMA_DFLASH_MAIN_NODE_NONE; +} + +static inline void llama_dflash_main_node_profile_add( + llama_dflash_profile_stats & profile, + llama_dflash_main_node_kind kind, + uint64_t elapsed_us) { + switch (kind) { + case LLAMA_DFLASH_MAIN_NODE_QCUR: + profile.graph_main_node_qcur_calls++; + profile.graph_main_node_qcur_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_DRAFT: + profile.graph_main_node_k_draft_calls++; + profile.graph_main_node_k_draft_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_DRAFT: + 
profile.graph_main_node_v_draft_calls++; + profile.graph_main_node_v_draft_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW: + profile.graph_main_node_k_ctx_view_calls++; + profile.graph_main_node_k_ctx_view_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW: + profile.graph_main_node_v_ctx_view_calls++; + profile.graph_main_node_v_ctx_view_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_CONCAT: + profile.graph_main_node_k_concat_calls++; + profile.graph_main_node_k_concat_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_CONCAT: + profile.graph_main_node_v_concat_calls++; + profile.graph_main_node_v_concat_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_PAD: + profile.graph_main_node_k_pad_calls++; + profile.graph_main_node_k_pad_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_PAD: + profile.graph_main_node_v_pad_calls++; + profile.graph_main_node_v_pad_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT: + profile.graph_main_node_k_perm_cont_calls++; + profile.graph_main_node_k_perm_cont_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT: + profile.graph_main_node_v_perm_cont_calls++; + profile.graph_main_node_v_perm_cont_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN: + profile.graph_main_node_flash_attn_calls++; + profile.graph_main_node_flash_attn_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_ATTN_OUT: + profile.graph_main_node_attn_out_calls++; + profile.graph_main_node_attn_out_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_FFN: + profile.graph_main_node_ffn_calls++; + profile.graph_main_node_ffn_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS: + profile.graph_main_node_result_rows_calls++; + profile.graph_main_node_result_rows_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_RESULT_NORM: + profile.graph_main_node_result_norm_calls++; + profile.graph_main_node_result_norm_us += 
elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_RESULT: + profile.graph_main_node_result_calls++; + profile.graph_main_node_result_us += elapsed_us; + break; + case LLAMA_DFLASH_MAIN_NODE_NONE: + break; + } +} + +static inline bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { + auto * profiler = static_cast(user_data); + if (profiler == nullptr || profiler->profile == nullptr) { + return false; + } + + const llama_dflash_kv_node_kind kind = llama_dflash_kv_node_kind_from_tensor(tensor); + if (ask) { + if (kind == LLAMA_DFLASH_KV_NODE_NONE) { + return false; + } + + profiler->active_kind = kind; + profiler->t_start_us = ggml_time_us(); + return true; + } + + if (kind != LLAMA_DFLASH_KV_NODE_NONE && profiler->active_kind == kind && profiler->t_start_us > 0) { + llama_dflash_kv_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); + } + + profiler->active_kind = LLAMA_DFLASH_KV_NODE_NONE; + profiler->t_start_us = 0; + return true; +} + +static inline bool llama_dflash_main_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { + auto * profiler = static_cast(user_data); + if (profiler == nullptr || profiler->profile == nullptr) { + return false; + } + + const llama_dflash_main_node_kind kind = llama_dflash_main_node_kind_from_tensor(tensor); + if (ask) { + profiler->prev_active = profiler->prev_callback != nullptr + ? 
profiler->prev_callback(tensor, ask, profiler->prev_user_data) + : false; + + if (kind == LLAMA_DFLASH_MAIN_NODE_NONE) { + profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; + profiler->t_start_us = 0; + return profiler->prev_active; + } + + profiler->active_kind = kind; + profiler->t_start_us = ggml_time_us(); + return true; + } + + bool prev_result = false; + if (profiler->prev_active && profiler->prev_callback != nullptr) { + prev_result = profiler->prev_callback(tensor, ask, profiler->prev_user_data); + } + + const bool tracked = kind != LLAMA_DFLASH_MAIN_NODE_NONE && + profiler->active_kind == kind && + profiler->t_start_us > 0; + if (tracked) { + llama_dflash_main_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); + } + + profiler->prev_active = false; + profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; + profiler->t_start_us = 0; + return prev_result || tracked; +} diff --git a/src/llama-dflash.cpp b/src/llama-dflash.cpp index 9230840d..e32f53d1 100644 --- a/src/llama-dflash.cpp +++ b/src/llama-dflash.cpp @@ -5,6 +5,7 @@ #include "llama-context.h" #include "llama-model.h" #include "llama-spec-features.h" +#include "llama-dflash-profile.h" #include "ggml.h" #include "ggml-backend.h" @@ -27,342 +28,6 @@ static bool llama_dflash_stats_log_enabled() { return llama_env_flag_enabled_local("IK_DFLASH_STATS_LOG"); } -enum llama_dflash_kv_node_kind { - LLAMA_DFLASH_KV_NODE_NONE = 0, - LLAMA_DFLASH_KV_NODE_FUSED_TARGET, - LLAMA_DFLASH_KV_NODE_K_PROJ, - LLAMA_DFLASH_KV_NODE_K_NORM, - LLAMA_DFLASH_KV_NODE_K_ROPE, - LLAMA_DFLASH_KV_NODE_V_PROJ, - LLAMA_DFLASH_KV_NODE_K_STORE, - LLAMA_DFLASH_KV_NODE_V_STORE, -}; - -enum llama_dflash_main_node_kind { - LLAMA_DFLASH_MAIN_NODE_NONE = 0, - LLAMA_DFLASH_MAIN_NODE_QCUR, - LLAMA_DFLASH_MAIN_NODE_K_DRAFT, - LLAMA_DFLASH_MAIN_NODE_V_DRAFT, - LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW, - LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW, - LLAMA_DFLASH_MAIN_NODE_K_CONCAT, - LLAMA_DFLASH_MAIN_NODE_V_CONCAT, 
- LLAMA_DFLASH_MAIN_NODE_K_PAD, - LLAMA_DFLASH_MAIN_NODE_V_PAD, - LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT, - LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT, - LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN, - LLAMA_DFLASH_MAIN_NODE_ATTN_OUT, - LLAMA_DFLASH_MAIN_NODE_FFN, - LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS, - LLAMA_DFLASH_MAIN_NODE_RESULT_NORM, - LLAMA_DFLASH_MAIN_NODE_RESULT, -}; - -struct llama_dflash_kv_node_profiler { - llama_dflash_profile_stats * profile = nullptr; - int64_t t_start_us = 0; - llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE; -}; - -struct llama_dflash_main_node_profiler { - llama_dflash_profile_stats * profile = nullptr; - ggml_backend_sched_eval_callback prev_callback = nullptr; - void * prev_user_data = nullptr; - bool prev_active = false; - int64_t t_start_us = 0; - llama_dflash_main_node_kind active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; -}; - -static bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) { - if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') { - return false; - } - - return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0; -} - -static bool llama_dflash_tensor_name_matches_label(const struct ggml_tensor * tensor, const char * label) { - if (!llama_dflash_tensor_name_has_prefix(tensor, label)) { - return false; - } - - const size_t label_len = std::strlen(label); - const char next = tensor->name[label_len]; - return next == '\0' || next == '-'; -} - -static llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) { - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_fused_target")) { - return LLAMA_DFLASH_KV_NODE_FUSED_TARGET; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_proj")) { - return LLAMA_DFLASH_KV_NODE_K_PROJ; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_norm")) { - return LLAMA_DFLASH_KV_NODE_K_NORM; - } - if (llama_dflash_tensor_name_has_prefix(tensor, 
"dflash_kv_k_rope")) { - return LLAMA_DFLASH_KV_NODE_K_ROPE; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_proj")) { - return LLAMA_DFLASH_KV_NODE_V_PROJ; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_store")) { - return LLAMA_DFLASH_KV_NODE_K_STORE; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_store")) { - return LLAMA_DFLASH_KV_NODE_V_STORE; - } - - return LLAMA_DFLASH_KV_NODE_NONE; -} - -static void llama_dflash_kv_node_profile_add( - llama_dflash_profile_stats & profile, - llama_dflash_kv_node_kind kind, - uint64_t elapsed_us) { - switch (kind) { - case LLAMA_DFLASH_KV_NODE_FUSED_TARGET: - profile.graph_kv_node_fused_target_calls++; - profile.graph_kv_node_fused_target_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_PROJ: - profile.graph_kv_node_k_proj_calls++; - profile.graph_kv_node_k_proj_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_NORM: - profile.graph_kv_node_k_norm_calls++; - profile.graph_kv_node_k_norm_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_ROPE: - profile.graph_kv_node_k_rope_calls++; - profile.graph_kv_node_k_rope_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_V_PROJ: - profile.graph_kv_node_v_proj_calls++; - profile.graph_kv_node_v_proj_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_STORE: - profile.graph_kv_node_k_store_calls++; - profile.graph_kv_node_k_store_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_V_STORE: - profile.graph_kv_node_v_store_calls++; - profile.graph_kv_node_v_store_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_NONE: - break; - } -} - -static llama_dflash_main_node_kind llama_dflash_main_node_kind_from_tensor(const struct ggml_tensor * tensor) { - if (llama_dflash_tensor_name_has_prefix(tensor, "Qcur")) { - return LLAMA_DFLASH_MAIN_NODE_QCUR; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_noise")) { - return LLAMA_DFLASH_MAIN_NODE_K_DRAFT; - } - if 
(llama_dflash_tensor_name_has_prefix(tensor, "Vcur_noise")) { - return LLAMA_DFLASH_MAIN_NODE_V_DRAFT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_ctx_cache")) { - return LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_ctx_cache")) { - return LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_concat")) { - return LLAMA_DFLASH_MAIN_NODE_K_CONCAT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_concat")) { - return LLAMA_DFLASH_MAIN_NODE_V_CONCAT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_pad")) { - return LLAMA_DFLASH_MAIN_NODE_K_PAD; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_pad")) { - return LLAMA_DFLASH_MAIN_NODE_V_PAD; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_perm_cont")) { - return LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_perm_cont")) { - return LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "flash_attn_reshaped")) { - return LLAMA_DFLASH_MAIN_NODE_NONE; - } - if (llama_dflash_tensor_name_matches_label(tensor, "flash_attn")) { - return LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "kqv_out")) { - return LLAMA_DFLASH_MAIN_NODE_ATTN_OUT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "ffn_out")) { - return LLAMA_DFLASH_MAIN_NODE_FFN; - } - if (llama_dflash_tensor_name_matches_label(tensor, "result_output_rows")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS; - } - if (llama_dflash_tensor_name_matches_label(tensor, "result_norm")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT_NORM; - } - if (llama_dflash_tensor_name_matches_label(tensor, "output")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT; - } - if (llama_dflash_tensor_name_matches_label(tensor, "result_output")) { - return 
LLAMA_DFLASH_MAIN_NODE_RESULT; - } - - return LLAMA_DFLASH_MAIN_NODE_NONE; -} - -static void llama_dflash_main_node_profile_add( - llama_dflash_profile_stats & profile, - llama_dflash_main_node_kind kind, - uint64_t elapsed_us) { - switch (kind) { - case LLAMA_DFLASH_MAIN_NODE_QCUR: - profile.graph_main_node_qcur_calls++; - profile.graph_main_node_qcur_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_DRAFT: - profile.graph_main_node_k_draft_calls++; - profile.graph_main_node_k_draft_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_DRAFT: - profile.graph_main_node_v_draft_calls++; - profile.graph_main_node_v_draft_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW: - profile.graph_main_node_k_ctx_view_calls++; - profile.graph_main_node_k_ctx_view_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW: - profile.graph_main_node_v_ctx_view_calls++; - profile.graph_main_node_v_ctx_view_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_CONCAT: - profile.graph_main_node_k_concat_calls++; - profile.graph_main_node_k_concat_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_CONCAT: - profile.graph_main_node_v_concat_calls++; - profile.graph_main_node_v_concat_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_PAD: - profile.graph_main_node_k_pad_calls++; - profile.graph_main_node_k_pad_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_PAD: - profile.graph_main_node_v_pad_calls++; - profile.graph_main_node_v_pad_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT: - profile.graph_main_node_k_perm_cont_calls++; - profile.graph_main_node_k_perm_cont_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT: - profile.graph_main_node_v_perm_cont_calls++; - profile.graph_main_node_v_perm_cont_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN: - profile.graph_main_node_flash_attn_calls++; - profile.graph_main_node_flash_attn_us += elapsed_us; - break; 
- case LLAMA_DFLASH_MAIN_NODE_ATTN_OUT: - profile.graph_main_node_attn_out_calls++; - profile.graph_main_node_attn_out_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_FFN: - profile.graph_main_node_ffn_calls++; - profile.graph_main_node_ffn_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS: - profile.graph_main_node_result_rows_calls++; - profile.graph_main_node_result_rows_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_RESULT_NORM: - profile.graph_main_node_result_norm_calls++; - profile.graph_main_node_result_norm_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_RESULT: - profile.graph_main_node_result_calls++; - profile.graph_main_node_result_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_NONE: - break; - } -} - -static bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { - auto * profiler = static_cast(user_data); - if (profiler == nullptr || profiler->profile == nullptr) { - return false; - } - - const llama_dflash_kv_node_kind kind = llama_dflash_kv_node_kind_from_tensor(tensor); - if (ask) { - if (kind == LLAMA_DFLASH_KV_NODE_NONE) { - return false; - } - - profiler->active_kind = kind; - profiler->t_start_us = ggml_time_us(); - return true; - } - - if (kind != LLAMA_DFLASH_KV_NODE_NONE && profiler->active_kind == kind && profiler->t_start_us > 0) { - llama_dflash_kv_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); - } - - profiler->active_kind = LLAMA_DFLASH_KV_NODE_NONE; - profiler->t_start_us = 0; - return true; -} - -static bool llama_dflash_main_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { - auto * profiler = static_cast(user_data); - if (profiler == nullptr || profiler->profile == nullptr) { - return false; - } - - const llama_dflash_main_node_kind kind = llama_dflash_main_node_kind_from_tensor(tensor); - if (ask) { - profiler->prev_active = profiler->prev_callback != 
nullptr - ? profiler->prev_callback(tensor, ask, profiler->prev_user_data) - : false; - - if (kind == LLAMA_DFLASH_MAIN_NODE_NONE) { - profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; - profiler->t_start_us = 0; - return profiler->prev_active; - } - - profiler->active_kind = kind; - profiler->t_start_us = ggml_time_us(); - return true; - } - - bool prev_result = false; - if (profiler->prev_active && profiler->prev_callback != nullptr) { - prev_result = profiler->prev_callback(tensor, ask, profiler->prev_user_data); - } - - const bool tracked = kind != LLAMA_DFLASH_MAIN_NODE_NONE && - profiler->active_kind == kind && - profiler->t_start_us > 0; - if (tracked) { - llama_dflash_main_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); - } - - profiler->prev_active = false; - profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; - profiler->t_start_us = 0; - return prev_result || tracked; -} - void llama_sync_dflash_workspace_if_pending(struct llama_context & lctx) { if (!lctx.dflash_kv_workspace_sync_pending || lctx.dflash_workspace_sched == nullptr) { return; diff --git a/src/llama-spec-features-dflash.cpp b/src/llama-spec-features-dflash.cpp index 2df71006..c8cd5181 100644 --- a/src/llama-spec-features-dflash.cpp +++ b/src/llama-spec-features-dflash.cpp @@ -880,6 +880,36 @@ static void llama_dflash_contract_log_output_indices( have_capture ? 
"true" : "false"); } +void llama_dflash_contract_log_accept( + int slot_id, + bool is_dflash, + const char * path, + bool any_rejected, + size_t n_draft, + size_t n_accepted, + llama_pos pos_base, + const std::vector & output_indices) { + if (!llama_dflash_contract_log_enabled() || !is_dflash) { + return; + } + + static std::atomic counter = 0; + const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); + if (ordinal >= 8) { + return; + } + + LLAMA_LOG_INFO("dflash contract accept[%llu]: slot=%d path=%s rejected=%s drafted=%zu accepted=%zu pos_base=%d output_indices=%s\n", + (unsigned long long) (ordinal + 1), + slot_id, + path, + any_rejected ? "true" : "false", + n_draft, + n_accepted, + (int) pos_base, + llama_dflash_contract_format_values(output_indices).c_str()); +} + static bool llama_spec_materialize_dflash_rows_prepared( struct llama_context * ctx, int32_t row_count, diff --git a/src/llama-spec-features-dflash.h b/src/llama-spec-features-dflash.h index 02e709d1..e05f2a91 100644 --- a/src/llama-spec-features-dflash.h +++ b/src/llama-spec-features-dflash.h @@ -277,3 +277,13 @@ bool llama_spec_copy_dflash_rows_from_output_indices( struct llama_context * ctx, const std::vector & output_indices, std::vector & hidden_rows); + +void llama_dflash_contract_log_accept( + int slot_id, + bool is_dflash, + const char * path, + bool any_rejected, + size_t n_draft, + size_t n_accepted, + llama_pos pos_base, + const std::vector & output_indices); diff --git a/src/llama-spec-features.cpp b/src/llama-spec-features.cpp index 86e00399..2e93b5fa 100644 --- a/src/llama-spec-features.cpp +++ b/src/llama-spec-features.cpp @@ -1,11 +1,6 @@ #include "llama-spec-features.h" -#include -#include -#include -#include #include -#include #include "llama-model.h" #include "llama-context.h" diff --git a/src/llama.cpp b/src/llama.cpp index 48b6815c..8adab80e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -19,6 +19,7 @@ #include "llama-context.h" #include 
"llama-spec-features.h" #include "llama-dflash.h" +#include "llama-dflash-profile.h" #include "llama-quantize.h" #include "unicode.h" @@ -180,343 +181,6 @@ static bool llama_env_flag_enabled(const char * name) { std::strcmp(env, "off") != 0; } -enum llama_dflash_kv_node_kind { - LLAMA_DFLASH_KV_NODE_NONE = 0, - LLAMA_DFLASH_KV_NODE_FUSED_TARGET, - LLAMA_DFLASH_KV_NODE_K_PROJ, - LLAMA_DFLASH_KV_NODE_K_NORM, - LLAMA_DFLASH_KV_NODE_K_ROPE, - LLAMA_DFLASH_KV_NODE_V_PROJ, - LLAMA_DFLASH_KV_NODE_K_STORE, - LLAMA_DFLASH_KV_NODE_V_STORE, -}; - -enum llama_dflash_main_node_kind { - LLAMA_DFLASH_MAIN_NODE_NONE = 0, - LLAMA_DFLASH_MAIN_NODE_QCUR, - LLAMA_DFLASH_MAIN_NODE_K_DRAFT, - LLAMA_DFLASH_MAIN_NODE_V_DRAFT, - LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW, - LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW, - LLAMA_DFLASH_MAIN_NODE_K_CONCAT, - LLAMA_DFLASH_MAIN_NODE_V_CONCAT, - LLAMA_DFLASH_MAIN_NODE_K_PAD, - LLAMA_DFLASH_MAIN_NODE_V_PAD, - LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT, - LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT, - LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN, - LLAMA_DFLASH_MAIN_NODE_ATTN_OUT, - LLAMA_DFLASH_MAIN_NODE_FFN, - LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS, - LLAMA_DFLASH_MAIN_NODE_RESULT_NORM, - LLAMA_DFLASH_MAIN_NODE_RESULT, -}; - -struct llama_dflash_kv_node_profiler { - llama_dflash_profile_stats * profile = nullptr; - int64_t t_start_us = 0; - llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE; -}; - -struct llama_dflash_main_node_profiler { - llama_dflash_profile_stats * profile = nullptr; - ggml_backend_sched_eval_callback prev_callback = nullptr; - void * prev_user_data = nullptr; - bool prev_active = false; - int64_t t_start_us = 0; - llama_dflash_main_node_kind active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; -}; - -static bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) { - if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') { - return false; - } - - return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0; -} - 
-static bool llama_dflash_tensor_name_matches_label(const struct ggml_tensor * tensor, const char * label) { - if (!llama_dflash_tensor_name_has_prefix(tensor, label)) { - return false; - } - - const size_t label_len = std::strlen(label); - const char next = tensor->name[label_len]; - return next == '\0' || next == '-'; -} - -static llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) { - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_fused_target")) { - return LLAMA_DFLASH_KV_NODE_FUSED_TARGET; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_proj")) { - return LLAMA_DFLASH_KV_NODE_K_PROJ; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_norm")) { - return LLAMA_DFLASH_KV_NODE_K_NORM; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_rope")) { - return LLAMA_DFLASH_KV_NODE_K_ROPE; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_proj")) { - return LLAMA_DFLASH_KV_NODE_V_PROJ; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_store")) { - return LLAMA_DFLASH_KV_NODE_K_STORE; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_store")) { - return LLAMA_DFLASH_KV_NODE_V_STORE; - } - - return LLAMA_DFLASH_KV_NODE_NONE; -} - -static void llama_dflash_kv_node_profile_add( - llama_dflash_profile_stats & profile, - llama_dflash_kv_node_kind kind, - uint64_t elapsed_us) { - switch (kind) { - case LLAMA_DFLASH_KV_NODE_FUSED_TARGET: - profile.graph_kv_node_fused_target_calls++; - profile.graph_kv_node_fused_target_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_PROJ: - profile.graph_kv_node_k_proj_calls++; - profile.graph_kv_node_k_proj_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_NORM: - profile.graph_kv_node_k_norm_calls++; - profile.graph_kv_node_k_norm_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_ROPE: - profile.graph_kv_node_k_rope_calls++; - profile.graph_kv_node_k_rope_us += 
elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_V_PROJ: - profile.graph_kv_node_v_proj_calls++; - profile.graph_kv_node_v_proj_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_STORE: - profile.graph_kv_node_k_store_calls++; - profile.graph_kv_node_k_store_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_V_STORE: - profile.graph_kv_node_v_store_calls++; - profile.graph_kv_node_v_store_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_NONE: - break; - } -} - -static llama_dflash_main_node_kind llama_dflash_main_node_kind_from_tensor(const struct ggml_tensor * tensor) { - if (llama_dflash_tensor_name_has_prefix(tensor, "Qcur")) { - return LLAMA_DFLASH_MAIN_NODE_QCUR; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_noise")) { - return LLAMA_DFLASH_MAIN_NODE_K_DRAFT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_noise")) { - return LLAMA_DFLASH_MAIN_NODE_V_DRAFT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_ctx_cache")) { - return LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_ctx_cache")) { - return LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_concat")) { - return LLAMA_DFLASH_MAIN_NODE_K_CONCAT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_concat")) { - return LLAMA_DFLASH_MAIN_NODE_V_CONCAT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_pad")) { - return LLAMA_DFLASH_MAIN_NODE_K_PAD; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_pad")) { - return LLAMA_DFLASH_MAIN_NODE_V_PAD; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_perm_cont")) { - return LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_perm_cont")) { - return LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "flash_attn_reshaped")) { - return 
LLAMA_DFLASH_MAIN_NODE_NONE; - } - if (llama_dflash_tensor_name_matches_label(tensor, "flash_attn")) { - return LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "kqv_out")) { - return LLAMA_DFLASH_MAIN_NODE_ATTN_OUT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "ffn_out")) { - return LLAMA_DFLASH_MAIN_NODE_FFN; - } - if (llama_dflash_tensor_name_matches_label(tensor, "result_output_rows")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS; - } - if (llama_dflash_tensor_name_matches_label(tensor, "result_norm")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT_NORM; - } - if (llama_dflash_tensor_name_matches_label(tensor, "output")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT; - } - if (llama_dflash_tensor_name_matches_label(tensor, "result_output")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT; - } - - return LLAMA_DFLASH_MAIN_NODE_NONE; -} - -static void llama_dflash_main_node_profile_add( - llama_dflash_profile_stats & profile, - llama_dflash_main_node_kind kind, - uint64_t elapsed_us) { - switch (kind) { - case LLAMA_DFLASH_MAIN_NODE_QCUR: - profile.graph_main_node_qcur_calls++; - profile.graph_main_node_qcur_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_DRAFT: - profile.graph_main_node_k_draft_calls++; - profile.graph_main_node_k_draft_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_DRAFT: - profile.graph_main_node_v_draft_calls++; - profile.graph_main_node_v_draft_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW: - profile.graph_main_node_k_ctx_view_calls++; - profile.graph_main_node_k_ctx_view_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW: - profile.graph_main_node_v_ctx_view_calls++; - profile.graph_main_node_v_ctx_view_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_CONCAT: - profile.graph_main_node_k_concat_calls++; - profile.graph_main_node_k_concat_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_CONCAT: - 
profile.graph_main_node_v_concat_calls++; - profile.graph_main_node_v_concat_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_PAD: - profile.graph_main_node_k_pad_calls++; - profile.graph_main_node_k_pad_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_PAD: - profile.graph_main_node_v_pad_calls++; - profile.graph_main_node_v_pad_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT: - profile.graph_main_node_k_perm_cont_calls++; - profile.graph_main_node_k_perm_cont_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT: - profile.graph_main_node_v_perm_cont_calls++; - profile.graph_main_node_v_perm_cont_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN: - profile.graph_main_node_flash_attn_calls++; - profile.graph_main_node_flash_attn_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_ATTN_OUT: - profile.graph_main_node_attn_out_calls++; - profile.graph_main_node_attn_out_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_FFN: - profile.graph_main_node_ffn_calls++; - profile.graph_main_node_ffn_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS: - profile.graph_main_node_result_rows_calls++; - profile.graph_main_node_result_rows_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_RESULT_NORM: - profile.graph_main_node_result_norm_calls++; - profile.graph_main_node_result_norm_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_RESULT: - profile.graph_main_node_result_calls++; - profile.graph_main_node_result_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_NONE: - break; - } -} - -static bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { - auto * profiler = static_cast(user_data); - if (profiler == nullptr || profiler->profile == nullptr) { - return false; - } - - const llama_dflash_kv_node_kind kind = llama_dflash_kv_node_kind_from_tensor(tensor); - if (ask) { - if (kind == LLAMA_DFLASH_KV_NODE_NONE) { 
- return false; - } - - profiler->active_kind = kind; - profiler->t_start_us = ggml_time_us(); - return true; - } - - if (kind != LLAMA_DFLASH_KV_NODE_NONE && profiler->active_kind == kind && profiler->t_start_us > 0) { - llama_dflash_kv_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); - } - - profiler->active_kind = LLAMA_DFLASH_KV_NODE_NONE; - profiler->t_start_us = 0; - return true; -} - -static bool llama_dflash_main_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { - auto * profiler = static_cast(user_data); - if (profiler == nullptr || profiler->profile == nullptr) { - return false; - } - - const llama_dflash_main_node_kind kind = llama_dflash_main_node_kind_from_tensor(tensor); - if (ask) { - profiler->prev_active = profiler->prev_callback != nullptr - ? profiler->prev_callback(tensor, ask, profiler->prev_user_data) - : false; - - if (kind == LLAMA_DFLASH_MAIN_NODE_NONE) { - profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; - profiler->t_start_us = 0; - return profiler->prev_active; - } - - profiler->active_kind = kind; - profiler->t_start_us = ggml_time_us(); - return true; - } - - bool prev_result = false; - if (profiler->prev_active && profiler->prev_callback != nullptr) { - prev_result = profiler->prev_callback(tensor, ask, profiler->prev_user_data); - } - - const bool tracked = kind != LLAMA_DFLASH_MAIN_NODE_NONE && - profiler->active_kind == kind && - profiler->t_start_us > 0; - if (tracked) { - llama_dflash_main_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); - } - - profiler->prev_active = false; - profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; - profiler->t_start_us = 0; - return prev_result || tracked; -} - - // extract ip and port from RPC[ip:port] for rpc and keep other device names static std::vector extract_device_from_rpc_device(std::vector devices) { std::vector rpc_servers; From 
6cae8c7ba2d699afa38506fd963cd425c183db32 Mon Sep 17 00:00:00 2001 From: SamuelOliveirads Date: Sun, 14 Jun 2026 21:07:57 -0300 Subject: [PATCH 12/13] clean logs --- common/speculative-dflash-impl.h | 348 +---------------- common/speculative.cpp | 395 +------------------ examples/server/server-context.cpp | 17 - src/graphs/build_dflash.cpp | 181 +++++---- src/llama-context.h | 77 +--- src/llama-dflash-profile.h | 340 ----------------- src/llama-dflash.cpp | 586 +++++++++++------------------ src/llama-spec-features-dflash.cpp | 442 +--------------------- src/llama-spec-features-dflash.h | 153 -------- src/llama.cpp | 117 +----- 10 files changed, 351 insertions(+), 2305 deletions(-) delete mode 100644 src/llama-dflash-profile.h diff --git a/common/speculative-dflash-impl.h b/common/speculative-dflash-impl.h index c644ddbb..746c8130 100644 --- a/common/speculative-dflash-impl.h +++ b/common/speculative-dflash-impl.h @@ -1,11 +1,8 @@ #pragma once #include -#include #include -#include #include -#include #include static bool common_speculative_are_dflash_compatible( @@ -71,102 +68,7 @@ static bool common_speculative_are_dflash_compatible( return true; } -static bool dflash_contract_log_enabled() { - const char * env = std::getenv("IK_DFLASH_CONTRACT_LOG"); - if (env == nullptr || *env == '\0') { - return false; - } - - return std::strcmp(env, "0") != 0 && - std::strcmp(env, "false") != 0 && - std::strcmp(env, "off") != 0; -} - -static bool dflash_stats_log_enabled() { - const char * env = std::getenv("IK_DFLASH_STATS_LOG"); - if (env == nullptr || *env == '\0') { - return false; - } - - return std::strcmp(env, "0") != 0 && - std::strcmp(env, "false") != 0 && - std::strcmp(env, "off") != 0; -} - -template -static std::string dflash_contract_format_values( - const std::vector & values, - size_t edge_count = 4) { - std::ostringstream oss; - oss << '['; - if (values.empty()) { - oss << ']'; - return oss.str(); - } - - const size_t head = std::min(edge_count, values.size()); - 
for (size_t i = 0; i < head; ++i) { - if (i > 0) { - oss << ','; - } - oss << values[i]; - } - - if (values.size() > edge_count * 2) { - oss << ",...,"; - for (size_t i = values.size() - edge_count; i < values.size(); ++i) { - if (i > values.size() - edge_count) { - oss << ','; - } - oss << values[i]; - } - } else { - for (size_t i = head; i < values.size(); ++i) { - oss << ',' << values[i]; - } - } - - oss << ']'; - return oss.str(); -} - -struct dflash_contract_pos_summary { - llama_pos first = -1; - llama_pos last = -1; - int32_t gap_count = 0; - int32_t nonmono_count = 0; -}; - -static dflash_contract_pos_summary dflash_contract_summarize_positions( - const std::vector & positions) { - dflash_contract_pos_summary summary; - if (positions.empty()) { - return summary; - } - - summary.first = positions.front(); - summary.last = positions.back(); - for (size_t i = 1; i < positions.size(); ++i) { - if (positions[i] <= positions[i - 1]) { - summary.nonmono_count++; - } else if (positions[i] != positions[i - 1] + 1) { - summary.gap_count++; - } - } - - return summary; -} - struct common_speculative_state_dflash; - -static void dflash_contract_log_append( - const common_speculative_state_dflash & state, - llama_seq_id seq_id, - const std::vector & new_positions); -static void dflash_contract_log_draft( - const common_speculative_state_dflash & state, - int32_t n_keep, - size_t result_size); static void dflash_materialize_target_window_features(common_speculative_state_dflash & state); // DFlash runtime state and draft path. 
@@ -198,41 +100,6 @@ struct common_speculative_state_dflash : public common_speculative_state { bool target_window_replace = false; bool target_window_materialized = false; llama_pos last_target_pos = -1; - size_t n_window_updates = 0; - size_t n_rows_seen = 0; - size_t n_rows_dropped = 0; - size_t n_context_shifts = 0; - size_t n_draft_empty = 0; - size_t n_set_target_fail = 0; - size_t n_decode_fail = 0; - llama_pos last_draft_pos_base = -1; - - uint64_t t_draft_decode_us = 0; - uint64_t t_draft_sample_us = 0; - uint64_t t_warmup_collect_us = 0; - uint64_t t_warmup_append_us = 0; - uint64_t t_accept_output_copy_us = 0; - uint64_t t_accept_commit_us = 0; - uint64_t t_accept_append_us = 0; - uint64_t t_accept_append_filter_us = 0; - uint64_t t_accept_append_window_alloc_us = 0; - uint64_t t_accept_append_replace_us = 0; - uint64_t t_accept_append_keep_old_us = 0; - uint64_t t_accept_append_new_rows_us = 0; - uint64_t t_accept_append_commit_detail_us = 0; - uint64_t t_accept_append_log_us = 0; - size_t n_warmup_collect_calls = 0; - size_t n_warmup_collect_rows = 0; - size_t n_warmup_append_calls = 0; - size_t n_warmup_append_rows = 0; - size_t n_accept_output_copy_calls = 0; - size_t n_accept_output_copy_rows = 0; - size_t n_accept_commit_calls = 0; - size_t n_accept_commit_rows = 0; - size_t n_accept_append_calls = 0; - size_t n_accept_append_rows = 0; - size_t n_accept_append_replace_calls = 0; - size_t n_accept_append_slide_calls = 0; common_speculative_state_dflash( enum common_speculative_type type, @@ -271,9 +138,7 @@ struct common_speculative_state_dflash : public common_speculative_state { } const auto * vocab_tgt = llama_model_get_vocab(model_tgt); - const auto * vocab_dft = llama_model_get_vocab(model_dft); const int32_t target_vocab_size = llama_vocab_n_tokens(vocab_tgt); - const int32_t draft_vocab_size = llama_vocab_n_tokens(vocab_dft); const int32_t target_hidden_size = llama_model_n_embd(model_tgt); const int32_t draft_hidden_size = 
llama_model_n_embd(model_dft); const int32_t target_mask_token_id = llama_model_dflash_target_mask_token_id(model_tgt); @@ -349,22 +214,8 @@ struct common_speculative_state_dflash : public common_speculative_state { ready = true; llama_set_dflash_visible_cross_ctx(ctx_dft, this->cross_ctx); - llama_dflash_profile_reset(ctx_tgt); - llama_dflash_profile_reset(ctx_dft); - - std::ostringstream layers_oss; - for (size_t i = 0; i < target_layer_ids.size(); ++i) { - if (i > 0) { - layers_oss << ","; - } - layers_oss << target_layer_ids[i]; - } - - const char * io_mode_name = io_mode == LLAMA_DFLASH_IO_MODE_SHARED ? "shared" : "self-contained"; - LOG_INF("%s: DFlash context ready (n_ctx=%d, block_size=%d, cross_ctx=%d, n_target_features=%d, target_layer_ids=[%s])\n", - __func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features, layers_oss.str().c_str()); - LOG_INF("%s: DFlash artifact io=%s draft_vocab=%d target_vocab=%d draft_hidden=%d target_hidden=%d mask_token_id=%d target_mask_token_id=%d\n", - __func__, io_mode_name, draft_vocab_size, target_vocab_size, draft_hidden_size, target_hidden_size, mask_token_id, target_mask_token_id); + LOG_INF("%s: DFlash context ready (n_ctx=%d, block_size=%d, cross_ctx=%d, n_target_features=%d, n_target_layers=%d)\n", + __func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features, n_target_layers); } ~common_speculative_state_dflash() override { @@ -381,42 +232,6 @@ struct common_speculative_state_dflash : public common_speculative_state { GGML_UNUSED(prompt); llama_kv_cache_clear(ctx_dft); llama_reset_dflash_kv_cache_state(ctx_dft); - n_window_updates = 0; - n_rows_seen = 0; - n_rows_dropped = 0; - n_context_shifts = 0; - n_draft_empty = 0; - n_set_target_fail = 0; - n_decode_fail = 0; - last_draft_pos_base = -1; - t_draft_decode_us = 0; - t_draft_sample_us = 0; - t_warmup_collect_us = 0; - t_warmup_append_us = 0; - t_accept_output_copy_us = 0; - t_accept_commit_us = 0; - t_accept_append_us = 0; - 
t_accept_append_filter_us = 0; - t_accept_append_window_alloc_us = 0; - t_accept_append_replace_us = 0; - t_accept_append_keep_old_us = 0; - t_accept_append_new_rows_us = 0; - t_accept_append_commit_detail_us = 0; - t_accept_append_log_us = 0; - n_warmup_collect_calls = 0; - n_warmup_collect_rows = 0; - n_warmup_append_calls = 0; - n_warmup_append_rows = 0; - n_accept_output_copy_calls = 0; - n_accept_output_copy_rows = 0; - n_accept_commit_calls = 0; - n_accept_commit_rows = 0; - n_accept_append_calls = 0; - n_accept_append_rows = 0; - n_accept_append_replace_calls = 0; - n_accept_append_slide_calls = 0; - llama_dflash_profile_reset(ctx_tgt); - llama_dflash_profile_reset(ctx_dft); } void draft( @@ -428,7 +243,6 @@ struct common_speculative_state_dflash : public common_speculative_state { result.clear(); if (!ready || target_window_rows <= 0) { - n_draft_empty++; return; } @@ -461,7 +275,6 @@ struct common_speculative_state_dflash : public common_speculative_state { if (!llama_set_dflash_target_features_view(ctx_dft, target_features, target_feature_floats, target_window_rows, target_window_pos.data(), &window_update)) { LOG_ERR("%s: failed to set DFlash target features\n", __func__); - n_set_target_fail++; return; } @@ -470,23 +283,18 @@ struct common_speculative_state_dflash : public common_speculative_state { const int32_t batch_len = n_keep + 1; const llama_pos draft_pos_base = last_target_pos >= 0 ? last_target_pos + 1 : (llama_pos) target_window_rows; const llama_pos seed_pos = last_target_pos >= 0 ? 
last_target_pos : draft_pos_base - 1; - last_draft_pos_base = draft_pos_base; common_batch_add(batch, id_last, seed_pos, { 0 }, false); for (int32_t i = 1; i < batch_len; ++i) { common_batch_add(batch, mask_token_id, draft_pos_base + (i - 1), { 0 }, i <= n_keep); } - const int64_t t_decode_us = ggml_time_us(); if (llama_decode(ctx_dft, batch) != 0) { LOG_ERR("%s: llama_decode() failed for DFlash draft batch\n", __func__); - n_decode_fail++; batch.n_tokens = 0; return; } - t_draft_decode_us += (uint64_t) (ggml_time_us() - t_decode_us); result.reserve((size_t) n_keep); - const int64_t t_sample_us = ggml_time_us(); for (int32_t i = 0; i < n_keep; ++i) { llama_token id = llama_get_dflash_draft_token_ith(ctx_dft, i); if (id == LLAMA_TOKEN_NULL) { @@ -494,10 +302,8 @@ struct common_speculative_state_dflash : public common_speculative_state { } result.push_back(id); } - t_draft_sample_us += (uint64_t) (ggml_time_us() - t_sample_us); batch.n_tokens = 0; - dflash_contract_log_draft(*this, n_keep, result.size()); } void accept(uint16_t n_accepted) override { @@ -505,104 +311,6 @@ struct common_speculative_state_dflash : public common_speculative_state { } }; -static void dflash_contract_log_append( - const common_speculative_state_dflash & state, - llama_seq_id seq_id, - const std::vector & new_positions) { - if (!dflash_contract_log_enabled()) { - return; - } - - static std::atomic counter = 0; - const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); - if (ordinal >= 8) { - return; - } - - const dflash_contract_pos_summary incoming = dflash_contract_summarize_positions(new_positions); - const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos); - - LOG_INF("dflash contract append[%llu]: seq=%d incoming_rows=%zu incoming_pos=%s pos=[%d..%d] gaps=%d nonmono=%d window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d\n", - (unsigned long long) (ordinal + 1), - (int) seq_id, - 
new_positions.size(), - dflash_contract_format_values(new_positions).c_str(), - (int) incoming.first, - (int) incoming.last, - incoming.gap_count, - incoming.nonmono_count, - state.target_window_rows, - dflash_contract_format_values(state.target_window_pos).c_str(), - (int) window.first, - (int) window.last, - window.gap_count, - window.nonmono_count, - (int) state.last_target_pos); -} - -static void dflash_contract_log_draft( - const common_speculative_state_dflash & state, - int32_t n_keep, - size_t result_size) { - if (!dflash_contract_log_enabled()) { - return; - } - - static std::atomic counter = 0; - const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); - if (ordinal >= 8) { - return; - } - - const dflash_contract_pos_summary window = dflash_contract_summarize_positions(state.target_window_pos); - llama_dflash_profile_stats graph_stats = {}; - llama_dflash_profile_get_stats(state.ctx_dft, &graph_stats); - const int draft_delta = (state.last_target_pos >= 0 && state.last_draft_pos_base >= 0) - ? (int) (state.last_draft_pos_base - state.last_target_pos) - : -1; - const llama_pos seed_pos = state.last_target_pos; - const llama_pos mask_first_pos = state.last_draft_pos_base; - const llama_pos mask_last_pos = state.last_draft_pos_base >= 0 - ? 
state.last_draft_pos_base + n_keep - 1 - : -1; - - LOG_INF("dflash contract draft[%llu]: window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d seed_pos=%d mask_pos=[%d..%d] sample_rows=[1..%d] output_rows=[1..%d] draft_pos_base=%d delta=%d n_keep=%d result=%zu set_target(missing/nonmono)=%llu/%llu graph(fallback/nonmono)=%llu/%llu graph_pos=[%d..%d]\n", - (unsigned long long) (ordinal + 1), - state.target_window_rows, - dflash_contract_format_values(state.target_window_pos).c_str(), - (int) window.first, - (int) window.last, - window.gap_count, - window.nonmono_count, - (int) state.last_target_pos, - (int) seed_pos, - (int) mask_first_pos, - (int) mask_last_pos, - n_keep, - n_keep, - (int) state.last_draft_pos_base, - draft_delta, - n_keep, - result_size, - (unsigned long long) graph_stats.set_target_missing_positions, - (unsigned long long) graph_stats.set_target_non_monotonic_positions, - (unsigned long long) graph_stats.graph_pos_fallbacks, - (unsigned long long) graph_stats.graph_pos_non_monotonic, - (int) graph_stats.last_pos_first, - (int) graph_stats.last_pos_last); -} - -struct dflash_append_breakdown { - uint64_t filter_us = 0; - uint64_t window_alloc_us = 0; - uint64_t replace_us = 0; - uint64_t keep_old_us = 0; - uint64_t new_rows_us = 0; - uint64_t commit_us = 0; - uint64_t log_us = 0; - bool replace_call = false; -}; - static void dflash_record_window_update( common_speculative_state_dflash & state, int32_t keep_rows, @@ -696,11 +404,7 @@ static void dflash_materialize_target_window_features(common_speculative_state_d static bool dflash_append_target_features( common_speculative_state_dflash & state, const common_speculative_feature_view & features, - const llama_batch & batch, - llama_seq_id seq_id, - dflash_append_breakdown * breakdown = nullptr) { - GGML_UNUSED(batch); - + llama_seq_id seq_id) { if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || features.width != state.n_target_features || features.rows.empty() 
|| @@ -714,7 +418,6 @@ static bool dflash_append_target_features( new_rows.reserve(features.rows.size() * row_width); new_positions.reserve(features.rows.size()); - const int64_t t_filter_us = ggml_time_us(); for (const auto & row : features.rows) { if (row.seq_id != seq_id || row.data == nullptr) { continue; @@ -723,89 +426,45 @@ static bool dflash_append_target_features( new_positions.push_back(row.pos); new_rows.insert(new_rows.end(), row.data, row.data + row_width); } - if (breakdown != nullptr) { - breakdown->filter_us += (uint64_t) (ggml_time_us() - t_filter_us); - } if (new_positions.empty()) { return false; } const int32_t n_rows = (int32_t) new_positions.size(); - state.n_window_updates++; - state.n_rows_seen += (size_t) n_rows; if (n_rows >= state.cross_ctx) { - state.n_rows_dropped += (size_t) state.target_window_rows + (size_t) (n_rows - state.cross_ctx); const int32_t keep_from = n_rows - state.cross_ctx; - const int64_t t_replace_us = ggml_time_us(); state.target_window_pos.assign(new_positions.begin() + keep_from, new_positions.end()); state.target_window_append_features.assign( new_rows.begin() + (ptrdiff_t) keep_from * (ptrdiff_t) row_width, new_rows.end()); dflash_ring_reset_rows(state, state.target_window_append_features.data(), state.cross_ctx); - if (breakdown != nullptr) { - breakdown->replace_us += (uint64_t) (ggml_time_us() - t_replace_us); - breakdown->replace_call = true; - } - const int64_t t_commit_us = ggml_time_us(); state.target_window_rows = state.cross_ctx; state.target_window_ring_filled = state.target_window_rows; state.last_target_pos = state.target_window_pos.empty() ? 
-1 : state.target_window_pos.back(); dflash_record_window_update(state, 0, state.target_window_rows, true); - if (breakdown != nullptr) { - breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us); - } - - const int64_t t_log_us = ggml_time_us(); - dflash_contract_log_append(state, seq_id, new_positions); - if (breakdown != nullptr) { - breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us); - } return true; } const int32_t keep_old_rows = std::min(state.target_window_rows, state.cross_ctx - n_rows); - state.n_rows_dropped += (size_t) std::max(0, state.target_window_rows - keep_old_rows); - const int64_t t_window_alloc_us = ggml_time_us(); std::vector & next_window_pos = state.target_window_pos_stage; next_window_pos.resize((size_t) (keep_old_rows + n_rows)); - if (breakdown != nullptr) { - breakdown->window_alloc_us += (uint64_t) (ggml_time_us() - t_window_alloc_us); - } if (keep_old_rows > 0) { - const int64_t t_keep_old_us = ggml_time_us(); std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), next_window_pos.begin()); - if (breakdown != nullptr) { - breakdown->keep_old_us += (uint64_t) (ggml_time_us() - t_keep_old_us); - } } - const int64_t t_new_rows_us = ggml_time_us(); state.target_window_append_features.assign(new_rows.begin(), new_rows.end()); dflash_ring_append_rows(state, state.target_window_append_features.data(), n_rows); std::copy(new_positions.begin(), new_positions.end(), next_window_pos.begin() + keep_old_rows); - if (breakdown != nullptr) { - breakdown->new_rows_us += (uint64_t) (ggml_time_us() - t_new_rows_us); - } - const int64_t t_commit_us = ggml_time_us(); state.target_window_pos.swap(next_window_pos); next_window_pos.clear(); state.target_window_rows = keep_old_rows + n_rows; state.target_window_ring_filled = state.target_window_rows; state.last_target_pos = state.target_window_pos.empty() ? 
-1 : state.target_window_pos.back(); dflash_record_window_update(state, keep_old_rows, n_rows, false); - if (breakdown != nullptr) { - breakdown->commit_us += (uint64_t) (ggml_time_us() - t_commit_us); - } - - const int64_t t_log_us = ggml_time_us(); - dflash_contract_log_append(state, seq_id, new_positions); - if (breakdown != nullptr) { - breakdown->log_us += (uint64_t) (ggml_time_us() - t_log_us); - } return true; } @@ -868,5 +527,4 @@ static void dflash_context_shift( state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); dflash_record_window_update(state, 0, state.target_window_rows, true); llama_reset_dflash_kv_cache_state(state.ctx_dft); - state.n_context_shifts++; } diff --git a/common/speculative.cpp b/common/speculative.cpp index 8f15f8ef..82819c9b 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -2039,8 +2039,6 @@ int32_t common_speculative_on_target_seq_batch( const llama_batch * batch_for_spec = &batch; llama_batch seq_batch = {}; const bool needs_seq_split = is_prompt_warmup && !common_speculative_batch_is_exact_single_seq(batch, seq_id); - auto * dflash_state = common_speculative_get_dflash_state(spec); - const bool measure_dflash_warmup_collect = dflash_state != nullptr && is_prompt_warmup; if (needs_seq_split) { const int n_seq_tokens = common_speculative_copy_seq_batch(batch, seq_id, seq_batch); @@ -2048,28 +2046,16 @@ int32_t common_speculative_on_target_seq_batch( return n_seq_tokens < 0 ? -1 : 0; } - const int64_t t_collect_us = measure_dflash_warmup_collect ? 
ggml_time_us() : 0; if (!common_speculative_collect_target_seq_batch_features(spec, ctx_tgt, batch, seq_id, feature_view)) { llama_batch_free(seq_batch); return -1; } - if (measure_dflash_warmup_collect) { - dflash_state->t_warmup_collect_us += (uint64_t) (ggml_time_us() - t_collect_us); - dflash_state->n_warmup_collect_calls++; - dflash_state->n_warmup_collect_rows += (size_t) n_seq_tokens; - } batch_for_spec = &seq_batch; } else { - const int64_t t_collect_us = measure_dflash_warmup_collect ? ggml_time_us() : 0; if (!common_speculative_collect_target_batch_features(spec, ctx_tgt, batch, feature_view)) { return -1; } - if (measure_dflash_warmup_collect) { - dflash_state->t_warmup_collect_us += (uint64_t) (ggml_time_us() - t_collect_us); - dflash_state->n_warmup_collect_calls++; - dflash_state->n_warmup_collect_rows += (size_t) batch.n_tokens; - } } const int32_t ret = common_speculative_on_target_batch(spec, *batch_for_spec, feature_view, is_prompt_warmup); @@ -2170,16 +2156,7 @@ bool common_speculative_commit_accepted_hidden_rows( return false; } - auto * dflash_state = common_speculative_get_dflash_state(spec); - const int64_t t_commit_us = dflash_state != nullptr ? ggml_time_us() : 0; - const bool ok = common_speculative_apply_hidden_rows(spec, seq_id, pos_base, commit_tokens, hidden_rows); - if (dflash_state != nullptr) { - dflash_state->t_accept_commit_us += (uint64_t) (ggml_time_us() - t_commit_us); - dflash_state->n_accept_commit_calls++; - dflash_state->n_accept_commit_rows += commit_tokens.size(); - } - - return ok; + return common_speculative_apply_hidden_rows(spec, seq_id, pos_base, commit_tokens, hidden_rows); } bool common_speculative_commit_accepted_output( @@ -2196,16 +2173,9 @@ bool common_speculative_commit_accepted_output( } std::vector hidden_rows; - auto * dflash_state = common_speculative_get_dflash_state(spec); - const int64_t t_copy_us = dflash_state != nullptr ? 
ggml_time_us() : 0; if (!common_speculative_copy_output_hidden_rows(spec, ctx, output_indices, hidden_rows)) { return false; } - if (dflash_state != nullptr) { - dflash_state->t_accept_output_copy_us += (uint64_t) (ggml_time_us() - t_copy_us); - dflash_state->n_accept_output_copy_calls++; - dflash_state->n_accept_output_copy_rows += output_indices.size(); - } return common_speculative_commit_accepted_hidden_rows( spec, @@ -2471,341 +2441,6 @@ void common_speculative_print_stats(const common_speculative * spec, double slot impl->n_acc_tokens, str_perf.c_str()); - if (impl->type == COMMON_SPECULATIVE_TYPE_DFLASH) { - const auto * dflash_state = dynamic_cast(impl.get()); - if (dflash_state != nullptr && dflash_stats_log_enabled()) { - llama_dflash_profile_stats capture_stats; - llama_dflash_profile_stats graph_stats; - const bool have_capture = llama_dflash_profile_get_stats(dflash_state->ctx_tgt, &capture_stats); - const bool have_graph = llama_dflash_profile_get_stats(dflash_state->ctx_dft, &graph_stats); - - LOG_INF("statistics dflash detail: cross_ctx=%d, window_rows=%d, pos=[%d..%d], window_updates=%zu, rows_seen=%zu, rows_dropped=%zu, shifts=%zu, draft_fail(empty/set/decode)=%zu/%zu/%zu, next_draft_pos=%d\n", - dflash_state->cross_ctx, - dflash_state->target_window_rows, - dflash_state->target_window_pos.empty() ? -1 : (int) dflash_state->target_window_pos.front(), - dflash_state->target_window_pos.empty() ? 
-1 : (int) dflash_state->target_window_pos.back(), - dflash_state->n_window_updates, - dflash_state->n_rows_seen, - dflash_state->n_rows_dropped, - dflash_state->n_context_shifts, - dflash_state->n_draft_empty, - dflash_state->n_set_target_fail, - dflash_state->n_decode_fail, - (int) dflash_state->last_draft_pos_base); - - if (have_capture || have_graph) { - const double kv_cache_total_ms = (double) ( - graph_stats.graph_kv_cache_build_us + - graph_stats.graph_kv_cache_reserve_us + - graph_stats.graph_kv_cache_reset_us + - graph_stats.graph_kv_cache_alloc_us + - graph_stats.graph_kv_cache_feature_upload_us + - graph_stats.graph_kv_cache_pos_upload_us + - graph_stats.graph_kv_cache_compute_us + - graph_stats.graph_kv_cache_sync_us + - graph_stats.graph_kv_cache_read_concat_pad_us) / 1000.0; - const double kv_upload_feature_ms = (double) graph_stats.graph_kv_cache_feature_upload_us / 1000.0; - const double kv_upload_pos_ms = (double) graph_stats.graph_kv_cache_pos_upload_us / 1000.0; - const double kv_upload_total_ms = kv_upload_feature_ms + kv_upload_pos_ms; - const double kv_compute_ms = (double) graph_stats.graph_kv_cache_compute_us / 1000.0; - const double kv_sync_ms = (double) graph_stats.graph_kv_cache_sync_us / 1000.0; - const double kv_workspace_total_ms = (double) ( - graph_stats.graph_kv_workspace_build_us + - graph_stats.graph_kv_workspace_reserve_us + - graph_stats.graph_kv_workspace_reset_us + - graph_stats.graph_kv_workspace_alloc_us + - graph_stats.graph_kv_workspace_compute_us + - graph_stats.graph_kv_workspace_sync_us) / 1000.0; - const double draft_kv_traffic_ms = (double) ( - graph_stats.graph_main_node_k_ctx_view_us + - graph_stats.graph_main_node_v_ctx_view_us + - graph_stats.graph_main_node_k_concat_us + - graph_stats.graph_main_node_v_concat_us + - graph_stats.graph_main_node_k_pad_us + - graph_stats.graph_main_node_v_pad_us + - graph_stats.graph_main_node_k_perm_cont_us + - graph_stats.graph_main_node_v_perm_cont_us) / 1000.0; - const double 
draft_main_profiled_ms = (double) ( - graph_stats.graph_main_node_qcur_us + - graph_stats.graph_main_node_k_draft_us + - graph_stats.graph_main_node_v_draft_us + - graph_stats.graph_main_node_flash_attn_us + - graph_stats.graph_main_node_attn_out_us + - graph_stats.graph_main_node_ffn_us + - graph_stats.graph_main_node_result_rows_us + - graph_stats.graph_main_node_result_norm_us + - graph_stats.graph_main_node_result_us) / 1000.0; - const double replay_append_ms = (double) dflash_state->t_accept_append_us / 1000.0; - const double feature_path_ms = (double) ( - capture_stats.capture_prepare_sync_us + - capture_stats.capture_materialize_us + - graph_stats.set_target_copy_us + - graph_stats.graph_feature_copy_us + - graph_stats.graph_pos_copy_us + - graph_stats.graph_mask_build_us) / 1000.0; - const double decode_internal_ms = (double) ( - graph_stats.decode_prelude_us + - graph_stats.decode_sched_reset_us + - graph_stats.decode_build_graph_us + - graph_stats.decode_sched_alloc_graph_us + - graph_stats.decode_prepare_us + - graph_stats.decode_set_inputs_us + - graph_stats.decode_graph_compute_us + - graph_stats.decode_result_us + - graph_stats.decode_embedding_us + - graph_stats.decode_final_sched_reset_us) / 1000.0; - - LOG_INF("statistics dflash profile: capture(sync/materialize)=%.3f/%.3f ms calls=%llu/%llu bytes=%llu phase(prompt/verify batches changes)=%llu/%llu %llu/%llu, set_target=%.3f ms rows=%llu bytes=%llu, decode(llama_output_reserve/prepare)=%.3f/%.3f ms calls=%llu/%llu realloc(bytes)=%llu/%llu, prep(total/features/pos/mask)=%.3f/%.3f/%.3f/%.3f ms kv_cache(total/build/reserve/reset/alloc/up_f/up_p/compute/sync/read)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls(prepare/cache/read)=%llu/%llu/%llu bytes(feature/pos/mask/read)=%llu/%llu/%llu/%llu host_layers=%d, fallback_pos(copy/graph)=%llu/%llu, nonmono(copy/graph)=%llu/%llu, capture_fail=%llu/%llu decode_prepare_fail=%llu, visible_kv_max=%llu, last(rows=%d width=%d left_pad=%d n_tokens=%d 
n_kv=%d pos=[%d..%d])\n", - (double) capture_stats.capture_prepare_sync_us / 1000.0, - (double) capture_stats.capture_materialize_us / 1000.0, - (unsigned long long) capture_stats.capture_prepare_calls, - (unsigned long long) capture_stats.capture_materialize_calls, - (unsigned long long) capture_stats.capture_materialize_bytes, - (unsigned long long) capture_stats.capture_prompt_batches, - (unsigned long long) capture_stats.capture_prompt_shape_changes, - (unsigned long long) capture_stats.capture_verify_batches, - (unsigned long long) capture_stats.capture_verify_shape_changes, - (double) graph_stats.set_target_copy_us / 1000.0, - (unsigned long long) graph_stats.set_target_rows, - (unsigned long long) graph_stats.set_target_copy_bytes, - (double) graph_stats.decode_output_reserve_us / 1000.0, - (double) graph_stats.decode_prepare_us / 1000.0, - (unsigned long long) graph_stats.decode_output_reserve_calls, - (unsigned long long) graph_stats.decode_prepare_calls, - (unsigned long long) graph_stats.decode_output_reserve_reallocs, - (unsigned long long) graph_stats.decode_output_reserve_realloc_bytes, - (double) graph_stats.graph_prepare_total_us / 1000.0, - (double) graph_stats.graph_feature_copy_us / 1000.0, - (double) graph_stats.graph_pos_copy_us / 1000.0, - (double) graph_stats.graph_mask_build_us / 1000.0, - kv_cache_total_ms, - (double) graph_stats.graph_kv_cache_build_us / 1000.0, - (double) graph_stats.graph_kv_cache_reserve_us / 1000.0, - (double) graph_stats.graph_kv_cache_reset_us / 1000.0, - (double) graph_stats.graph_kv_cache_alloc_us / 1000.0, - (double) graph_stats.graph_kv_cache_feature_upload_us / 1000.0, - (double) graph_stats.graph_kv_cache_pos_upload_us / 1000.0, - (double) graph_stats.graph_kv_cache_compute_us / 1000.0, - (double) graph_stats.graph_kv_cache_sync_us / 1000.0, - (double) graph_stats.graph_kv_cache_read_concat_pad_us / 1000.0, - (unsigned long long) graph_stats.graph_prepare_calls, - (unsigned long long) 
graph_stats.graph_kv_cache_calls, - (unsigned long long) graph_stats.graph_kv_cache_read_concat_pad_calls, - (unsigned long long) graph_stats.graph_feature_bytes, - (unsigned long long) graph_stats.graph_pos_bytes, - (unsigned long long) graph_stats.graph_mask_bytes, - (unsigned long long) graph_stats.graph_kv_cache_cached_bytes, - graph_stats.last_kv_cache_host_layers, - (unsigned long long) graph_stats.set_target_missing_positions, - (unsigned long long) graph_stats.graph_pos_fallbacks, - (unsigned long long) graph_stats.set_target_non_monotonic_positions, - (unsigned long long) graph_stats.graph_pos_non_monotonic, - (unsigned long long) capture_stats.capture_prepare_failures, - (unsigned long long) capture_stats.capture_materialize_failures, - (unsigned long long) graph_stats.decode_prepare_failures, - (unsigned long long) graph_stats.graph_visible_kv_max, - graph_stats.last_n_rows, - graph_stats.last_width, - graph_stats.last_left_pad, - graph_stats.last_n_tokens, - graph_stats.last_n_kv_total, - (int) graph_stats.last_pos_first, - (int) graph_stats.last_pos_last); - - LOG_INF("statistics dflash features: total=%.3f ms capture(sync/materialize)=%.3f/%.3f ms set_target=%.3f ms prep(feature/pos/mask)=%.3f/%.3f/%.3f ms rows(materialize/set_target)=%llu/%llu bytes(materialize/set_target/feature/pos/mask)=%llu/%llu/%llu/%llu/%llu\n", - feature_path_ms, - (double) capture_stats.capture_prepare_sync_us / 1000.0, - (double) capture_stats.capture_materialize_us / 1000.0, - (double) graph_stats.set_target_copy_us / 1000.0, - (double) graph_stats.graph_feature_copy_us / 1000.0, - (double) graph_stats.graph_pos_copy_us / 1000.0, - (double) graph_stats.graph_mask_build_us / 1000.0, - (unsigned long long) capture_stats.capture_materialize_rows, - (unsigned long long) graph_stats.set_target_rows, - (unsigned long long) capture_stats.capture_materialize_bytes, - (unsigned long long) graph_stats.set_target_copy_bytes, - (unsigned long long) graph_stats.graph_feature_bytes, - 
(unsigned long long) graph_stats.graph_pos_bytes, - (unsigned long long) graph_stats.graph_mask_bytes); - - LOG_INF("statistics dflash kv: total=%.3f ms build/reserve/reset/alloc/upload_f/upload_p/compute/sync/read=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu cached_bytes=%llu host_layers=%d\n", - kv_cache_total_ms, - (double) graph_stats.graph_kv_cache_build_us / 1000.0, - (double) graph_stats.graph_kv_cache_reserve_us / 1000.0, - (double) graph_stats.graph_kv_cache_reset_us / 1000.0, - (double) graph_stats.graph_kv_cache_alloc_us / 1000.0, - (double) graph_stats.graph_kv_cache_feature_upload_us / 1000.0, - (double) graph_stats.graph_kv_cache_pos_upload_us / 1000.0, - (double) graph_stats.graph_kv_cache_compute_us / 1000.0, - (double) graph_stats.graph_kv_cache_sync_us / 1000.0, - (double) graph_stats.graph_kv_cache_read_concat_pad_us / 1000.0, - (unsigned long long) graph_stats.graph_kv_cache_calls, - (unsigned long long) graph_stats.graph_kv_cache_cached_bytes, - graph_stats.last_kv_cache_host_layers); - - if (graph_stats.graph_kv_workspace_calls > 0) { - LOG_INF("statistics dflash kv workspace: total=%.3f ms build/reserve/reset/alloc/compute/sync=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu\n", - kv_workspace_total_ms, - (double) graph_stats.graph_kv_workspace_build_us / 1000.0, - (double) graph_stats.graph_kv_workspace_reserve_us / 1000.0, - (double) graph_stats.graph_kv_workspace_reset_us / 1000.0, - (double) graph_stats.graph_kv_workspace_alloc_us / 1000.0, - (double) graph_stats.graph_kv_workspace_compute_us / 1000.0, - (double) graph_stats.graph_kv_workspace_sync_us / 1000.0, - (unsigned long long) graph_stats.graph_kv_workspace_calls); - } - - if (graph_stats.decode_internal_chunks > 0) { - LOG_INF("statistics dflash decode: llama_decode(total)=%.3f ms calls=%zu chunks=%llu rebuilds=%llu sync_points=%llu 
internal(total/prelude/sched_reset/build/alloc/prepare/set_inputs/compute/get_result/get_embedding/final_reset)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms\n", - (double) dflash_state->t_draft_decode_us / 1000.0, - dflash_state->n_call_draft, - (unsigned long long) graph_stats.decode_internal_chunks, - (unsigned long long) graph_stats.decode_graph_rebuilds, - (unsigned long long) graph_stats.decode_sync_profile_points, - decode_internal_ms, - (double) graph_stats.decode_prelude_us / 1000.0, - (double) graph_stats.decode_sched_reset_us / 1000.0, - (double) graph_stats.decode_build_graph_us / 1000.0, - (double) graph_stats.decode_sched_alloc_graph_us / 1000.0, - (double) graph_stats.decode_prepare_us / 1000.0, - (double) graph_stats.decode_set_inputs_us / 1000.0, - (double) graph_stats.decode_graph_compute_us / 1000.0, - (double) graph_stats.decode_result_us / 1000.0, - (double) graph_stats.decode_embedding_us / 1000.0, - (double) graph_stats.decode_final_sched_reset_us / 1000.0); - } - - if (graph_stats.graph_kv_node_fused_target_calls > 0 || - graph_stats.graph_kv_node_k_proj_calls > 0 || - graph_stats.graph_kv_node_k_norm_calls > 0 || - graph_stats.graph_kv_node_k_rope_calls > 0 || - graph_stats.graph_kv_node_v_proj_calls > 0 || - graph_stats.graph_kv_node_k_store_calls > 0 || - graph_stats.graph_kv_node_v_store_calls > 0) { - LOG_INF("statistics dflash kv nodes: fused_target/k_proj/k_norm/k_rope/v_proj/k_store/v_store=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu/%llu/%llu/%llu/%llu/%llu/%llu\n", - (double) graph_stats.graph_kv_node_fused_target_us / 1000.0, - (double) graph_stats.graph_kv_node_k_proj_us / 1000.0, - (double) graph_stats.graph_kv_node_k_norm_us / 1000.0, - (double) graph_stats.graph_kv_node_k_rope_us / 1000.0, - (double) graph_stats.graph_kv_node_v_proj_us / 1000.0, - (double) graph_stats.graph_kv_node_k_store_us / 1000.0, - (double) graph_stats.graph_kv_node_v_store_us / 1000.0, - (unsigned long long) 
graph_stats.graph_kv_node_fused_target_calls, - (unsigned long long) graph_stats.graph_kv_node_k_proj_calls, - (unsigned long long) graph_stats.graph_kv_node_k_norm_calls, - (unsigned long long) graph_stats.graph_kv_node_k_rope_calls, - (unsigned long long) graph_stats.graph_kv_node_v_proj_calls, - (unsigned long long) graph_stats.graph_kv_node_k_store_calls, - (unsigned long long) graph_stats.graph_kv_node_v_store_calls); - } - - if (graph_stats.graph_main_node_qcur_calls > 0 || - graph_stats.graph_main_node_k_draft_calls > 0 || - graph_stats.graph_main_node_v_draft_calls > 0 || - graph_stats.graph_main_node_flash_attn_calls > 0 || - graph_stats.graph_main_node_attn_out_calls > 0 || - graph_stats.graph_main_node_ffn_calls > 0 || - graph_stats.graph_main_node_result_rows_calls > 0 || - graph_stats.graph_main_node_result_norm_calls > 0 || - graph_stats.graph_main_node_result_calls > 0) { - LOG_INF("statistics dflash draft nodes: profiled=%.3f ms graph_compute=%.3f ms qcur/k_draft/v_draft/flash_attn/attn_out/ffn/result_rows/result_norm/result=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu/%llu/%llu/%llu/%llu/%llu/%llu/%llu/%llu\n", - draft_main_profiled_ms, - (double) graph_stats.decode_graph_compute_us / 1000.0, - (double) graph_stats.graph_main_node_qcur_us / 1000.0, - (double) graph_stats.graph_main_node_k_draft_us / 1000.0, - (double) graph_stats.graph_main_node_v_draft_us / 1000.0, - (double) graph_stats.graph_main_node_flash_attn_us / 1000.0, - (double) graph_stats.graph_main_node_attn_out_us / 1000.0, - (double) graph_stats.graph_main_node_ffn_us / 1000.0, - (double) graph_stats.graph_main_node_result_rows_us / 1000.0, - (double) graph_stats.graph_main_node_result_norm_us / 1000.0, - (double) graph_stats.graph_main_node_result_us / 1000.0, - (unsigned long long) graph_stats.graph_main_node_qcur_calls, - (unsigned long long) graph_stats.graph_main_node_k_draft_calls, - (unsigned long long) graph_stats.graph_main_node_v_draft_calls, - (unsigned long 
long) graph_stats.graph_main_node_flash_attn_calls, - (unsigned long long) graph_stats.graph_main_node_attn_out_calls, - (unsigned long long) graph_stats.graph_main_node_ffn_calls, - (unsigned long long) graph_stats.graph_main_node_result_rows_calls, - (unsigned long long) graph_stats.graph_main_node_result_norm_calls, - (unsigned long long) graph_stats.graph_main_node_result_calls); - } - - if (graph_stats.graph_main_node_k_ctx_view_calls > 0 || - graph_stats.graph_main_node_v_ctx_view_calls > 0 || - graph_stats.graph_main_node_k_concat_calls > 0 || - graph_stats.graph_main_node_v_concat_calls > 0 || - graph_stats.graph_main_node_k_pad_calls > 0 || - graph_stats.graph_main_node_v_pad_calls > 0 || - graph_stats.graph_main_node_k_perm_cont_calls > 0 || - graph_stats.graph_main_node_v_perm_cont_calls > 0) { - LOG_INF("statistics dflash draft kv traffic: total=%.3f ms graph_compute=%.3f ms k_ctx_view/v_ctx_view/k_concat/v_concat/k_pad/v_pad/k_perm_cont/v_perm_cont=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu/%llu/%llu/%llu/%llu/%llu/%llu/%llu\n", - draft_kv_traffic_ms, - (double) graph_stats.decode_graph_compute_us / 1000.0, - (double) graph_stats.graph_main_node_k_ctx_view_us / 1000.0, - (double) graph_stats.graph_main_node_v_ctx_view_us / 1000.0, - (double) graph_stats.graph_main_node_k_concat_us / 1000.0, - (double) graph_stats.graph_main_node_v_concat_us / 1000.0, - (double) graph_stats.graph_main_node_k_pad_us / 1000.0, - (double) graph_stats.graph_main_node_v_pad_us / 1000.0, - (double) graph_stats.graph_main_node_k_perm_cont_us / 1000.0, - (double) graph_stats.graph_main_node_v_perm_cont_us / 1000.0, - (unsigned long long) graph_stats.graph_main_node_k_ctx_view_calls, - (unsigned long long) graph_stats.graph_main_node_v_ctx_view_calls, - (unsigned long long) graph_stats.graph_main_node_k_concat_calls, - (unsigned long long) graph_stats.graph_main_node_v_concat_calls, - (unsigned long long) graph_stats.graph_main_node_k_pad_calls, - (unsigned long long) 
graph_stats.graph_main_node_v_pad_calls, - (unsigned long long) graph_stats.graph_main_node_k_perm_cont_calls, - (unsigned long long) graph_stats.graph_main_node_v_perm_cont_calls); - } - - LOG_INF("statistics dflash hot: kv(upload_f/upload_p/upload/compute/sync)=%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu replay(accepted_prefix_append)=%.3f ms calls=%zu rows=%zu\n", - kv_upload_feature_ms, - kv_upload_pos_ms, - kv_upload_total_ms, - kv_compute_ms, - kv_sync_ms, - (unsigned long long) graph_stats.graph_kv_cache_calls, - replay_append_ms, - dflash_state->n_accept_append_calls, - dflash_state->n_accept_append_rows); - - LOG_INF("statistics dflash stages: draft(decode/sample)=%.3f/%.3f ms warmup(collect/append)=%.3f/%.3f ms calls=%zu/%zu rows=%zu/%zu accept(total/output_copy/append)=%.3f/%.3f/%.3f ms calls=%zu/%zu/%zu rows=%zu/%zu/%zu\n", - (double) dflash_state->t_draft_decode_us / 1000.0, - (double) dflash_state->t_draft_sample_us / 1000.0, - (double) dflash_state->t_warmup_collect_us / 1000.0, - (double) dflash_state->t_warmup_append_us / 1000.0, - dflash_state->n_warmup_collect_calls, - dflash_state->n_warmup_append_calls, - dflash_state->n_warmup_collect_rows, - dflash_state->n_warmup_append_rows, - (double) dflash_state->t_accept_commit_us / 1000.0, - (double) dflash_state->t_accept_output_copy_us / 1000.0, - (double) dflash_state->t_accept_append_us / 1000.0, - dflash_state->n_accept_commit_calls, - dflash_state->n_accept_output_copy_calls, - dflash_state->n_accept_append_calls, - dflash_state->n_accept_commit_rows, - dflash_state->n_accept_output_copy_rows, - dflash_state->n_accept_append_rows); - - if (dflash_state->n_accept_append_calls > 0) { - LOG_INF("statistics dflash replay: append(filter/window_alloc/replace/keep_old/new_rows/commit/log)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%zu replace/slide=%zu/%zu\n", - (double) dflash_state->t_accept_append_filter_us / 1000.0, - (double) dflash_state->t_accept_append_window_alloc_us / 1000.0, - (double) 
dflash_state->t_accept_append_replace_us / 1000.0, - (double) dflash_state->t_accept_append_keep_old_us / 1000.0, - (double) dflash_state->t_accept_append_new_rows_us / 1000.0, - (double) dflash_state->t_accept_append_commit_detail_us / 1000.0, - (double) dflash_state->t_accept_append_log_us / 1000.0, - dflash_state->n_accept_append_calls, - dflash_state->n_accept_append_replace_calls, - dflash_state->n_accept_append_slide_calls); - } - } - } - } } if (spec->tuner && spec->tuner->enabled && slot_tps > 0.0 && n_decoded > 0) { @@ -3076,35 +2711,9 @@ int32_t common_speculative_on_target_batch( } } - dflash_append_breakdown append_breakdown; - const int64_t t_append_us = ggml_time_us(); - if (!dflash_append_target_features(*dflash_state, features, batch, seq_id, &append_breakdown)) { + if (!dflash_append_target_features(*dflash_state, features, seq_id)) { return -1; } - - const uint64_t append_us = (uint64_t) (ggml_time_us() - t_append_us); - if (is_prompt_warmup) { - dflash_state->t_warmup_append_us += append_us; - dflash_state->n_warmup_append_calls++; - dflash_state->n_warmup_append_rows += (size_t) batch.n_tokens; - } else { - dflash_state->t_accept_append_us += append_us; - dflash_state->t_accept_append_filter_us += append_breakdown.filter_us; - dflash_state->t_accept_append_window_alloc_us += append_breakdown.window_alloc_us; - dflash_state->t_accept_append_replace_us += append_breakdown.replace_us; - dflash_state->t_accept_append_keep_old_us += append_breakdown.keep_old_us; - dflash_state->t_accept_append_new_rows_us += append_breakdown.new_rows_us; - dflash_state->t_accept_append_commit_detail_us += append_breakdown.commit_us; - dflash_state->t_accept_append_log_us += append_breakdown.log_us; - dflash_state->n_accept_append_calls++; - dflash_state->n_accept_append_rows += (size_t) batch.n_tokens; - if (append_breakdown.replace_call) { - dflash_state->n_accept_append_replace_calls++; - } else { - dflash_state->n_accept_append_slide_calls++; - } - } - return 0; } 
diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 7be87f97..21b2be85 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -4076,23 +4076,6 @@ void server_context::speculative_decoding_accept() { slot.sampled = ids.back(); // last accepted token slot.n_past = slot.cache_tokens.n_tokens(); - const common_speculative_type spec_type_used = common_speculative_current_type(slot.spec); - const bool any_rejected = (ids.size() - 1) < n_draft; - const common_speculative_checkpoint * ckpt = common_speculative_get_checkpoint(slot.spec); - const bool will_restore = any_rejected && ckpt != nullptr && ckpt->valid; - - if (server_speculative_uses_target_features(slot.params.speculative) && !accepted_output_indices.empty()) { - llama_dflash_contract_log_accept( - slot.id, - spec_type_used == COMMON_SPECULATIVE_TYPE_DFLASH, - will_restore ? "restore" : "direct", - any_rejected, - n_draft, - ids.size(), - spec_pos_base, - accepted_output_indices); - } - common_speculative_commit( slot.spec, ctx, diff --git a/src/graphs/build_dflash.cpp b/src/graphs/build_dflash.cpp index cbb03403..adb583ef 100644 --- a/src/graphs/build_dflash.cpp +++ b/src/graphs/build_dflash.cpp @@ -7,23 +7,23 @@ ggml_cgraph * llm_build_context::build_dflash_kv_workspace() { const int64_t n_embd_head_k = hparams.n_embd_head_k(0); const int64_t n_embd_head_v = hparams.n_embd_head_v(0); - const int64_t ctx_len = lctx.dflash_visible_cross_ctx > 0 - ? (int64_t) lctx.dflash_visible_cross_ctx + const int64_t ctx_len = lctx.dflash.visible_cross_ctx > 0 + ? (int64_t) lctx.dflash.visible_cross_ctx : std::max(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size); - const int32_t cache_rows = std::clamp(lctx.dflash_kv_cache_view_n_filled, 0, (int32_t) ctx_len); + const int32_t cache_rows = std::clamp(lctx.dflash.kv.cache_view_n_filled, 0, (int32_t) ctx_len); const int32_t cache_write_pos = ctx_len > 0 - ? 
((lctx.dflash_kv_cache_view_write_pos % (int32_t) ctx_len) + (int32_t) ctx_len) % (int32_t) ctx_len + ? ((lctx.dflash.kv.cache_view_write_pos % (int32_t) ctx_len) + (int32_t) ctx_len) % (int32_t) ctx_len : 0; GGML_ASSERT(n_embd_head_k == n_embd_head_v); GGML_ASSERT(lctx.ensure_dflash_kv_cache_tensors((int32_t) ctx_len)); - GGML_ASSERT((int32_t) lctx.dflash_k_ctx_workspace.size() == n_layer); - GGML_ASSERT((int32_t) lctx.dflash_v_ctx_workspace.size() == n_layer); + GGML_ASSERT((int32_t) lctx.dflash.kv.k_ctx_workspace.size() == n_layer); + GGML_ASSERT((int32_t) lctx.dflash.kv.v_ctx_workspace.size() == n_layer); ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max(1, ctx_len)) + 16 * n_layer, false); auto build_ordered_cache_view = [&](ggml_tensor * cache) -> ggml_tensor * { - if (!lctx.dflash_kv_cache_view_valid || cache_rows <= 0) { + if (!lctx.dflash.kv.cache_view_valid || cache_rows <= 0) { return cache; } @@ -67,11 +67,11 @@ ggml_cgraph * llm_build_context::build_dflash_kv_workspace() { }; for (int il = 0; il < n_layer; ++il) { - GGML_ASSERT((size_t) il < lctx.dflash_k_ctx_cache.size()); - GGML_ASSERT((size_t) il < lctx.dflash_v_ctx_cache.size()); + GGML_ASSERT((size_t) il < lctx.dflash.kv.k_ctx_cache.size()); + GGML_ASSERT((size_t) il < lctx.dflash.kv.v_ctx_cache.size()); - ggml_tensor * Kordered = build_ordered_cache_view(lctx.dflash_k_ctx_cache[(size_t) il]); - ggml_tensor * Vordered = build_ordered_cache_view(lctx.dflash_v_ctx_cache[(size_t) il]); + ggml_tensor * Kordered = build_ordered_cache_view(lctx.dflash.kv.k_ctx_cache[(size_t) il]); + ggml_tensor * Vordered = build_ordered_cache_view(lctx.dflash.kv.v_ctx_cache[(size_t) il]); cb(Kordered, "dflash_workspace_k_ctx_view", il); cb(Vordered, "dflash_workspace_v_ctx_view", il); @@ -80,19 +80,19 @@ ggml_cgraph * llm_build_context::build_dflash_kv_workspace() { cb(Kworkspace, "dflash_workspace_k_perm_cont", il); cb(Vworkspace, "dflash_workspace_v_perm_cont", il); - ggml_tensor * Kdst = 
ggml_view_3d(ctx0, lctx.dflash_k_ctx_workspace[(size_t) il], - lctx.dflash_k_ctx_workspace[(size_t) il]->ne[0], + ggml_tensor * Kdst = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_workspace[(size_t) il], + lctx.dflash.kv.k_ctx_workspace[(size_t) il]->ne[0], ctx_len, - lctx.dflash_k_ctx_workspace[(size_t) il]->ne[2], - lctx.dflash_k_ctx_workspace[(size_t) il]->nb[1], - lctx.dflash_k_ctx_workspace[(size_t) il]->nb[2], + lctx.dflash.kv.k_ctx_workspace[(size_t) il]->ne[2], + lctx.dflash.kv.k_ctx_workspace[(size_t) il]->nb[1], + lctx.dflash.kv.k_ctx_workspace[(size_t) il]->nb[2], 0); - ggml_tensor * Vdst = ggml_view_3d(ctx0, lctx.dflash_v_ctx_workspace[(size_t) il], - lctx.dflash_v_ctx_workspace[(size_t) il]->ne[0], + ggml_tensor * Vdst = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_workspace[(size_t) il], + lctx.dflash.kv.v_ctx_workspace[(size_t) il]->ne[0], ctx_len, - lctx.dflash_v_ctx_workspace[(size_t) il]->ne[2], - lctx.dflash_v_ctx_workspace[(size_t) il]->nb[1], - lctx.dflash_v_ctx_workspace[(size_t) il]->nb[2], + lctx.dflash.kv.v_ctx_workspace[(size_t) il]->ne[2], + lctx.dflash.kv.v_ctx_workspace[(size_t) il]->nb[1], + lctx.dflash.kv.v_ctx_workspace[(size_t) il]->nb[2], 0); ggml_tensor * Kstore = ggml_cpy(ctx0, Kworkspace, Kdst); @@ -110,11 +110,11 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() { const int64_t n_embd_head_k = hparams.n_embd_head_k(0); const int64_t n_embd_head_v = hparams.n_embd_head_v(0); const int64_t n_target_features = hparams.dflash_n_target_features; - const int64_t ctx_len = lctx.dflash_visible_cross_ctx > 0 - ? (int64_t) lctx.dflash_visible_cross_ctx + const int64_t ctx_len = lctx.dflash.visible_cross_ctx > 0 + ? (int64_t) lctx.dflash.visible_cross_ctx : std::max(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size); - const int64_t update_rows = std::max(1, lctx.dflash_kv_cache_update_rows > 0 ? 
lctx.dflash_kv_cache_update_rows : ctx_len); - const int32_t write_pos = lctx.dflash_kv_cache_write_pos; + const int64_t update_rows = std::max(1, lctx.dflash.kv.cache_update_rows > 0 ? lctx.dflash.kv.cache_update_rows : ctx_len); + const int32_t write_pos = lctx.dflash.kv.cache_write_pos; GGML_ASSERT(n_embd_head_k == n_embd_head_v); GGML_ASSERT(n_target_features > 0); @@ -124,21 +124,21 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() { ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max(1, update_rows)) + 24 * n_layer, false); - lctx.dflash_kv_input_target_features = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_target_features, update_rows); - ggml_set_input(lctx.dflash_kv_input_target_features); - cb(lctx.dflash_kv_input_target_features, "dflash_kv_input_target_features", -1); + lctx.dflash.kv.cache_input_target_features = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_target_features, update_rows); + ggml_set_input(lctx.dflash.kv.cache_input_target_features); + cb(lctx.dflash.kv.cache_input_target_features, "dflash_kv_input_target_features", -1); - lctx.dflash_kv_input_pos_ctx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, update_rows); - ggml_set_input(lctx.dflash_kv_input_pos_ctx); - cb(lctx.dflash_kv_input_pos_ctx, "dflash_kv_input_pos_ctx", -1); + lctx.dflash.kv.cache_input_pos_ctx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, update_rows); + ggml_set_input(lctx.dflash.kv.cache_input_pos_ctx); + cb(lctx.dflash.kv.cache_input_pos_ctx, "dflash_kv_input_pos_ctx", -1); - ggml_tensor * fused_target = llm_build_lora_mm(lctx, ctx0, model.dflash_fc, lctx.dflash_kv_input_target_features); + ggml_tensor * fused_target = llm_build_lora_mm(lctx, ctx0, model.dflash_fc, lctx.dflash.kv.cache_input_target_features); fused_target = llm_build_norm(ctx0, fused_target, hparams, model.dflash_hidden_norm, nullptr, LLM_NORM_RMS, cb, -1); cb(fused_target, "dflash_kv_fused_target", -1); for (int il = 0; il < n_layer; ++il) { - GGML_ASSERT((size_t) il < 
lctx.dflash_k_ctx_cache.size()); - GGML_ASSERT((size_t) il < lctx.dflash_v_ctx_cache.size()); + GGML_ASSERT((size_t) il < lctx.dflash.kv.k_ctx_cache.size()); + GGML_ASSERT((size_t) il < lctx.dflash.kv.v_ctx_cache.size()); ggml_tensor * Kcur_ctx_proj = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused_target); cb(Kcur_ctx_proj, "dflash_kv_k_proj", il); @@ -146,7 +146,7 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() { ggml_tensor * Kcur_ctx = ggml_reshape_3d(ctx0, Kcur_ctx_proj, n_embd_head_k, n_head_kv, update_rows); Kcur_ctx = llm_build_norm(ctx0, Kcur_ctx, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il); cb(Kcur_ctx, "dflash_kv_k_norm", il); - Kcur_ctx = ggml_rope_ext(ctx0, Kcur_ctx, lctx.dflash_kv_input_pos_ctx, nullptr, + Kcur_ctx = ggml_rope_ext(ctx0, Kcur_ctx, lctx.dflash.kv.cache_input_pos_ctx, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur_ctx, "dflash_kv_k_rope", il); @@ -177,20 +177,20 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() { Vcur_ctx->nb[1], Vcur_ctx->nb[2], 0); - ggml_tensor * Kdst_first = ggml_view_3d(ctx0, lctx.dflash_k_ctx_cache[(size_t) il], - lctx.dflash_k_ctx_cache[(size_t) il]->ne[0], - lctx.dflash_k_ctx_cache[(size_t) il]->ne[1], + ggml_tensor * Kdst_first = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_cache[(size_t) il], + lctx.dflash.kv.k_ctx_cache[(size_t) il]->ne[0], + lctx.dflash.kv.k_ctx_cache[(size_t) il]->ne[1], first_rows, - lctx.dflash_k_ctx_cache[(size_t) il]->nb[1], - lctx.dflash_k_ctx_cache[(size_t) il]->nb[2], - (size_t) write_pos * lctx.dflash_k_ctx_cache[(size_t) il]->nb[2]); - ggml_tensor * Vdst_first = ggml_view_3d(ctx0, lctx.dflash_v_ctx_cache[(size_t) il], - lctx.dflash_v_ctx_cache[(size_t) il]->ne[0], - lctx.dflash_v_ctx_cache[(size_t) il]->ne[1], + lctx.dflash.kv.k_ctx_cache[(size_t) il]->nb[1], + lctx.dflash.kv.k_ctx_cache[(size_t) il]->nb[2], + (size_t) write_pos * 
lctx.dflash.kv.k_ctx_cache[(size_t) il]->nb[2]); + ggml_tensor * Vdst_first = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_cache[(size_t) il], + lctx.dflash.kv.v_ctx_cache[(size_t) il]->ne[0], + lctx.dflash.kv.v_ctx_cache[(size_t) il]->ne[1], first_rows, - lctx.dflash_v_ctx_cache[(size_t) il]->nb[1], - lctx.dflash_v_ctx_cache[(size_t) il]->nb[2], - (size_t) write_pos * lctx.dflash_v_ctx_cache[(size_t) il]->nb[2]); + lctx.dflash.kv.v_ctx_cache[(size_t) il]->nb[1], + lctx.dflash.kv.v_ctx_cache[(size_t) il]->nb[2], + (size_t) write_pos * lctx.dflash.kv.v_ctx_cache[(size_t) il]->nb[2]); ggml_tensor * Kstore_first = ggml_cpy(ctx0, Ksrc_first, Kdst_first); cb(Kstore_first, "dflash_kv_k_store", il); @@ -216,19 +216,19 @@ ggml_cgraph * llm_build_context::build_dflash_kv_cache() { Vcur_ctx->nb[1], Vcur_ctx->nb[2], (size_t) first_rows * Vcur_ctx->nb[2]); - ggml_tensor * Kdst_second = ggml_view_3d(ctx0, lctx.dflash_k_ctx_cache[(size_t) il], - lctx.dflash_k_ctx_cache[(size_t) il]->ne[0], - lctx.dflash_k_ctx_cache[(size_t) il]->ne[1], + ggml_tensor * Kdst_second = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_cache[(size_t) il], + lctx.dflash.kv.k_ctx_cache[(size_t) il]->ne[0], + lctx.dflash.kv.k_ctx_cache[(size_t) il]->ne[1], second_rows, - lctx.dflash_k_ctx_cache[(size_t) il]->nb[1], - lctx.dflash_k_ctx_cache[(size_t) il]->nb[2], + lctx.dflash.kv.k_ctx_cache[(size_t) il]->nb[1], + lctx.dflash.kv.k_ctx_cache[(size_t) il]->nb[2], 0); - ggml_tensor * Vdst_second = ggml_view_3d(ctx0, lctx.dflash_v_ctx_cache[(size_t) il], - lctx.dflash_v_ctx_cache[(size_t) il]->ne[0], - lctx.dflash_v_ctx_cache[(size_t) il]->ne[1], + ggml_tensor * Vdst_second = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_cache[(size_t) il], + lctx.dflash.kv.v_ctx_cache[(size_t) il]->ne[0], + lctx.dflash.kv.v_ctx_cache[(size_t) il]->ne[1], second_rows, - lctx.dflash_v_ctx_cache[(size_t) il]->nb[1], - lctx.dflash_v_ctx_cache[(size_t) il]->nb[2], + lctx.dflash.kv.v_ctx_cache[(size_t) il]->nb[1], + lctx.dflash.kv.v_ctx_cache[(size_t) 
il]->nb[2], 0); ggml_tensor * Kstore_second = ggml_cpy(ctx0, Ksrc_second, Kdst_second); @@ -248,12 +248,11 @@ ggml_cgraph * llm_build_context::build_dflash() { const int64_t n_embd_head_k = hparams.n_embd_head_k(0); const int64_t n_embd_head_v = hparams.n_embd_head_v(0); const int64_t n_target_features = hparams.dflash_n_target_features; - auto & profile = lctx.dflash_profile; - const int64_t ctx_len = lctx.dflash_visible_cross_ctx > 0 - ? (int64_t) lctx.dflash_visible_cross_ctx + const int64_t ctx_len = lctx.dflash.visible_cross_ctx > 0 + ? (int64_t) lctx.dflash.visible_cross_ctx : std::max(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size); const int32_t cache_write_pos = ctx_len > 0 - ? ((lctx.dflash_kv_cache_view_write_pos % (int32_t) ctx_len) + (int32_t) ctx_len) % (int32_t) ctx_len + ? ((lctx.dflash.kv.cache_view_write_pos % (int32_t) ctx_len) + (int32_t) ctx_len) % (int32_t) ctx_len : 0; const int64_t n_kv_total = GGML_PAD(ctx_len + n_tokens, flash_attn ? 256 : 32); const int64_t n_kv_pad = n_kv_total - (ctx_len + n_tokens); @@ -273,21 +272,21 @@ ggml_cgraph * llm_build_context::build_dflash() { } } - lctx.inp_dflash_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - lctx.dflash_kq_mask_tensor = lctx.inp_dflash_kq_mask; - ggml_set_input(lctx.inp_dflash_kq_mask); - cb(lctx.inp_dflash_kq_mask, "dflash_kq_mask", -1); + lctx.dflash.inputs.kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + lctx.dflash.kv.kq_mask_tensor = lctx.dflash.inputs.kq_mask; + ggml_set_input(lctx.dflash.inputs.kq_mask); + cb(lctx.dflash.inputs.kq_mask, "dflash_kq_mask", -1); - ggml_tensor * dflash_kq_mask_full = flash_attn ? ggml_cast(ctx0, lctx.inp_dflash_kq_mask, GGML_TYPE_F16) : lctx.inp_dflash_kq_mask; + ggml_tensor * dflash_kq_mask_full = flash_attn ? 
ggml_cast(ctx0, lctx.dflash.inputs.kq_mask, GGML_TYPE_F16) : lctx.dflash.inputs.kq_mask; ggml_tensor * dflash_kq_mask_swa = nullptr; - lctx.inp_dflash_kq_mask_swa = nullptr; - lctx.dflash_kq_mask_swa_tensor = nullptr; + lctx.dflash.inputs.kq_mask_swa = nullptr; + lctx.dflash.kv.kq_mask_swa_tensor = nullptr; if (have_swa_layers && hparams.n_swa > 0) { - lctx.inp_dflash_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - lctx.dflash_kq_mask_swa_tensor = lctx.inp_dflash_kq_mask_swa; - ggml_set_input(lctx.inp_dflash_kq_mask_swa); - cb(lctx.inp_dflash_kq_mask_swa, "dflash_kq_mask_swa", -1); - dflash_kq_mask_swa = flash_attn ? ggml_cast(ctx0, lctx.inp_dflash_kq_mask_swa, GGML_TYPE_F16) : lctx.inp_dflash_kq_mask_swa; + lctx.dflash.inputs.kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + lctx.dflash.kv.kq_mask_swa_tensor = lctx.dflash.inputs.kq_mask_swa; + ggml_set_input(lctx.dflash.inputs.kq_mask_swa); + cb(lctx.dflash.inputs.kq_mask_swa, "dflash_kq_mask_swa", -1); + dflash_kq_mask_swa = flash_attn ? 
ggml_cast(ctx0, lctx.dflash.inputs.kq_mask_swa, GGML_TYPE_F16) : lctx.dflash.inputs.kq_mask_swa; } ggml_tensor * tok_embd = model.tok_embd; @@ -328,25 +327,24 @@ ggml_cgraph * llm_build_context::build_dflash() { Vcur_noise = ggml_reshape_3d(ctx0, Vcur_noise, n_embd_head_v, n_head_kv, n_tokens); cb(Vcur_noise, "Vcur_noise", il); - const int64_t t_cache_read_us = ggml_time_us(); - GGML_ASSERT((size_t) il < lctx.dflash_k_ctx_workspace.size()); - GGML_ASSERT((size_t) il < lctx.dflash_v_ctx_workspace.size()); - GGML_ASSERT(lctx.dflash_k_ctx_workspace[(size_t) il] != nullptr); - GGML_ASSERT(lctx.dflash_v_ctx_workspace[(size_t) il] != nullptr); + GGML_ASSERT((size_t) il < lctx.dflash.kv.k_ctx_workspace.size()); + GGML_ASSERT((size_t) il < lctx.dflash.kv.v_ctx_workspace.size()); + GGML_ASSERT(lctx.dflash.kv.k_ctx_workspace[(size_t) il] != nullptr); + GGML_ASSERT(lctx.dflash.kv.v_ctx_workspace[(size_t) il] != nullptr); - ggml_tensor * Kcur_ctx = ggml_view_3d(ctx0, lctx.dflash_k_ctx_workspace[(size_t) il], - lctx.dflash_k_ctx_workspace[(size_t) il]->ne[0], + ggml_tensor * Kcur_ctx = ggml_view_3d(ctx0, lctx.dflash.kv.k_ctx_workspace[(size_t) il], + lctx.dflash.kv.k_ctx_workspace[(size_t) il]->ne[0], ctx_len, - lctx.dflash_k_ctx_workspace[(size_t) il]->ne[2], - lctx.dflash_k_ctx_workspace[(size_t) il]->nb[1], - lctx.dflash_k_ctx_workspace[(size_t) il]->nb[2], + lctx.dflash.kv.k_ctx_workspace[(size_t) il]->ne[2], + lctx.dflash.kv.k_ctx_workspace[(size_t) il]->nb[1], + lctx.dflash.kv.k_ctx_workspace[(size_t) il]->nb[2], 0); - ggml_tensor * Vcur_ctx = ggml_view_3d(ctx0, lctx.dflash_v_ctx_workspace[(size_t) il], - lctx.dflash_v_ctx_workspace[(size_t) il]->ne[0], + ggml_tensor * Vcur_ctx = ggml_view_3d(ctx0, lctx.dflash.kv.v_ctx_workspace[(size_t) il], + lctx.dflash.kv.v_ctx_workspace[(size_t) il]->ne[0], ctx_len, - lctx.dflash_v_ctx_workspace[(size_t) il]->ne[2], - lctx.dflash_v_ctx_workspace[(size_t) il]->nb[1], - lctx.dflash_v_ctx_workspace[(size_t) il]->nb[2], + 
lctx.dflash.kv.v_ctx_workspace[(size_t) il]->ne[2], + lctx.dflash.kv.v_ctx_workspace[(size_t) il]->nb[1], + lctx.dflash.kv.v_ctx_workspace[(size_t) il]->nb[2], 0); cb(Kcur_ctx, "Kcur_ctx_workspace", il); cb(Vcur_ctx, "Vcur_ctx_workspace", il); @@ -368,9 +366,6 @@ ggml_cgraph * llm_build_context::build_dflash() { cb(Vcur, "dflash_main_v_pad", il); } - profile.graph_kv_cache_read_concat_pad_us += (uint64_t) (ggml_time_us() - t_cache_read_us); - profile.graph_kv_cache_read_concat_pad_calls++; - profile.graph_kv_cache_cached_bytes += ggml_nbytes(lctx.dflash_k_ctx_cache[(size_t) il]) + ggml_nbytes(lctx.dflash_v_ctx_cache[(size_t) il]); cb(Qcur, "Qcur", il); ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); @@ -434,11 +429,11 @@ ggml_cgraph * llm_build_context::build_dflash() { cb(result, "result_output", -1); ggml_build_forward_expand(gf, result); - lctx.dflash_draft_tokens_tensor = nullptr; + lctx.dflash.draft_tokens_tensor = nullptr; ggml_tensor * draft_tokens = ggml_argmax(ctx0, result); ggml_set_name(draft_tokens, "draft_argmax"); ggml_build_forward_expand(gf, draft_tokens); - lctx.dflash_draft_tokens_tensor = draft_tokens; + lctx.dflash.draft_tokens_tensor = draft_tokens; return gf; } diff --git a/src/llama-context.h b/src/llama-context.h index c06a4b27..b8a2b4cc 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -365,82 +365,15 @@ struct llama_context { std::vector feature_view_buffer; input_state inputs; int32_t visible_cross_ctx = 0; - llama_dflash_profile_stats profile; + + // Argmax token IDs from the DFlash draft graph, computed via GPU argmax. + // Populated in llama_decode_internal after graph compute. 
+ std::vector draft_tokens; + struct ggml_tensor * draft_tokens_tensor = nullptr; }; dflash_runtime dflash; using dflash_capture_state = dflash_runtime::capture_state; - const float * & dflash_target_features = dflash.target.features; - size_t & dflash_target_features_n_floats = dflash.target.features_n_floats; - int32_t & dflash_target_features_n_rows = dflash.target.features_n_rows; - const float * & dflash_target_append_features = dflash.target.append_features; - size_t & dflash_target_append_features_n_floats = dflash.target.append_features_n_floats; - int32_t & dflash_target_append_features_n_rows = dflash.target.append_features_n_rows; - const llama_pos * & dflash_target_positions = dflash.target.positions; - size_t & dflash_target_positions_n = dflash.target.positions_n; - uint64_t & dflash_target_window_version = dflash.target.version; - int32_t & dflash_target_window_keep_rows = dflash.target.keep_rows; - int32_t & dflash_target_window_append_rows = dflash.target.append_rows; - bool & dflash_target_window_replace = dflash.target.replace; - std::vector & dflash_target_features_owned = dflash.target.features_owned; - std::vector & dflash_target_append_features_owned = dflash.target.append_features_owned; - std::vector & dflash_target_positions_owned = dflash.target.positions_owned; - std::vector & dflash_target_features_padded = dflash.target.features_padded; - std::vector & dflash_feature_view_buffer = dflash.feature_view_buffer; - std::vector & dflash_pos_ctx_data = dflash.target.pos_ctx_data; - std::vector & dflash_kq_mask_data = dflash.target.kq_mask_data; - std::vector & dflash_kq_mask_swa_data = dflash.target.kq_mask_swa_data; - int32_t & dflash_visible_cross_ctx = dflash.visible_cross_ctx; - std::vector & dflash_k_ctx_cache = dflash.kv.k_ctx_cache; - std::vector & dflash_v_ctx_cache = dflash.kv.v_ctx_cache; - - // Argmax token IDs from the DFlash draft graph, computed via GPU argmax. - // Populated in llama_decode_internal after graph compute. 
- std::vector dflash_draft_tokens; - struct ggml_tensor * dflash_draft_tokens_tensor = nullptr; - - std::vector & dflash_k_ctx_workspace = dflash.kv.k_ctx_workspace; - std::vector & dflash_v_ctx_workspace = dflash.kv.v_ctx_workspace; - struct ggml_context * & dflash_cache_ctx = dflash.kv.cache_ctx; - std::vector & dflash_cache_bufs = dflash.kv.cache_bufs; - int32_t & dflash_kv_cache_write_pos = dflash.kv.cache_write_pos; - int32_t & dflash_kv_cache_n_filled = dflash.kv.cache_n_filled; - int32_t & dflash_kv_cache_update_rows = dflash.kv.cache_update_rows; - int32_t & dflash_kv_cache_reserved_rows = dflash.kv.cache_reserved_rows; - int32_t & dflash_kv_cache_view_write_pos = dflash.kv.cache_view_write_pos; - int32_t & dflash_kv_cache_view_n_filled = dflash.kv.cache_view_n_filled; - uint64_t & dflash_kv_cache_applied_window_version = dflash.kv.cache_applied_window_version; - bool & dflash_kv_cache_valid = dflash.kv.cache_valid; - bool & dflash_kv_cache_view_valid = dflash.kv.cache_view_valid; - int32_t & dflash_kv_workspace_write_pos = dflash.kv.workspace_write_pos; - int32_t & dflash_kv_workspace_n_filled = dflash.kv.workspace_n_filled; - int32_t & dflash_kv_workspace_reserved_rows = dflash.kv.workspace_reserved_rows; - int32_t & dflash_kv_workspace_token_capacity = dflash.kv.workspace_token_capacity; - int32_t & dflash_kv_workspace_n_kv_total = dflash.kv.workspace_n_kv_total; - uint64_t & dflash_kv_workspace_applied_window_version = dflash.kv.workspace_applied_window_version; - bool & dflash_kv_workspace_valid = dflash.kv.workspace_valid; - bool & dflash_kv_workspace_sync_pending = dflash.kv.workspace_sync_pending; - std::vector & dflash_buf_compute_meta = dflash.kv.cache_compute_meta; - std::vector & dflash_workspace_buf_compute_meta = dflash.kv.workspace_compute_meta; - ggml_backend_sched_t & dflash_sched = dflash.kv.cache_sched; - ggml_backend_sched_t & dflash_workspace_sched = dflash.kv.workspace_sched; - ggml_cgraph * & dflash_kv_graph = dflash.kv.cache_graph; - 
ggml_cgraph * & dflash_kv_workspace_graph = dflash.kv.workspace_graph; - int32_t & dflash_kv_graph_rows = dflash.kv.cache_graph_rows; - int32_t & dflash_kv_graph_write_pos = dflash.kv.cache_graph_write_pos; - int32_t & dflash_kv_workspace_graph_rows = dflash.kv.workspace_graph_rows; - int32_t & dflash_kv_workspace_graph_write_pos = dflash.kv.workspace_graph_write_pos; - struct ggml_tensor * & dflash_kv_input_target_features = dflash.kv.cache_input_target_features; - struct ggml_tensor * & dflash_kv_input_pos_ctx = dflash.kv.cache_input_pos_ctx; - struct ggml_tensor * & dflash_kq_mask_tensor = dflash.kv.kq_mask_tensor; - struct ggml_tensor * & dflash_kq_mask_swa_tensor = dflash.kv.kq_mask_swa_tensor; - std::unique_ptr & dflash_capture = dflash.capture; - llama_dflash_profile_stats & dflash_profile = dflash.profile; - struct ggml_tensor * & inp_dflash_target_features = dflash.inputs.target_features; - struct ggml_tensor * & inp_dflash_pos_ctx = dflash.inputs.pos_ctx; - struct ggml_tensor * & inp_dflash_kq_mask = dflash.inputs.kq_mask; - struct ggml_tensor * & inp_dflash_kq_mask_swa = dflash.inputs.kq_mask_swa; - // input tensors struct ggml_tensor * inp_tokens; // I32 [n_batch] struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] diff --git a/src/llama-dflash-profile.h b/src/llama-dflash-profile.h deleted file mode 100644 index 4d998488..00000000 --- a/src/llama-dflash-profile.h +++ /dev/null @@ -1,340 +0,0 @@ -#pragma once - -#include -#include - -enum llama_dflash_kv_node_kind { - LLAMA_DFLASH_KV_NODE_NONE = 0, - LLAMA_DFLASH_KV_NODE_FUSED_TARGET, - LLAMA_DFLASH_KV_NODE_K_PROJ, - LLAMA_DFLASH_KV_NODE_K_NORM, - LLAMA_DFLASH_KV_NODE_K_ROPE, - LLAMA_DFLASH_KV_NODE_V_PROJ, - LLAMA_DFLASH_KV_NODE_K_STORE, - LLAMA_DFLASH_KV_NODE_V_STORE, -}; - -enum llama_dflash_main_node_kind { - LLAMA_DFLASH_MAIN_NODE_NONE = 0, - LLAMA_DFLASH_MAIN_NODE_QCUR, - LLAMA_DFLASH_MAIN_NODE_K_DRAFT, - LLAMA_DFLASH_MAIN_NODE_V_DRAFT, - LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW, - 
LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW, - LLAMA_DFLASH_MAIN_NODE_K_CONCAT, - LLAMA_DFLASH_MAIN_NODE_V_CONCAT, - LLAMA_DFLASH_MAIN_NODE_K_PAD, - LLAMA_DFLASH_MAIN_NODE_V_PAD, - LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT, - LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT, - LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN, - LLAMA_DFLASH_MAIN_NODE_ATTN_OUT, - LLAMA_DFLASH_MAIN_NODE_FFN, - LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS, - LLAMA_DFLASH_MAIN_NODE_RESULT_NORM, - LLAMA_DFLASH_MAIN_NODE_RESULT, -}; - -struct llama_dflash_kv_node_profiler { - llama_dflash_profile_stats * profile = nullptr; - int64_t t_start_us = 0; - llama_dflash_kv_node_kind active_kind = LLAMA_DFLASH_KV_NODE_NONE; -}; - -struct llama_dflash_main_node_profiler { - llama_dflash_profile_stats * profile = nullptr; - ggml_backend_sched_eval_callback prev_callback = nullptr; - void * prev_user_data = nullptr; - bool prev_active = false; - int64_t t_start_us = 0; - llama_dflash_main_node_kind active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; -}; - -static inline bool llama_dflash_tensor_name_has_prefix(const struct ggml_tensor * tensor, const char * prefix) { - if (tensor == nullptr || prefix == nullptr || prefix[0] == '\0') { - return false; - } - - return std::strncmp(tensor->name, prefix, std::strlen(prefix)) == 0; -} - -static inline bool llama_dflash_tensor_name_matches_label(const struct ggml_tensor * tensor, const char * label) { - if (!llama_dflash_tensor_name_has_prefix(tensor, label)) { - return false; - } - - const size_t label_len = std::strlen(label); - const char next = tensor->name[label_len]; - return next == '\0' || next == '-'; -} - -static inline llama_dflash_kv_node_kind llama_dflash_kv_node_kind_from_tensor(const struct ggml_tensor * tensor) { - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_fused_target")) { - return LLAMA_DFLASH_KV_NODE_FUSED_TARGET; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_proj")) { - return LLAMA_DFLASH_KV_NODE_K_PROJ; - } - if (llama_dflash_tensor_name_has_prefix(tensor, 
"dflash_kv_k_norm")) { - return LLAMA_DFLASH_KV_NODE_K_NORM; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_rope")) { - return LLAMA_DFLASH_KV_NODE_K_ROPE; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_proj")) { - return LLAMA_DFLASH_KV_NODE_V_PROJ; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_k_store")) { - return LLAMA_DFLASH_KV_NODE_K_STORE; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_kv_v_store")) { - return LLAMA_DFLASH_KV_NODE_V_STORE; - } - - return LLAMA_DFLASH_KV_NODE_NONE; -} - -static inline void llama_dflash_kv_node_profile_add( - llama_dflash_profile_stats & profile, - llama_dflash_kv_node_kind kind, - uint64_t elapsed_us) { - switch (kind) { - case LLAMA_DFLASH_KV_NODE_FUSED_TARGET: - profile.graph_kv_node_fused_target_calls++; - profile.graph_kv_node_fused_target_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_PROJ: - profile.graph_kv_node_k_proj_calls++; - profile.graph_kv_node_k_proj_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_NORM: - profile.graph_kv_node_k_norm_calls++; - profile.graph_kv_node_k_norm_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_ROPE: - profile.graph_kv_node_k_rope_calls++; - profile.graph_kv_node_k_rope_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_V_PROJ: - profile.graph_kv_node_v_proj_calls++; - profile.graph_kv_node_v_proj_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_K_STORE: - profile.graph_kv_node_k_store_calls++; - profile.graph_kv_node_k_store_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_V_STORE: - profile.graph_kv_node_v_store_calls++; - profile.graph_kv_node_v_store_us += elapsed_us; - break; - case LLAMA_DFLASH_KV_NODE_NONE: - break; - } -} - -static inline llama_dflash_main_node_kind llama_dflash_main_node_kind_from_tensor(const struct ggml_tensor * tensor) { - if (llama_dflash_tensor_name_has_prefix(tensor, "Qcur")) { - return LLAMA_DFLASH_MAIN_NODE_QCUR; - } - if 
(llama_dflash_tensor_name_has_prefix(tensor, "Kcur_noise")) { - return LLAMA_DFLASH_MAIN_NODE_K_DRAFT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_noise")) { - return LLAMA_DFLASH_MAIN_NODE_V_DRAFT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "Kcur_ctx_cache")) { - return LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "Vcur_ctx_cache")) { - return LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_concat")) { - return LLAMA_DFLASH_MAIN_NODE_K_CONCAT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_concat")) { - return LLAMA_DFLASH_MAIN_NODE_V_CONCAT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_pad")) { - return LLAMA_DFLASH_MAIN_NODE_K_PAD; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_pad")) { - return LLAMA_DFLASH_MAIN_NODE_V_PAD; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_k_perm_cont")) { - return LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "dflash_main_v_perm_cont")) { - return LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "flash_attn_reshaped")) { - return LLAMA_DFLASH_MAIN_NODE_NONE; - } - if (llama_dflash_tensor_name_matches_label(tensor, "flash_attn")) { - return LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "kqv_out")) { - return LLAMA_DFLASH_MAIN_NODE_ATTN_OUT; - } - if (llama_dflash_tensor_name_has_prefix(tensor, "ffn_out")) { - return LLAMA_DFLASH_MAIN_NODE_FFN; - } - if (llama_dflash_tensor_name_matches_label(tensor, "result_output_rows")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS; - } - if (llama_dflash_tensor_name_matches_label(tensor, "result_norm")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT_NORM; - } - if (llama_dflash_tensor_name_matches_label(tensor, "output")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT; 
- } - if (llama_dflash_tensor_name_matches_label(tensor, "result_output")) { - return LLAMA_DFLASH_MAIN_NODE_RESULT; - } - - return LLAMA_DFLASH_MAIN_NODE_NONE; -} - -static inline void llama_dflash_main_node_profile_add( - llama_dflash_profile_stats & profile, - llama_dflash_main_node_kind kind, - uint64_t elapsed_us) { - switch (kind) { - case LLAMA_DFLASH_MAIN_NODE_QCUR: - profile.graph_main_node_qcur_calls++; - profile.graph_main_node_qcur_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_DRAFT: - profile.graph_main_node_k_draft_calls++; - profile.graph_main_node_k_draft_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_DRAFT: - profile.graph_main_node_v_draft_calls++; - profile.graph_main_node_v_draft_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_CTX_VIEW: - profile.graph_main_node_k_ctx_view_calls++; - profile.graph_main_node_k_ctx_view_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_CTX_VIEW: - profile.graph_main_node_v_ctx_view_calls++; - profile.graph_main_node_v_ctx_view_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_CONCAT: - profile.graph_main_node_k_concat_calls++; - profile.graph_main_node_k_concat_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_CONCAT: - profile.graph_main_node_v_concat_calls++; - profile.graph_main_node_v_concat_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_PAD: - profile.graph_main_node_k_pad_calls++; - profile.graph_main_node_k_pad_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_PAD: - profile.graph_main_node_v_pad_calls++; - profile.graph_main_node_v_pad_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_K_PERM_CONT: - profile.graph_main_node_k_perm_cont_calls++; - profile.graph_main_node_k_perm_cont_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_V_PERM_CONT: - profile.graph_main_node_v_perm_cont_calls++; - profile.graph_main_node_v_perm_cont_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_FLASH_ATTN: - 
profile.graph_main_node_flash_attn_calls++; - profile.graph_main_node_flash_attn_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_ATTN_OUT: - profile.graph_main_node_attn_out_calls++; - profile.graph_main_node_attn_out_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_FFN: - profile.graph_main_node_ffn_calls++; - profile.graph_main_node_ffn_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_RESULT_ROWS: - profile.graph_main_node_result_rows_calls++; - profile.graph_main_node_result_rows_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_RESULT_NORM: - profile.graph_main_node_result_norm_calls++; - profile.graph_main_node_result_norm_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_RESULT: - profile.graph_main_node_result_calls++; - profile.graph_main_node_result_us += elapsed_us; - break; - case LLAMA_DFLASH_MAIN_NODE_NONE: - break; - } -} - -static inline bool llama_dflash_kv_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { - auto * profiler = static_cast(user_data); - if (profiler == nullptr || profiler->profile == nullptr) { - return false; - } - - const llama_dflash_kv_node_kind kind = llama_dflash_kv_node_kind_from_tensor(tensor); - if (ask) { - if (kind == LLAMA_DFLASH_KV_NODE_NONE) { - return false; - } - - profiler->active_kind = kind; - profiler->t_start_us = ggml_time_us(); - return true; - } - - if (kind != LLAMA_DFLASH_KV_NODE_NONE && profiler->active_kind == kind && profiler->t_start_us > 0) { - llama_dflash_kv_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); - } - - profiler->active_kind = LLAMA_DFLASH_KV_NODE_NONE; - profiler->t_start_us = 0; - return true; -} - -static inline bool llama_dflash_main_node_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) { - auto * profiler = static_cast(user_data); - if (profiler == nullptr || profiler->profile == nullptr) { - return false; - } - - const llama_dflash_main_node_kind kind 
= llama_dflash_main_node_kind_from_tensor(tensor); - if (ask) { - profiler->prev_active = profiler->prev_callback != nullptr - ? profiler->prev_callback(tensor, ask, profiler->prev_user_data) - : false; - - if (kind == LLAMA_DFLASH_MAIN_NODE_NONE) { - profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; - profiler->t_start_us = 0; - return profiler->prev_active; - } - - profiler->active_kind = kind; - profiler->t_start_us = ggml_time_us(); - return true; - } - - bool prev_result = false; - if (profiler->prev_active && profiler->prev_callback != nullptr) { - prev_result = profiler->prev_callback(tensor, ask, profiler->prev_user_data); - } - - const bool tracked = kind != LLAMA_DFLASH_MAIN_NODE_NONE && - profiler->active_kind == kind && - profiler->t_start_us > 0; - if (tracked) { - llama_dflash_main_node_profile_add(*profiler->profile, kind, (uint64_t) (ggml_time_us() - profiler->t_start_us)); - } - - profiler->prev_active = false; - profiler->active_kind = LLAMA_DFLASH_MAIN_NODE_NONE; - profiler->t_start_us = 0; - return prev_result || tracked; -} diff --git a/src/llama-dflash.cpp b/src/llama-dflash.cpp index e32f53d1..bfb4595b 100644 --- a/src/llama-dflash.cpp +++ b/src/llama-dflash.cpp @@ -5,38 +5,22 @@ #include "llama-context.h" #include "llama-model.h" #include "llama-spec-features.h" -#include "llama-dflash-profile.h" #include "ggml.h" #include "ggml-backend.h" #include -#include #include #include #include -static bool llama_env_flag_enabled_local(const char * name) { - const char * env = std::getenv(name); - return env != nullptr && *env != '\0' && - std::strcmp(env, "0") != 0 && - std::strcmp(env, "false") != 0 && - std::strcmp(env, "off") != 0; -} - -static bool llama_dflash_stats_log_enabled() { - return llama_env_flag_enabled_local("IK_DFLASH_STATS_LOG"); -} - void llama_sync_dflash_workspace_if_pending(struct llama_context & lctx) { - if (!lctx.dflash_kv_workspace_sync_pending || lctx.dflash_workspace_sched == nullptr) { + if 
(!lctx.dflash.kv.workspace_sync_pending || lctx.dflash.kv.workspace_sched == nullptr) { return; } - const int64_t t_workspace_sync_us = ggml_time_us(); - ggml_backend_sched_synchronize(lctx.dflash_workspace_sched); - lctx.dflash_profile.graph_kv_workspace_sync_us += (uint64_t) (ggml_time_us() - t_workspace_sync_us); - lctx.dflash_kv_workspace_sync_pending = false; + ggml_backend_sched_synchronize(lctx.dflash.kv.workspace_sched); + lctx.dflash.kv.workspace_sync_pending = false; } static ggml_backend_buffer_type_t llama_dflash_kv_cache_layer_buft(const llama_context & lctx, int32_t il) { @@ -86,36 +70,36 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { const int64_t n_embd_head_v = model.hparams.n_embd_head_v(0); const int64_t n_head_kv = model.hparams.n_head_kv(); - if (dflash_cache_ctx != nullptr && !dflash_k_ctx_cache.empty()) { - const bool cache_matches = (int32_t) dflash_k_ctx_cache.size() == n_layer && - dflash_k_ctx_cache.front() != nullptr && - (int32_t) dflash_k_ctx_cache.front()->ne[2] == target_cross_ctx; - const bool workspace_matches = (int32_t) dflash_k_ctx_workspace.size() == n_layer && - dflash_k_ctx_workspace.front() != nullptr && - (int32_t) dflash_k_ctx_workspace.front()->ne[1] == target_workspace_n_kv_total; + if (dflash.kv.cache_ctx != nullptr && !dflash.kv.k_ctx_cache.empty()) { + const bool cache_matches = (int32_t) dflash.kv.k_ctx_cache.size() == n_layer && + dflash.kv.k_ctx_cache.front() != nullptr && + (int32_t) dflash.kv.k_ctx_cache.front()->ne[2] == target_cross_ctx; + const bool workspace_matches = (int32_t) dflash.kv.k_ctx_workspace.size() == n_layer && + dflash.kv.k_ctx_workspace.front() != nullptr && + (int32_t) dflash.kv.k_ctx_workspace.front()->ne[1] == target_workspace_n_kv_total; if (cache_matches && workspace_matches) { return true; } free_dflash_kv_cache_tensors(); - if (dflash_sched != nullptr) { - ggml_backend_sched_free(dflash_sched); - dflash_sched = nullptr; + if (dflash.kv.cache_sched != nullptr) 
{ + ggml_backend_sched_free(dflash.kv.cache_sched); + dflash.kv.cache_sched = nullptr; } - if (dflash_workspace_sched != nullptr) { - ggml_backend_sched_free(dflash_workspace_sched); - dflash_workspace_sched = nullptr; + if (dflash.kv.workspace_sched != nullptr) { + ggml_backend_sched_free(dflash.kv.workspace_sched); + dflash.kv.workspace_sched = nullptr; } - dflash_kv_graph = nullptr; - dflash_kv_workspace_graph = nullptr; - dflash_kv_graph_rows = 0; - dflash_kv_graph_write_pos = 0; - dflash_kv_workspace_graph_rows = 0; - dflash_kv_workspace_graph_write_pos = 0; - dflash_kv_workspace_reserved_rows = 0; - dflash_buf_compute_meta.clear(); - dflash_workspace_buf_compute_meta.clear(); + dflash.kv.cache_graph = nullptr; + dflash.kv.workspace_graph = nullptr; + dflash.kv.cache_graph_rows = 0; + dflash.kv.cache_graph_write_pos = 0; + dflash.kv.workspace_graph_rows = 0; + dflash.kv.workspace_graph_write_pos = 0; + dflash.kv.workspace_reserved_rows = 0; + dflash.kv.cache_compute_meta.clear(); + dflash.kv.workspace_compute_meta.clear(); } ggml_init_params params = { @@ -124,166 +108,146 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { /*.no_alloc =*/ true, }; - dflash_cache_ctx = ggml_init(params); - if (dflash_cache_ctx == nullptr) { + dflash.kv.cache_ctx = ggml_init(params); + if (dflash.kv.cache_ctx == nullptr) { return false; } - dflash_k_ctx_cache.resize((size_t) n_layer); - dflash_v_ctx_cache.resize((size_t) n_layer); - dflash_k_ctx_workspace.clear(); - dflash_v_ctx_workspace.clear(); - dflash_k_ctx_workspace.resize((size_t) n_layer); - dflash_v_ctx_workspace.resize((size_t) n_layer); - dflash_cache_bufs.clear(); - dflash_cache_bufs.reserve((size_t) std::max(1, n_layer) * 4); - int32_t host_layers = 0; - const char * first_buft_name = nullptr; - const char * last_buft_name = nullptr; + dflash.kv.k_ctx_cache.resize((size_t) n_layer); + dflash.kv.v_ctx_cache.resize((size_t) n_layer); + dflash.kv.k_ctx_workspace.clear(); + 
dflash.kv.v_ctx_workspace.clear(); + dflash.kv.k_ctx_workspace.resize((size_t) n_layer); + dflash.kv.v_ctx_workspace.resize((size_t) n_layer); + dflash.kv.cache_bufs.clear(); + dflash.kv.cache_bufs.reserve((size_t) std::max(1, n_layer) * 4); for (int32_t il = 0; il < n_layer; ++il) { ggml_backend_buffer_type_t layer_buft = llama_dflash_kv_cache_layer_buft(*this, il); - if (ggml_backend_buft_is_host(layer_buft)) { - host_layers++; - } - if (first_buft_name == nullptr) { - first_buft_name = ggml_backend_buft_name(layer_buft); - } - last_buft_name = ggml_backend_buft_name(layer_buft); - dflash_k_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_k, n_head_kv, target_cross_ctx); - dflash_v_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_v, n_head_kv, target_cross_ctx); - if (dflash_k_ctx_cache[(size_t) il] == nullptr || dflash_v_ctx_cache[(size_t) il] == nullptr) { + dflash.kv.k_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, n_embd_head_k, n_head_kv, target_cross_ctx); + dflash.kv.v_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, n_embd_head_v, n_head_kv, target_cross_ctx); + if (dflash.kv.k_ctx_cache[(size_t) il] == nullptr || dflash.kv.v_ctx_cache[(size_t) il] == nullptr) { free_dflash_kv_cache_tensors(); return false; } - ggml_set_input(dflash_k_ctx_cache[(size_t) il]); - ggml_set_input(dflash_v_ctx_cache[(size_t) il]); - ggml_format_name(dflash_k_ctx_cache[(size_t) il], "dflash_k_ctx_cache_%d", il); - ggml_format_name(dflash_v_ctx_cache[(size_t) il], "dflash_v_ctx_cache_%d", il); + ggml_set_input(dflash.kv.k_ctx_cache[(size_t) il]); + ggml_set_input(dflash.kv.v_ctx_cache[(size_t) il]); + ggml_format_name(dflash.kv.k_ctx_cache[(size_t) il], "dflash_k_ctx_cache_%d", il); + ggml_format_name(dflash.kv.v_ctx_cache[(size_t) il], "dflash_v_ctx_cache_%d", il); - const size_t k_bytes = ggml_backend_buft_get_alloc_size(layer_buft, 
dflash_k_ctx_cache[(size_t) il]); + const size_t k_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash.kv.k_ctx_cache[(size_t) il]); ggml_backend_buffer_t k_buf = ggml_backend_buft_alloc_buffer(layer_buft, k_bytes); if (k_buf == nullptr) { free_dflash_kv_cache_tensors(); return false; } ggml_backend_buffer_set_usage(k_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); - ggml_backend_tensor_alloc(k_buf, dflash_k_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(k_buf)); + ggml_backend_tensor_alloc(k_buf, dflash.kv.k_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(k_buf)); ggml_backend_buffer_clear(k_buf, 0); - dflash_cache_bufs.push_back(k_buf); + dflash.kv.cache_bufs.push_back(k_buf); - const size_t v_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_v_ctx_cache[(size_t) il]); + const size_t v_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash.kv.v_ctx_cache[(size_t) il]); ggml_backend_buffer_t v_buf = ggml_backend_buft_alloc_buffer(layer_buft, v_bytes); if (v_buf == nullptr) { free_dflash_kv_cache_tensors(); return false; } ggml_backend_buffer_set_usage(v_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); - ggml_backend_tensor_alloc(v_buf, dflash_v_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(v_buf)); + ggml_backend_tensor_alloc(v_buf, dflash.kv.v_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(v_buf)); ggml_backend_buffer_clear(v_buf, 0); - dflash_cache_bufs.push_back(v_buf); + dflash.kv.cache_bufs.push_back(v_buf); - dflash_k_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_k, target_workspace_n_kv_total, n_head_kv); - dflash_v_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_v, target_workspace_n_kv_total, n_head_kv); - if (dflash_k_ctx_workspace[(size_t) il] == nullptr || dflash_v_ctx_workspace[(size_t) il] == nullptr) { + dflash.kv.k_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, n_embd_head_k, 
target_workspace_n_kv_total, n_head_kv); + dflash.kv.v_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, n_embd_head_v, target_workspace_n_kv_total, n_head_kv); + if (dflash.kv.k_ctx_workspace[(size_t) il] == nullptr || dflash.kv.v_ctx_workspace[(size_t) il] == nullptr) { free_dflash_kv_cache_tensors(); return false; } - ggml_set_input(dflash_k_ctx_workspace[(size_t) il]); - ggml_set_input(dflash_v_ctx_workspace[(size_t) il]); - ggml_format_name(dflash_k_ctx_workspace[(size_t) il], "dflash_k_ctx_workspace_%d", il); - ggml_format_name(dflash_v_ctx_workspace[(size_t) il], "dflash_v_ctx_workspace_%d", il); + ggml_set_input(dflash.kv.k_ctx_workspace[(size_t) il]); + ggml_set_input(dflash.kv.v_ctx_workspace[(size_t) il]); + ggml_format_name(dflash.kv.k_ctx_workspace[(size_t) il], "dflash_k_ctx_workspace_%d", il); + ggml_format_name(dflash.kv.v_ctx_workspace[(size_t) il], "dflash_v_ctx_workspace_%d", il); - const size_t k_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_k_ctx_workspace[(size_t) il]); + const size_t k_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash.kv.k_ctx_workspace[(size_t) il]); ggml_backend_buffer_t k_workspace_buf = ggml_backend_buft_alloc_buffer(layer_buft, k_workspace_bytes); if (k_workspace_buf == nullptr) { free_dflash_kv_cache_tensors(); return false; } ggml_backend_buffer_set_usage(k_workspace_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); - ggml_backend_tensor_alloc(k_workspace_buf, dflash_k_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(k_workspace_buf)); + ggml_backend_tensor_alloc(k_workspace_buf, dflash.kv.k_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(k_workspace_buf)); ggml_backend_buffer_clear(k_workspace_buf, 0); - dflash_cache_bufs.push_back(k_workspace_buf); + dflash.kv.cache_bufs.push_back(k_workspace_buf); - const size_t v_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_v_ctx_workspace[(size_t) il]); + const size_t 
v_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash.kv.v_ctx_workspace[(size_t) il]); ggml_backend_buffer_t v_workspace_buf = ggml_backend_buft_alloc_buffer(layer_buft, v_workspace_bytes); if (v_workspace_buf == nullptr) { free_dflash_kv_cache_tensors(); return false; } ggml_backend_buffer_set_usage(v_workspace_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); - ggml_backend_tensor_alloc(v_workspace_buf, dflash_v_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(v_workspace_buf)); + ggml_backend_tensor_alloc(v_workspace_buf, dflash.kv.v_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(v_workspace_buf)); ggml_backend_buffer_clear(v_workspace_buf, 0); - dflash_cache_bufs.push_back(v_workspace_buf); + dflash.kv.cache_bufs.push_back(v_workspace_buf); } - dflash_profile.last_kv_cache_host_layers = host_layers; - dflash_kv_workspace_token_capacity = target_token_capacity; - dflash_kv_workspace_n_kv_total = target_workspace_n_kv_total; + dflash.kv.workspace_token_capacity = target_token_capacity; + dflash.kv.workspace_n_kv_total = target_workspace_n_kv_total; llama_reset_dflash_kv_cache_state(this); - if (llama_dflash_stats_log_enabled()) { - LLAMA_LOG_INFO("%s: DFlash K/V cache placement cross_ctx=%d host_layers=%d/%d first=%s last=%s\n", - __func__, - target_cross_ctx, - host_layers, - n_layer, - first_buft_name != nullptr ? first_buft_name : "(none)", - last_buft_name != nullptr ? 
last_buft_name : "(none)"); - } return true; } void llama_context::free_dflash_kv_cache_tensors() { - dflash_k_ctx_cache.clear(); - dflash_v_ctx_cache.clear(); - dflash_k_ctx_workspace.clear(); - dflash_v_ctx_workspace.clear(); - dflash_kv_cache_write_pos = 0; - dflash_kv_cache_n_filled = 0; - dflash_kv_cache_update_rows = 0; - dflash_kv_cache_reserved_rows = 0; - dflash_kv_cache_view_write_pos = 0; - dflash_kv_cache_view_n_filled = 0; - dflash_kv_cache_applied_window_version = 0; - dflash_kv_cache_valid = false; - dflash_kv_cache_view_valid = false; - dflash_kv_workspace_write_pos = 0; - dflash_kv_workspace_n_filled = 0; - dflash_kv_workspace_reserved_rows = 0; - dflash_kv_workspace_token_capacity = 0; - dflash_kv_workspace_n_kv_total = 0; - dflash_kv_workspace_applied_window_version = 0; - dflash_kv_workspace_valid = false; - dflash_kv_workspace_sync_pending = false; - dflash_kv_graph = nullptr; - dflash_kv_workspace_graph = nullptr; - dflash_kv_graph_rows = 0; - dflash_kv_graph_write_pos = 0; - dflash_kv_workspace_graph_rows = 0; - dflash_kv_workspace_graph_write_pos = 0; - dflash_kv_input_target_features = nullptr; - dflash_kv_input_pos_ctx = nullptr; - dflash_kq_mask_tensor = nullptr; - dflash_kq_mask_swa_tensor = nullptr; + dflash.kv.k_ctx_cache.clear(); + dflash.kv.v_ctx_cache.clear(); + dflash.kv.k_ctx_workspace.clear(); + dflash.kv.v_ctx_workspace.clear(); + dflash.kv.cache_write_pos = 0; + dflash.kv.cache_n_filled = 0; + dflash.kv.cache_update_rows = 0; + dflash.kv.cache_reserved_rows = 0; + dflash.kv.cache_view_write_pos = 0; + dflash.kv.cache_view_n_filled = 0; + dflash.kv.cache_applied_window_version = 0; + dflash.kv.cache_valid = false; + dflash.kv.cache_view_valid = false; + dflash.kv.workspace_write_pos = 0; + dflash.kv.workspace_n_filled = 0; + dflash.kv.workspace_reserved_rows = 0; + dflash.kv.workspace_token_capacity = 0; + dflash.kv.workspace_n_kv_total = 0; + dflash.kv.workspace_applied_window_version = 0; + dflash.kv.workspace_valid = false; + 
dflash.kv.workspace_sync_pending = false; + dflash.kv.cache_graph = nullptr; + dflash.kv.workspace_graph = nullptr; + dflash.kv.cache_graph_rows = 0; + dflash.kv.cache_graph_write_pos = 0; + dflash.kv.workspace_graph_rows = 0; + dflash.kv.workspace_graph_write_pos = 0; + dflash.kv.cache_input_target_features = nullptr; + dflash.kv.cache_input_pos_ctx = nullptr; + dflash.kv.kq_mask_tensor = nullptr; + dflash.kv.kq_mask_swa_tensor = nullptr; - if (dflash_workspace_sched != nullptr) { - ggml_backend_sched_synchronize(dflash_workspace_sched); - ggml_backend_sched_free(dflash_workspace_sched); - dflash_workspace_sched = nullptr; + if (dflash.kv.workspace_sched != nullptr) { + ggml_backend_sched_synchronize(dflash.kv.workspace_sched); + ggml_backend_sched_free(dflash.kv.workspace_sched); + dflash.kv.workspace_sched = nullptr; } - for (ggml_backend_buffer_t buf : dflash_cache_bufs) { + for (ggml_backend_buffer_t buf : dflash.kv.cache_bufs) { if (buf != nullptr) { ggml_backend_buffer_free(buf); } } - dflash_cache_bufs.clear(); - if (dflash_cache_ctx != nullptr) { - ggml_free(dflash_cache_ctx); - dflash_cache_ctx = nullptr; + dflash.kv.cache_bufs.clear(); + if (dflash.kv.cache_ctx != nullptr) { + ggml_free(dflash.kv.cache_ctx); + dflash.kv.cache_ctx = nullptr; } } @@ -418,13 +382,11 @@ static bool validate_dflash_graph_contract(const llama_context & lctx) { bool llama_prepare_dflash_graph_inputs( struct llama_context & lctx, uint32_t n_tokens) { - const bool kv_node_timing = llama_env_flag_enabled_local("IK_DFLASH_KV_NODE_TIMING"); - auto & profile = lctx.dflash_profile; - const int32_t cross_ctx = lctx.dflash_visible_cross_ctx > 0 - ? lctx.dflash_visible_cross_ctx + const int32_t cross_ctx = lctx.dflash.visible_cross_ctx > 0 + ? 
lctx.dflash.visible_cross_ctx : std::max(1, (int32_t) lctx.cparams.n_ctx - (int32_t) lctx.model.hparams.dflash_block_size); - ggml_tensor * kq_mask = lctx.dflash_kq_mask_tensor; - ggml_tensor * kq_mask_swa = lctx.dflash_kq_mask_swa_tensor; + ggml_tensor * kq_mask = lctx.dflash.kv.kq_mask_tensor; + ggml_tensor * kq_mask_swa = lctx.dflash.kv.kq_mask_swa_tensor; if (kq_mask == nullptr) { LLAMA_LOG_ERROR("%s: DFlash graph inputs are not initialized\n", __func__); @@ -432,113 +394,84 @@ bool llama_prepare_dflash_graph_inputs( } if (!validate_dflash_graph_contract(lctx)) { - profile.graph_shape_failures++; return false; } - if (!lctx.ensure_dflash_kv_cache_tensors(cross_ctx) || lctx.dflash_k_ctx_cache.empty() || lctx.dflash_v_ctx_cache.empty()) { + if (!lctx.ensure_dflash_kv_cache_tensors(cross_ctx) || lctx.dflash.kv.k_ctx_cache.empty() || lctx.dflash.kv.v_ctx_cache.empty()) { LLAMA_LOG_ERROR("%s: DFlash K/V cache inputs are not initialized\n", __func__); return false; } - const float * src = lctx.dflash_target_features; - const float * append_src = lctx.dflash_target_append_features; - const llama_pos * src_pos = lctx.dflash_target_positions; - const size_t total_floats = lctx.dflash_target_features_n_floats; - const size_t append_floats = lctx.dflash_target_append_features_n_floats; - const size_t total_positions = lctx.dflash_target_positions_n; - const int32_t n_rows = lctx.dflash_target_features_n_rows; - const int32_t append_rows_available = lctx.dflash_target_append_features_n_rows; + const float * src = lctx.dflash.target.features; + const float * append_src = lctx.dflash.target.append_features; + const llama_pos * src_pos = lctx.dflash.target.positions; + const size_t total_floats = lctx.dflash.target.features_n_floats; + const size_t append_floats = lctx.dflash.target.append_features_n_floats; + const size_t total_positions = lctx.dflash.target.positions_n; + const int32_t n_rows = lctx.dflash.target.features_n_rows; + const int32_t append_rows_available = 
lctx.dflash.target.append_features_n_rows; const int32_t width = (int32_t) lctx.model.hparams.dflash_n_target_features; - const int32_t graph_cross_ctx = lctx.dflash_k_ctx_cache.front() != nullptr - ? (int32_t) lctx.dflash_k_ctx_cache.front()->ne[2] + const int32_t graph_cross_ctx = lctx.dflash.kv.k_ctx_cache.front() != nullptr + ? (int32_t) lctx.dflash.kv.k_ctx_cache.front()->ne[2] : 0; const int32_t n_mask_tokens = (int32_t) kq_mask->ne[1]; const int32_t n_kv_total = (int32_t) kq_mask->ne[0]; - const int64_t t_total_us = ggml_time_us(); - - profile.graph_prepare_calls++; - profile.last_n_rows = n_rows; - profile.last_width = width; - profile.last_cross_ctx = cross_ctx; - profile.last_n_tokens = (int32_t) n_tokens; - profile.last_n_kv_total = n_kv_total; llama_sync_dflash_workspace_if_pending(lctx); if (graph_cross_ctx != cross_ctx) { - profile.graph_shape_failures++; - LLAMA_LOG_ERROR("%s: DFlash graph cross_ctx drift (graph=%d configured=%d)\n", __func__, graph_cross_ctx, cross_ctx); return false; } if (n_rows <= 0) { - profile.graph_shape_failures++; LLAMA_LOG_ERROR("%s: missing DFlash target feature rows\n", __func__); return false; } const bool have_full_src = src != nullptr && total_floats == (size_t) n_rows * (size_t) width; if (n_rows > cross_ctx || (src != nullptr && !have_full_src)) { - profile.graph_shape_failures++; LLAMA_LOG_ERROR("%s: invalid DFlash target feature shape (rows=%d width=%d floats=%zu cross_ctx=%d)\n", __func__, n_rows, width, total_floats, cross_ctx); return false; } if (n_kv_total < cross_ctx + (int32_t) n_tokens) { - profile.graph_mask_overflow++; LLAMA_LOG_ERROR("%s: invalid DFlash mask shape (n_kv_total=%d < cross_ctx+n_tokens=%d)\n", __func__, n_kv_total, cross_ctx + (int32_t) n_tokens); return false; } const int32_t left_pad = cross_ctx - n_rows; - profile.last_left_pad = left_pad; - const int64_t t_pos_us = ggml_time_us(); - lctx.dflash_pos_ctx_data.resize((size_t) cross_ctx); - std::fill(lctx.dflash_pos_ctx_data.begin(), 
lctx.dflash_pos_ctx_data.end(), 0); + lctx.dflash.target.pos_ctx_data.resize((size_t) cross_ctx); + std::fill(lctx.dflash.target.pos_ctx_data.begin(), lctx.dflash.target.pos_ctx_data.end(), 0); if (src_pos == nullptr || total_positions != (size_t) n_rows) { - profile.graph_pos_fallbacks++; - profile.graph_shape_failures++; - profile.last_pos_first = -1; - profile.last_pos_last = -1; - if (profile.graph_pos_fallbacks <= 3) { - LLAMA_LOG_ERROR("%s: missing DFlash target positions (rows=%d positions=%zu cross_ctx=%d)\n", - __func__, n_rows, total_positions, cross_ctx); - } + LLAMA_LOG_ERROR("%s: missing DFlash target positions (rows=%d positions=%zu cross_ctx=%d)\n", + __func__, n_rows, total_positions, cross_ctx); return false; } - profile.last_pos_first = src_pos[0]; - profile.last_pos_last = src_pos[n_rows - 1]; + const llama_pos last_target_pos = src_pos[n_rows - 1]; for (int32_t i = 1; i < n_rows; ++i) { if (src_pos[i] <= src_pos[i - 1]) { - profile.graph_pos_non_monotonic++; - profile.graph_shape_failures++; - if (profile.graph_pos_non_monotonic <= 3) { - LLAMA_LOG_ERROR("%s: DFlash target positions are not strictly increasing (rows=%d first=%d last=%d)\n", - __func__, n_rows, (int) src_pos[0], (int) src_pos[n_rows - 1]); - } + LLAMA_LOG_ERROR("%s: DFlash target positions are not strictly increasing (rows=%d first=%d last=%d)\n", + __func__, n_rows, (int) src_pos[0], (int) src_pos[n_rows - 1]); return false; } } - std::copy(src_pos, src_pos + n_rows, lctx.dflash_pos_ctx_data.begin() + (ptrdiff_t) left_pad); - profile.graph_pos_copy_us += (uint64_t) (ggml_time_us() - t_pos_us); - profile.graph_pos_bytes += lctx.dflash_pos_ctx_data.size() * sizeof(llama_pos); + std::copy(src_pos, src_pos + n_rows, lctx.dflash.target.pos_ctx_data.begin() + (ptrdiff_t) left_pad); const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition( cross_ctx, - lctx.dflash_kv_cache_n_filled, - lctx.dflash_kv_cache_write_pos, - lctx.dflash_kv_cache_valid, - 
lctx.dflash_kv_cache_applied_window_version, - lctx.dflash_target_window_version, - lctx.dflash_target_window_keep_rows, - lctx.dflash_target_window_append_rows, - lctx.dflash_target_window_replace, + lctx.dflash.kv.cache_n_filled, + lctx.dflash.kv.cache_write_pos, + lctx.dflash.kv.cache_valid, + lctx.dflash.kv.cache_applied_window_version, + lctx.dflash.target.version, + lctx.dflash.target.keep_rows, + lctx.dflash.target.append_rows, + lctx.dflash.target.replace, n_rows); const bool have_append_src = append_src != nullptr && @@ -550,11 +483,11 @@ bool llama_prepare_dflash_graph_inputs( : (cache_plan.rebuild_cache ? n_rows : cache_plan.append_rows); const size_t max_nodes = lctx.model.max_nodes((int) std::max(1, cross_ctx)) + 24 * lctx.model.hparams.n_layer; const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false); - if (lctx.dflash_buf_compute_meta.size() != meta_size) { - lctx.dflash_buf_compute_meta.resize(meta_size); + if (lctx.dflash.kv.cache_compute_meta.size() != meta_size) { + lctx.dflash.kv.cache_compute_meta.resize(meta_size); } - if (lctx.dflash_sched == nullptr || lctx.dflash_kv_cache_reserved_rows != cross_ctx) { + if (lctx.dflash.kv.cache_sched == nullptr || lctx.dflash.kv.cache_reserved_rows != cross_ctx) { std::vector backend_buft; backend_buft.reserve(lctx.backends.size()); for (auto * backend : lctx.backends) { @@ -565,36 +498,30 @@ bool llama_prepare_dflash_graph_inputs( } } - if (lctx.dflash_sched != nullptr) { - ggml_backend_sched_free(lctx.dflash_sched); - lctx.dflash_sched = nullptr; + if (lctx.dflash.kv.cache_sched != nullptr) { + ggml_backend_sched_free(lctx.dflash.kv.cache_sched); + lctx.dflash.kv.cache_sched = nullptr; } - lctx.dflash_kv_graph = nullptr; - lctx.dflash_kv_graph_rows = 0; - lctx.dflash_kv_graph_write_pos = 0; + lctx.dflash.kv.cache_graph = nullptr; + lctx.dflash.kv.cache_graph_rows = 0; + lctx.dflash.kv.cache_graph_write_pos = 0; - const int32_t saved_update_rows = 
lctx.dflash_kv_cache_update_rows; - lctx.dflash_kv_cache_update_rows = cross_ctx; - const int64_t t_build_us = ggml_time_us(); + const int32_t saved_update_rows = lctx.dflash.kv.cache_update_rows; + lctx.dflash.kv.cache_update_rows = cross_ctx; ggml_cgraph * gf_reserve = llm_build_context::llama_build_graph_dflash_kv_cache(lctx); - profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us); - lctx.dflash_kv_cache_update_rows = saved_update_rows; + lctx.dflash.kv.cache_update_rows = saved_update_rows; if (gf_reserve == nullptr) { - profile.graph_shape_failures++; LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache reserve graph\n", __func__); return false; } - const int64_t t_reserve_us = ggml_time_us(); - lctx.dflash_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false); - const bool reserved = lctx.dflash_sched != nullptr && ggml_backend_sched_reserve(lctx.dflash_sched, gf_reserve); - profile.graph_kv_cache_reserve_us += (uint64_t) (ggml_time_us() - t_reserve_us); + lctx.dflash.kv.cache_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false); + const bool reserved = lctx.dflash.kv.cache_sched != nullptr && ggml_backend_sched_reserve(lctx.dflash.kv.cache_sched, gf_reserve); if (!reserved) { - profile.graph_shape_failures++; LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V scheduler\n", __func__); return false; } - lctx.dflash_kv_cache_reserved_rows = cross_ctx; + lctx.dflash.kv.cache_reserved_rows = cross_ctx; } if (update_rows > 0) { @@ -607,7 +534,6 @@ bool llama_prepare_dflash_graph_inputs( const llama_pos * update_pos = src_pos + (n_rows - update_rows); if (update_src == nullptr) { - profile.graph_shape_failures++; LLAMA_LOG_ERROR("%s: missing DFlash appended target features for cached update (rows=%d append_rows=%d floats=%zu)\n", __func__, n_rows, update_rows, append_floats); return false; @@ -617,108 +543,77 @@ bool 
llama_prepare_dflash_graph_inputs( llama_reset_dflash_kv_cache_state(&lctx); } - lctx.dflash_kv_cache_update_rows = update_rows; + lctx.dflash.kv.cache_update_rows = update_rows; ggml_cgraph * gf_kv = nullptr; - const bool can_reuse_kv_graph = lctx.dflash_kv_graph != nullptr && - lctx.dflash_kv_graph_rows == update_rows && - lctx.dflash_kv_graph_write_pos == lctx.dflash_kv_cache_write_pos; + const bool can_reuse_kv_graph = lctx.dflash.kv.cache_graph != nullptr && + lctx.dflash.kv.cache_graph_rows == update_rows && + lctx.dflash.kv.cache_graph_write_pos == lctx.dflash.kv.cache_write_pos; if (can_reuse_kv_graph) { - gf_kv = lctx.dflash_kv_graph; + gf_kv = lctx.dflash.kv.cache_graph; } else { - const int64_t t_build_us = ggml_time_us(); gf_kv = llm_build_context::llama_build_graph_dflash_kv_cache(lctx); - profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us); - if (gf_kv == nullptr || lctx.dflash_kv_input_target_features == nullptr || lctx.dflash_kv_input_pos_ctx == nullptr) { - profile.graph_shape_failures++; + if (gf_kv == nullptr || lctx.dflash.kv.cache_input_target_features == nullptr || lctx.dflash.kv.cache_input_pos_ctx == nullptr) { LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__); return false; } - const int64_t t_reset_us = ggml_time_us(); - ggml_backend_sched_reset(lctx.dflash_sched); - profile.graph_kv_cache_reset_us += (uint64_t) (ggml_time_us() - t_reset_us); + ggml_backend_sched_reset(lctx.dflash.kv.cache_sched); + ggml_backend_sched_alloc_graph(lctx.dflash.kv.cache_sched, gf_kv); - const int64_t t_alloc_us = ggml_time_us(); - ggml_backend_sched_alloc_graph(lctx.dflash_sched, gf_kv); - profile.graph_kv_cache_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us); - - lctx.dflash_kv_graph = gf_kv; - lctx.dflash_kv_graph_rows = update_rows; - lctx.dflash_kv_graph_write_pos = lctx.dflash_kv_cache_write_pos; + lctx.dflash.kv.cache_graph = gf_kv; + lctx.dflash.kv.cache_graph_rows = update_rows; + 
lctx.dflash.kv.cache_graph_write_pos = lctx.dflash.kv.cache_write_pos; } - ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_target_features); - const int64_t t_feature_upload_us = ggml_time_us(); + ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash.kv.cache_input_target_features); if (kv_feature_backend != nullptr) { - ggml_backend_tensor_set_async(kv_feature_backend, lctx.dflash_kv_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); + ggml_backend_tensor_set_async(kv_feature_backend, lctx.dflash.kv.cache_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash.kv.cache_input_target_features)); } else { - ggml_backend_tensor_set(lctx.dflash_kv_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash_kv_input_target_features)); + ggml_backend_tensor_set(lctx.dflash.kv.cache_input_target_features, update_src, 0, ggml_nbytes(lctx.dflash.kv.cache_input_target_features)); } - profile.graph_kv_cache_feature_upload_us += (uint64_t) (ggml_time_us() - t_feature_upload_us); - profile.graph_feature_bytes += (size_t) update_rows * (size_t) width * sizeof(float); - ggml_backend_t kv_pos_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_pos_ctx); - const int64_t t_pos_upload_us = ggml_time_us(); + ggml_backend_t kv_pos_backend = llama_backend_for_tensor(lctx, lctx.dflash.kv.cache_input_pos_ctx); if (kv_pos_backend != nullptr) { - ggml_backend_tensor_set_async(kv_pos_backend, lctx.dflash_kv_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); + ggml_backend_tensor_set_async(kv_pos_backend, lctx.dflash.kv.cache_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash.kv.cache_input_pos_ctx)); } else { - ggml_backend_tensor_set(lctx.dflash_kv_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx)); + ggml_backend_tensor_set(lctx.dflash.kv.cache_input_pos_ctx, update_pos, 0, ggml_nbytes(lctx.dflash.kv.cache_input_pos_ctx)); 
} - profile.graph_kv_cache_pos_upload_us += (uint64_t) (ggml_time_us() - t_pos_upload_us); + llama_graph_compute_sched(lctx, lctx.dflash.kv.cache_sched, gf_kv, lctx.cparams.n_threads); + ggml_backend_sched_synchronize(lctx.dflash.kv.cache_sched); - const int64_t t_kv_cache_us = ggml_time_us(); - llama_dflash_kv_node_profiler kv_node_profiler; - if (kv_node_timing) { - kv_node_profiler.profile = &profile; - ggml_backend_sched_set_eval_callback(lctx.dflash_sched, llama_dflash_kv_node_eval_callback, &kv_node_profiler); - } - llama_graph_compute_sched(lctx, lctx.dflash_sched, gf_kv, lctx.cparams.n_threads); - if (kv_node_timing) { - ggml_backend_sched_set_eval_callback(lctx.dflash_sched, nullptr, nullptr); - } - profile.graph_kv_cache_compute_us += (uint64_t) (ggml_time_us() - t_kv_cache_us); - - const int64_t t_sync_us = ggml_time_us(); - ggml_backend_sched_synchronize(lctx.dflash_sched); - profile.graph_kv_cache_sync_us += (uint64_t) (ggml_time_us() - t_sync_us); - profile.graph_kv_cache_calls++; - - lctx.dflash_kv_cache_n_filled = std::min(cross_ctx, lctx.dflash_kv_cache_n_filled + update_rows); - lctx.dflash_kv_cache_write_pos = (lctx.dflash_kv_cache_write_pos + update_rows) % cross_ctx; - lctx.dflash_kv_cache_applied_window_version = lctx.dflash_target_window_version; - lctx.dflash_kv_cache_valid = true; - lctx.dflash_kv_cache_view_n_filled = lctx.dflash_kv_cache_n_filled; - lctx.dflash_kv_cache_view_write_pos = lctx.dflash_kv_cache_write_pos; - lctx.dflash_kv_cache_view_valid = true; + lctx.dflash.kv.cache_n_filled = std::min(cross_ctx, lctx.dflash.kv.cache_n_filled + update_rows); + lctx.dflash.kv.cache_write_pos = (lctx.dflash.kv.cache_write_pos + update_rows) % cross_ctx; + lctx.dflash.kv.cache_applied_window_version = lctx.dflash.target.version; + lctx.dflash.kv.cache_valid = true; + lctx.dflash.kv.cache_view_n_filled = lctx.dflash.kv.cache_n_filled; + lctx.dflash.kv.cache_view_write_pos = lctx.dflash.kv.cache_write_pos; + lctx.dflash.kv.cache_view_valid = 
true; } - if (lctx.dflash_kv_cache_view_valid && - !lctx.dflash_k_ctx_workspace.empty() && !lctx.dflash_v_ctx_workspace.empty()) { - const bool need_workspace_refresh = !lctx.dflash_kv_workspace_valid || - lctx.dflash_kv_workspace_n_filled != lctx.dflash_kv_cache_view_n_filled || - lctx.dflash_kv_workspace_write_pos != lctx.dflash_kv_cache_view_write_pos || - lctx.dflash_kv_workspace_applied_window_version != lctx.dflash_kv_cache_applied_window_version; + if (lctx.dflash.kv.cache_view_valid && + !lctx.dflash.kv.k_ctx_workspace.empty() && !lctx.dflash.kv.v_ctx_workspace.empty()) { + const bool need_workspace_refresh = !lctx.dflash.kv.workspace_valid || + lctx.dflash.kv.workspace_n_filled != lctx.dflash.kv.cache_view_n_filled || + lctx.dflash.kv.workspace_write_pos != lctx.dflash.kv.cache_view_write_pos || + lctx.dflash.kv.workspace_applied_window_version != lctx.dflash.kv.cache_applied_window_version; if (need_workspace_refresh) { const size_t max_nodes = lctx.model.max_nodes((int) std::max(1, cross_ctx)) + 16 * lctx.model.hparams.n_layer; const size_t meta_size = ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false); - if (lctx.dflash_workspace_buf_compute_meta.size() != meta_size) { - lctx.dflash_workspace_buf_compute_meta.resize(meta_size); + if (lctx.dflash.kv.workspace_compute_meta.size() != meta_size) { + lctx.dflash.kv.workspace_compute_meta.resize(meta_size); } ggml_cgraph * gf_workspace = nullptr; - const bool can_reuse_workspace_graph = lctx.dflash_kv_workspace_graph != nullptr && - lctx.dflash_kv_workspace_graph_rows == lctx.dflash_kv_cache_view_n_filled && - lctx.dflash_kv_workspace_graph_write_pos == lctx.dflash_kv_cache_view_write_pos; + const bool can_reuse_workspace_graph = lctx.dflash.kv.workspace_graph != nullptr && + lctx.dflash.kv.workspace_graph_rows == lctx.dflash.kv.cache_view_n_filled && + lctx.dflash.kv.workspace_graph_write_pos == lctx.dflash.kv.cache_view_write_pos; if (can_reuse_workspace_graph) { - gf_workspace 
= lctx.dflash_kv_workspace_graph; + gf_workspace = lctx.dflash.kv.workspace_graph; } else { - const int64_t t_build_us = ggml_time_us(); gf_workspace = llm_build_context::llama_build_graph_dflash_kv_workspace(lctx); - profile.graph_kv_workspace_build_us += (uint64_t) (ggml_time_us() - t_build_us); if (gf_workspace == nullptr) { - profile.graph_shape_failures++; LLAMA_LOG_ERROR("%s: failed to build DFlash K/V workspace graph\n", __func__); return false; } @@ -733,95 +628,75 @@ bool llama_prepare_dflash_graph_inputs( } } - if (lctx.dflash_workspace_sched == nullptr) { - lctx.dflash_workspace_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false); + if (lctx.dflash.kv.workspace_sched == nullptr) { + lctx.dflash.kv.workspace_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false); } - if (lctx.dflash_kv_workspace_reserved_rows != cross_ctx) { - const bool saved_view_valid = lctx.dflash_kv_cache_view_valid; - const int32_t saved_view_rows = lctx.dflash_kv_cache_view_n_filled; - const int32_t saved_view_write_pos = lctx.dflash_kv_cache_view_write_pos; + if (lctx.dflash.kv.workspace_reserved_rows != cross_ctx) { + const bool saved_view_valid = lctx.dflash.kv.cache_view_valid; + const int32_t saved_view_rows = lctx.dflash.kv.cache_view_n_filled; + const int32_t saved_view_write_pos = lctx.dflash.kv.cache_view_write_pos; - lctx.dflash_kv_cache_view_valid = true; - lctx.dflash_kv_cache_view_n_filled = cross_ctx; - lctx.dflash_kv_cache_view_write_pos = cross_ctx > 1 ? 1 : 0; + lctx.dflash.kv.cache_view_valid = true; + lctx.dflash.kv.cache_view_n_filled = cross_ctx; + lctx.dflash.kv.cache_view_write_pos = cross_ctx > 1 ? 
1 : 0; - const int64_t t_reserve_build_us = ggml_time_us(); ggml_cgraph * gf_workspace_reserve = llm_build_context::llama_build_graph_dflash_kv_workspace(lctx); - profile.graph_kv_workspace_build_us += (uint64_t) (ggml_time_us() - t_reserve_build_us); - lctx.dflash_kv_cache_view_valid = saved_view_valid; - lctx.dflash_kv_cache_view_n_filled = saved_view_rows; - lctx.dflash_kv_cache_view_write_pos = saved_view_write_pos; + lctx.dflash.kv.cache_view_valid = saved_view_valid; + lctx.dflash.kv.cache_view_n_filled = saved_view_rows; + lctx.dflash.kv.cache_view_write_pos = saved_view_write_pos; - const int64_t t_reserve_us = ggml_time_us(); - const bool reserved = lctx.dflash_workspace_sched != nullptr && + const bool reserved = lctx.dflash.kv.workspace_sched != nullptr && gf_workspace_reserve != nullptr && - ggml_backend_sched_reserve(lctx.dflash_workspace_sched, gf_workspace_reserve); - profile.graph_kv_workspace_reserve_us += (uint64_t) (ggml_time_us() - t_reserve_us); + ggml_backend_sched_reserve(lctx.dflash.kv.workspace_sched, gf_workspace_reserve); if (!reserved) { - profile.graph_shape_failures++; LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V workspace scheduler\n", __func__); return false; } - lctx.dflash_kv_workspace_reserved_rows = cross_ctx; + lctx.dflash.kv.workspace_reserved_rows = cross_ctx; } - const int64_t t_reset_us = ggml_time_us(); - ggml_backend_sched_reset(lctx.dflash_workspace_sched); - profile.graph_kv_workspace_reset_us += (uint64_t) (ggml_time_us() - t_reset_us); + ggml_backend_sched_reset(lctx.dflash.kv.workspace_sched); + ggml_backend_sched_alloc_graph(lctx.dflash.kv.workspace_sched, gf_workspace); - const int64_t t_alloc_us = ggml_time_us(); - ggml_backend_sched_alloc_graph(lctx.dflash_workspace_sched, gf_workspace); - profile.graph_kv_workspace_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us); - - lctx.dflash_kv_workspace_graph = gf_workspace; - lctx.dflash_kv_workspace_graph_rows = lctx.dflash_kv_cache_view_n_filled; - 
lctx.dflash_kv_workspace_graph_write_pos = lctx.dflash_kv_cache_view_write_pos; + lctx.dflash.kv.workspace_graph = gf_workspace; + lctx.dflash.kv.workspace_graph_rows = lctx.dflash.kv.cache_view_n_filled; + lctx.dflash.kv.workspace_graph_write_pos = lctx.dflash.kv.cache_view_write_pos; } - const int64_t t_workspace_us = ggml_time_us(); - llama_graph_compute_sched(lctx, lctx.dflash_workspace_sched, gf_workspace, lctx.cparams.n_threads); - profile.graph_kv_workspace_compute_us += (uint64_t) (ggml_time_us() - t_workspace_us); - lctx.dflash_kv_workspace_sync_pending = true; - profile.graph_kv_workspace_calls++; + llama_graph_compute_sched(lctx, lctx.dflash.kv.workspace_sched, gf_workspace, lctx.cparams.n_threads); + lctx.dflash.kv.workspace_sync_pending = true; - lctx.dflash_kv_workspace_n_filled = lctx.dflash_kv_cache_view_n_filled; - lctx.dflash_kv_workspace_write_pos = lctx.dflash_kv_cache_view_write_pos; - lctx.dflash_kv_workspace_applied_window_version = lctx.dflash_kv_cache_applied_window_version; - lctx.dflash_kv_workspace_valid = true; + lctx.dflash.kv.workspace_n_filled = lctx.dflash.kv.cache_view_n_filled; + lctx.dflash.kv.workspace_write_pos = lctx.dflash.kv.cache_view_write_pos; + lctx.dflash.kv.workspace_applied_window_version = lctx.dflash.kv.cache_applied_window_version; + lctx.dflash.kv.workspace_valid = true; } } - const int64_t t_mask_us = ggml_time_us(); const int32_t full_visible_first = left_pad; const int32_t full_visible_last = cross_ctx + (int32_t) n_tokens - 1; - lctx.dflash_kq_mask_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY); - int32_t visible_kv_max = 0; + lctx.dflash.target.kq_mask_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY); for (uint32_t j = 0; j < n_tokens; ++j) { - float * row = lctx.dflash_kq_mask_data.data() + (size_t) j * (size_t) n_kv_total; - const int32_t visible_kv = cross_ctx + (int32_t) n_tokens; - visible_kv_max = std::max(visible_kv_max, visible_kv); - 
profile.graph_visible_kv_sum += (uint64_t) visible_kv; + float * row = lctx.dflash.target.kq_mask_data.data() + (size_t) j * (size_t) n_kv_total; for (int32_t i = full_visible_first; i <= full_visible_last; ++i) { row[i] = 0.0f; } } - ggml_backend_tensor_set(kq_mask, lctx.dflash_kq_mask_data.data(), 0, ggml_nbytes(kq_mask)); - profile.graph_mask_build_us += (uint64_t) (ggml_time_us() - t_mask_us); - profile.graph_mask_bytes += ggml_nbytes(kq_mask); + ggml_backend_tensor_set(kq_mask, lctx.dflash.target.kq_mask_data.data(), 0, ggml_nbytes(kq_mask)); if (kq_mask_swa != nullptr) { - lctx.dflash_kq_mask_swa_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY); + lctx.dflash.target.kq_mask_swa_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY); const int32_t swa_window = (int32_t) lctx.model.hparams.n_swa; - const int32_t draft_pos_base = (int32_t) profile.last_pos_last; + const int32_t draft_pos_base = (int32_t) last_target_pos; for (uint32_t j = 0; j < n_tokens; ++j) { - float * row = lctx.dflash_kq_mask_swa_data.data() + (size_t) j * (size_t) n_kv_total; + float * row = lctx.dflash.target.kq_mask_swa_data.data() + (size_t) j * (size_t) n_kv_total; const int32_t q_pos = draft_pos_base + (int32_t) j; for (int32_t k = left_pad; k < cross_ctx; ++k) { - const int32_t k_pos = (int32_t) lctx.dflash_pos_ctx_data[(size_t) k]; + const int32_t k_pos = (int32_t) lctx.dflash.target.pos_ctx_data[(size_t) k]; if (q_pos - k_pos < swa_window) { row[k] = 0.0f; } @@ -835,26 +710,7 @@ bool llama_prepare_dflash_graph_inputs( } } - ggml_backend_tensor_set(kq_mask_swa, lctx.dflash_kq_mask_swa_data.data(), 0, ggml_nbytes(kq_mask_swa)); - profile.graph_mask_bytes += ggml_nbytes(kq_mask_swa); - } - - profile.graph_visible_kv_max = std::max(profile.graph_visible_kv_max, (uint64_t) visible_kv_max); - profile.graph_prepare_total_us += (uint64_t) (ggml_time_us() - t_total_us); - - if (profile.graph_prepare_calls == 1 && llama_dflash_stats_log_enabled()) { - 
int32_t n_swa_layers = 0; - for (int32_t il = 0; il < lctx.model.hparams.n_layer; ++il) { - n_swa_layers += lctx.model.hparams.swa_layers[(size_t) il] ? 1 : 0; - } - - LLAMA_LOG_INFO("%s: DFlash graph contract rows=%d width=%d cross_ctx=%d n_tokens=%u left_pad=%d n_kv_total=%d draft_n_ctx=%u pos=%s [%d..%d] full_mask=[%d..%d] swa_window=%u swa_layers=%d\n", - __func__, n_rows, width, cross_ctx, n_tokens, left_pad, n_kv_total, lctx.cparams.n_ctx, - (src_pos != nullptr && total_positions == (size_t) n_rows) ? "target" : "synthetic", - (int) profile.last_pos_first, (int) profile.last_pos_last, - full_visible_first, full_visible_last, - lctx.model.hparams.n_swa, - n_swa_layers); + ggml_backend_tensor_set(kq_mask_swa, lctx.dflash.target.kq_mask_swa_data.data(), 0, ggml_nbytes(kq_mask_swa)); } return true; diff --git a/src/llama-spec-features-dflash.cpp b/src/llama-spec-features-dflash.cpp index c8cd5181..6a5e3ed5 100644 --- a/src/llama-spec-features-dflash.cpp +++ b/src/llama-spec-features-dflash.cpp @@ -1,55 +1,13 @@ #include "llama-spec-features.h" #include -#include #include #include #include -#include #include "llama-model.h" #include "llama-context.h" -static bool llama_dflash_stats_log_enabled() { - const char * env = std::getenv("IK_DFLASH_STATS_LOG"); - return env != nullptr && *env != '\0' && - std::strcmp(env, "0") != 0 && - std::strcmp(env, "false") != 0 && - std::strcmp(env, "off") != 0; -} - -static bool llama_dflash_positions_strictly_increasing( - const llama_pos * positions, - int32_t n_rows, - llama_pos & first_pos, - llama_pos & last_pos) { - first_pos = -1; - last_pos = -1; - - if (positions == nullptr || n_rows <= 0) { - return false; - } - - first_pos = positions[0]; - last_pos = positions[n_rows - 1]; - - for (int32_t i = 1; i < n_rows; ++i) { - if (positions[i] <= positions[i - 1]) { - return false; - } - } - - return true; -} - -void llama_dflash_profile_reset(struct llama_context * ctx) { - if (ctx == nullptr) { - return; - } - - 
ctx->dflash.profile = {}; -} - void llama_reset_dflash_kv_cache_state(struct llama_context * ctx) { if (ctx == nullptr) { return; @@ -120,17 +78,6 @@ int32_t llama_get_dflash_visible_cross_ctx( return ctx != nullptr ? ctx->dflash.visible_cross_ctx : 0; } -bool llama_dflash_profile_get_stats( - const struct llama_context * ctx, - llama_dflash_profile_stats * stats) { - if (ctx == nullptr || stats == nullptr) { - return false; - } - - *stats = ctx->dflash.profile; - return true; -} - int32_t llama_model_dflash_block_size(const struct llama_model * model) { return model ? (int32_t) model->hparams.dflash_block_size : 0; } @@ -188,48 +135,6 @@ const struct ggml_tensor * llama_model_dflash_output_tensor( return model->tok_embd; } -static const char * llama_dflash_io_mode_name(int32_t io_mode) { - switch (io_mode) { - case LLAMA_DFLASH_IO_MODE_SHARED: - return "shared"; - case LLAMA_DFLASH_IO_MODE_SELF_CONTAINED: - return "self-contained"; - case LLAMA_DFLASH_IO_MODE_MIXED: - return "mixed"; - default: - return "invalid"; - } -} - -static const char * llama_dflash_output_head_kind( - const struct llama_model * draft_model, - const struct llama_model * target_model) { - const struct ggml_tensor * output = llama_model_dflash_output_tensor(draft_model); - if (output == nullptr) { - return "missing"; - } - - if (output == draft_model->tok_embd) { - return draft_model->tok_embd == (target_model ? target_model->tok_embd : nullptr) - ? 
"shared_token_embedding" - : "token_embedding"; - } - - if (draft_model->output_mtp != nullptr && output == draft_model->output_mtp) { - if (target_model != nullptr && target_model->output_mtp != nullptr && output == target_model->output_mtp) { - return "output_mtp"; - } - - if (std::strcmp(output->name, "output_extra.weight") == 0) { - return "output_extra"; - } - - return "output_mtp"; - } - - return "output"; -} - int32_t llama_model_dflash_io_mode( const struct llama_model * draft_model, const struct llama_model * target_model) { @@ -302,19 +207,6 @@ bool llama_model_share_dflash_io_tensors( } const struct ggml_tensor * output = llama_model_dflash_output_tensor(draft_model); - if (draft_model->tok_embd != nullptr && output != nullptr) { - LLAMA_LOG_INFO("%s: DFlash ready io=%s output_head=%s\n", - __func__, - llama_dflash_io_mode_name(llama_model_dflash_io_mode(draft_model, target_model)), - llama_dflash_output_head_kind(draft_model, target_model)); - if (llama_dflash_stats_log_enabled()) { - LLAMA_LOG_INFO("%s: DFlash IO tensor=%s type=%s\n", - __func__, - output->name[0] != '\0' ? output->name : "(unnamed)", - ggml_type_name(output->type)); - } - } - return draft_model->tok_embd != nullptr && output != nullptr; } @@ -336,14 +228,6 @@ static bool llama_set_dflash_target_features_impl( return false; } - auto & profile = ctx->dflash.profile; - const int64_t t_start_us = ggml_time_us(); - const int32_t row_width = have_full_features - ? (n_rows > 0 ? (int32_t) (n_floats / (size_t) n_rows) : 0) - : (window_update->append_rows > 0 ? 
(int32_t) (window_update->append_floats / (size_t) window_update->append_rows) : 0); - llama_pos first_pos = -1; - llama_pos last_pos = -1; - if (have_full_features && copy_data) { ctx->dflash.target.features_owned.assign(target_features, target_features + n_floats); ctx->dflash.target.features = ctx->dflash.target.features_owned.data(); @@ -424,28 +308,6 @@ static bool llama_set_dflash_target_features_impl( ctx->dflash.target.positions_n = 0; } - profile.set_target_copy_calls++; - profile.set_target_copy_us += (uint64_t) (ggml_time_us() - t_start_us); - profile.set_target_rows += (uint64_t) n_rows; - profile.set_target_copy_bytes += - (have_full_features ? n_floats : 0) * sizeof(float) + - (have_append_features ? window_update->append_floats : 0) * sizeof(float) + - (target_positions ? (size_t) n_rows * sizeof(llama_pos) : 0); - profile.last_n_rows = n_rows; - profile.last_width = row_width; - - if (target_positions == nullptr) { - profile.set_target_missing_positions++; - profile.last_pos_first = -1; - profile.last_pos_last = -1; - } else { - if (!llama_dflash_positions_strictly_increasing(target_positions, n_rows, first_pos, last_pos)) { - profile.set_target_non_monotonic_positions++; - } - profile.last_pos_first = first_pos; - profile.last_pos_last = last_pos; - } - return true; } @@ -469,35 +331,6 @@ bool llama_set_dflash_target_features_view( return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, false, window_update); } -static void llama_record_dflash_capture_phase( - struct llama_context * ctx, - bool is_prompt_warmup, - int32_t row_count, - int32_t row_width) { - if (ctx == nullptr || row_count <= 0 || row_width <= 0) { - return; - } - - auto & profile = ctx->dflash.profile; - if (is_prompt_warmup) { - profile.capture_prompt_batches++; - if (profile.capture_prompt_last_rows > 0 && profile.capture_prompt_last_width > 0 && - (profile.capture_prompt_last_rows != row_count || profile.capture_prompt_last_width != 
row_width)) { - profile.capture_prompt_shape_changes++; - } - profile.capture_prompt_last_rows = row_count; - profile.capture_prompt_last_width = row_width; - } else { - profile.capture_verify_batches++; - if (profile.capture_verify_last_rows > 0 && profile.capture_verify_last_width > 0 && - (profile.capture_verify_last_rows != row_count || profile.capture_verify_last_width != row_width)) { - profile.capture_verify_shape_changes++; - } - profile.capture_verify_last_rows = row_count; - profile.capture_verify_last_width = row_width; - } -} - static bool llama_dflash_parse_layer_id(const struct ggml_tensor * tensor, int32_t & layer_id) { if (tensor == nullptr) { return false; @@ -644,9 +477,8 @@ void llama_finish_dflash_capture_batch( return; } + GGML_UNUSED(is_prompt_warmup); auto & capture = *ctx->dflash.capture; - llama_record_dflash_capture_phase(ctx, is_prompt_warmup, capture.row_count, capture.row_width); - // Reset the batch-local reference shape so the next decode only compares layers within // the same batch, not against the previous prompt/verify batch. 
capture.row_count = 0; @@ -662,59 +494,42 @@ static bool llama_spec_prepare_dflash_capture( return false; } - auto & profile = ctx->dflash.profile; - profile.capture_prepare_calls++; - const int64_t t_sync_us = ggml_time_us(); llama_synchronize(ctx); - profile.capture_prepare_sync_us += (uint64_t) (ggml_time_us() - t_sync_us); auto & capture = *ctx->dflash.capture; row_count = capture.row_count; row_width = capture.row_width; n_layers = (int32_t) capture.layer_ids.size(); if (row_count <= 0 || row_width <= 0 || n_layers <= 0 || capture.layer_rows.size() != (size_t) n_layers) { - profile.capture_prepare_failures++; return false; } if (capture.capture_batch_id == 0 || capture.layer_seen_batch_id.size() != (size_t) n_layers) { - profile.capture_prepare_failures++; - profile.capture_layer_batch_mismatch++; - if (profile.capture_layer_batch_mismatch <= 3) { - LLAMA_LOG_WARN("%s: DFlash capture batch markers are not initialized (batch_id=%llu layers=%zu expected=%d)\n", - __func__, - (unsigned long long) capture.capture_batch_id, - capture.layer_seen_batch_id.size(), - n_layers); - } + LLAMA_LOG_WARN("%s: DFlash capture batch markers are not initialized (batch_id=%llu layers=%zu expected=%d)\n", + __func__, + (unsigned long long) capture.capture_batch_id, + capture.layer_seen_batch_id.size(), + n_layers); return false; } for (int32_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) { if (capture.layer_seen_batch_id[(size_t) layer_idx] != capture.capture_batch_id) { - profile.capture_prepare_failures++; - profile.capture_layer_batch_mismatch++; - if (profile.capture_layer_batch_mismatch <= 3) { - LLAMA_LOG_WARN("%s: DFlash capture is stale for layer %d (seen_batch=%llu current_batch=%llu rows=%d width=%d)\n", - __func__, - capture.layer_ids[(size_t) layer_idx], - (unsigned long long) capture.layer_seen_batch_id[(size_t) layer_idx], - (unsigned long long) capture.capture_batch_id, - row_count, - row_width); - } + LLAMA_LOG_WARN("%s: DFlash capture is stale for layer %d 
(seen_batch=%llu current_batch=%llu rows=%d width=%d)\n", + __func__, + capture.layer_ids[(size_t) layer_idx], + (unsigned long long) capture.layer_seen_batch_id[(size_t) layer_idx], + (unsigned long long) capture.capture_batch_id, + row_count, + row_width); return false; } const auto & rows = capture.layer_rows[(size_t) layer_idx]; if (rows.size() != (size_t) row_count * (size_t) row_width) { - profile.capture_prepare_failures++; - profile.capture_layer_shape_mismatch++; - if (profile.capture_layer_shape_mismatch <= 3) { - LLAMA_LOG_WARN("%s: DFlash capture rows mismatch for layer %d: got=%zu expected=%zu (rows=%d width=%d)\n", - __func__, capture.layer_ids[(size_t) layer_idx], rows.size(), - (size_t) row_count * (size_t) row_width, row_count, row_width); - } + LLAMA_LOG_WARN("%s: DFlash capture rows mismatch for layer %d: got=%zu expected=%zu (rows=%d width=%d)\n", + __func__, capture.layer_ids[(size_t) layer_idx], rows.size(), + (size_t) row_count * (size_t) row_width, row_count, row_width); return false; } } @@ -722,194 +537,6 @@ static bool llama_spec_prepare_dflash_capture( return true; } -static bool llama_dflash_contract_log_enabled() { - const char * env = std::getenv("IK_DFLASH_CONTRACT_LOG"); - if (env == nullptr || *env == '\0') { - return false; - } - - return std::strcmp(env, "0") != 0 && - std::strcmp(env, "false") != 0 && - std::strcmp(env, "off") != 0; -} - -template -static std::string llama_dflash_contract_format_values( - const std::vector & values, - size_t edge_count = 4) { - std::ostringstream oss; - oss << '['; - if (values.empty()) { - oss << ']'; - return oss.str(); - } - - const size_t head = std::min(edge_count, values.size()); - for (size_t i = 0; i < head; ++i) { - if (i > 0) { - oss << ','; - } - oss << values[i]; - } - - if (values.size() > edge_count * 2) { - oss << ",...,"; - for (size_t i = values.size() - edge_count; i < values.size(); ++i) { - if (i > values.size() - edge_count) { - oss << ','; - } - oss << values[i]; - } - } 
else { - for (size_t i = head; i < values.size(); ++i) { - oss << ',' << values[i]; - } - } - - oss << ']'; - return oss.str(); -} - -static std::vector llama_dflash_contract_collect_batch_positions( - const llama_batch & batch, - const std::vector & batch_indices) { - std::vector positions; - positions.reserve(batch_indices.size()); - for (int32_t batch_index : batch_indices) { - positions.push_back(batch.pos[batch_index]); - } - return positions; -} - -static void llama_dflash_contract_summarize_positions( - const std::vector & positions, - llama_pos & first_pos, - llama_pos & last_pos, - int32_t & gap_count, - int32_t & nonmono_count) { - first_pos = -1; - last_pos = -1; - gap_count = 0; - nonmono_count = 0; - if (positions.empty()) { - return; - } - - first_pos = positions.front(); - last_pos = positions.back(); - for (size_t i = 1; i < positions.size(); ++i) { - if (positions[i] <= positions[i - 1]) { - nonmono_count++; - } else if (positions[i] != positions[i - 1] + 1) { - gap_count++; - } - } -} - -static void llama_dflash_contract_log_feature_view( - const char * kind, - llama_seq_id seq_id, - const llama_batch & batch, - int32_t row_count, - int32_t row_width, - int32_t n_layers, - int32_t batch_row_offset, - const std::vector & row_indices, - const std::vector & batch_indices) { - if (!llama_dflash_contract_log_enabled()) { - return; - } - - static std::atomic counter = 0; - const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); - if (ordinal >= 8) { - return; - } - - const std::vector positions = llama_dflash_contract_collect_batch_positions(batch, batch_indices); - llama_pos first_pos = -1; - llama_pos last_pos = -1; - int32_t gap_count = 0; - int32_t nonmono_count = 0; - llama_dflash_contract_summarize_positions(positions, first_pos, last_pos, gap_count, nonmono_count); - - LLAMA_LOG_INFO("%s[%llu]: kind=%s seq=%d batch_tokens=%d capture_rows=%d row_width=%d layers=%d batch_row_offset=%d row_indices=%s batch_indices=%s batch_pos=%s 
pos=[%d..%d] gaps=%d nonmono=%d\n", - __func__, - (unsigned long long) (ordinal + 1), - kind, - (int) seq_id, - batch.n_tokens, - row_count, - row_width, - n_layers, - batch_row_offset, - llama_dflash_contract_format_values(row_indices).c_str(), - llama_dflash_contract_format_values(batch_indices).c_str(), - llama_dflash_contract_format_values(positions).c_str(), - (int) first_pos, - (int) last_pos, - gap_count, - nonmono_count); -} - -static void llama_dflash_contract_log_output_indices( - struct llama_context * ctx, - const std::vector & output_indices) { - if (!llama_dflash_contract_log_enabled()) { - return; - } - - static std::atomic counter = 0; - const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); - if (ordinal >= 8) { - return; - } - - int32_t row_count = 0; - int32_t row_width = 0; - int32_t n_layers = 0; - const bool have_capture = llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers); - - LLAMA_LOG_INFO("%s[%llu]: output_indices=%s capture_rows=%d row_width=%d layers=%d have_capture=%s\n", - __func__, - (unsigned long long) (ordinal + 1), - llama_dflash_contract_format_values(output_indices).c_str(), - row_count, - row_width, - n_layers, - have_capture ? "true" : "false"); -} - -void llama_dflash_contract_log_accept( - int slot_id, - bool is_dflash, - const char * path, - bool any_rejected, - size_t n_draft, - size_t n_accepted, - llama_pos pos_base, - const std::vector & output_indices) { - if (!llama_dflash_contract_log_enabled() || !is_dflash) { - return; - } - - static std::atomic counter = 0; - const uint64_t ordinal = counter.fetch_add(1, std::memory_order_relaxed); - if (ordinal >= 8) { - return; - } - - LLAMA_LOG_INFO("dflash contract accept[%llu]: slot=%d path=%s rejected=%s drafted=%zu accepted=%zu pos_base=%d output_indices=%s\n", - (unsigned long long) (ordinal + 1), - slot_id, - path, - any_rejected ? 
"true" : "false", - n_draft, - n_accepted, - (int) pos_base, - llama_dflash_contract_format_values(output_indices).c_str()); -} - static bool llama_spec_materialize_dflash_rows_prepared( struct llama_context * ctx, int32_t row_count, @@ -928,9 +555,6 @@ static bool llama_spec_materialize_dflash_rows( int32_t row_width = 0; int32_t n_layers = 0; if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) { - if (ctx != nullptr) { - ctx->dflash.profile.capture_materialize_failures++; - } return false; } @@ -951,12 +575,7 @@ static bool llama_spec_materialize_dflash_rows_prepared( return false; } - auto & profile = ctx->dflash.profile; - profile.capture_materialize_calls++; - const int64_t t_start_us = ggml_time_us(); - if (row_count <= 0 || row_width <= 0 || n_layers <= 0 || ctx->dflash.capture == nullptr) { - profile.capture_materialize_failures++; return false; } @@ -972,7 +591,6 @@ static bool llama_spec_materialize_dflash_rows_prepared( if (row_index < 0 || row_index >= row_count) { rows_out.clear(); combined_width = 0; - profile.capture_materialize_failures++; return false; } @@ -983,10 +601,6 @@ static bool llama_spec_materialize_dflash_rows_prepared( } } - profile.capture_materialize_us += (uint64_t) (ggml_time_us() - t_start_us); - profile.capture_materialize_rows += (uint64_t) row_indices.size(); - profile.capture_materialize_bytes += rows_out.size() * sizeof(float); - return true; } @@ -1040,17 +654,6 @@ bool llama_spec_get_dflash_feature_view( }); } - llama_dflash_contract_log_feature_view( - "batch", - view.rows.empty() ? 
-1 : view.rows.front().seq_id, - batch, - row_count, - row_width, - n_layers, - batch_row_offset, - row_indices, - batch_indices); - return true; } @@ -1109,17 +712,6 @@ bool llama_spec_get_dflash_feature_view_for_seq( }); } - llama_dflash_contract_log_feature_view( - "seq", - seq_id, - batch, - row_count, - row_width, - n_layers, - batch_row_offset, - row_indices, - batch_indices); - return true; } @@ -1133,7 +725,5 @@ bool llama_spec_copy_dflash_rows_from_output_indices( return false; } - llama_dflash_contract_log_output_indices(ctx, output_indices); - return hidden_rows.size() == (size_t) output_indices.size() * (size_t) combined_width; } diff --git a/src/llama-spec-features-dflash.h b/src/llama-spec-features-dflash.h index e05f2a91..c893db7a 100644 --- a/src/llama-spec-features-dflash.h +++ b/src/llama-spec-features-dflash.h @@ -11,147 +11,6 @@ struct llama_model; struct ggml_tensor; struct llama_spec_feature_view; -struct llama_dflash_profile_stats { - uint64_t decode_internal_chunks = 0; - uint64_t decode_graph_rebuilds = 0; - uint64_t decode_sync_profile_points = 0; - uint64_t decode_prelude_us = 0; - uint64_t decode_sched_reset_us = 0; - uint64_t decode_build_graph_us = 0; - uint64_t decode_sched_alloc_graph_us = 0; - uint64_t decode_set_inputs_us = 0; - uint64_t decode_graph_compute_us = 0; - uint64_t decode_result_us = 0; - uint64_t decode_embedding_us = 0; - uint64_t decode_final_sched_reset_us = 0; - - uint64_t decode_output_reserve_calls = 0; - uint64_t decode_output_reserve_us = 0; - uint64_t decode_output_reserve_reallocs = 0; - uint64_t decode_output_reserve_realloc_bytes = 0; - uint64_t decode_prepare_calls = 0; - uint64_t decode_prepare_us = 0; - uint64_t decode_prepare_failures = 0; - - uint64_t set_target_copy_calls = 0; - uint64_t set_target_copy_us = 0; - uint64_t set_target_rows = 0; - uint64_t set_target_copy_bytes = 0; - uint64_t set_target_missing_positions = 0; - uint64_t set_target_non_monotonic_positions = 0; - - uint64_t 
capture_prepare_calls = 0; - uint64_t capture_prepare_sync_us = 0; - uint64_t capture_prepare_failures = 0; - uint64_t capture_layer_shape_mismatch = 0; - uint64_t capture_layer_batch_mismatch = 0; - uint64_t capture_prompt_batches = 0; - uint64_t capture_prompt_shape_changes = 0; - uint64_t capture_verify_batches = 0; - uint64_t capture_verify_shape_changes = 0; - uint64_t capture_materialize_calls = 0; - uint64_t capture_materialize_rows = 0; - uint64_t capture_materialize_bytes = 0; - uint64_t capture_materialize_us = 0; - uint64_t capture_materialize_failures = 0; - - uint64_t graph_prepare_calls = 0; - uint64_t graph_prepare_total_us = 0; - uint64_t graph_feature_copy_us = 0; - uint64_t graph_pos_copy_us = 0; - uint64_t graph_mask_build_us = 0; - uint64_t graph_kv_cache_build_us = 0; - uint64_t graph_kv_cache_reserve_us = 0; - uint64_t graph_kv_cache_reset_us = 0; - uint64_t graph_kv_cache_alloc_us = 0; - uint64_t graph_kv_cache_feature_upload_us = 0; - uint64_t graph_kv_cache_pos_upload_us = 0; - uint64_t graph_kv_cache_compute_us = 0; - uint64_t graph_kv_cache_sync_us = 0; - uint64_t graph_kv_cache_read_concat_pad_us = 0; - uint64_t graph_kv_cache_read_concat_pad_calls = 0; - uint64_t graph_kv_cache_cached_bytes = 0; - uint64_t graph_kv_cache_calls = 0; - uint64_t graph_kv_workspace_build_us = 0; - uint64_t graph_kv_workspace_reserve_us = 0; - uint64_t graph_kv_workspace_reset_us = 0; - uint64_t graph_kv_workspace_alloc_us = 0; - uint64_t graph_kv_workspace_compute_us = 0; - uint64_t graph_kv_workspace_sync_us = 0; - uint64_t graph_kv_workspace_calls = 0; - uint64_t graph_kv_node_fused_target_calls = 0; - uint64_t graph_kv_node_fused_target_us = 0; - uint64_t graph_kv_node_k_proj_calls = 0; - uint64_t graph_kv_node_k_proj_us = 0; - uint64_t graph_kv_node_k_norm_calls = 0; - uint64_t graph_kv_node_k_norm_us = 0; - uint64_t graph_kv_node_k_rope_calls = 0; - uint64_t graph_kv_node_k_rope_us = 0; - uint64_t graph_kv_node_v_proj_calls = 0; - uint64_t 
graph_kv_node_v_proj_us = 0; - uint64_t graph_kv_node_k_store_calls = 0; - uint64_t graph_kv_node_k_store_us = 0; - uint64_t graph_kv_node_v_store_calls = 0; - uint64_t graph_kv_node_v_store_us = 0; - uint64_t graph_main_node_qcur_calls = 0; - uint64_t graph_main_node_qcur_us = 0; - uint64_t graph_main_node_k_draft_calls = 0; - uint64_t graph_main_node_k_draft_us = 0; - uint64_t graph_main_node_v_draft_calls = 0; - uint64_t graph_main_node_v_draft_us = 0; - uint64_t graph_main_node_k_ctx_view_calls = 0; - uint64_t graph_main_node_k_ctx_view_us = 0; - uint64_t graph_main_node_v_ctx_view_calls = 0; - uint64_t graph_main_node_v_ctx_view_us = 0; - uint64_t graph_main_node_k_concat_calls = 0; - uint64_t graph_main_node_k_concat_us = 0; - uint64_t graph_main_node_v_concat_calls = 0; - uint64_t graph_main_node_v_concat_us = 0; - uint64_t graph_main_node_k_pad_calls = 0; - uint64_t graph_main_node_k_pad_us = 0; - uint64_t graph_main_node_v_pad_calls = 0; - uint64_t graph_main_node_v_pad_us = 0; - uint64_t graph_main_node_k_perm_cont_calls = 0; - uint64_t graph_main_node_k_perm_cont_us = 0; - uint64_t graph_main_node_v_perm_cont_calls = 0; - uint64_t graph_main_node_v_perm_cont_us = 0; - uint64_t graph_main_node_flash_attn_calls = 0; - uint64_t graph_main_node_flash_attn_us = 0; - uint64_t graph_main_node_attn_out_calls = 0; - uint64_t graph_main_node_attn_out_us = 0; - uint64_t graph_main_node_ffn_calls = 0; - uint64_t graph_main_node_ffn_us = 0; - uint64_t graph_main_node_result_rows_calls = 0; - uint64_t graph_main_node_result_rows_us = 0; - uint64_t graph_main_node_result_norm_calls = 0; - uint64_t graph_main_node_result_norm_us = 0; - uint64_t graph_main_node_result_calls = 0; - uint64_t graph_main_node_result_us = 0; - uint64_t graph_feature_bytes = 0; - uint64_t graph_pos_bytes = 0; - uint64_t graph_mask_bytes = 0; - uint64_t graph_visible_kv_sum = 0; - uint64_t graph_visible_kv_max = 0; - uint64_t graph_pos_fallbacks = 0; - uint64_t graph_pos_non_monotonic = 0; - 
uint64_t graph_shape_failures = 0; - uint64_t graph_mask_overflow = 0; - - int32_t last_n_rows = 0; - int32_t last_width = 0; - int32_t last_cross_ctx = 0; - int32_t last_left_pad = 0; - int32_t last_n_tokens = 0; - int32_t last_n_kv_total = 0; - int32_t last_kv_cache_host_layers = 0; - int32_t capture_prompt_last_rows = 0; - int32_t capture_prompt_last_width = 0; - int32_t capture_verify_last_rows = 0; - int32_t capture_verify_last_width = 0; - llama_pos last_pos_first = -1; - llama_pos last_pos_last = -1; -}; - struct llama_dflash_window_update { uint64_t version = 0; int32_t keep_rows = 0; @@ -216,11 +75,9 @@ llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition_for_ctx( const llama_dflash_window_update & window_update, int32_t n_rows); -void llama_dflash_profile_reset(struct llama_context * ctx); void llama_reset_dflash_kv_cache_state(struct llama_context * ctx); void llama_set_dflash_visible_cross_ctx(struct llama_context * ctx, int32_t cross_ctx); int32_t llama_get_dflash_visible_cross_ctx(const struct llama_context * ctx); -bool llama_dflash_profile_get_stats(const struct llama_context * ctx, llama_dflash_profile_stats * stats); int32_t llama_model_dflash_block_size(const struct llama_model * model); int32_t llama_model_dflash_mask_token_id(const struct llama_model * model); @@ -277,13 +134,3 @@ bool llama_spec_copy_dflash_rows_from_output_indices( struct llama_context * ctx, const std::vector & output_indices, std::vector & hidden_rows); - -void llama_dflash_contract_log_accept( - int slot_id, - bool is_dflash, - const char * path, - bool any_rejected, - size_t n_draft, - size_t n_accepted, - llama_pos pos_base, - const std::vector & output_indices); diff --git a/src/llama.cpp b/src/llama.cpp index 8adab80e..3389b3f1 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -19,7 +19,6 @@ #include "llama-context.h" #include "llama-spec-features.h" #include "llama-dflash.h" -#include "llama-dflash-profile.h" #include "llama-quantize.h" #include 
"unicode.h" @@ -697,8 +696,8 @@ void llama_context::set_mtp_op_type(llama_mtp_op_type value) { } llama_context::~llama_context() { - if (dflash_sched != nullptr) { - ggml_backend_sched_free(dflash_sched); + if (dflash.kv.cache_sched != nullptr) { + ggml_backend_sched_free(dflash.kv.cache_sched); } free_dflash_kv_cache_tensors(); ggml_backend_sched_free(sched); @@ -5096,10 +5095,6 @@ static int llama_decode_internal( } lctx.n_queued_tokens += n_tokens_all; - auto * dflash_profile = lctx.model.arch == LLM_ARCH_DFLASH_DRAFT ? &lctx.dflash_profile : nullptr; - const bool dflash_decode_timing = dflash_profile != nullptr && llama_env_flag_enabled("IK_DFLASH_DECODE_TIMING"); - const bool dflash_draft_node_timing = dflash_profile != nullptr && llama_env_flag_enabled("IK_DFLASH_DRAFT_NODE_TIMING"); - auto & kv_self = lctx.kv_self; const int64_t n_embd = hparams.n_embd; @@ -5139,20 +5134,7 @@ static int llama_decode_internal( n_outputs_embd = has_mtp && cparams.mtp_op_type == MTP_OP_NONE ? n_tokens_all : n_outputs; const size_t required_outputs = std::max(n_outputs, n_outputs_embd); const bool is_dflash_decode = lctx.model.arch == LLM_ARCH_DFLASH_DRAFT; - const size_t output_buf_size_before = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output) : 0; - const int64_t t_output_reserve_us = is_dflash_decode ? ggml_time_us() : 0; const size_t reserved_outputs = llama_output_reserve(lctx, required_outputs); - if (is_dflash_decode) { - auto & profile = lctx.dflash_profile; - profile.decode_output_reserve_calls++; - profile.decode_output_reserve_us += (uint64_t) (ggml_time_us() - t_output_reserve_us); - - const size_t output_buf_size_after = lctx.buf_output ? 
ggml_backend_buffer_get_size(lctx.buf_output) : 0; - if (output_buf_size_after > output_buf_size_before) { - profile.decode_output_reserve_reallocs++; - profile.decode_output_reserve_realloc_bytes += (uint64_t) output_buf_size_after; - } - } if (reserved_outputs < required_outputs) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %zu outputs\n", __func__, required_outputs); return -2; @@ -5184,10 +5166,6 @@ static int llama_decode_internal( #if IK_PRINT_TIMING auto tim1 = ggml_time_us(); #endif - const int64_t t_dflash_prelude_us = dflash_decode_timing ? ggml_time_us() : 0; - if (dflash_decode_timing) { - dflash_profile->decode_internal_chunks++; - } uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token); if (llm_arch_is_hybrid(model.arch) && n_tokens > 1 && @@ -5353,55 +5331,36 @@ static int llama_decode_internal( auto tim2 = ggml_time_us(); printf("prelude(...): %d us\n", int(tim2-tim1)); #endif - if (dflash_decode_timing) { - dflash_profile->decode_prelude_us += (uint64_t) (ggml_time_us() - t_dflash_prelude_us); - } - #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif auto & prev = cparams.mtp_op_type == MTP_OP_NONE ? lctx.prev : lctx.prev_mtp; ggml_cgraph * gf = nullptr; if (!lctx.can_reuse_graph(u_batch)) { - if (dflash_decode_timing) { - dflash_profile->decode_graph_rebuilds++; - } - const int64_t t_dflash_sched_reset_us = dflash_decode_timing ? ggml_time_us() : 0; lctx.reset_scheduler(); ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); #if IK_PRINT_TIMING tim2 = ggml_time_us(); printf("sched_reset(...): %d us\n", int(tim2-tim1)); #endif - if (dflash_decode_timing) { - dflash_profile->decode_sched_reset_us += (uint64_t) (ggml_time_us() - t_dflash_sched_reset_us); - } #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif - const int64_t t_dflash_build_graph_us = dflash_decode_timing ? 
ggml_time_us() : 0; gf = llm_build_context::llama_build_graph(lctx, u_batch, false); #if IK_PRINT_TIMING tim2 = ggml_time_us(); printf("build_graph(...): %d us\n", int(tim2-tim1)); #endif - if (dflash_decode_timing) { - dflash_profile->decode_build_graph_us += (uint64_t) (ggml_time_us() - t_dflash_build_graph_us); - } #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif - const int64_t t_dflash_sched_alloc_us = dflash_decode_timing ? ggml_time_us() : 0; ggml_backend_sched_alloc_graph(lctx.sched, gf); #if IK_PRINT_TIMING tim2 = ggml_time_us(); printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1)); #endif - if (dflash_decode_timing) { - dflash_profile->decode_sched_alloc_graph_us += (uint64_t) (ggml_time_us() - t_dflash_sched_alloc_us); - } //if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) { if (u_batch.embd == nullptr && lctx.cparams.graph_reuse && !((lctx.model.arch == LLM_ARCH_GEMMA4_MTP || lctx.model.arch == LLM_ARCH_GEMMA4_ASSISTANT) && lctx.mtp_target_ctx != nullptr)) { @@ -5422,15 +5381,8 @@ static int llama_decode_internal( } } - if (dflash_profile != nullptr) { - dflash_profile->decode_prepare_calls++; - const int64_t t_prepare_dflash_us = ggml_time_us(); - if (!llama_prepare_dflash_graph_inputs(lctx, n_tokens)) { - dflash_profile->decode_prepare_failures++; - dflash_profile->decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us); - return GGML_STATUS_FAILED; - } - dflash_profile->decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us); + if (is_dflash_decode && !llama_prepare_dflash_graph_inputs(lctx, n_tokens)) { + return GGML_STATUS_FAILED; } // the output is always the last tensor in the graph @@ -5438,7 +5390,7 @@ static int llama_decode_internal( struct ggml_tensor * embd = nullptr; // DFlash GPU argmax draft_argmax node - if (lctx.dflash_draft_tokens_tensor != nullptr && + if (lctx.dflash.draft_tokens_tensor != nullptr && strcmp(res->name, "result_output") != 0) { for (int i = 
gf->n_nodes - 2; i >= 0; --i) { if (strcmp(gf->nodes[i]->name, "result_output") == 0) { @@ -5489,39 +5441,18 @@ static int llama_decode_internal( #if IK_PRINT_TIMING == 1 tim1 = ggml_time_us(); #endif - const int64_t t_dflash_set_inputs_us = dflash_decode_timing ? ggml_time_us() : 0; llama_set_inputs(lctx, u_batch); #if IK_PRINT_TIMING == 1 tim2 = ggml_time_us(); printf("set_inputs(...): %d us\n", int(tim2-tim1)); #endif - if (dflash_decode_timing) { - dflash_profile->decode_set_inputs_us += (uint64_t) (ggml_time_us() - t_dflash_set_inputs_us); - } - #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif - if (lctx.dflash_kv_workspace_sync_pending) { + if (lctx.dflash.kv.workspace_sync_pending) { llama_sync_dflash_workspace_if_pending(lctx); } - const int64_t t_dflash_graph_compute_us = dflash_decode_timing ? ggml_time_us() : 0; - llama_dflash_main_node_profiler draft_node_profiler; - if (dflash_draft_node_timing) { - draft_node_profiler.profile = dflash_profile; - draft_node_profiler.prev_callback = lctx.cparams.cb_eval; - draft_node_profiler.prev_user_data = lctx.cparams.cb_eval_user_data; - ggml_backend_sched_set_eval_callback(lctx.sched, llama_dflash_main_node_eval_callback, &draft_node_profiler); - } llama_graph_compute(lctx, gf, n_threads); - if (dflash_draft_node_timing) { - ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); - } - if (dflash_decode_timing) { - llama_synchronize(&lctx); - dflash_profile->decode_sync_profile_points++; - dflash_profile->decode_graph_compute_us += (uint64_t) (ggml_time_us() - t_dflash_graph_compute_us); - } #if IK_PRINT_TIMING llama_synchronize(&lctx); tim2 = ggml_time_us(); @@ -5547,16 +5478,16 @@ static int llama_decode_internal( // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - lctx.dflash_draft_tokens.clear(); - if (lctx.dflash_draft_tokens_tensor != nullptr) { + lctx.dflash.draft_tokens.clear(); + if (lctx.dflash.draft_tokens_tensor != nullptr) { ggml_backend_t 
backend_argmax = ggml_backend_sched_get_tensor_backend( - lctx.sched, lctx.dflash_draft_tokens_tensor); + lctx.sched, lctx.dflash.draft_tokens_tensor); if (backend_argmax != nullptr) { - const int64_t n_tokens_argmax = lctx.dflash_draft_tokens_tensor->ne[0]; - lctx.dflash_draft_tokens.resize((size_t) n_tokens_argmax); + const int64_t n_tokens_argmax = lctx.dflash.draft_tokens_tensor->ne[0]; + lctx.dflash.draft_tokens.resize((size_t) n_tokens_argmax); ggml_backend_tensor_get_async(backend_argmax, - lctx.dflash_draft_tokens_tensor, - lctx.dflash_draft_tokens.data(), 0, + lctx.dflash.draft_tokens_tensor, + lctx.dflash.draft_tokens.data(), 0, (size_t) n_tokens_argmax * sizeof(int32_t)); } } @@ -5564,7 +5495,7 @@ static int llama_decode_internal( // extract logits { const bool dflash_skip_logits = (lctx.model.arch == LLM_ARCH_DFLASH_DRAFT - && !lctx.dflash_draft_tokens.empty()); + && !lctx.dflash.draft_tokens.empty()); if (dflash_skip_logits) { res = nullptr; } @@ -5573,7 +5504,6 @@ static int llama_decode_internal( #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif - const int64_t t_dflash_get_result_us = dflash_decode_timing ? ggml_time_us() : 0; // Do not process logits if MTP is only updating the KV cache. if (cparams.mtp_op_type != MTP_OP_WARMUP) { // && cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res); @@ -5604,11 +5534,6 @@ static int llama_decode_internal( } } } - if (dflash_decode_timing) { - llama_synchronize(&lctx); - dflash_profile->decode_sync_profile_points++; - dflash_profile->decode_result_us += (uint64_t) (ggml_time_us() - t_dflash_get_result_us); - } #if IK_PRINT_TIMING tim2 = ggml_time_us(); printf("get_result(...): %d us\n", int(tim2-tim1)); @@ -5621,7 +5546,6 @@ static int llama_decode_internal( #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif - const int64_t t_dflash_get_embedding_us = dflash_decode_timing ? 
ggml_time_us() : 0; ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd); GGML_ASSERT(backend_embd != nullptr); @@ -5661,11 +5585,6 @@ static int llama_decode_internal( GGML_ABORT("unknown pooling type"); } } - if (dflash_decode_timing) { - llama_synchronize(&lctx); - dflash_profile->decode_sync_profile_points++; - dflash_profile->decode_embedding_us += (uint64_t) (ggml_time_us() - t_dflash_get_embedding_us); - } #if IK_PRINT_TIMING tim2 = ggml_time_us(); printf("get_embedding(...): %d us\n", int(tim2-tim1)); @@ -5709,13 +5628,9 @@ static int llama_decode_internal( #if IK_PRINT_TIMING auto tim1 = ggml_time_us(); #endif - const int64_t t_dflash_final_sched_reset_us = dflash_decode_timing ? ggml_time_us() : 0; if (!lctx.prev) { lctx.reset_scheduler(); } - if (dflash_decode_timing) { - dflash_profile->decode_final_sched_reset_us += (uint64_t) (ggml_time_us() - t_dflash_final_sched_reset_us); - } #if IK_PRINT_TIMING auto tim2 = ggml_time_us(); printf("sched_reset(...): %d us\n", int(tim2-tim1)); @@ -9838,10 +9753,10 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { } llama_token llama_get_dflash_draft_token_ith(struct llama_context * ctx, int32_t i) { - if ((size_t) i >= ctx->dflash_draft_tokens.size()) { + if ((size_t) i >= ctx->dflash.draft_tokens.size()) { return LLAMA_TOKEN_NULL; } - return ctx->dflash_draft_tokens[(size_t) i]; + return ctx->dflash.draft_tokens[(size_t) i]; } float * llama_get_embeddings(struct llama_context * ctx) { From ad24046b51640262cc4ad6d902b54102f024061c Mon Sep 17 00:00:00 2001 From: SamuelOliveirads Date: Mon, 15 Jun 2026 18:22:56 -0300 Subject: [PATCH 13/13] minor refactor in DFlash kv cache graph --- gguf-py/gguf/constants.py | 164 ++++++++++++++++++------------------- src/llama-dflash.cpp | 135 +++++++++++++----------------- src/llama-load-tensors.cpp | 6 +- 3 files changed, 143 insertions(+), 162 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py 
index 19b4d99f..76c81d9e 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1764,89 +1764,89 @@ class ExpertGatingFuncType(IntEnum): # ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE. class LlamaFileType(IntEnum): ALL_F32 = 0 - MOSTLY_F16 = 1 # except 1d tensors - MOSTLY_Q4_0 = 2 # except 1d tensors - MOSTLY_Q4_1 = 3 # except 1d tensors - MOSTLY_Q8_0 = 7 # except 1d tensors - MOSTLY_Q5_0 = 8 # except 1d tensors - MOSTLY_Q5_1 = 9 # except 1d tensors - MOSTLY_Q2_K = 10 # except 1d tensors - MOSTLY_Q3_K_S = 11 # except 1d tensors - MOSTLY_Q3_K_M = 12 # except 1d tensors - MOSTLY_Q3_K_L = 13 # except 1d tensors - MOSTLY_Q4_K_S = 14 # except 1d tensors - MOSTLY_Q4_K_M = 15 # except 1d tensors - MOSTLY_Q5_K_S = 16 # except 1d tensors - MOSTLY_Q5_K_M = 17 # except 1d tensors - MOSTLY_Q6_K = 18 # except 1d tensors - MOSTLY_IQ2_XXS = 19 # except 1d tensors - MOSTLY_IQ2_XS = 20 # except 1d tensors - MOSTLY_Q2_K_S = 21 # except 1d tensors - MOSTLY_IQ3_XS = 22 # except 1d tensors - MOSTLY_IQ3_XXS = 23 # except 1d tensors - MOSTLY_IQ1_S = 24 # except 1d tensors - MOSTLY_IQ4_NL = 25 # except 1d tensors - MOSTLY_IQ3_S = 26 # except 1d tensors - MOSTLY_IQ3_M = 27 # except 1d tensors - MOSTLY_IQ2_S = 28 # except 1d tensors - MOSTLY_IQ2_M = 29 # except 1d tensors - MOSTLY_IQ4_XS = 30 # except 1d tensors - MOSTLY_IQ1_M = 31 # except 1d tensors - MOSTLY_BF16 = 32 # except 1d tensors - MOSTLY_Q4_0_4_4 = 33 # except 1d tensors - MOSTLY_Q4_0_4_8 = 34 # except 1d tensors - MOSTLY_Q4_0_8_8 = 35 # except 1d tensors - MOSTLY_MXFP4 = 38 # except 1d tensors, 38 to be compatible with mainline + MOSTLY_F16 = 1 #except 1d tensors + MOSTLY_Q4_0 = 2 #except 1d tensors + MOSTLY_Q4_1 = 3 #except 1d tensors + MOSTLY_Q8_0 = 7 #except 1d tensors + MOSTLY_Q5_0 = 8 #except 1d tensors + MOSTLY_Q5_1 = 9 #except 1d tensors + MOSTLY_Q2_K = 10 #except 1d tensors + MOSTLY_Q3_K_S = 11 #except 1d tensors + MOSTLY_Q3_K_M = 12 #except 1d tensors + MOSTLY_Q3_K_L = 13 #except 1d 
tensors + MOSTLY_Q4_K_S = 14 #except 1d tensors + MOSTLY_Q4_K_M = 15 #except 1d tensors + MOSTLY_Q5_K_S = 16 #except 1d tensors + MOSTLY_Q5_K_M = 17 #except 1d tensors + MOSTLY_Q6_K = 18 #except 1d tensors + MOSTLY_IQ2_XXS = 19 #except 1d tensors + MOSTLY_IQ2_XS = 20 #except 1d tensors + MOSTLY_Q2_K_S = 21 #except 1d tensors + MOSTLY_IQ3_XS = 22 #except 1d tensors + MOSTLY_IQ3_XXS = 23 #except 1d tensors + MOSTLY_IQ1_S = 24 #except 1d tensors + MOSTLY_IQ4_NL = 25 #except 1d tensors + MOSTLY_IQ3_S = 26 #except 1d tensors + MOSTLY_IQ3_M = 27 #except 1d tensors + MOSTLY_IQ2_S = 28 #except 1d tensors + MOSTLY_IQ2_M = 29 #except 1d tensors + MOSTLY_IQ4_XS = 30 #except 1d tensors + MOSTLY_IQ1_M = 31 #except 1d tensors + MOSTLY_BF16 = 32 #except 1d tensors + MOSTLY_Q4_0_4_4 = 33 #except 1d tensors + MOSTLY_Q4_0_4_8 = 34 #except 1d tensors + MOSTLY_Q4_0_8_8 = 35 #except 1d tensors + MOSTLY_MXFP4 = 38 #except 1d tensors, 38 to be compatible with mainline - MOSTLY_Q6_0 = 135 # except 1d tensors - MOSTLY_IQ1_BN = 136 # except 1d tensors - MOSTLY_IQ2_BN = 137 # except 1d tensors - MOSTLY_IQ2_K = 138 # except 1d tensors - MOSTLY_IQ3_K = 139 # except 1d tensors - MOSTLY_IQ4_K = 140 # except 1d tensors - MOSTLY_IQ5_K = 141 # except 1d tensors - MOSTLY_IQ6_K = 142 # except 1d tensors - MOSTLY_IQ4_KS = 145 # except 1d tensors - MOSTLY_IQ3_KL = 146 # except 1d tensors - MOSTLY_IQ2_KS = 147 # except 1d tensors - MOSTLY_IQ4_KSS = 148 # except 1d tensors - MOSTLY_Q8_KV = 149 # except 1d tensors - MOSTLY_IQ5_KS = 150 # except 1d tensors - MOSTLY_IQ2_KT = 151 # except 1d tensors - MOSTLY_IQ3_KT = 152 # except 1d tensors - MOSTLY_IQ4_KT = 153 # except 1d tensors - MOSTLY_IQ3_KS = 154 # except 1d tensors - MOSTLY_IQ2_KL = 155 # except 1d tensors - MOSTLY_IQ1_KT = 156 # except 1d tensors + MOSTLY_Q6_0 = 135 #except 1d tensors + MOSTLY_IQ1_BN = 136 #except 1d tensors + MOSTLY_IQ2_BN = 137 #except 1d tensors + MOSTLY_IQ2_K = 138 #except 1d tensors + MOSTLY_IQ3_K = 139 #except 1d tensors + 
MOSTLY_IQ4_K = 140 #except 1d tensors + MOSTLY_IQ5_K = 141 #except 1d tensors + MOSTLY_IQ6_K = 142 #except 1d tensors + MOSTLY_IQ4_KS = 145 #except 1d tensors + MOSTLY_IQ3_KL = 146 #except 1d tensors + MOSTLY_IQ2_KS = 147 #except 1d tensors + MOSTLY_IQ4_KSS = 148 #except 1d tensors + MOSTLY_Q8_KV = 149 #except 1d tensors + MOSTLY_IQ5_KS = 150 #except 1d tensors + MOSTLY_IQ2_KT = 151 #except 1d tensors + MOSTLY_IQ3_KT = 152 #except 1d tensors + MOSTLY_IQ4_KT = 153 #except 1d tensors + MOSTLY_IQ3_KS = 154 #except 1d tensors + MOSTLY_IQ2_KL = 155 #except 1d tensors + MOSTLY_IQ1_KT = 156 #except 1d tensors - MOSTLY_Q4_0_R8 = 202 # except 1d tensors - MOSTLY_Q8_0_R8 = 207 # except 1d tensors - MOSTLY_Q5_0_R4 = 208 # except 1d tensors - MOSTLY_Q2_K_R4 = 210 # except 1d tensors - MOSTLY_Q3_K_R4 = 211 # except 1d tensors - MOSTLY_Q4_K_R4 = 214 # except 1d tensors - MOSTLY_Q5_K_R4 = 216 # except 1d tensors - MOSTLY_Q6_K_R4 = 218 # except 1d tensors - MOSTLY_IQ2_XXS_R4 = 219 # except 1d tensors - MOSTLY_IQ2_XS_R4 = 220 # except 1d tensors - MOSTLY_IQ3_XXS_R4 = 223 # except 1d tensors - MOSTLY_IQ1_S_R4 = 224 # except 1d tensors - MOSTLY_IQ4_NL_R4 = 225 # except 1d tensors - MOSTLY_IQ3_S_R4 = 226 # except 1d tensors - MOSTLY_IQ2_M_R4 = 229 # except 1d tensors - MOSTLY_IQ4_XS_R8 = 230 # except 1d tensors - MOSTLY_IQ1_M_R4 = 231 # except 1d tensors - MOSTLY_Q6_0_R4 = 335 # except 1d tensors - MOSTLY_BF16_R16 = 232 # except 1d tensors - MOSTLY_IQ2_BN_R4 = 337 # except 1d tensors - MOSTLY_IQ2_K_R4 = 338 # except 1d tensors - MOSTLY_IQ3_K_R4 = 339 # except 1d tensors - MOSTLY_IQ4_K_R4 = 340 # except 1d tensors - MOSTLY_IQ5_K_R4 = 341 # except 1d tensors - MOSTLY_IQ4_KS_R4 = 345 # except 1d tensors - MOSTLY_IQ5_KS_R4 = 350 # except 1d tensors - MOSTLY_Q8_KV_R8 = 398 # except 1d tensors - MOSTLY_Q8_K_R8 = 399 # except 1d tensors + MOSTLY_Q4_0_R8 = 202 #except 1d tensors + MOSTLY_Q8_0_R8 = 207 #except 1d tensors + MOSTLY_Q5_0_R4 = 208 #except 1d tensors + MOSTLY_Q2_K_R4 = 210 #except 
1d tensors + MOSTLY_Q3_K_R4 = 211 #except 1d tensors + MOSTLY_Q4_K_R4 = 214 #except 1d tensors + MOSTLY_Q5_K_R4 = 216 #except 1d tensors + MOSTLY_Q6_K_R4 = 218 #except 1d tensors + MOSTLY_IQ2_XXS_R4 = 219 #except 1d tensors + MOSTLY_IQ2_XS_R4 = 220 #except 1d tensors + MOSTLY_IQ3_XXS_R4 = 223 #except 1d tensors + MOSTLY_IQ1_S_R4 = 224 #except 1d tensors + MOSTLY_IQ4_NL_R4 = 225 #except 1d tensors + MOSTLY_IQ3_S_R4 = 226 #except 1d tensors + MOSTLY_IQ2_M_R4 = 229 #except 1d tensors + MOSTLY_IQ4_XS_R8 = 230 #except 1d tensors + MOSTLY_IQ1_M_R4 = 231 #except 1d tensors + MOSTLY_Q6_0_R4 = 335 #except 1d tensors + MOSTLY_BF16_R16 = 232 #except 1d tensors + MOSTLY_IQ2_BN_R4 = 337 #except 1d tensors + MOSTLY_IQ2_K_R4 = 338 #except 1d tensors + MOSTLY_IQ3_K_R4 = 339 #except 1d tensors + MOSTLY_IQ4_K_R4 = 340 #except 1d tensors + MOSTLY_IQ5_K_R4 = 341 #except 1d tensors + MOSTLY_IQ4_KS_R4 = 345 #except 1d tensors + MOSTLY_IQ5_KS_R4 = 350 #except 1d tensors + MOSTLY_Q8_KV_R8 = 398 #except 1d tensors + MOSTLY_Q8_K_R8 = 399 #except 1d tensors GUESSED = 1024 # not specified in the model file @@ -1891,7 +1891,7 @@ class GGUFValueType(IntEnum): # Items here are (block size, type size) QK_K = 256 -# Values generated programatically +#Values generated programatically GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { GGMLQuantizationType.F32 : ( 1, 4), GGMLQuantizationType.F16 : ( 1, 2), diff --git a/src/llama-dflash.cpp b/src/llama-dflash.cpp index bfb4595b..277a6ffd 100644 --- a/src/llama-dflash.cpp +++ b/src/llama-dflash.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include void llama_sync_dflash_workspace_if_pending(struct llama_context & lctx) { @@ -70,15 +71,15 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { const int64_t n_embd_head_v = model.hparams.n_embd_head_v(0); const int64_t n_head_kv = model.hparams.n_head_kv(); - if (dflash.kv.cache_ctx != nullptr && !dflash.kv.k_ctx_cache.empty()) { - const bool cache_matches 
= (int32_t) dflash.kv.k_ctx_cache.size() == n_layer && - dflash.kv.k_ctx_cache.front() != nullptr && - (int32_t) dflash.kv.k_ctx_cache.front()->ne[2] == target_cross_ctx; - const bool workspace_matches = (int32_t) dflash.kv.k_ctx_workspace.size() == n_layer && - dflash.kv.k_ctx_workspace.front() != nullptr && - (int32_t) dflash.kv.k_ctx_workspace.front()->ne[1] == target_workspace_n_kv_total; + if (dflash.kv.cache_ctx != nullptr && + (int32_t) dflash.kv.k_ctx_cache.size() == n_layer && + (int32_t) dflash.kv.k_ctx_workspace.size() == n_layer) { + const bool cache_matches = + (int32_t) dflash.kv.k_ctx_cache.front()->ne[2] == target_cross_ctx; + const bool workspace_matches = + (int32_t) dflash.kv.k_ctx_workspace.front()->ne[1] == target_workspace_n_kv_total; - if (cache_matches && workspace_matches) { + if (cache_matches && workspace_matches) { return true; } @@ -98,8 +99,6 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { dflash.kv.workspace_graph_rows = 0; dflash.kv.workspace_graph_write_pos = 0; dflash.kv.workspace_reserved_rows = 0; - dflash.kv.cache_compute_meta.clear(); - dflash.kv.workspace_compute_meta.clear(); } ggml_init_params params = { @@ -110,6 +109,7 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { dflash.kv.cache_ctx = ggml_init(params); if (dflash.kv.cache_ctx == nullptr) { + LLAMA_LOG_ERROR("%s: failed to allocate DFlash K/V cache context\n", __func__); return false; } @@ -123,74 +123,44 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { dflash.kv.cache_bufs.reserve((size_t) std::max(1, n_layer) * 4); for (int32_t il = 0; il < n_layer; ++il) { ggml_backend_buffer_type_t layer_buft = llama_dflash_kv_cache_layer_buft(*this, il); + auto alloc_kv_input = [&](ggml_tensor *& tensor, const char * tensor_tag, const char * tensor_name, + int64_t ne0, int64_t ne1, int64_t ne2) -> bool { + tensor = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, ne0, ne1, ne2); + if (tensor == 
nullptr) { + LLAMA_LOG_ERROR("%s: failed to create %s for layer %d\n", __func__, tensor_tag, il); + return false; + } - dflash.kv.k_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, n_embd_head_k, n_head_kv, target_cross_ctx); - dflash.kv.v_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, n_embd_head_v, n_head_kv, target_cross_ctx); - if (dflash.kv.k_ctx_cache[(size_t) il] == nullptr || dflash.kv.v_ctx_cache[(size_t) il] == nullptr) { + ggml_set_input(tensor); + ggml_format_name(tensor, tensor_name, il); + + const size_t tensor_bytes = ggml_backend_buft_get_alloc_size(layer_buft, tensor); + ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(layer_buft, tensor_bytes); + if (buf == nullptr) { + LLAMA_LOG_ERROR("%s: failed to allocate %s buffer for layer %d (%zu bytes)\n", + __func__, tensor_tag, il, tensor_bytes); + return false; + } + + ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); + ggml_backend_tensor_alloc(buf, tensor, ggml_backend_buffer_get_base(buf)); + ggml_backend_buffer_clear(buf, 0); + dflash.kv.cache_bufs.push_back(buf); + + return true; + }; + + if (!alloc_kv_input(dflash.kv.k_ctx_cache[(size_t) il], "dflash_k_ctx_cache", "dflash_k_ctx_cache_%d", + n_embd_head_k, n_head_kv, target_cross_ctx) || + !alloc_kv_input(dflash.kv.v_ctx_cache[(size_t) il], "dflash_v_ctx_cache", "dflash_v_ctx_cache_%d", + n_embd_head_v, n_head_kv, target_cross_ctx) || + !alloc_kv_input(dflash.kv.k_ctx_workspace[(size_t) il], "dflash_k_ctx_workspace", "dflash_k_ctx_workspace_%d", + n_embd_head_k, target_workspace_n_kv_total, n_head_kv) || + !alloc_kv_input(dflash.kv.v_ctx_workspace[(size_t) il], "dflash_v_ctx_workspace", "dflash_v_ctx_workspace_%d", + n_embd_head_v, target_workspace_n_kv_total, n_head_kv)) { free_dflash_kv_cache_tensors(); return false; } - - ggml_set_input(dflash.kv.k_ctx_cache[(size_t) il]); - ggml_set_input(dflash.kv.v_ctx_cache[(size_t) il]); - 
ggml_format_name(dflash.kv.k_ctx_cache[(size_t) il], "dflash_k_ctx_cache_%d", il); - ggml_format_name(dflash.kv.v_ctx_cache[(size_t) il], "dflash_v_ctx_cache_%d", il); - - const size_t k_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash.kv.k_ctx_cache[(size_t) il]); - ggml_backend_buffer_t k_buf = ggml_backend_buft_alloc_buffer(layer_buft, k_bytes); - if (k_buf == nullptr) { - free_dflash_kv_cache_tensors(); - return false; - } - ggml_backend_buffer_set_usage(k_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); - ggml_backend_tensor_alloc(k_buf, dflash.kv.k_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(k_buf)); - ggml_backend_buffer_clear(k_buf, 0); - dflash.kv.cache_bufs.push_back(k_buf); - - const size_t v_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash.kv.v_ctx_cache[(size_t) il]); - ggml_backend_buffer_t v_buf = ggml_backend_buft_alloc_buffer(layer_buft, v_bytes); - if (v_buf == nullptr) { - free_dflash_kv_cache_tensors(); - return false; - } - ggml_backend_buffer_set_usage(v_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); - ggml_backend_tensor_alloc(v_buf, dflash.kv.v_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(v_buf)); - ggml_backend_buffer_clear(v_buf, 0); - dflash.kv.cache_bufs.push_back(v_buf); - - dflash.kv.k_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, n_embd_head_k, target_workspace_n_kv_total, n_head_kv); - dflash.kv.v_ctx_workspace[(size_t) il] = ggml_new_tensor_3d(dflash.kv.cache_ctx, GGML_TYPE_F32, n_embd_head_v, target_workspace_n_kv_total, n_head_kv); - if (dflash.kv.k_ctx_workspace[(size_t) il] == nullptr || dflash.kv.v_ctx_workspace[(size_t) il] == nullptr) { - free_dflash_kv_cache_tensors(); - return false; - } - - ggml_set_input(dflash.kv.k_ctx_workspace[(size_t) il]); - ggml_set_input(dflash.kv.v_ctx_workspace[(size_t) il]); - ggml_format_name(dflash.kv.k_ctx_workspace[(size_t) il], "dflash_k_ctx_workspace_%d", il); - ggml_format_name(dflash.kv.v_ctx_workspace[(size_t) il], 
"dflash_v_ctx_workspace_%d", il); - - const size_t k_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash.kv.k_ctx_workspace[(size_t) il]); - ggml_backend_buffer_t k_workspace_buf = ggml_backend_buft_alloc_buffer(layer_buft, k_workspace_bytes); - if (k_workspace_buf == nullptr) { - free_dflash_kv_cache_tensors(); - return false; - } - ggml_backend_buffer_set_usage(k_workspace_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); - ggml_backend_tensor_alloc(k_workspace_buf, dflash.kv.k_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(k_workspace_buf)); - ggml_backend_buffer_clear(k_workspace_buf, 0); - dflash.kv.cache_bufs.push_back(k_workspace_buf); - - const size_t v_workspace_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash.kv.v_ctx_workspace[(size_t) il]); - ggml_backend_buffer_t v_workspace_buf = ggml_backend_buft_alloc_buffer(layer_buft, v_workspace_bytes); - if (v_workspace_buf == nullptr) { - free_dflash_kv_cache_tensors(); - return false; - } - ggml_backend_buffer_set_usage(v_workspace_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE); - ggml_backend_tensor_alloc(v_workspace_buf, dflash.kv.v_ctx_workspace[(size_t) il], ggml_backend_buffer_get_base(v_workspace_buf)); - ggml_backend_buffer_clear(v_workspace_buf, 0); - dflash.kv.cache_bufs.push_back(v_workspace_buf); } dflash.kv.workspace_token_capacity = target_token_capacity; @@ -201,10 +171,15 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) { } void llama_context::free_dflash_kv_cache_tensors() { - dflash.kv.k_ctx_cache.clear(); - dflash.kv.v_ctx_cache.clear(); - dflash.kv.k_ctx_workspace.clear(); - dflash.kv.v_ctx_workspace.clear(); + auto release_vector = [](auto & v) { + using vec_type = std::decay_t; + vec_type().swap(v); + }; + + release_vector(dflash.kv.k_ctx_cache); + release_vector(dflash.kv.v_ctx_cache); + release_vector(dflash.kv.k_ctx_workspace); + release_vector(dflash.kv.v_ctx_workspace); dflash.kv.cache_write_pos = 0; dflash.kv.cache_n_filled = 0; 
dflash.kv.cache_update_rows = 0; @@ -244,7 +219,9 @@ void llama_context::free_dflash_kv_cache_tensors() { ggml_backend_buffer_free(buf); } } - dflash.kv.cache_bufs.clear(); + release_vector(dflash.kv.cache_bufs); + release_vector(dflash.kv.cache_compute_meta); + release_vector(dflash.kv.workspace_compute_meta); if (dflash.kv.cache_ctx != nullptr) { ggml_free(dflash.kv.cache_ctx); dflash.kv.cache_ctx = nullptr; diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index 88e8816b..0585f0ba 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -2257,10 +2257,14 @@ bool create_tensors_helper::create_dflash_tensors(const LLM_TN & tn) { model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); - model.output_mtp = create_tensor(ctx_output, "output_extra.weight", {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + auto output_extra = create_tensor(ctx_output, "output_extra.weight", {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + if (output_extra != nullptr) { + model.output = output_extra; + } if (model.output == nullptr && model.tok_embd != nullptr) { model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } + model.output_mtp = model.output; model.dflash_fc = create_tensor(ctx_output, tn(LLM_TENSOR_DFLASH_FC, "weight"), {(int64_t) hparams.dflash_n_target_features, n_embd}, 0); model.dflash_hidden_norm = create_tensor(ctx_output, tn(LLM_TENSOR_DFLASH_HIDDEN_NORM, "weight"), {n_embd}, 0);