mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
fix graph mask, swa layers and tokens positions
This commit is contained in:
parent
532499836e
commit
1369e68471
@ -15,6 +15,7 @@
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <unordered_map>
|
||||
@ -133,6 +134,31 @@ static bool common_speculative_are_dflash_compatible(
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool add_bos_tgt = llama_vocab_get_add_bos(vocab_tgt);
|
||||
const bool add_bos_dft = llama_vocab_get_add_bos(vocab_dft);
|
||||
const bool add_eos_tgt = llama_vocab_get_add_eos(vocab_tgt);
|
||||
const bool add_eos_dft = llama_vocab_get_add_eos(vocab_dft);
|
||||
const llama_token bos_tgt = llama_vocab_bos(vocab_tgt);
|
||||
const llama_token bos_dft = llama_vocab_bos(vocab_dft);
|
||||
const llama_token eos_tgt = llama_vocab_eos(vocab_tgt);
|
||||
const llama_token eos_dft = llama_vocab_eos(vocab_dft);
|
||||
|
||||
if (add_bos_tgt != add_bos_dft || add_eos_tgt != add_eos_dft ||
|
||||
(add_bos_tgt && bos_tgt != bos_dft) ||
|
||||
(add_eos_tgt && eos_tgt != eos_dft)) {
|
||||
LOG_DBG("%s: DFlash draft special tokens must match the target model (add_bos=%d/%d add_eos=%d/%d bos=%d/%d eos=%d/%d)\n",
|
||||
__func__,
|
||||
(int) add_bos_tgt,
|
||||
(int) add_bos_dft,
|
||||
(int) add_eos_tgt,
|
||||
(int) add_eos_dft,
|
||||
(int) bos_tgt,
|
||||
(int) bos_dft,
|
||||
(int) eos_tgt,
|
||||
(int) eos_dft);
|
||||
return false;
|
||||
}
|
||||
|
||||
const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
|
||||
const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
|
||||
const int vocab_diff = n_vocab_tgt > n_vocab_dft
|
||||
@ -464,6 +490,24 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
size_t n_decode_fail = 0;
|
||||
llama_pos last_draft_pos_base = -1;
|
||||
|
||||
uint64_t t_draft_decode_us = 0;
|
||||
uint64_t t_draft_sample_us = 0;
|
||||
uint64_t t_warmup_collect_us = 0;
|
||||
uint64_t t_warmup_append_us = 0;
|
||||
uint64_t t_accept_output_copy_us = 0;
|
||||
uint64_t t_accept_commit_us = 0;
|
||||
uint64_t t_accept_append_us = 0;
|
||||
size_t n_warmup_collect_calls = 0;
|
||||
size_t n_warmup_collect_rows = 0;
|
||||
size_t n_warmup_append_calls = 0;
|
||||
size_t n_warmup_append_rows = 0;
|
||||
size_t n_accept_output_copy_calls = 0;
|
||||
size_t n_accept_output_copy_rows = 0;
|
||||
size_t n_accept_commit_calls = 0;
|
||||
size_t n_accept_commit_rows = 0;
|
||||
size_t n_accept_append_calls = 0;
|
||||
size_t n_accept_append_rows = 0;
|
||||
|
||||
common_speculative_state_dflash(
|
||||
enum common_speculative_type type,
|
||||
llama_context * ctx_tgt,
|
||||
@ -474,9 +518,10 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
, ctx_dft(ctx_dft)
|
||||
, cross_ctx(std::max(1, cross_ctx))
|
||||
{
|
||||
const llama_model * model_tgt = llama_get_model(ctx_tgt);
|
||||
const llama_model * model_dft = llama_get_model(ctx_dft);
|
||||
|
||||
if (!common_speculative_are_dflash_compatible(llama_get_model(ctx_tgt), model_dft)) {
|
||||
if (!common_speculative_are_dflash_compatible(model_tgt, model_dft)) {
|
||||
LOG_ERR("%s: DFlash draft model vocab/tokenizer is incompatible with the target model\n", __func__);
|
||||
return;
|
||||
}
|
||||
@ -499,6 +544,70 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
return;
|
||||
}
|
||||
|
||||
const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
|
||||
const auto * vocab_dft = llama_model_get_vocab(model_dft);
|
||||
const int32_t target_vocab_size = llama_vocab_n_tokens(vocab_tgt);
|
||||
const int32_t draft_vocab_size = llama_vocab_n_tokens(vocab_dft);
|
||||
const int32_t target_hidden_size = llama_model_n_embd(model_tgt);
|
||||
const int32_t draft_hidden_size = llama_model_n_embd(model_dft);
|
||||
const int32_t target_mask_token_id = llama_model_dflash_target_mask_token_id(model_tgt);
|
||||
const int32_t expected_n_target_features = target_hidden_size > 0 ? target_hidden_size * n_target_layers : 0;
|
||||
|
||||
if (target_mask_token_id != (int32_t) LLAMA_TOKEN_NULL && mask_token_id != target_mask_token_id) {
|
||||
LOG_ERR("%s: DFlash mask token mismatch (draft=%d target=%d)\n",
|
||||
__func__, mask_token_id, target_mask_token_id);
|
||||
return;
|
||||
}
|
||||
|
||||
if (target_hidden_size <= 0 || draft_hidden_size <= 0) {
|
||||
LOG_ERR("%s: invalid DFlash hidden sizes (draft=%d target=%d)\n",
|
||||
__func__, draft_hidden_size, target_hidden_size);
|
||||
return;
|
||||
}
|
||||
|
||||
if (expected_n_target_features <= 0 || n_target_features != expected_n_target_features) {
|
||||
LOG_ERR("%s: DFlash target feature width mismatch (metadata=%d expected=%d target_hidden=%d target_layers=%d)\n",
|
||||
__func__, n_target_features, expected_n_target_features, target_hidden_size, n_target_layers);
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<int32_t> sorted_target_layer_ids = target_layer_ids;
|
||||
std::sort(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end());
|
||||
if (std::adjacent_find(sorted_target_layer_ids.begin(), sorted_target_layer_ids.end()) != sorted_target_layer_ids.end()) {
|
||||
LOG_ERR("%s: duplicate DFlash target layer ids survived into runtime validation\n", __func__);
|
||||
target_layer_ids.clear();
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t n_target_model_layers = llama_n_layer(model_tgt);
|
||||
for (int32_t layer_id : target_layer_ids) {
|
||||
if (layer_id < 0 || layer_id >= n_target_model_layers) {
|
||||
LOG_ERR("%s: invalid DFlash target layer id %d for target model with %d layers\n",
|
||||
__func__, layer_id, n_target_model_layers);
|
||||
target_layer_ids.clear();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const int32_t io_mode = llama_model_dflash_io_mode(model_dft, model_tgt);
|
||||
if (io_mode == LLAMA_DFLASH_IO_MODE_INVALID) {
|
||||
LOG_ERR("%s: DFlash draft is missing required IO tensors after target sharing\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
if (io_mode == LLAMA_DFLASH_IO_MODE_MIXED) {
|
||||
LOG_ERR("%s: DFlash IO contract must be fully shared or fully self-contained, but resolved to mixed mode\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
if (io_mode == LLAMA_DFLASH_IO_MODE_SELF_CONTAINED && !llama_model_dflash_io_tensors_match(model_dft, target_hidden_size, target_vocab_size)) {
|
||||
LOG_ERR("%s: DFlash self-contained IO tensors do not match the target hidden/vocab contract (target_hidden=%d target_vocab=%d)\n",
|
||||
__func__,
|
||||
target_hidden_size,
|
||||
target_vocab_size);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!llama_set_dflash_capture_layers(ctx_tgt, target_layer_ids.data(), (int32_t) target_layer_ids.size())) {
|
||||
LOG_ERR("%s: failed to configure DFlash target capture callback\n", __func__);
|
||||
return;
|
||||
@ -519,8 +628,11 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
layers_oss << target_layer_ids[i];
|
||||
}
|
||||
|
||||
const char * io_mode_name = io_mode == LLAMA_DFLASH_IO_MODE_SHARED ? "shared" : "self-contained";
|
||||
LOG_INF("%s: DFlash context ready (n_ctx=%d, block_size=%d, cross_ctx=%d, n_target_features=%d, target_layer_ids=[%s])\n",
|
||||
__func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features, layers_oss.str().c_str());
|
||||
__func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features, layers_oss.str().c_str());
|
||||
LOG_INF("%s: DFlash artifact io=%s draft_vocab=%d target_vocab=%d draft_hidden=%d target_hidden=%d mask_token_id=%d target_mask_token_id=%d\n",
|
||||
__func__, io_mode_name, draft_vocab_size, target_vocab_size, draft_hidden_size, target_hidden_size, mask_token_id, target_mask_token_id);
|
||||
}
|
||||
|
||||
~common_speculative_state_dflash() override {
|
||||
@ -544,6 +656,23 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
n_set_target_fail = 0;
|
||||
n_decode_fail = 0;
|
||||
last_draft_pos_base = -1;
|
||||
t_draft_decode_us = 0;
|
||||
t_draft_sample_us = 0;
|
||||
t_warmup_collect_us = 0;
|
||||
t_warmup_append_us = 0;
|
||||
t_accept_output_copy_us = 0;
|
||||
t_accept_commit_us = 0;
|
||||
t_accept_append_us = 0;
|
||||
n_warmup_collect_calls = 0;
|
||||
n_warmup_collect_rows = 0;
|
||||
n_warmup_append_calls = 0;
|
||||
n_warmup_append_rows = 0;
|
||||
n_accept_output_copy_calls = 0;
|
||||
n_accept_output_copy_rows = 0;
|
||||
n_accept_commit_calls = 0;
|
||||
n_accept_commit_rows = 0;
|
||||
n_accept_append_calls = 0;
|
||||
n_accept_append_rows = 0;
|
||||
llama_dflash_profile_reset(ctx_tgt);
|
||||
llama_dflash_profile_reset(ctx_dft);
|
||||
}
|
||||
@ -554,7 +683,6 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
GGML_UNUSED(prompt_tgt);
|
||||
GGML_UNUSED(id_last);
|
||||
|
||||
result.clear();
|
||||
if (!ready || target_window_rows <= 0) {
|
||||
@ -562,7 +690,7 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t n_keep = std::min<int32_t>(params.n_max, block_size);
|
||||
const int32_t n_keep = std::min<int32_t>(params.n_max, block_size - 1);
|
||||
if (n_keep <= 0) {
|
||||
return;
|
||||
}
|
||||
@ -575,23 +703,30 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
|
||||
llama_kv_cache_clear(ctx_dft);
|
||||
batch.n_tokens = 0;
|
||||
const int32_t batch_len = n_keep + 1;
|
||||
const llama_pos draft_pos_base = last_target_pos >= 0 ? last_target_pos + 1 : (llama_pos) target_window_rows;
|
||||
const llama_pos seed_pos = last_target_pos >= 0 ? last_target_pos : draft_pos_base - 1;
|
||||
last_draft_pos_base = draft_pos_base;
|
||||
for (int32_t i = 0; i < block_size; ++i) {
|
||||
common_batch_add(batch, mask_token_id, draft_pos_base + i, { 0 }, i < n_keep);
|
||||
common_batch_add(batch, id_last, seed_pos, { 0 }, false);
|
||||
for (int32_t i = 1; i < batch_len; ++i) {
|
||||
common_batch_add(batch, mask_token_id, draft_pos_base + (i - 1), { 0 }, i <= n_keep);
|
||||
}
|
||||
|
||||
const int64_t t_decode_us = ggml_time_us();
|
||||
if (llama_decode(ctx_dft, batch) != 0) {
|
||||
LOG_ERR("%s: llama_decode() failed for DFlash draft batch\n", __func__);
|
||||
n_decode_fail++;
|
||||
batch.n_tokens = 0;
|
||||
return;
|
||||
}
|
||||
t_draft_decode_us += (uint64_t) (ggml_time_us() - t_decode_us);
|
||||
|
||||
result.reserve((size_t) n_keep);
|
||||
const int64_t t_sample_us = ggml_time_us();
|
||||
for (int32_t i = 0; i < n_keep; ++i) {
|
||||
result.push_back(common_sampler_sample_speculative(nullptr, ctx_dft, i, nullptr));
|
||||
result.push_back(common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr));
|
||||
}
|
||||
t_draft_sample_us += (uint64_t) (ggml_time_us() - t_sample_us);
|
||||
|
||||
batch.n_tokens = 0;
|
||||
dflash_contract_log_draft(*this, n_keep, result.size());
|
||||
@ -657,8 +792,13 @@ static void dflash_contract_log_draft(
|
||||
const int draft_delta = (state.last_target_pos >= 0 && state.last_draft_pos_base >= 0)
|
||||
? (int) (state.last_draft_pos_base - state.last_target_pos)
|
||||
: -1;
|
||||
const llama_pos seed_pos = state.last_target_pos;
|
||||
const llama_pos mask_first_pos = state.last_draft_pos_base;
|
||||
const llama_pos mask_last_pos = state.last_draft_pos_base >= 0
|
||||
? state.last_draft_pos_base + n_keep - 1
|
||||
: -1;
|
||||
|
||||
LOG_INF("dflash contract draft[%llu]: window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d draft_pos_base=%d delta=%d n_keep=%d result=%zu set_target(missing/nonmono)=%llu/%llu graph(fallback/nonmono)=%llu/%llu graph_pos=[%d..%d]\n",
|
||||
LOG_INF("dflash contract draft[%llu]: window_rows=%d window_pos=%s pos=[%d..%d] gaps=%d nonmono=%d last_target_pos=%d seed_pos=%d mask_pos=[%d..%d] sample_rows=[1..%d] output_rows=[1..%d] draft_pos_base=%d delta=%d n_keep=%d result=%zu set_target(missing/nonmono)=%llu/%llu graph(fallback/nonmono)=%llu/%llu graph_pos=[%d..%d]\n",
|
||||
(unsigned long long) (ordinal + 1),
|
||||
state.target_window_rows,
|
||||
dflash_contract_format_values(state.target_window_pos).c_str(),
|
||||
@ -667,6 +807,11 @@ static void dflash_contract_log_draft(
|
||||
window.gap_count,
|
||||
window.nonmono_count,
|
||||
(int) state.last_target_pos,
|
||||
(int) seed_pos,
|
||||
(int) mask_first_pos,
|
||||
(int) mask_last_pos,
|
||||
n_keep,
|
||||
n_keep,
|
||||
(int) state.last_draft_pos_base,
|
||||
draft_delta,
|
||||
n_keep,
|
||||
@ -1583,7 +1728,14 @@ common_speculative * common_speculative_init(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
cparams_dft.n_ctx = (uint32_t) (max_cross_ctx + block_size);
|
||||
const int64_t required_n_ctx = (int64_t) max_cross_ctx + (int64_t) block_size;
|
||||
if (required_n_ctx > std::numeric_limits<int32_t>::max()) {
|
||||
LOG_ERR("%s: invalid DFlash draft context size cross_ctx=%d block_size=%d required_n_ctx=%lld\n",
|
||||
__func__, max_cross_ctx, block_size, (long long) required_n_ctx);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
cparams_dft.n_ctx = (uint32_t) required_n_ctx;
|
||||
}
|
||||
|
||||
ctx_dft = llama_init_from_model(params.model_dft, cparams_dft);
|
||||
@ -2127,6 +2279,8 @@ int32_t common_speculative_on_target_seq_batch(
|
||||
const llama_batch * batch_for_spec = &batch;
|
||||
llama_batch seq_batch = {};
|
||||
const bool needs_seq_split = is_prompt_warmup && !common_speculative_batch_is_exact_single_seq(batch, seq_id);
|
||||
auto * dflash_state = common_speculative_get_dflash_state(spec);
|
||||
const bool measure_dflash_warmup_collect = dflash_state != nullptr && is_prompt_warmup;
|
||||
|
||||
if (needs_seq_split) {
|
||||
const int n_seq_tokens = common_speculative_copy_seq_batch(batch, seq_id, seq_batch);
|
||||
@ -2134,16 +2288,28 @@ int32_t common_speculative_on_target_seq_batch(
|
||||
return n_seq_tokens < 0 ? -1 : 0;
|
||||
}
|
||||
|
||||
const int64_t t_collect_us = measure_dflash_warmup_collect ? ggml_time_us() : 0;
|
||||
if (!common_speculative_collect_target_seq_batch_features(spec, ctx_tgt, batch, seq_id, feature_view)) {
|
||||
llama_batch_free(seq_batch);
|
||||
return -1;
|
||||
}
|
||||
if (measure_dflash_warmup_collect) {
|
||||
dflash_state->t_warmup_collect_us += (uint64_t) (ggml_time_us() - t_collect_us);
|
||||
dflash_state->n_warmup_collect_calls++;
|
||||
dflash_state->n_warmup_collect_rows += (size_t) n_seq_tokens;
|
||||
}
|
||||
|
||||
batch_for_spec = &seq_batch;
|
||||
} else {
|
||||
const int64_t t_collect_us = measure_dflash_warmup_collect ? ggml_time_us() : 0;
|
||||
if (!common_speculative_collect_target_batch_features(spec, ctx_tgt, batch, feature_view)) {
|
||||
return -1;
|
||||
}
|
||||
if (measure_dflash_warmup_collect) {
|
||||
dflash_state->t_warmup_collect_us += (uint64_t) (ggml_time_us() - t_collect_us);
|
||||
dflash_state->n_warmup_collect_calls++;
|
||||
dflash_state->n_warmup_collect_rows += (size_t) batch.n_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
const int32_t ret = common_speculative_on_target_batch(spec, *batch_for_spec, feature_view, is_prompt_warmup);
|
||||
@ -2244,7 +2410,16 @@ bool common_speculative_commit_accepted_hidden_rows(
|
||||
return false;
|
||||
}
|
||||
|
||||
return common_speculative_apply_hidden_rows(spec, seq_id, pos_base, commit_tokens, hidden_rows);
|
||||
auto * dflash_state = common_speculative_get_dflash_state(spec);
|
||||
const int64_t t_commit_us = dflash_state != nullptr ? ggml_time_us() : 0;
|
||||
const bool ok = common_speculative_apply_hidden_rows(spec, seq_id, pos_base, commit_tokens, hidden_rows);
|
||||
if (dflash_state != nullptr) {
|
||||
dflash_state->t_accept_commit_us += (uint64_t) (ggml_time_us() - t_commit_us);
|
||||
dflash_state->n_accept_commit_calls++;
|
||||
dflash_state->n_accept_commit_rows += commit_tokens.size();
|
||||
}
|
||||
|
||||
return ok;
|
||||
}
|
||||
|
||||
bool common_speculative_commit_accepted_output(
|
||||
@ -2261,9 +2436,16 @@ bool common_speculative_commit_accepted_output(
|
||||
}
|
||||
|
||||
std::vector<float> hidden_rows;
|
||||
auto * dflash_state = common_speculative_get_dflash_state(spec);
|
||||
const int64_t t_copy_us = dflash_state != nullptr ? ggml_time_us() : 0;
|
||||
if (!common_speculative_copy_output_hidden_rows(spec, ctx, output_indices, hidden_rows)) {
|
||||
return false;
|
||||
}
|
||||
if (dflash_state != nullptr) {
|
||||
dflash_state->t_accept_output_copy_us += (uint64_t) (ggml_time_us() - t_copy_us);
|
||||
dflash_state->n_accept_output_copy_calls++;
|
||||
dflash_state->n_accept_output_copy_rows += output_indices.size();
|
||||
}
|
||||
|
||||
return common_speculative_commit_accepted_hidden_rows(
|
||||
spec,
|
||||
@ -2324,7 +2506,24 @@ void common_speculative_print_stats(const common_speculative * spec, double slot
|
||||
(int) dflash_state->last_draft_pos_base);
|
||||
|
||||
if (have_capture || have_graph) {
|
||||
LOG_INF("statistics dflash profile: capture(sync/materialize)=%.3f/%.3f ms calls=%llu/%llu bytes=%llu phase(prompt/verify batches changes)=%llu/%llu %llu/%llu, set_target=%.3f ms rows=%llu bytes=%llu, prep(total/features/pos/mask)=%.3f/%.3f/%.3f/%.3f ms kv_cache=%.3f ms calls=%llu/%llu bytes=%llu/%llu/%llu, fallback_pos(copy/graph)=%llu/%llu, nonmono(copy/graph)=%llu/%llu, capture_fail=%llu/%llu, visible_kv_max=%llu, last(rows=%d width=%d left_pad=%d n_tokens=%d n_kv=%d pos=[%d..%d])\n",
|
||||
const double kv_cache_total_ms = (double) (
|
||||
graph_stats.graph_kv_cache_build_us +
|
||||
graph_stats.graph_kv_cache_reserve_us +
|
||||
graph_stats.graph_kv_cache_reset_us +
|
||||
graph_stats.graph_kv_cache_alloc_us +
|
||||
graph_stats.graph_kv_cache_feature_upload_us +
|
||||
graph_stats.graph_kv_cache_pos_upload_us +
|
||||
graph_stats.graph_kv_cache_compute_us +
|
||||
graph_stats.graph_kv_cache_sync_us +
|
||||
graph_stats.graph_kv_cache_read_concat_pad_us) / 1000.0;
|
||||
const double kv_upload_feature_ms = (double) graph_stats.graph_kv_cache_feature_upload_us / 1000.0;
|
||||
const double kv_upload_pos_ms = (double) graph_stats.graph_kv_cache_pos_upload_us / 1000.0;
|
||||
const double kv_upload_total_ms = kv_upload_feature_ms + kv_upload_pos_ms;
|
||||
const double kv_compute_ms = (double) graph_stats.graph_kv_cache_compute_us / 1000.0;
|
||||
const double kv_sync_ms = (double) graph_stats.graph_kv_cache_sync_us / 1000.0;
|
||||
const double replay_append_ms = (double) dflash_state->t_accept_append_us / 1000.0;
|
||||
|
||||
LOG_INF("statistics dflash profile: capture(sync/materialize)=%.3f/%.3f ms calls=%llu/%llu bytes=%llu phase(prompt/verify batches changes)=%llu/%llu %llu/%llu, set_target=%.3f ms rows=%llu bytes=%llu, decode(llama_output_reserve/prepare)=%.3f/%.3f ms calls=%llu/%llu realloc(bytes)=%llu/%llu, prep(total/features/pos/mask)=%.3f/%.3f/%.3f/%.3f ms kv_cache(total/build/reserve/reset/alloc/up_f/up_p/compute/sync/read)=%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f/%.3f ms calls(prepare/cache/read)=%llu/%llu/%llu bytes(feature/pos/mask/read)=%llu/%llu/%llu/%llu host_layers=%d, fallback_pos(copy/graph)=%llu/%llu, nonmono(copy/graph)=%llu/%llu, capture_fail=%llu/%llu decode_prepare_fail=%llu, visible_kv_max=%llu, last(rows=%d width=%d left_pad=%d n_tokens=%d n_kv=%d pos=[%d..%d])\n",
|
||||
(double) capture_stats.capture_prepare_sync_us / 1000.0,
|
||||
(double) capture_stats.capture_materialize_us / 1000.0,
|
||||
(unsigned long long) capture_stats.capture_prepare_calls,
|
||||
@ -2337,22 +2536,41 @@ void common_speculative_print_stats(const common_speculative * spec, double slot
|
||||
(double) graph_stats.set_target_copy_us / 1000.0,
|
||||
(unsigned long long) graph_stats.set_target_rows,
|
||||
(unsigned long long) graph_stats.set_target_copy_bytes,
|
||||
(double) graph_stats.decode_output_reserve_us / 1000.0,
|
||||
(double) graph_stats.decode_prepare_us / 1000.0,
|
||||
(unsigned long long) graph_stats.decode_output_reserve_calls,
|
||||
(unsigned long long) graph_stats.decode_prepare_calls,
|
||||
(unsigned long long) graph_stats.decode_output_reserve_reallocs,
|
||||
(unsigned long long) graph_stats.decode_output_reserve_realloc_bytes,
|
||||
(double) graph_stats.graph_prepare_total_us / 1000.0,
|
||||
(double) graph_stats.graph_feature_copy_us / 1000.0,
|
||||
(double) graph_stats.graph_pos_copy_us / 1000.0,
|
||||
(double) graph_stats.graph_mask_build_us / 1000.0,
|
||||
kv_cache_total_ms,
|
||||
(double) graph_stats.graph_kv_cache_build_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_cache_reserve_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_cache_reset_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_cache_alloc_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_cache_feature_upload_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_cache_pos_upload_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_cache_compute_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_cache_sync_us / 1000.0,
|
||||
(double) graph_stats.graph_kv_cache_read_concat_pad_us / 1000.0,
|
||||
(unsigned long long) graph_stats.graph_prepare_calls,
|
||||
(unsigned long long) graph_stats.graph_kv_cache_calls,
|
||||
(unsigned long long) graph_stats.graph_kv_cache_read_concat_pad_calls,
|
||||
(unsigned long long) graph_stats.graph_feature_bytes,
|
||||
(unsigned long long) graph_stats.graph_pos_bytes,
|
||||
(unsigned long long) graph_stats.graph_mask_bytes,
|
||||
(unsigned long long) graph_stats.graph_kv_cache_cached_bytes,
|
||||
graph_stats.last_kv_cache_host_layers,
|
||||
(unsigned long long) graph_stats.set_target_missing_positions,
|
||||
(unsigned long long) graph_stats.graph_pos_fallbacks,
|
||||
(unsigned long long) graph_stats.set_target_non_monotonic_positions,
|
||||
(unsigned long long) graph_stats.graph_pos_non_monotonic,
|
||||
(unsigned long long) capture_stats.capture_prepare_failures,
|
||||
(unsigned long long) capture_stats.capture_materialize_failures,
|
||||
(unsigned long long) graph_stats.decode_prepare_failures,
|
||||
(unsigned long long) graph_stats.graph_visible_kv_max,
|
||||
graph_stats.last_n_rows,
|
||||
graph_stats.last_width,
|
||||
@ -2361,6 +2579,36 @@ void common_speculative_print_stats(const common_speculative * spec, double slot
|
||||
graph_stats.last_n_kv_total,
|
||||
(int) graph_stats.last_pos_first,
|
||||
(int) graph_stats.last_pos_last);
|
||||
|
||||
LOG_INF("statistics dflash hot: kv(upload_f/upload_p/upload/compute/sync)=%.3f/%.3f/%.3f/%.3f/%.3f ms calls=%llu replay(accepted_prefix_append)=%.3f ms calls=%zu rows=%zu\n",
|
||||
kv_upload_feature_ms,
|
||||
kv_upload_pos_ms,
|
||||
kv_upload_total_ms,
|
||||
kv_compute_ms,
|
||||
kv_sync_ms,
|
||||
(unsigned long long) graph_stats.graph_kv_cache_calls,
|
||||
replay_append_ms,
|
||||
dflash_state->n_accept_append_calls,
|
||||
dflash_state->n_accept_append_rows);
|
||||
|
||||
LOG_INF("statistics dflash stages: draft(decode/sample)=%.3f/%.3f ms warmup(collect/append)=%.3f/%.3f ms calls=%zu/%zu rows=%zu/%zu accept(total/output_copy/append)=%.3f/%.3f/%.3f ms calls=%zu/%zu/%zu rows=%zu/%zu/%zu\n",
|
||||
(double) dflash_state->t_draft_decode_us / 1000.0,
|
||||
(double) dflash_state->t_draft_sample_us / 1000.0,
|
||||
(double) dflash_state->t_warmup_collect_us / 1000.0,
|
||||
(double) dflash_state->t_warmup_append_us / 1000.0,
|
||||
dflash_state->n_warmup_collect_calls,
|
||||
dflash_state->n_warmup_append_calls,
|
||||
dflash_state->n_warmup_collect_rows,
|
||||
dflash_state->n_warmup_append_rows,
|
||||
(double) dflash_state->t_accept_commit_us / 1000.0,
|
||||
(double) dflash_state->t_accept_output_copy_us / 1000.0,
|
||||
(double) dflash_state->t_accept_append_us / 1000.0,
|
||||
dflash_state->n_accept_commit_calls,
|
||||
dflash_state->n_accept_output_copy_calls,
|
||||
dflash_state->n_accept_append_calls,
|
||||
dflash_state->n_accept_commit_rows,
|
||||
dflash_state->n_accept_output_copy_rows,
|
||||
dflash_state->n_accept_append_rows);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2690,8 +2938,6 @@ int32_t common_speculative_on_target_batch(
|
||||
const common_speculative_feature_view & features,
|
||||
bool is_prompt_warmup) {
|
||||
if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) {
|
||||
GGML_UNUSED(is_prompt_warmup);
|
||||
|
||||
if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || batch.n_tokens <= 0) {
|
||||
return 0;
|
||||
}
|
||||
@ -2713,9 +2959,22 @@ int32_t common_speculative_on_target_batch(
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t t_append_us = ggml_time_us();
|
||||
if (!dflash_append_target_features(*dflash_state, features, batch, seq_id)) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
const uint64_t append_us = (uint64_t) (ggml_time_us() - t_append_us);
|
||||
if (is_prompt_warmup) {
|
||||
dflash_state->t_warmup_append_us += append_us;
|
||||
dflash_state->n_warmup_append_calls++;
|
||||
dflash_state->n_warmup_append_rows += (size_t) batch.n_tokens;
|
||||
} else {
|
||||
dflash_state->t_accept_append_us += append_us;
|
||||
dflash_state->n_accept_append_calls++;
|
||||
dflash_state->n_accept_append_rows += (size_t) batch.n_tokens;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -2287,6 +2287,8 @@ class DFlashDraftModel(Qwen3Model):
|
||||
model_arch = gguf.MODEL_ARCH.DFLASH_DRAFT
|
||||
|
||||
_target_hparams: dict[str, Any] | None = None
|
||||
_saw_token_embd = False
|
||||
_saw_output = False
|
||||
|
||||
def _require_target_model_dir(self) -> Path:
|
||||
if self.target_model_dir is None:
|
||||
@ -2338,25 +2340,28 @@ class DFlashDraftModel(Qwen3Model):
|
||||
elif "n_target_features" in self.hparams:
|
||||
n_target_features = int(self.hparams["n_target_features"])
|
||||
else:
|
||||
target_hparams = self._get_target_hparams()
|
||||
target_hidden_size = target_hparams.get("hidden_size")
|
||||
if target_hidden_size is None:
|
||||
raise ValueError("DFlashDraftModel: target config is missing hidden_size")
|
||||
|
||||
draft_hidden_size = self.hparams.get("hidden_size")
|
||||
if draft_hidden_size is None:
|
||||
raise ValueError("DFlashDraftModel: draft config is missing hidden_size")
|
||||
|
||||
n_target_features = int(draft_hidden_size) * len(target_layer_ids)
|
||||
n_target_features = int(target_hidden_size) * len(target_layer_ids)
|
||||
|
||||
target_hparams = self._get_target_hparams()
|
||||
target_hidden_size = target_hparams.get("hidden_size")
|
||||
if target_hidden_size is not None and int(target_hidden_size) != int(draft_hidden_size):
|
||||
logger.warning(
|
||||
"DFlashDraftModel: target hidden_size=%d differs from draft hidden_size=%d; using draft hidden width for n_target_features",
|
||||
"DFlashDraftModel: target hidden_size=%d differs from draft hidden_size=%d; using target hidden width for n_target_features",
|
||||
int(target_hidden_size),
|
||||
int(draft_hidden_size),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"DFlashDraftModel: inferred n_target_features=%d from draft hidden_size=%d and n_target_layers=%d",
|
||||
"DFlashDraftModel: inferred n_target_features=%d from target hidden_size=%d and n_target_layers=%d",
|
||||
n_target_features,
|
||||
int(draft_hidden_size),
|
||||
int(target_hidden_size),
|
||||
len(target_layer_ids),
|
||||
)
|
||||
|
||||
@ -2370,17 +2375,52 @@ class DFlashDraftModel(Qwen3Model):
|
||||
n_target_features,
|
||||
)
|
||||
|
||||
def prepare_tensors(self):
    """Run the base tensor preparation, then validate and report the draft's IO tensor layout.

    After the parent class has processed every tensor, the `_saw_token_embd` /
    `_saw_output` flags say which IO tensors the draft checkpoint itself provided.

    Raises:
        ValueError: if an output tensor was seen without a token embedding tensor.
    """
    super().prepare_tensors()

    has_token_embd = self._saw_token_embd
    has_output = self._saw_output

    # An output head without its own embedding table is rejected outright.
    if has_output and not has_token_embd:
        raise ValueError(
            "DFlashDraftModel conversion requires token_embd.weight when output.weight is present"
        )

    # Classify the IO layout purely from which tensors were observed.
    if not has_token_embd:
        io_mode = "shared-target"
    elif has_output:
        io_mode = "self-contained"
    else:
        io_mode = "self-contained-tied"

    # NOTE(review): _require_target_model_dir() is evaluated for every mode, including
    # the self-contained ones — confirm it cannot fail when no target dir was supplied.
    target_model_dir = self._require_target_model_dir()

    logger.info(
        "DFlashDraftModel IO contract: io=%s token_embd=%s output=%s target_model_dir=%s",
        io_mode,
        has_token_embd,
        has_output,
        target_model_dir,
    )
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
    """Map one draft checkpoint tensor to its GGUF name(s) and record which IO tensors appear.

    Args:
        data_torch: tensor data for the checkpoint entry.
        name: tensor name as stored in the checkpoint, with or without a "model." prefix.
        bid: block (layer) index passed through to the parent implementation.

    Returns:
        A list of (gguf_tensor_name, tensor) pairs.
    """
    # Accept both "fc.weight" and "model.fc.weight" style names for the DFlash-specific tensors.
    top_level_name = name[len("model."):] if name.startswith("model.") else name

    if top_level_name == "fc.weight":
        return [(f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.DFLASH_FC]}.weight", data_torch)]
    if top_level_name == "hidden_norm.weight":
        return [(f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.DFLASH_HIDDEN_NORM]}.weight", data_torch)]

    # Un-prefixed names are given the "model." prefix before being handed to the parent mapping.
    if name == "norm.weight":
        name = "model.norm.weight"
    elif name.startswith("layers."):
        name = f"model.{name}"

    # Materialize the parent's result so it can be inspected and still returned intact.
    tensors = list(super().modify_tensors(data_torch, name, bid))
    token_embd_name = f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD]}.weight"
    output_name = f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT]}.weight"

    # Remember which IO tensors the draft checkpoint itself carries.
    for tensor_name, _ in tensors:
        if tensor_name == token_embd_name:
            self._saw_token_embd = True
        elif tensor_name == output_name:
            self._saw_output = True

    return tensors
|
||||
|
||||
|
||||
@Model.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM")
|
||||
|
||||
@ -126,6 +126,17 @@ static void server_dflash_contract_log_accept(
|
||||
server_dflash_contract_format_indices(output_indices).c_str());
|
||||
}
|
||||
|
||||
static bool server_slot_prompt_batch_overlaps(
|
||||
const server_slot & slot,
|
||||
int32_t batch_i0,
|
||||
int32_t batch_i1) {
|
||||
if (slot.prompt_batch_i0 < 0 || slot.prompt_batch_i1 <= slot.prompt_batch_i0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return slot.prompt_batch_i0 < batch_i1 && batch_i0 < slot.prompt_batch_i1;
|
||||
}
|
||||
|
||||
static bool params_use_gemma4_external_mtp(const gpt_params & params_base) {
|
||||
return params_base.has_mtp &&
|
||||
llama_model_is_gemma4_mtp_assistant(params_base.speculative.model_dft);
|
||||
@ -708,6 +719,8 @@ void server_slot::reset() {
|
||||
n_past_prompt = 0;
|
||||
n_discarded_prompt = 0;
|
||||
n_kept_prompt = 0;
|
||||
prompt_batch_i0 = -1;
|
||||
prompt_batch_i1 = -1;
|
||||
n_sent_text = 0;
|
||||
drafted.clear();
|
||||
drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE;
|
||||
@ -3840,6 +3853,9 @@ bool server_context::create_checkpoint(server_slot & slot) {
|
||||
void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) {
|
||||
if (params_base.cont_batching || batch.n_tokens == 0) {
|
||||
for (auto& slot : slots) {
|
||||
slot.prompt_batch_i0 = -1;
|
||||
slot.prompt_batch_i1 = -1;
|
||||
|
||||
// this slot still has a prompt to be processed
|
||||
if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) {
|
||||
auto& prompt_tokens = slot.prompt_tokens;
|
||||
@ -4127,6 +4143,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
||||
int32_t ga_i = slot.ga_i;
|
||||
int32_t ga_n = slot.ga_n;
|
||||
int32_t ga_w = slot.ga_w;
|
||||
const int32_t prompt_batch_i0 = batch.n_tokens;
|
||||
|
||||
// add prompt tokens for processing in the current batch
|
||||
// TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
|
||||
@ -4161,6 +4178,9 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
||||
}
|
||||
|
||||
}
|
||||
slot.prompt_batch_i0 = prompt_batch_i0;
|
||||
slot.prompt_batch_i1 = batch.n_tokens;
|
||||
|
||||
LOG_VERBOSE("prompt processing progress", {
|
||||
{"id_slot", slot.id},
|
||||
{"n_past", slot.n_past},
|
||||
@ -4770,6 +4790,7 @@ void server_context::update_allowlist_state(server_slot& slot) {
|
||||
void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
||||
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
|
||||
bool finish_prompt_warmup_batch = false;
|
||||
extend_context(n_tokens);
|
||||
|
||||
llama_batch batch_view = {
|
||||
@ -4830,7 +4851,6 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
}
|
||||
|
||||
if (server_speculative_has_target_features(params_base.speculative)) {
|
||||
bool finished_prompt_warmup_batch = false;
|
||||
for (auto & slot : slots) {
|
||||
if (!slot.spec || !server_speculative_has_target_features(slot.params.speculative)) {
|
||||
continue;
|
||||
@ -4840,7 +4860,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (slot.command != SLOT_COMMAND_LOAD_PROMPT) {
|
||||
if (!server_slot_prompt_batch_overlaps(slot, i, i + n_tokens)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -4849,13 +4869,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
slot.spec_prompt_warmup_failed = true;
|
||||
LOG_ERROR("failed to warm up speculative target-feature state from prompt batch for slot %d\n", slot.id);
|
||||
} else {
|
||||
finished_prompt_warmup_batch = true;
|
||||
finish_prompt_warmup_batch = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (finished_prompt_warmup_batch) {
|
||||
llama_finish_dflash_capture_batch(ctx, true);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& slot : slots) {
|
||||
@ -4974,6 +4990,10 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
}
|
||||
// speculative decoding - main model sample and accept
|
||||
speculative_decoding_accept();
|
||||
|
||||
if (finish_prompt_warmup_batch) {
|
||||
llama_finish_dflash_capture_batch(ctx, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -5005,6 +5025,11 @@ void server_context::update_slots() {
|
||||
// start populating the batch for this iteration
|
||||
common_batch_clear(batch);
|
||||
|
||||
for (auto & slot : slots) {
|
||||
slot.prompt_batch_i0 = -1;
|
||||
slot.prompt_batch_i1 = -1;
|
||||
}
|
||||
|
||||
// first, add sampled tokens from any ongoing sequences
|
||||
add_sampled_tokens(); // Prepare batch for inference
|
||||
|
||||
|
||||
@ -64,6 +64,8 @@ struct server_slot {
|
||||
|
||||
int32_t i_batch = -1;
|
||||
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
|
||||
int32_t prompt_batch_i0 = -1;
|
||||
int32_t prompt_batch_i1 = -1;
|
||||
|
||||
int32_t n_prompt_tokens = 0;
|
||||
int32_t n_prompt_tokens_cache = 0;
|
||||
|
||||
@ -1261,7 +1261,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.DFLASH_HIDDEN_NORM,
|
||||
],
|
||||
MODEL_ARCH.DFLASH_DRAFT: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
|
||||
@ -64,6 +64,7 @@ ggml_cgraph * llm_build_context::build_dflash() {
|
||||
const int64_t n_embd_head_k = hparams.n_embd_head_k(0);
|
||||
const int64_t n_embd_head_v = hparams.n_embd_head_v(0);
|
||||
const int64_t n_target_features = hparams.dflash_n_target_features;
|
||||
auto & profile = lctx.dflash_profile;
|
||||
const bool use_kv_cache = dflash_use_kv_cache_experiment();
|
||||
const int64_t ctx_len = lctx.dflash_visible_cross_ctx > 0
|
||||
? (int64_t) lctx.dflash_visible_cross_ctx
|
||||
@ -77,12 +78,30 @@ ggml_cgraph * llm_build_context::build_dflash() {
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max<int64_t>(n_tokens, ctx_len)) + 32 * n_layer, false);
|
||||
|
||||
bool have_swa_layers = false;
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
if (hparams.swa_layers[il]) {
|
||||
have_swa_layers = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
lctx.inp_dflash_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
lctx.dflash_kq_mask_tensor = lctx.inp_dflash_kq_mask;
|
||||
ggml_set_input(lctx.inp_dflash_kq_mask);
|
||||
cb(lctx.inp_dflash_kq_mask, "dflash_kq_mask", -1);
|
||||
|
||||
ggml_tensor * dflash_kq_mask = flash_attn ? ggml_cast(ctx0, lctx.inp_dflash_kq_mask, GGML_TYPE_F16) : lctx.inp_dflash_kq_mask;
|
||||
ggml_tensor * dflash_kq_mask_full = flash_attn ? ggml_cast(ctx0, lctx.inp_dflash_kq_mask, GGML_TYPE_F16) : lctx.inp_dflash_kq_mask;
|
||||
ggml_tensor * dflash_kq_mask_swa = nullptr;
|
||||
lctx.inp_dflash_kq_mask_swa = nullptr;
|
||||
lctx.dflash_kq_mask_swa_tensor = nullptr;
|
||||
if (have_swa_layers && hparams.n_swa > 0) {
|
||||
lctx.inp_dflash_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
lctx.dflash_kq_mask_swa_tensor = lctx.inp_dflash_kq_mask_swa;
|
||||
ggml_set_input(lctx.inp_dflash_kq_mask_swa);
|
||||
cb(lctx.inp_dflash_kq_mask_swa, "dflash_kq_mask_swa", -1);
|
||||
dflash_kq_mask_swa = flash_attn ? ggml_cast(ctx0, lctx.inp_dflash_kq_mask_swa, GGML_TYPE_F16) : lctx.inp_dflash_kq_mask_swa;
|
||||
}
|
||||
|
||||
ggml_tensor * fused_target = nullptr;
|
||||
ggml_tensor * pos_ctx = nullptr;
|
||||
@ -137,6 +156,7 @@ ggml_cgraph * llm_build_context::build_dflash() {
|
||||
Vcur_noise = ggml_reshape_3d(ctx0, Vcur_noise, n_embd_head_v, n_head_kv, n_tokens);
|
||||
cb(Vcur_noise, "Vcur_noise", il);
|
||||
|
||||
const int64_t t_cache_read_us = use_kv_cache ? ggml_time_us() : 0;
|
||||
ggml_tensor * Kcur_ctx = nullptr;
|
||||
ggml_tensor * Vcur_ctx = nullptr;
|
||||
if (use_kv_cache) {
|
||||
@ -164,6 +184,11 @@ ggml_cgraph * llm_build_context::build_dflash() {
|
||||
Kcur = ggml_pad(ctx0, Kcur, 0, 0, (int) n_kv_pad, 0);
|
||||
Vcur = ggml_pad(ctx0, Vcur, 0, 0, (int) n_kv_pad, 0);
|
||||
}
|
||||
if (use_kv_cache) {
|
||||
profile.graph_kv_cache_read_concat_pad_us += (uint64_t) (ggml_time_us() - t_cache_read_us);
|
||||
profile.graph_kv_cache_read_concat_pad_calls++;
|
||||
profile.graph_kv_cache_cached_bytes += ggml_nbytes(lctx.dflash_k_ctx_cache[(size_t) il]) + ggml_nbytes(lctx.dflash_v_ctx_cache[(size_t) il]);
|
||||
}
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
cb(Qcur, "Qcur", il);
|
||||
@ -173,11 +198,14 @@ ggml_cgraph * llm_build_context::build_dflash() {
|
||||
ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
|
||||
ggml_tensor * v = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 0, 2, 1, 3));
|
||||
ggml_tensor * dflash_kq_mask_l = (hparams.swa_layers[il] && dflash_kq_mask_swa != nullptr)
|
||||
? dflash_kq_mask_swa
|
||||
: dflash_kq_mask_full;
|
||||
cb(q, "q", il);
|
||||
cb(k, "k", il);
|
||||
cb(v, "v", il);
|
||||
|
||||
cur = ggml_flash_attn_ext(ctx0, q, k, v, dflash_kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
||||
cur = ggml_flash_attn_ext(ctx0, q, k, v, dflash_kq_mask_l, kq_scale, hparams.f_max_alibi_bias,
|
||||
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
||||
cb(cur, "flash_attn", il);
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
@ -289,22 +289,26 @@ struct llama_context {
|
||||
std::vector<float> dflash_feature_view_buffer;
|
||||
std::vector<llama_pos> dflash_pos_ctx_data;
|
||||
std::vector<float> dflash_kq_mask_data;
|
||||
std::vector<float> dflash_kq_mask_swa_data;
|
||||
int32_t dflash_visible_cross_ctx = 0;
|
||||
std::vector<struct ggml_tensor *> dflash_k_ctx_cache;
|
||||
std::vector<struct ggml_tensor *> dflash_v_ctx_cache;
|
||||
struct ggml_context * dflash_cache_ctx = nullptr;
|
||||
ggml_backend_buffer_t dflash_cache_buf = nullptr;
|
||||
std::vector<ggml_backend_buffer_t> dflash_cache_bufs;
|
||||
std::vector<uint8_t> dflash_buf_compute_meta;
|
||||
ggml_backend_sched_t dflash_sched = nullptr;
|
||||
struct ggml_tensor * dflash_kv_input_target_features = nullptr;
|
||||
struct ggml_tensor * dflash_kv_input_pos_ctx = nullptr;
|
||||
struct ggml_tensor * dflash_kq_mask_tensor = nullptr;
|
||||
struct ggml_tensor * dflash_kq_mask_swa_tensor = nullptr;
|
||||
|
||||
struct dflash_capture_state {
|
||||
std::vector<int32_t> layer_ids;
|
||||
std::vector<std::vector<float>> layer_rows;
|
||||
int32_t row_count = 0;
|
||||
int32_t row_width = 0;
|
||||
uint64_t capture_batch_id = 0;
|
||||
std::vector<uint64_t> layer_seen_batch_id;
|
||||
ggml_backend_sched_eval_callback prev_cb_eval = nullptr;
|
||||
void * prev_cb_eval_user_data = nullptr;
|
||||
};
|
||||
@ -333,6 +337,7 @@ struct llama_context {
|
||||
struct ggml_tensor * inp_dflash_target_features = nullptr; // F32 [n_target_features, cross_ctx]
|
||||
struct ggml_tensor * inp_dflash_pos_ctx = nullptr; // I32 [cross_ctx]
|
||||
struct ggml_tensor * inp_dflash_kq_mask = nullptr; // F32 [cross_ctx + n_batch, GGML_PAD(n_batch)]
|
||||
struct ggml_tensor * inp_dflash_kq_mask_swa = nullptr; // F32 [cross_ctx + n_batch, GGML_PAD(n_batch)]
|
||||
|
||||
ggml_backend_t ggml_backend_by_name(const char * name);
|
||||
|
||||
|
||||
@ -79,6 +79,17 @@ static bool load_dflash_target_layer_ids(
|
||||
} else {
|
||||
hparams.dflash_target_layer_ids[i] = ((const uint32_t *) data)[i];
|
||||
}
|
||||
|
||||
const uint32_t id = hparams.dflash_target_layer_ids[i];
|
||||
|
||||
for (uint32_t j = 0; j < i; ++j) {
|
||||
if (hparams.dflash_target_layer_ids[j] == id) {
|
||||
throw std::runtime_error(format(
|
||||
"dflash: %s contains duplicate layer id %u",
|
||||
key.c_str(),
|
||||
id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
@ -92,19 +103,8 @@ static void validate_dflash_hparams(llama_hparams & hparams, llm_arch arch) {
|
||||
throw std::runtime_error(format("%s: dflash target_layer_ids are required", llama_model_arch_name(arch)));
|
||||
}
|
||||
|
||||
if (arch == LLM_ARCH_DFLASH_DRAFT && hparams.n_embd > 0) {
|
||||
const uint32_t expected_n_target_features = hparams.n_embd * hparams.dflash_n_target_layers;
|
||||
if (expected_n_target_features > 0 && hparams.dflash_n_target_features != expected_n_target_features) {
|
||||
LLAMA_LOG_WARN(
|
||||
"%s: overriding dflash n_target_features from %u to %u based on n_embd=%u and n_target_layers=%u\n",
|
||||
llama_model_arch_name(arch),
|
||||
hparams.dflash_n_target_features,
|
||||
expected_n_target_features,
|
||||
hparams.n_embd,
|
||||
hparams.dflash_n_target_layers);
|
||||
hparams.dflash_n_target_features = expected_n_target_features;
|
||||
}
|
||||
}
|
||||
// DFlash feature width is target-model specific. Keep the serialized metadata intact here
|
||||
// and validate it against the live target model during DFlash init.
|
||||
|
||||
if (hparams.dflash_n_target_features == 0) {
|
||||
throw std::runtime_error(format(
|
||||
|
||||
@ -113,6 +113,53 @@ int32_t llama_model_dflash_target_layer_ids(
|
||||
return n_layers;
|
||||
}
|
||||
|
||||
int32_t llama_model_dflash_target_mask_token_id(const struct llama_model * model) {
|
||||
if (model == nullptr) {
|
||||
return (int32_t) LLAMA_TOKEN_NULL;
|
||||
}
|
||||
|
||||
return (int32_t) model->vocab.token_mask();
|
||||
}
|
||||
|
||||
int32_t llama_model_dflash_io_mode(
|
||||
const struct llama_model * draft_model,
|
||||
const struct llama_model * target_model) {
|
||||
if (draft_model == nullptr || target_model == nullptr || draft_model->arch != LLM_ARCH_DFLASH_DRAFT) {
|
||||
return LLAMA_DFLASH_IO_MODE_INVALID;
|
||||
}
|
||||
|
||||
const ggml_tensor * target_output = target_model->output != nullptr ? target_model->output : target_model->tok_embd;
|
||||
if (draft_model->tok_embd == nullptr || draft_model->output == nullptr || target_model->tok_embd == nullptr || target_output == nullptr) {
|
||||
return LLAMA_DFLASH_IO_MODE_INVALID;
|
||||
}
|
||||
|
||||
const bool shared_tok = draft_model->tok_embd == target_model->tok_embd;
|
||||
const bool shared_output = draft_model->output == target_output;
|
||||
if (shared_tok && shared_output) {
|
||||
return LLAMA_DFLASH_IO_MODE_SHARED;
|
||||
}
|
||||
|
||||
if (!shared_tok && !shared_output) {
|
||||
return LLAMA_DFLASH_IO_MODE_SELF_CONTAINED;
|
||||
}
|
||||
|
||||
return LLAMA_DFLASH_IO_MODE_MIXED;
|
||||
}
|
||||
|
||||
bool llama_model_dflash_io_tensors_match(
|
||||
const struct llama_model * draft_model,
|
||||
int32_t n_embd,
|
||||
int32_t n_vocab) {
|
||||
if (draft_model == nullptr || draft_model->tok_embd == nullptr || draft_model->output == nullptr || n_embd <= 0 || n_vocab <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return (int32_t) draft_model->tok_embd->ne[0] == n_embd &&
|
||||
(int32_t) draft_model->tok_embd->ne[1] == n_vocab &&
|
||||
(int32_t) draft_model->output->ne[0] == n_embd &&
|
||||
(int32_t) draft_model->output->ne[1] == n_vocab;
|
||||
}
|
||||
|
||||
bool llama_model_share_dflash_io_tensors(
|
||||
struct llama_model * draft_model,
|
||||
const struct llama_model * target_model) {
|
||||
@ -323,11 +370,19 @@ static bool llama_dflash_capture_eval_callback(struct ggml_tensor * tensor, bool
|
||||
}
|
||||
|
||||
auto & capture = *ctx->dflash_capture;
|
||||
if (capture.capture_batch_id == 0) {
|
||||
capture.capture_batch_id = 1;
|
||||
}
|
||||
if (capture.layer_seen_batch_id.size() != capture.layer_ids.size()) {
|
||||
capture.layer_seen_batch_id.assign(capture.layer_ids.size(), 0);
|
||||
}
|
||||
|
||||
auto & rows = capture.layer_rows[(size_t) layer_idx];
|
||||
rows.resize((size_t) row_count * (size_t) row_width);
|
||||
ggml_backend_tensor_get(tensor, rows.data(), 0, ggml_nbytes(tensor));
|
||||
capture.row_width = row_width;
|
||||
capture.row_count = row_count;
|
||||
capture.layer_seen_batch_id[(size_t) layer_idx] = capture.capture_batch_id;
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -342,6 +397,7 @@ bool llama_set_dflash_capture_layers(
|
||||
auto capture = std::make_unique<llama_context::dflash_capture_state>();
|
||||
capture->layer_ids.assign(layer_ids, layer_ids + n_layers);
|
||||
capture->layer_rows.resize((size_t) n_layers);
|
||||
capture->layer_seen_batch_id.assign((size_t) n_layers, 0);
|
||||
capture->prev_cb_eval = ctx->cparams.cb_eval;
|
||||
capture->prev_cb_eval_user_data = ctx->cparams.cb_eval_user_data;
|
||||
ctx->dflash_capture = std::move(capture);
|
||||
@ -380,6 +436,18 @@ void llama_clear_dflash_capture(struct llama_context * ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
void llama_begin_dflash_capture_batch(struct llama_context * ctx) {
|
||||
if (ctx == nullptr || !ctx->dflash_capture) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto & capture = *ctx->dflash_capture;
|
||||
capture.capture_batch_id++;
|
||||
capture.row_count = 0;
|
||||
capture.row_width = 0;
|
||||
std::fill(capture.layer_seen_batch_id.begin(), capture.layer_seen_batch_id.end(), 0);
|
||||
}
|
||||
|
||||
void llama_finish_dflash_capture_batch(
|
||||
struct llama_context * ctx,
|
||||
bool is_prompt_warmup) {
|
||||
@ -420,7 +488,35 @@ static bool llama_spec_prepare_dflash_capture(
|
||||
return false;
|
||||
}
|
||||
|
||||
if (capture.capture_batch_id == 0 || capture.layer_seen_batch_id.size() != (size_t) n_layers) {
|
||||
profile.capture_prepare_failures++;
|
||||
profile.capture_layer_batch_mismatch++;
|
||||
if (profile.capture_layer_batch_mismatch <= 3) {
|
||||
LLAMA_LOG_WARN("%s: DFlash capture batch markers are not initialized (batch_id=%llu layers=%zu expected=%d)\n",
|
||||
__func__,
|
||||
(unsigned long long) capture.capture_batch_id,
|
||||
capture.layer_seen_batch_id.size(),
|
||||
n_layers);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int32_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) {
|
||||
if (capture.layer_seen_batch_id[(size_t) layer_idx] != capture.capture_batch_id) {
|
||||
profile.capture_prepare_failures++;
|
||||
profile.capture_layer_batch_mismatch++;
|
||||
if (profile.capture_layer_batch_mismatch <= 3) {
|
||||
LLAMA_LOG_WARN("%s: DFlash capture is stale for layer %d (seen_batch=%llu current_batch=%llu rows=%d width=%d)\n",
|
||||
__func__,
|
||||
capture.layer_ids[(size_t) layer_idx],
|
||||
(unsigned long long) capture.layer_seen_batch_id[(size_t) layer_idx],
|
||||
(unsigned long long) capture.capture_batch_id,
|
||||
row_count,
|
||||
row_width);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto & rows = capture.layer_rows[(size_t) layer_idx];
|
||||
if (rows.size() != (size_t) row_count * (size_t) row_width) {
|
||||
profile.capture_prepare_failures++;
|
||||
@ -595,11 +691,41 @@ static void llama_dflash_contract_log_output_indices(
|
||||
have_capture ? "true" : "false");
|
||||
}
|
||||
|
||||
static bool llama_spec_materialize_dflash_rows_prepared(
|
||||
struct llama_context * ctx,
|
||||
int32_t row_count,
|
||||
int32_t row_width,
|
||||
int32_t n_layers,
|
||||
const std::vector<int32_t> & row_indices,
|
||||
std::vector<float> & rows_out,
|
||||
int32_t & combined_width);
|
||||
|
||||
static bool llama_spec_materialize_dflash_rows(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<int32_t> & row_indices,
|
||||
std::vector<float> & rows_out,
|
||||
int32_t & combined_width) {
|
||||
int32_t row_count = 0;
|
||||
int32_t row_width = 0;
|
||||
int32_t n_layers = 0;
|
||||
if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) {
|
||||
if (ctx != nullptr) {
|
||||
ctx->dflash_profile.capture_materialize_failures++;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, rows_out, combined_width);
|
||||
}
|
||||
|
||||
static bool llama_spec_materialize_dflash_rows_prepared(
|
||||
struct llama_context * ctx,
|
||||
int32_t row_count,
|
||||
int32_t row_width,
|
||||
int32_t n_layers,
|
||||
const std::vector<int32_t> & row_indices,
|
||||
std::vector<float> & rows_out,
|
||||
int32_t & combined_width) {
|
||||
rows_out.clear();
|
||||
combined_width = 0;
|
||||
if (ctx == nullptr || row_indices.empty()) {
|
||||
@ -610,10 +736,7 @@ static bool llama_spec_materialize_dflash_rows(
|
||||
profile.capture_materialize_calls++;
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
int32_t row_count = 0;
|
||||
int32_t row_width = 0;
|
||||
int32_t n_layers = 0;
|
||||
if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) {
|
||||
if (row_count <= 0 || row_width <= 0 || n_layers <= 0 || ctx->dflash_capture == nullptr) {
|
||||
profile.capture_materialize_failures++;
|
||||
return false;
|
||||
}
|
||||
@ -735,7 +858,7 @@ bool llama_spec_get_dflash_feature_view(
|
||||
|
||||
view = {};
|
||||
view.kind = LLAMA_SPEC_FEATURE_HIDDEN_STATE;
|
||||
if (!llama_spec_materialize_dflash_rows(ctx, row_indices, ctx->dflash_feature_view_buffer, view.width)) {
|
||||
if (!llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, ctx->dflash_feature_view_buffer, view.width)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -808,7 +931,7 @@ bool llama_spec_get_dflash_feature_view_for_seq(
|
||||
|
||||
view = {};
|
||||
view.kind = LLAMA_SPEC_FEATURE_HIDDEN_STATE;
|
||||
if (!llama_spec_materialize_dflash_rows(ctx, row_indices, ctx->dflash_feature_view_buffer, view.width)) {
|
||||
if (!llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, ctx->dflash_feature_view_buffer, view.width)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@ -24,6 +24,14 @@ struct llama_spec_feature_view {
|
||||
};
|
||||
|
||||
struct llama_dflash_profile_stats {
|
||||
uint64_t decode_output_reserve_calls = 0;
|
||||
uint64_t decode_output_reserve_us = 0;
|
||||
uint64_t decode_output_reserve_reallocs = 0;
|
||||
uint64_t decode_output_reserve_realloc_bytes = 0;
|
||||
uint64_t decode_prepare_calls = 0;
|
||||
uint64_t decode_prepare_us = 0;
|
||||
uint64_t decode_prepare_failures = 0;
|
||||
|
||||
uint64_t set_target_copy_calls = 0;
|
||||
uint64_t set_target_copy_us = 0;
|
||||
uint64_t set_target_rows = 0;
|
||||
@ -35,6 +43,7 @@ struct llama_dflash_profile_stats {
|
||||
uint64_t capture_prepare_sync_us = 0;
|
||||
uint64_t capture_prepare_failures = 0;
|
||||
uint64_t capture_layer_shape_mismatch = 0;
|
||||
uint64_t capture_layer_batch_mismatch = 0;
|
||||
uint64_t capture_prompt_batches = 0;
|
||||
uint64_t capture_prompt_shape_changes = 0;
|
||||
uint64_t capture_verify_batches = 0;
|
||||
@ -50,7 +59,17 @@ struct llama_dflash_profile_stats {
|
||||
uint64_t graph_feature_copy_us = 0;
|
||||
uint64_t graph_pos_copy_us = 0;
|
||||
uint64_t graph_mask_build_us = 0;
|
||||
uint64_t graph_kv_cache_build_us = 0;
|
||||
uint64_t graph_kv_cache_reserve_us = 0;
|
||||
uint64_t graph_kv_cache_reset_us = 0;
|
||||
uint64_t graph_kv_cache_alloc_us = 0;
|
||||
uint64_t graph_kv_cache_feature_upload_us = 0;
|
||||
uint64_t graph_kv_cache_pos_upload_us = 0;
|
||||
uint64_t graph_kv_cache_compute_us = 0;
|
||||
uint64_t graph_kv_cache_sync_us = 0;
|
||||
uint64_t graph_kv_cache_read_concat_pad_us = 0;
|
||||
uint64_t graph_kv_cache_read_concat_pad_calls = 0;
|
||||
uint64_t graph_kv_cache_cached_bytes = 0;
|
||||
uint64_t graph_kv_cache_calls = 0;
|
||||
uint64_t graph_feature_bytes = 0;
|
||||
uint64_t graph_pos_bytes = 0;
|
||||
@ -68,6 +87,7 @@ struct llama_dflash_profile_stats {
|
||||
int32_t last_left_pad = 0;
|
||||
int32_t last_n_tokens = 0;
|
||||
int32_t last_n_kv_total = 0;
|
||||
int32_t last_kv_cache_host_layers = 0;
|
||||
int32_t capture_prompt_last_rows = 0;
|
||||
int32_t capture_prompt_last_width = 0;
|
||||
int32_t capture_verify_last_rows = 0;
|
||||
@ -104,6 +124,24 @@ int32_t llama_model_dflash_target_layer_ids(
|
||||
int32_t * layer_ids,
|
||||
int32_t capacity);
|
||||
|
||||
// Result codes of llama_model_dflash_io_mode(): how a DFlash draft model's
// token-embedding / output tensors relate to the target model's.
enum llama_dflash_io_mode {
    LLAMA_DFLASH_IO_MODE_INVALID        = 0, // missing model or tensor, or draft is not a DFlash draft
    LLAMA_DFLASH_IO_MODE_SHARED         = 1, // both tensors are shared with the target
    LLAMA_DFLASH_IO_MODE_SELF_CONTAINED = 2, // neither tensor is shared with the target
    LLAMA_DFLASH_IO_MODE_MIXED          = 3, // exactly one of the two tensors is shared
};
|
||||
|
||||
int32_t llama_model_dflash_target_mask_token_id(const struct llama_model * model);
|
||||
|
||||
int32_t llama_model_dflash_io_mode(
|
||||
const struct llama_model * draft_model,
|
||||
const struct llama_model * target_model);
|
||||
|
||||
bool llama_model_dflash_io_tensors_match(
|
||||
const struct llama_model * draft_model,
|
||||
int32_t n_embd,
|
||||
int32_t n_vocab);
|
||||
|
||||
bool llama_model_share_dflash_io_tensors(
|
||||
struct llama_model * draft_model,
|
||||
const struct llama_model * target_model);
|
||||
@ -134,6 +172,8 @@ bool llama_set_dflash_capture_layers(
|
||||
|
||||
void llama_clear_dflash_capture(struct llama_context * ctx);
|
||||
|
||||
void llama_begin_dflash_capture_batch(struct llama_context * ctx);
|
||||
|
||||
void llama_finish_dflash_capture_batch(
|
||||
struct llama_context * ctx,
|
||||
bool is_prompt_warmup);
|
||||
|
||||
363
src/llama.cpp
363
src/llama.cpp
@ -565,6 +565,44 @@ void llama_context::reset_scheduler() {
|
||||
prev_mtp.reset();
|
||||
}
|
||||
|
||||
// Picks the buffer type for the DFlash K/V cache tensors of layer `il`.
//
// Preference order:
//   1. the buffer type registered for the layer in model.buft_layer
//   2. the buffer type of the buffer that holds the layer's wk weight
//   3. the default CPU buffer type
static ggml_backend_buffer_type_t llama_dflash_kv_cache_layer_buft(const llama_context & lctx, int32_t il) {
    if (il >= 0) {
        const size_t idx = (size_t) il;

        const auto & layer_bufts = lctx.model.buft_layer;
        if (idx < layer_bufts.size() && layer_bufts[idx].buft != nullptr) {
            return layer_bufts[idx].buft;
        }

        const auto & layers = lctx.model.layers;
        if (idx < layers.size()) {
            const ggml_tensor * k_weight = layers[idx].wk;
            if (k_weight != nullptr && k_weight->buffer != nullptr) {
                return ggml_backend_buffer_get_type(k_weight->buffer);
            }
        }
    }

    return llama_default_buffer_type_cpu(true);
}
|
||||
|
||||
// Finds the context backend whose buffer type matches the buffer that backs `tensor`.
//
// For a view, the buffer of the view source is used. Returns nullptr when the tensor is null,
// has no buffer, or no backend in lctx.backends matches.
static ggml_backend_t llama_backend_for_tensor(const llama_context & lctx, const ggml_tensor * tensor) {
    if (tensor == nullptr) {
        return nullptr;
    }

    const ggml_tensor * owner = tensor->view_src != nullptr ? tensor->view_src : tensor;
    ggml_backend_buffer_t buffer = owner->buffer;
    if (buffer == nullptr) {
        return nullptr;
    }

    const ggml_backend_buffer_type_t wanted = ggml_backend_buffer_get_type(buffer);

    for (ggml_backend_t backend : lctx.backends) {
        // the CPU backend is matched against llama's own default CPU buffer type
        ggml_backend_buffer_type_t candidate;
        if (ggml_backend_is_cpu(backend)) {
            candidate = llama_default_buffer_type_cpu(true);
        } else {
            candidate = ggml_backend_get_default_buffer_type(backend);
        }

        if (candidate == wanted) {
            return backend;
        }
    }

    return nullptr;
}
|
||||
|
||||
bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
|
||||
const int32_t target_cross_ctx = std::max<int32_t>(1, cross_ctx);
|
||||
const int32_t n_layer = model.hparams.n_layer;
|
||||
@ -587,8 +625,6 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
|
||||
dflash_buf_compute_meta.clear();
|
||||
}
|
||||
|
||||
ggml_backend_buffer_type_t buft = llama_default_buffer_type_cpu(true);
|
||||
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ (size_t) (2 * std::max(1, n_layer)) * ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
@ -602,7 +638,21 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
|
||||
|
||||
dflash_k_ctx_cache.resize((size_t) n_layer);
|
||||
dflash_v_ctx_cache.resize((size_t) n_layer);
|
||||
dflash_cache_bufs.clear();
|
||||
dflash_cache_bufs.reserve((size_t) std::max(1, n_layer) * 2);
|
||||
int32_t host_layers = 0;
|
||||
const char * first_buft_name = nullptr;
|
||||
const char * last_buft_name = nullptr;
|
||||
for (int32_t il = 0; il < n_layer; ++il) {
|
||||
ggml_backend_buffer_type_t layer_buft = llama_dflash_kv_cache_layer_buft(*this, il);
|
||||
if (ggml_backend_buft_is_host(layer_buft)) {
|
||||
host_layers++;
|
||||
}
|
||||
if (first_buft_name == nullptr) {
|
||||
first_buft_name = ggml_backend_buft_name(layer_buft);
|
||||
}
|
||||
last_buft_name = ggml_backend_buft_name(layer_buft);
|
||||
|
||||
dflash_k_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_k, n_head_kv, target_cross_ctx);
|
||||
dflash_v_ctx_cache[(size_t) il] = ggml_new_tensor_3d(dflash_cache_ctx, GGML_TYPE_F32, n_embd_head_v, n_head_kv, target_cross_ctx);
|
||||
if (dflash_k_ctx_cache[(size_t) il] == nullptr || dflash_v_ctx_cache[(size_t) il] == nullptr) {
|
||||
@ -614,15 +664,39 @@ bool llama_context::ensure_dflash_kv_cache_tensors(int32_t cross_ctx) {
|
||||
ggml_set_input(dflash_v_ctx_cache[(size_t) il]);
|
||||
ggml_format_name(dflash_k_ctx_cache[(size_t) il], "dflash_k_ctx_cache_%d", il);
|
||||
ggml_format_name(dflash_v_ctx_cache[(size_t) il], "dflash_v_ctx_cache_%d", il);
|
||||
|
||||
const size_t k_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_k_ctx_cache[(size_t) il]);
|
||||
ggml_backend_buffer_t k_buf = ggml_backend_buft_alloc_buffer(layer_buft, k_bytes);
|
||||
if (k_buf == nullptr) {
|
||||
free_dflash_kv_cache_tensors();
|
||||
return false;
|
||||
}
|
||||
ggml_backend_buffer_set_usage(k_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE);
|
||||
ggml_backend_tensor_alloc(k_buf, dflash_k_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(k_buf));
|
||||
ggml_backend_buffer_clear(k_buf, 0);
|
||||
dflash_cache_bufs.push_back(k_buf);
|
||||
|
||||
const size_t v_bytes = ggml_backend_buft_get_alloc_size(layer_buft, dflash_v_ctx_cache[(size_t) il]);
|
||||
ggml_backend_buffer_t v_buf = ggml_backend_buft_alloc_buffer(layer_buft, v_bytes);
|
||||
if (v_buf == nullptr) {
|
||||
free_dflash_kv_cache_tensors();
|
||||
return false;
|
||||
}
|
||||
ggml_backend_buffer_set_usage(v_buf, GGML_BACKEND_BUFFER_USAGE_COMPUTE);
|
||||
ggml_backend_tensor_alloc(v_buf, dflash_v_ctx_cache[(size_t) il], ggml_backend_buffer_get_base(v_buf));
|
||||
ggml_backend_buffer_clear(v_buf, 0);
|
||||
dflash_cache_bufs.push_back(v_buf);
|
||||
}
|
||||
|
||||
dflash_cache_buf = ggml_backend_alloc_ctx_tensors_from_buft(dflash_cache_ctx, buft);
|
||||
if (dflash_cache_buf == nullptr) {
|
||||
free_dflash_kv_cache_tensors();
|
||||
return false;
|
||||
}
|
||||
dflash_profile.last_kv_cache_host_layers = host_layers;
|
||||
LLAMA_LOG_INFO("%s: DFlash K/V cache placement cross_ctx=%d host_layers=%d/%d first=%s last=%s\n",
|
||||
__func__,
|
||||
target_cross_ctx,
|
||||
host_layers,
|
||||
n_layer,
|
||||
first_buft_name != nullptr ? first_buft_name : "(none)",
|
||||
last_buft_name != nullptr ? last_buft_name : "(none)");
|
||||
|
||||
ggml_backend_buffer_clear(dflash_cache_buf, 0);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -632,11 +706,14 @@ void llama_context::free_dflash_kv_cache_tensors() {
|
||||
dflash_kv_input_target_features = nullptr;
|
||||
dflash_kv_input_pos_ctx = nullptr;
|
||||
dflash_kq_mask_tensor = nullptr;
|
||||
dflash_kq_mask_swa_tensor = nullptr;
|
||||
|
||||
if (dflash_cache_buf != nullptr) {
|
||||
ggml_backend_buffer_free(dflash_cache_buf);
|
||||
dflash_cache_buf = nullptr;
|
||||
for (ggml_backend_buffer_t buf : dflash_cache_bufs) {
|
||||
if (buf != nullptr) {
|
||||
ggml_backend_buffer_free(buf);
|
||||
}
|
||||
}
|
||||
dflash_cache_bufs.clear();
|
||||
if (dflash_cache_ctx != nullptr) {
|
||||
ggml_free(dflash_cache_ctx);
|
||||
dflash_cache_ctx = nullptr;
|
||||
@ -5087,6 +5164,110 @@ static bool prepare_mtp_graph_inputs(
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool dflash_layer_has_attention_bias(const llama_layer & layer) {
|
||||
return layer.bq != nullptr ||
|
||||
layer.bk != nullptr ||
|
||||
layer.bv != nullptr ||
|
||||
layer.bo != nullptr ||
|
||||
layer.bqkv != nullptr ||
|
||||
layer.bqk != nullptr ||
|
||||
layer.bkv != nullptr;
|
||||
}
|
||||
|
||||
static bool validate_dflash_graph_contract(const llama_context & lctx) {
|
||||
const auto & model = lctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
auto rope_dim_for_layer = [&hparams](int32_t il) -> uint32_t {
|
||||
if (hparams.rope_dim_per_layer[(size_t) il] != 0) {
|
||||
return hparams.rope_dim_per_layer[(size_t) il];
|
||||
}
|
||||
|
||||
return hparams.swa_layers[(size_t) il] ? hparams.n_rot_swa : hparams.n_rot;
|
||||
};
|
||||
|
||||
auto rope_base_for_layer = [&hparams](int32_t il) -> float {
|
||||
if (hparams.has_rope_freq_base_per_layer) {
|
||||
return hparams.rope_freq_base_per_layer[(size_t) il];
|
||||
}
|
||||
|
||||
return hparams.swa_layers[(size_t) il] ? hparams.rope_freq_base_train_swa : hparams.rope_freq_base_train;
|
||||
};
|
||||
|
||||
auto rope_scale_for_layer = [&hparams](int32_t il) -> float {
|
||||
return hparams.swa_layers[(size_t) il] ? hparams.rope_freq_scale_train_swa : hparams.rope_freq_scale_train;
|
||||
};
|
||||
|
||||
const uint32_t ref_n_head = hparams.n_head(0);
|
||||
const uint32_t ref_n_head_kv = hparams.n_head_kv(0);
|
||||
const uint32_t ref_n_embd_head_k = hparams.n_embd_head_k(0);
|
||||
const uint32_t ref_n_embd_head_v = hparams.n_embd_head_v(0);
|
||||
const uint32_t ref_rope_dim = rope_dim_for_layer(0);
|
||||
const float ref_rope_base = rope_base_for_layer(0);
|
||||
const float ref_rope_scale = rope_scale_for_layer(0);
|
||||
|
||||
for (int32_t il = 0; il < (int32_t) hparams.n_layer; ++il) {
|
||||
if (hparams.n_head((uint32_t) il) != ref_n_head ||
|
||||
hparams.n_head_kv((uint32_t) il) != ref_n_head_kv ||
|
||||
hparams.n_embd_head_k(il) != ref_n_embd_head_k ||
|
||||
hparams.n_embd_head_v(il) != ref_n_embd_head_v) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash graph assumes layer-invariant head config, but layer %d differs (n_head=%u/%u n_head_kv=%u/%u head_k=%u/%u head_v=%u/%u)\n",
|
||||
__func__,
|
||||
il,
|
||||
hparams.n_head((uint32_t) il), ref_n_head,
|
||||
hparams.n_head_kv((uint32_t) il), ref_n_head_kv,
|
||||
hparams.n_embd_head_k(il), ref_n_embd_head_k,
|
||||
hparams.n_embd_head_v(il), ref_n_embd_head_v);
|
||||
return false;
|
||||
}
|
||||
|
||||
const uint32_t rope_dim = rope_dim_for_layer(il);
|
||||
const float rope_base = rope_base_for_layer(il);
|
||||
const float rope_scale = rope_scale_for_layer(il);
|
||||
if (rope_dim != ref_rope_dim || std::fabs(rope_base - ref_rope_base) > 1e-6f || std::fabs(rope_scale - ref_rope_scale) > 1e-6f) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash graph assumes layer-invariant RoPE config, but layer %d differs (dim=%u/%u base=%g/%g scale=%g/%g)\n",
|
||||
__func__,
|
||||
il,
|
||||
rope_dim, ref_rope_dim,
|
||||
(double) rope_base, (double) ref_rope_base,
|
||||
(double) rope_scale, (double) ref_rope_scale);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (model.layers[(size_t) il].attn_norm == nullptr ||
|
||||
model.layers[(size_t) il].attn_q_norm == nullptr ||
|
||||
model.layers[(size_t) il].attn_k_norm == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash graph requires attn_norm, attn_q_norm, and attn_k_norm weights, but layer %d is missing one or more of them\n",
|
||||
__func__, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool has_q_norm = model.layers[(size_t) il].attn_q_norm != nullptr;
|
||||
const bool has_k_norm = model.layers[(size_t) il].attn_k_norm != nullptr;
|
||||
if (has_q_norm != has_k_norm) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash graph requires symmetric Q/K norm presence, but layer %d has q_norm=%d k_norm=%d\n",
|
||||
__func__, il, (int) has_q_norm, (int) has_k_norm);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (model.layers[(size_t) il].attn_norm_b != nullptr ||
|
||||
model.layers[(size_t) il].attn_q_norm_b != nullptr ||
|
||||
model.layers[(size_t) il].attn_k_norm_b != nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash graph does not implement norm-bias tensors, but layer %d requires attn_norm_b/q_norm_b/k_norm_b\n",
|
||||
__func__, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dflash_layer_has_attention_bias(model.layers[(size_t) il])) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash graph does not implement attention bias tensors, but layer %d requires them\n",
|
||||
__func__, il);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool prepare_dflash_graph_inputs(
|
||||
struct llama_context & lctx,
|
||||
uint32_t n_tokens) {
|
||||
@ -5095,16 +5276,23 @@ static bool prepare_dflash_graph_inputs(
|
||||
std::strcmp(dflash_kv_cache_env, "0") != 0 &&
|
||||
std::strcmp(dflash_kv_cache_env, "false") != 0 &&
|
||||
std::strcmp(dflash_kv_cache_env, "off") != 0;
|
||||
auto & profile = lctx.dflash_profile;
|
||||
const int32_t cross_ctx = lctx.dflash_visible_cross_ctx > 0
|
||||
? lctx.dflash_visible_cross_ctx
|
||||
: std::max<int32_t>(1, (int32_t) lctx.cparams.n_ctx - (int32_t) lctx.model.hparams.dflash_block_size);
|
||||
ggml_tensor * kq_mask = lctx.dflash_kq_mask_tensor;
|
||||
ggml_tensor * kq_mask_swa = lctx.dflash_kq_mask_swa_tensor;
|
||||
|
||||
if (kq_mask == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash graph inputs are not initialized\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!validate_dflash_graph_contract(lctx)) {
|
||||
profile.graph_shape_failures++;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (use_kv_cache) {
|
||||
if (!lctx.ensure_dflash_kv_cache_tensors(cross_ctx) || lctx.dflash_k_ctx_cache.empty() || lctx.dflash_v_ctx_cache.empty()) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash K/V cache inputs are not initialized\n", __func__);
|
||||
@ -5126,7 +5314,6 @@ static bool prepare_dflash_graph_inputs(
|
||||
: (lctx.inp_dflash_target_features != nullptr ? (int32_t) lctx.inp_dflash_target_features->ne[1] : 0);
|
||||
const int32_t n_mask_tokens = (int32_t) kq_mask->ne[1];
|
||||
const int32_t n_kv_total = (int32_t) kq_mask->ne[0];
|
||||
auto & profile = lctx.dflash_profile;
|
||||
const int64_t t_total_us = ggml_time_us();
|
||||
|
||||
profile.graph_prepare_calls++;
|
||||
@ -5186,36 +5373,32 @@ static bool prepare_dflash_graph_inputs(
|
||||
const int64_t t_pos_us = ggml_time_us();
|
||||
lctx.dflash_pos_ctx_data.resize((size_t) cross_ctx);
|
||||
std::fill(lctx.dflash_pos_ctx_data.begin(), lctx.dflash_pos_ctx_data.end(), 0);
|
||||
if (src_pos != nullptr && total_positions == (size_t) n_rows) {
|
||||
bool monotonic = true;
|
||||
for (int32_t i = 1; i < n_rows; ++i) {
|
||||
if (src_pos[i] <= src_pos[i - 1]) {
|
||||
monotonic = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!monotonic) {
|
||||
profile.graph_pos_non_monotonic++;
|
||||
if (profile.graph_pos_non_monotonic <= 3) {
|
||||
LLAMA_LOG_WARN("%s: DFlash target positions are not strictly increasing (rows=%d first=%d last=%d)\n",
|
||||
__func__, n_rows, (int) src_pos[0], (int) src_pos[n_rows - 1]);
|
||||
}
|
||||
}
|
||||
profile.last_pos_first = src_pos[0];
|
||||
profile.last_pos_last = src_pos[n_rows - 1];
|
||||
std::copy(src_pos, src_pos + n_rows, lctx.dflash_pos_ctx_data.begin() + (ptrdiff_t) left_pad);
|
||||
} else {
|
||||
if (src_pos == nullptr || total_positions != (size_t) n_rows) {
|
||||
profile.graph_pos_fallbacks++;
|
||||
profile.graph_shape_failures++;
|
||||
profile.last_pos_first = -1;
|
||||
profile.last_pos_last = -1;
|
||||
if (profile.graph_pos_fallbacks <= 3) {
|
||||
LLAMA_LOG_WARN("%s: using synthetic DFlash positions (rows=%d positions=%zu cross_ctx=%d)\n",
|
||||
LLAMA_LOG_ERROR("%s: missing DFlash target positions (rows=%d positions=%zu cross_ctx=%d)\n",
|
||||
__func__, n_rows, total_positions, cross_ctx);
|
||||
}
|
||||
for (int32_t i = 0; i < n_rows; ++i) {
|
||||
lctx.dflash_pos_ctx_data[(size_t) left_pad + (size_t) i] = i;
|
||||
return false;
|
||||
}
|
||||
|
||||
profile.last_pos_first = src_pos[0];
|
||||
profile.last_pos_last = src_pos[n_rows - 1];
|
||||
for (int32_t i = 1; i < n_rows; ++i) {
|
||||
if (src_pos[i] <= src_pos[i - 1]) {
|
||||
profile.graph_pos_non_monotonic++;
|
||||
profile.graph_shape_failures++;
|
||||
if (profile.graph_pos_non_monotonic <= 3) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash target positions are not strictly increasing (rows=%d first=%d last=%d)\n",
|
||||
__func__, n_rows, (int) src_pos[0], (int) src_pos[n_rows - 1]);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
std::copy(src_pos, src_pos + n_rows, lctx.dflash_pos_ctx_data.begin() + (ptrdiff_t) left_pad);
|
||||
profile.graph_pos_copy_us += (uint64_t) (ggml_time_us() - t_pos_us);
|
||||
profile.graph_pos_bytes += lctx.dflash_pos_ctx_data.size() * sizeof(llama_pos);
|
||||
|
||||
@ -5226,7 +5409,9 @@ static bool prepare_dflash_graph_inputs(
|
||||
lctx.dflash_buf_compute_meta.resize(meta_size);
|
||||
}
|
||||
|
||||
const int64_t t_build_us = ggml_time_us();
|
||||
ggml_cgraph * gf_kv = llm_build_context::llama_build_graph_dflash_kv_cache(lctx);
|
||||
profile.graph_kv_cache_build_us += (uint64_t) (ggml_time_us() - t_build_us);
|
||||
if (gf_kv == nullptr || lctx.dflash_kv_input_target_features == nullptr || lctx.dflash_kv_input_pos_ctx == nullptr) {
|
||||
profile.graph_shape_failures++;
|
||||
LLAMA_LOG_ERROR("%s: failed to build DFlash K/V cache graph\n", __func__);
|
||||
@ -5244,22 +5429,50 @@ static bool prepare_dflash_graph_inputs(
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t t_reserve_us = ggml_time_us();
|
||||
lctx.dflash_sched = ggml_backend_sched_new(lctx.backends.data(), backend_buft.data(), lctx.backends.size(), max_nodes, false);
|
||||
if (lctx.dflash_sched == nullptr || !ggml_backend_sched_reserve(lctx.dflash_sched, gf_kv)) {
|
||||
const bool reserved = lctx.dflash_sched != nullptr && ggml_backend_sched_reserve(lctx.dflash_sched, gf_kv);
|
||||
profile.graph_kv_cache_reserve_us += (uint64_t) (ggml_time_us() - t_reserve_us);
|
||||
if (!reserved) {
|
||||
profile.graph_shape_failures++;
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize DFlash K/V scheduler\n", __func__);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t t_reset_us = ggml_time_us();
|
||||
ggml_backend_sched_reset(lctx.dflash_sched);
|
||||
profile.graph_kv_cache_reset_us += (uint64_t) (ggml_time_us() - t_reset_us);
|
||||
|
||||
const int64_t t_alloc_us = ggml_time_us();
|
||||
ggml_backend_sched_alloc_graph(lctx.dflash_sched, gf_kv);
|
||||
ggml_backend_tensor_set(lctx.dflash_kv_input_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.dflash_kv_input_target_features));
|
||||
ggml_backend_tensor_set(lctx.dflash_kv_input_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx));
|
||||
profile.graph_kv_cache_alloc_us += (uint64_t) (ggml_time_us() - t_alloc_us);
|
||||
|
||||
ggml_backend_t kv_feature_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_target_features);
|
||||
const int64_t t_feature_upload_us = ggml_time_us();
|
||||
if (kv_feature_backend != nullptr) {
|
||||
ggml_backend_tensor_set_async(kv_feature_backend, lctx.dflash_kv_input_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.dflash_kv_input_target_features));
|
||||
} else {
|
||||
ggml_backend_tensor_set(lctx.dflash_kv_input_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.dflash_kv_input_target_features));
|
||||
}
|
||||
profile.graph_kv_cache_feature_upload_us += (uint64_t) (ggml_time_us() - t_feature_upload_us);
|
||||
|
||||
ggml_backend_t kv_pos_backend = llama_backend_for_tensor(lctx, lctx.dflash_kv_input_pos_ctx);
|
||||
const int64_t t_pos_upload_us = ggml_time_us();
|
||||
if (kv_pos_backend != nullptr) {
|
||||
ggml_backend_tensor_set_async(kv_pos_backend, lctx.dflash_kv_input_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx));
|
||||
} else {
|
||||
ggml_backend_tensor_set(lctx.dflash_kv_input_pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(lctx.dflash_kv_input_pos_ctx));
|
||||
}
|
||||
profile.graph_kv_cache_pos_upload_us += (uint64_t) (ggml_time_us() - t_pos_upload_us);
|
||||
|
||||
const int64_t t_kv_cache_us = ggml_time_us();
|
||||
llama_graph_compute_sched(lctx, lctx.dflash_sched, gf_kv, lctx.cparams.n_threads);
|
||||
llama_synchronize(&lctx);
|
||||
profile.graph_kv_cache_compute_us += (uint64_t) (ggml_time_us() - t_kv_cache_us);
|
||||
|
||||
const int64_t t_sync_us = ggml_time_us();
|
||||
ggml_backend_sched_synchronize(lctx.dflash_sched);
|
||||
profile.graph_kv_cache_sync_us += (uint64_t) (ggml_time_us() - t_sync_us);
|
||||
profile.graph_kv_cache_calls++;
|
||||
} else {
|
||||
ggml_backend_tensor_set(lctx.inp_dflash_target_features, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(lctx.inp_dflash_target_features));
|
||||
@ -5267,28 +5480,66 @@ static bool prepare_dflash_graph_inputs(
|
||||
}
|
||||
|
||||
const int64_t t_mask_us = ggml_time_us();
|
||||
const int32_t full_visible_first = left_pad;
|
||||
const int32_t full_visible_last = cross_ctx + (int32_t) n_tokens - 1;
|
||||
lctx.dflash_kq_mask_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY);
|
||||
int32_t visible_kv_max = 0;
|
||||
for (uint32_t j = 0; j < n_tokens; ++j) {
|
||||
float * row = lctx.dflash_kq_mask_data.data() + (size_t) j * (size_t) n_kv_total;
|
||||
const int32_t visible_kv = cross_ctx + (int32_t) j + 1;
|
||||
const int32_t visible_kv = cross_ctx + (int32_t) n_tokens;
|
||||
visible_kv_max = std::max(visible_kv_max, visible_kv);
|
||||
profile.graph_visible_kv_sum += (uint64_t) visible_kv;
|
||||
for (int32_t i = left_pad; i < visible_kv; ++i) {
|
||||
for (int32_t i = full_visible_first; i <= full_visible_last; ++i) {
|
||||
row[i] = 0.0f;
|
||||
}
|
||||
}
|
||||
ggml_backend_tensor_set(kq_mask, lctx.dflash_kq_mask_data.data(), 0, ggml_nbytes(kq_mask));
|
||||
profile.graph_mask_build_us += (uint64_t) (ggml_time_us() - t_mask_us);
|
||||
profile.graph_mask_bytes += ggml_nbytes(kq_mask);
|
||||
|
||||
if (kq_mask_swa != nullptr) {
|
||||
lctx.dflash_kq_mask_swa_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY);
|
||||
const int32_t swa_window = (int32_t) lctx.model.hparams.n_swa;
|
||||
const int32_t draft_pos_base = (int32_t) profile.last_pos_last;
|
||||
for (uint32_t j = 0; j < n_tokens; ++j) {
|
||||
float * row = lctx.dflash_kq_mask_swa_data.data() + (size_t) j * (size_t) n_kv_total;
|
||||
const int32_t q_pos = draft_pos_base + (int32_t) j;
|
||||
|
||||
for (int32_t k = left_pad; k < cross_ctx; ++k) {
|
||||
const int32_t k_pos = (int32_t) lctx.dflash_pos_ctx_data[(size_t) k];
|
||||
if (q_pos - k_pos < swa_window) {
|
||||
row[k] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int32_t k = cross_ctx; k < cross_ctx + (int32_t) n_tokens; ++k) {
|
||||
const int32_t block_k = k - cross_ctx;
|
||||
if (block_k <= (int32_t) j) {
|
||||
row[k] = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(kq_mask_swa, lctx.dflash_kq_mask_swa_data.data(), 0, ggml_nbytes(kq_mask_swa));
|
||||
profile.graph_mask_bytes += ggml_nbytes(kq_mask_swa);
|
||||
}
|
||||
|
||||
profile.graph_visible_kv_max = std::max<uint64_t>(profile.graph_visible_kv_max, (uint64_t) visible_kv_max);
|
||||
profile.graph_prepare_total_us += (uint64_t) (ggml_time_us() - t_total_us);
|
||||
|
||||
if (profile.graph_prepare_calls == 1) {
|
||||
LLAMA_LOG_INFO("%s: DFlash graph contract rows=%d width=%d cross_ctx=%d n_tokens=%u left_pad=%d n_kv_total=%d draft_n_ctx=%u pos=%s [%d..%d]\n",
|
||||
int32_t n_swa_layers = 0;
|
||||
for (int32_t il = 0; il < lctx.model.hparams.n_layer; ++il) {
|
||||
n_swa_layers += lctx.model.hparams.swa_layers[(size_t) il] ? 1 : 0;
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("%s: DFlash graph contract rows=%d width=%d cross_ctx=%d n_tokens=%u left_pad=%d n_kv_total=%d draft_n_ctx=%u pos=%s [%d..%d] full_mask=[%d..%d] swa_window=%u swa_layers=%d\n",
|
||||
__func__, n_rows, width, cross_ctx, n_tokens, left_pad, n_kv_total, lctx.cparams.n_ctx,
|
||||
(src_pos != nullptr && total_positions == (size_t) n_rows) ? "target" : "synthetic",
|
||||
(int) profile.last_pos_first, (int) profile.last_pos_last);
|
||||
(int) profile.last_pos_first, (int) profile.last_pos_last,
|
||||
full_visible_first, full_visible_last,
|
||||
lctx.model.hparams.n_swa,
|
||||
n_swa_layers);
|
||||
}
|
||||
|
||||
return true;
|
||||
@ -5322,6 +5573,8 @@ static int llama_decode_internal(
|
||||
const auto & hparams = model.hparams;
|
||||
const auto & cparams = lctx.cparams;
|
||||
|
||||
llama_begin_dflash_capture_batch(&lctx);
|
||||
|
||||
GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
|
||||
|
||||
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
|
||||
@ -5370,8 +5623,24 @@ static int llama_decode_internal(
|
||||
|
||||
// reserve output buffer
|
||||
n_outputs_embd = has_mtp && cparams.mtp_op_type == MTP_OP_NONE ? n_tokens_all : n_outputs;
|
||||
if (llama_output_reserve(lctx, std::max<size_t>(n_outputs, n_outputs_embd)) < std::max<size_t>(n_outputs, n_outputs_embd)) {
|
||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %zu outputs\n", __func__, std::max<size_t>(n_outputs, n_outputs_embd));
|
||||
const size_t required_outputs = std::max<size_t>(n_outputs, n_outputs_embd);
|
||||
const bool is_dflash_decode = lctx.model.arch == LLM_ARCH_DFLASH_DRAFT;
|
||||
const size_t output_buf_size_before = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output) : 0;
|
||||
const int64_t t_output_reserve_us = is_dflash_decode ? ggml_time_us() : 0;
|
||||
const size_t reserved_outputs = llama_output_reserve(lctx, required_outputs);
|
||||
if (is_dflash_decode) {
|
||||
auto & profile = lctx.dflash_profile;
|
||||
profile.decode_output_reserve_calls++;
|
||||
profile.decode_output_reserve_us += (uint64_t) (ggml_time_us() - t_output_reserve_us);
|
||||
|
||||
const size_t output_buf_size_after = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output) : 0;
|
||||
if (output_buf_size_after > output_buf_size_before) {
|
||||
profile.decode_output_reserve_reallocs++;
|
||||
profile.decode_output_reserve_realloc_bytes += (uint64_t) output_buf_size_after;
|
||||
}
|
||||
}
|
||||
if (reserved_outputs < required_outputs) {
|
||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %zu outputs\n", __func__, required_outputs);
|
||||
return -2;
|
||||
};
|
||||
|
||||
@ -5587,9 +5856,15 @@ static int llama_decode_internal(
|
||||
}
|
||||
|
||||
if (lctx.model.arch == LLM_ARCH_DFLASH_DRAFT) {
|
||||
auto & profile = lctx.dflash_profile;
|
||||
profile.decode_prepare_calls++;
|
||||
const int64_t t_prepare_dflash_us = ggml_time_us();
|
||||
if (!prepare_dflash_graph_inputs(lctx, n_tokens)) {
|
||||
profile.decode_prepare_failures++;
|
||||
profile.decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us);
|
||||
return GGML_STATUS_FAILED;
|
||||
}
|
||||
profile.decode_prepare_us += (uint64_t) (ggml_time_us() - t_prepare_dflash_us);
|
||||
}
|
||||
|
||||
// the output is always the last tensor in the graph
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user