From 9f5f70cf7e97971d5f3ea94228653689601ccc0a Mon Sep 17 00:00:00 2001 From: SamuelOliveirads Date: Fri, 29 May 2026 23:11:38 -0300 Subject: [PATCH] implement target position tracking and context management --- common/speculative.cpp | 122 +++++++++++++++++++++++------ examples/server/server-context.cpp | 14 +++- examples/server/server-context.h | 2 + src/graphs/build_dflash.cpp | 16 +++- src/llama-context.h | 3 + src/llama-spec-features.cpp | 58 +++++++++++--- src/llama-spec-features.h | 3 +- src/llama.cpp | 16 +++- 8 files changed, 189 insertions(+), 45 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 3b08b26a..d740ded9 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -365,7 +365,9 @@ struct common_speculative_state_dflash : public common_speculative_state { std::vector target_layer_ids; std::vector target_window; + std::vector target_window_pos; int32_t target_window_rows = 0; + llama_pos last_target_pos = -1; common_speculative_state_dflash( enum common_speculative_type type, @@ -426,8 +428,6 @@ struct common_speculative_state_dflash : public common_speculative_state { void begin(const llama_tokens & prompt) override { GGML_UNUSED(prompt); - target_window.clear(); - target_window_rows = 0; llama_kv_cache_clear(ctx_dft); } @@ -444,20 +444,21 @@ struct common_speculative_state_dflash : public common_speculative_state { return; } - const int32_t n_draft = std::min(params.n_max, block_size); - if (n_draft <= 0) { + const int32_t n_keep = std::min(params.n_max, block_size); + if (n_keep <= 0) { return; } - if (!llama_set_dflash_target_features_copy(ctx_dft, target_window.data(), target_window.size(), target_window_rows)) { + if (!llama_set_dflash_target_features_copy(ctx_dft, target_window.data(), target_window.size(), target_window_rows, target_window_pos.data())) { LOG_ERR("%s: failed to set DFlash target features\n", __func__); return; } llama_kv_cache_clear(ctx_dft); batch.n_tokens = 0; - for (int32_t i = 0; i < n_draft; ++i) { - common_batch_add(batch, mask_token_id, cross_ctx + i, { 0 }, true); + const llama_pos draft_pos_base = last_target_pos >= 0 ? last_target_pos + 1 : (llama_pos) target_window_rows; + for (int32_t i = 0; i < block_size; ++i) { + common_batch_add(batch, mask_token_id, draft_pos_base + i, { 0 }, i < n_keep); } if (llama_decode(ctx_dft, batch) != 0) { @@ -466,8 +467,8 @@ struct common_speculative_state_dflash : public common_speculative_state { return; } - result.reserve((size_t) n_draft); - for (int32_t i = 0; i < n_draft; ++i) { + result.reserve((size_t) n_keep); + for (int32_t i = 0; i < n_keep; ++i) { result.push_back(common_sampler_sample_speculative(nullptr, ctx_dft, i, nullptr)); } @@ -2216,42 +2217,118 @@ static void mtp_clear_target_hidden(common_speculative_state_mtp & state, llama_ state.draft_cache_by_seq.erase(seq_id); } -static void dflash_append_target_features( +static bool dflash_append_target_features( common_speculative_state_dflash & state, - const float * feature_rows, - int32_t n_rows) { - if (feature_rows == nullptr || n_rows <= 0 || state.n_target_features <= 0 || state.cross_ctx <= 0) { - return; + const common_speculative_feature_view & features, + const llama_batch & batch, + llama_seq_id seq_id) { + GGML_UNUSED(batch); + + if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || + features.width != state.n_target_features || + features.rows.empty() || + state.cross_ctx <= 0) { + return false; } const size_t row_width = (size_t) state.n_target_features; + std::vector new_rows; + std::vector new_positions; + new_rows.reserve(features.rows.size() * row_width); + new_positions.reserve(features.rows.size()); + + for (const auto & row : features.rows) { + if (row.seq_id != seq_id || row.data == nullptr) { + continue; + } + + new_positions.push_back(row.pos); + new_rows.insert(new_rows.end(), row.data, row.data + row_width); + } + + if (new_positions.empty()) { + return false; + } + + const int32_t n_rows = (int32_t) new_positions.size(); if (n_rows >= state.cross_ctx) { - const float * src = feature_rows + (size_t) (n_rows - state.cross_ctx) * row_width; - state.target_window.assign(src, src + (size_t) state.cross_ctx * row_width); + const int32_t keep_from = n_rows - state.cross_ctx; + state.target_window.assign( + new_rows.begin() + (ptrdiff_t) keep_from * (ptrdiff_t) row_width, + new_rows.end()); + state.target_window_pos.assign(new_positions.begin() + keep_from, new_positions.end()); state.target_window_rows = state.cross_ctx; - return; + state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); + return true; } const int32_t keep_old_rows = std::min(state.target_window_rows, state.cross_ctx - n_rows); std::vector next_window((size_t) (keep_old_rows + n_rows) * row_width); + std::vector next_window_pos((size_t) (keep_old_rows + n_rows)); if (keep_old_rows > 0) { const float * old_src = state.target_window.data() + (size_t) (state.target_window_rows - keep_old_rows) * row_width; std::memcpy(next_window.data(), old_src, (size_t) keep_old_rows * row_width * sizeof(float)); + std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), next_window_pos.begin()); } std::memcpy( next_window.data() + (size_t) keep_old_rows * row_width, - feature_rows, + new_rows.data(), (size_t) n_rows * row_width * sizeof(float)); + std::copy(new_positions.begin(), new_positions.end(), next_window_pos.begin() + keep_old_rows); state.target_window = std::move(next_window); + state.target_window_pos = std::move(next_window_pos); state.target_window_rows = keep_old_rows + n_rows; + state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); + return true; } static void dflash_clear_target_features(common_speculative_state_dflash & state) { state.target_window.clear(); + state.target_window_pos.clear(); state.target_window_rows = 0; + state.last_target_pos = -1; +} + +static void dflash_context_shift( + common_speculative_state_dflash & state, + llama_pos kv_keep, + llama_pos kv_discard, + llama_pos kv_past) { + if (kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window.empty() || state.target_window_pos.empty()) { + return; + } + + const size_t row_width = (size_t) state.n_target_features; + const llama_pos discard_begin = kv_keep; + const llama_pos discard_end = kv_keep + kv_discard; + + std::vector shifted_rows; + std::vector shifted_positions; + shifted_rows.reserve(state.target_window.size()); + shifted_positions.reserve(state.target_window_pos.size()); + + for (int32_t row = 0; row < state.target_window_rows; ++row) { + llama_pos pos = state.target_window_pos[(size_t) row]; + if (pos >= discard_begin && pos < discard_end) { + continue; + } + + if (pos >= discard_end && pos < kv_past) { + pos -= kv_discard; + } + + const float * row_src = state.target_window.data() + (size_t) row * row_width; + shifted_rows.insert(shifted_rows.end(), row_src, row_src + row_width); + shifted_positions.push_back(pos); + } + + state.target_window = std::move(shifted_rows); + state.target_window_pos = std::move(shifted_positions); + state.target_window_rows = (int32_t) state.target_window_pos.size(); + state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back(); } static bool common_speculative_capture_target_features(common_speculative * spec, const common_speculative_feature_view & features) { @@ -2366,12 +2443,9 @@ int32_t common_speculative_on_target_batch( } } - std::vector hidden_rows_storage; - if (!common_speculative_feature_view_copy_batch_rows(features, batch, seq_id, &hidden_rows_storage)) { + if (!dflash_append_target_features(*dflash_state, features, batch, seq_id)) { return -1; } - - dflash_append_target_features(*dflash_state, hidden_rows_storage.data(), batch.n_tokens); return 0; } @@ -2439,6 +2513,10 @@ void common_speculative_context_shift( llama_kv_cache_seq_rm (ctx_mtp, seq_id, kv_keep, kv_keep + kv_discard); llama_kv_cache_seq_add(ctx_mtp, seq_id, kv_keep + kv_discard, kv_past, -kv_discard); } + + if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) { + dflash_context_shift(*dflash_state, kv_keep, kv_discard, kv_past); + } } std::vector mtp_speculative_gen_draft( diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index af6924fe..2a4e8c95 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -367,6 +367,9 @@ bool server_context::load_model(const gpt_params& params_) { if (params_dft.n_ctx == 0) { params_dft.n_ctx = params_base.speculative.n_ctx; } + if (server_speculative_has_dflash(params_base.speculative) && params_dft.n_gpu_layers < 0) { + params_dft.n_gpu_layers = params_base.n_gpu_layers; + } params_dft.n_ctx = params_dft.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx; params_dft.n_parallel = 1; params_dft.n_batch = params_dft.n_ctx; @@ -629,6 +632,7 @@ void server_slot::reset() { drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE; i_batch_dft.clear(); spec_ckpt.clear(); + spec_prompt_warmup_failed = false; n_sent_token_probs = 0; infill = false; ga_i = 0; @@ -717,7 +721,7 @@ void server_slot::add_token_string(const completion_token_output& token) { } bool server_slot::can_speculate() const { - return (!!spec || has_mtp); + return !spec_prompt_warmup_failed && (!!spec || has_mtp); } int server_slot::get_n_draft_max() const { @@ -3347,7 +3351,7 @@ void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_sl const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(kv_keep), slot.cache_tokens.pos_next(kv_keep + kv_discard)); llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard); - if (slot.has_mtp && slot.spec) { + if (slot.spec) { common_speculative_context_shift(slot.spec, slot.id, kv_keep, kv_discard, kv_past); } if (slot.params.cache_prompt) { @@ -4730,12 +4734,18 @@ void server_context::process_batch_tokens(int32_t & n_batch) { continue; } + if (slot.spec_prompt_warmup_failed) { + continue; + } + if ((slot.state != SLOT_STATE_PROCESSING || slot.n_decoded != 0) && (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_LOAD_PROMPT)) { continue; } if (common_speculative_on_target_seq_batch(slot.spec, ctx, batch_view, slot.id, true) != 0) { + common_speculative_clear_sequence_hidden(slot.spec, slot.id); + slot.spec_prompt_warmup_failed = true; LOG_ERROR("failed to warm up speculative target-feature state from prompt batch for slot %d\n", slot.id); } } diff --git a/examples/server/server-context.h b/examples/server/server-context.h index a33c2113..d4d0913c 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ -176,6 +176,8 @@ struct server_slot { // saves recurrent state before a speculative batch so it can be restored on rejection server_speculative_checkpoint spec_ckpt; + bool spec_prompt_warmup_failed = false; + // speculative decoding stats int32_t n_draft_total = 0; // Total draft tokens generated int32_t n_draft_accepted = 0; // Draft tokens actually accepted diff --git a/src/graphs/build_dflash.cpp b/src/graphs/build_dflash.cpp index fe1cec15..542821ad 100644 --- a/src/graphs/build_dflash.cpp +++ b/src/graphs/build_dflash.cpp @@ -28,6 +28,8 @@ ggml_cgraph * llm_build_context::build_dflash() { ggml_set_input(lctx.inp_dflash_kq_mask); cb(lctx.inp_dflash_kq_mask, "dflash_kq_mask", -1); + ggml_tensor * dflash_kq_mask = flash_attn ? ggml_cast(ctx0, lctx.inp_dflash_kq_mask, GGML_TYPE_F16) : lctx.inp_dflash_kq_mask; + ggml_tensor * tok_embd = model.tok_embd; if (tok_embd == nullptr) { tok_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab); @@ -35,6 +37,7 @@ ggml_cgraph * llm_build_context::build_dflash() { ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, tok_embd, cb); ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = (n_tokens > 1 && n_outputs < n_tokens) ? build_inp_out_ids() : nullptr; ggml_tensor * fused_target = llm_build_lora_mm(lctx, ctx0, model.dflash_fc, lctx.inp_dflash_target_features); fused_target = llm_build_norm(ctx0, fused_target, hparams, model.dflash_hidden_norm, nullptr, LLM_NORM_RMS, cb, -1); @@ -85,10 +88,9 @@ ggml_cgraph * llm_build_context::build_dflash() { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_cast(ctx0, Qcur, GGML_TYPE_F16); Kcur = ggml_cast(ctx0, Kcur, GGML_TYPE_F16); Vcur = ggml_cast(ctx0, Vcur, GGML_TYPE_F16); - cb(Qcur, "Qcur_f16", il); + cb(Qcur, "Qcur", il); cb(Kcur, "Kcur_f16", il); cb(Vcur, "Vcur_f16", il); @@ -99,7 +101,7 @@ ggml_cgraph * llm_build_context::build_dflash() { cb(k, "k", il); cb(v, "v", il); - cur = ggml_flash_attn_ext(ctx0, q, k, v, lctx.inp_dflash_kq_mask, kq_scale, hparams.f_max_alibi_bias, + cur = ggml_flash_attn_ext(ctx0, q, k, v, dflash_kq_mask, kq_scale, hparams.f_max_alibi_bias, hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); cb(cur, "flash_attn", il); ggml_build_forward_expand(gf, cur); @@ -136,7 +138,13 @@ ggml_cgraph * llm_build_context::build_dflash() { output = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab); } - ggml_tensor * result = build_output(lctx, ctx0, inpL, output, model.output_norm, cb); + ggml_tensor * result_input = inpL; + if (inp_out_ids) { + result_input = ggml_get_rows(ctx0, result_input, inp_out_ids); + cb(result_input, "result_output_rows", -1); + } + + ggml_tensor * result = build_output(lctx, ctx0, result_input, output, model.output_norm, cb); cb(result, "result_output", -1); ggml_build_forward_expand(gf, result); diff --git a/src/llama-context.h b/src/llama-context.h index d0d9fe61..c4f62ac1 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -281,7 +281,10 @@ struct llama_context { const float * dflash_target_features = nullptr; size_t dflash_target_features_n_floats = 0; int32_t dflash_target_features_n_rows = 0; + const llama_pos * dflash_target_positions = nullptr; + size_t dflash_target_positions_n = 0; std::vector dflash_target_features_owned; + std::vector dflash_target_positions_owned; std::vector dflash_target_features_padded; std::vector dflash_feature_view_buffer; std::vector dflash_pos_ctx_data; diff --git a/src/llama-spec-features.cpp b/src/llama-spec-features.cpp index 827d536f..ccc6fb5d 100644 --- a/src/llama-spec-features.cpp +++ b/src/llama-spec-features.cpp @@ -96,7 +96,8 @@ bool llama_set_dflash_target_features_copy( struct llama_context * ctx, const float * target_features, size_t n_floats, - int32_t n_rows) { + int32_t n_rows, + const llama_pos * target_positions) { if (ctx == nullptr || target_features == nullptr || n_floats == 0 || n_rows <= 0) { return false; } @@ -105,6 +106,15 @@ bool llama_set_dflash_target_features_copy( ctx->dflash_target_features = ctx->dflash_target_features_owned.data(); ctx->dflash_target_features_n_floats = n_floats; ctx->dflash_target_features_n_rows = n_rows; + if (target_positions != nullptr) { + ctx->dflash_target_positions_owned.assign(target_positions, target_positions + n_rows); + ctx->dflash_target_positions = ctx->dflash_target_positions_owned.data(); + ctx->dflash_target_positions_n = (size_t) n_rows; + } else { + ctx->dflash_target_positions_owned.clear(); + ctx->dflash_target_positions = nullptr; + ctx->dflash_target_positions_n = 0; + } return true; } @@ -361,9 +371,25 @@ bool llama_spec_get_dflash_feature_view( return false; } - std::vector row_indices((size_t) batch.n_tokens); - for (int32_t i = 0; i < batch.n_tokens; ++i) { - row_indices[(size_t) i] = i; + int32_t row_count = 0; + int32_t row_width = 0; + int32_t n_layers = 0; + if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) { + return false; + } + + const int32_t batch_row_offset = std::max(0, batch.n_tokens - row_count); + std::vector row_indices; + std::vector batch_indices; + row_indices.reserve((size_t) (batch.n_tokens - batch_row_offset)); + batch_indices.reserve((size_t) (batch.n_tokens - batch_row_offset)); + for (int32_t i = batch_row_offset; i < batch.n_tokens; ++i) { + row_indices.push_back(i - batch_row_offset); + batch_indices.push_back(i); + } + + if (row_indices.empty()) { + return false; } view = {}; @@ -372,17 +398,17 @@ bool llama_spec_get_dflash_feature_view( return false; } - view.rows.reserve((size_t) batch.n_tokens); - for (int32_t i = 0; i < batch.n_tokens; ++i) { - if (batch.n_seq_id[i] <= 0 || batch.seq_id[i] == nullptr) { + view.rows.reserve(batch_indices.size()); + for (int32_t batch_index : batch_indices) { + if (batch.n_seq_id[batch_index] <= 0 || batch.seq_id[batch_index] == nullptr) { view.rows.clear(); return false; } view.rows.push_back({ - /* .seq_id = */ batch.seq_id[i][0], - /* .pos = */ batch.pos[i], - /* .data = */ ctx->dflash_feature_view_buffer.data() + (size_t) i * (size_t) view.width, + /* .seq_id = */ batch.seq_id[batch_index][0], + /* .pos = */ batch.pos[batch_index], + /* .data = */ ctx->dflash_feature_view_buffer.data() + view.rows.size() * (size_t) view.width, }); } @@ -398,18 +424,26 @@ bool llama_spec_get_dflash_feature_view_for_seq( return false; } + int32_t row_count = 0; + int32_t row_width = 0; + int32_t n_layers = 0; + if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) { + return false; + } + + const int32_t batch_row_offset = std::max(0, batch.n_tokens - row_count); std::vector row_indices; row_indices.reserve((size_t) batch.n_tokens); std::vector batch_indices; batch_indices.reserve((size_t) batch.n_tokens); - for (int32_t i = 0; i < batch.n_tokens; ++i) { + for (int32_t i = batch_row_offset; i < batch.n_tokens; ++i) { if (batch.n_seq_id[i] <= 0 || batch.seq_id[i] == nullptr) { return false; } for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) { if (batch.seq_id[i][j] == seq_id) { - row_indices.push_back(i); + row_indices.push_back(i - batch_row_offset); batch_indices.push_back(i); break; } diff --git a/src/llama-spec-features.h b/src/llama-spec-features.h index ea177c1e..130d3895 100644 --- a/src/llama-spec-features.h +++ b/src/llama-spec-features.h @@ -51,7 +51,8 @@ bool llama_set_dflash_target_features_copy( struct llama_context * ctx, const float * target_features, size_t n_floats, - int32_t n_rows); + int32_t n_rows, + const llama_pos * target_positions); bool llama_set_dflash_capture_layers( struct llama_context * ctx, diff --git a/src/llama.cpp b/src/llama.cpp index 0bc960c3..a9b443c7 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4994,7 +4994,9 @@ static bool prepare_dflash_graph_inputs( } const float * src = lctx.dflash_target_features; + const llama_pos * src_pos = lctx.dflash_target_positions; const size_t total_floats = lctx.dflash_target_features_n_floats; + const size_t total_positions = lctx.dflash_target_positions_n; const int32_t n_rows = lctx.dflash_target_features_n_rows; const int32_t width = (int32_t) target_hidden->ne[0]; const int32_t cross_ctx = (int32_t) target_hidden->ne[1]; @@ -5014,20 +5016,26 @@ static bool prepare_dflash_graph_inputs( lctx.dflash_target_features_padded.assign((size_t) cross_ctx * (size_t) width, 0.0f); const size_t dst_offset = (size_t) (cross_ctx - n_rows) * (size_t) width; + const int32_t left_pad = cross_ctx - n_rows; std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset); ggml_backend_tensor_set(target_hidden, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(target_hidden)); lctx.dflash_pos_ctx_data.resize((size_t) cross_ctx); - for (int32_t i = 0; i < cross_ctx; ++i) { - lctx.dflash_pos_ctx_data[i] = i; + std::fill(lctx.dflash_pos_ctx_data.begin(), lctx.dflash_pos_ctx_data.end(), 0); + if (src_pos != nullptr && total_positions == (size_t) n_rows) { + std::copy(src_pos, src_pos + n_rows, lctx.dflash_pos_ctx_data.begin() + (ptrdiff_t) left_pad); + } else { + for (int32_t i = 0; i < n_rows; ++i) { + lctx.dflash_pos_ctx_data[(size_t) left_pad + (size_t) i] = i; + } } ggml_backend_tensor_set(pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(pos_ctx)); lctx.dflash_kq_mask_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY); - const int32_t left_pad = cross_ctx - n_rows; for (uint32_t j = 0; j < n_tokens; ++j) { float * row = lctx.dflash_kq_mask_data.data() + (size_t) j * (size_t) n_kv_total; - for (int32_t i = left_pad; i < n_kv_total; ++i) { + const int32_t visible_kv = cross_ctx + (int32_t) j + 1; + for (int32_t i = left_pad; i < visible_kv; ++i) { row[i] = 0.0f; } }