mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
implement target position tracking and context management
This commit is contained in:
parent
82cff238fe
commit
9f5f70cf7e
@ -365,7 +365,9 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
|
||||
std::vector<int32_t> target_layer_ids;
|
||||
std::vector<float> target_window;
|
||||
std::vector<llama_pos> target_window_pos;
|
||||
int32_t target_window_rows = 0;
|
||||
llama_pos last_target_pos = -1;
|
||||
|
||||
common_speculative_state_dflash(
|
||||
enum common_speculative_type type,
|
||||
@ -426,8 +428,6 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
|
||||
void begin(const llama_tokens & prompt) override {
|
||||
GGML_UNUSED(prompt);
|
||||
target_window.clear();
|
||||
target_window_rows = 0;
|
||||
llama_kv_cache_clear(ctx_dft);
|
||||
}
|
||||
|
||||
@ -444,20 +444,21 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t n_draft = std::min<int32_t>(params.n_max, block_size);
|
||||
if (n_draft <= 0) {
|
||||
const int32_t n_keep = std::min<int32_t>(params.n_max, block_size);
|
||||
if (n_keep <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!llama_set_dflash_target_features_copy(ctx_dft, target_window.data(), target_window.size(), target_window_rows)) {
|
||||
if (!llama_set_dflash_target_features_copy(ctx_dft, target_window.data(), target_window.size(), target_window_rows, target_window_pos.data())) {
|
||||
LOG_ERR("%s: failed to set DFlash target features\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
llama_kv_cache_clear(ctx_dft);
|
||||
batch.n_tokens = 0;
|
||||
for (int32_t i = 0; i < n_draft; ++i) {
|
||||
common_batch_add(batch, mask_token_id, cross_ctx + i, { 0 }, true);
|
||||
const llama_pos draft_pos_base = last_target_pos >= 0 ? last_target_pos + 1 : (llama_pos) target_window_rows;
|
||||
for (int32_t i = 0; i < block_size; ++i) {
|
||||
common_batch_add(batch, mask_token_id, draft_pos_base + i, { 0 }, i < n_keep);
|
||||
}
|
||||
|
||||
if (llama_decode(ctx_dft, batch) != 0) {
|
||||
@ -466,8 +467,8 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
return;
|
||||
}
|
||||
|
||||
result.reserve((size_t) n_draft);
|
||||
for (int32_t i = 0; i < n_draft; ++i) {
|
||||
result.reserve((size_t) n_keep);
|
||||
for (int32_t i = 0; i < n_keep; ++i) {
|
||||
result.push_back(common_sampler_sample_speculative(nullptr, ctx_dft, i, nullptr));
|
||||
}
|
||||
|
||||
@ -2216,42 +2217,118 @@ static void mtp_clear_target_hidden(common_speculative_state_mtp & state, llama_
|
||||
state.draft_cache_by_seq.erase(seq_id);
|
||||
}
|
||||
|
||||
static void dflash_append_target_features(
|
||||
static bool dflash_append_target_features(
|
||||
common_speculative_state_dflash & state,
|
||||
const float * feature_rows,
|
||||
int32_t n_rows) {
|
||||
if (feature_rows == nullptr || n_rows <= 0 || state.n_target_features <= 0 || state.cross_ctx <= 0) {
|
||||
return;
|
||||
const common_speculative_feature_view & features,
|
||||
const llama_batch & batch,
|
||||
llama_seq_id seq_id) {
|
||||
GGML_UNUSED(batch);
|
||||
|
||||
if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE ||
|
||||
features.width != state.n_target_features ||
|
||||
features.rows.empty() ||
|
||||
state.cross_ctx <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t row_width = (size_t) state.n_target_features;
|
||||
std::vector<float> new_rows;
|
||||
std::vector<llama_pos> new_positions;
|
||||
new_rows.reserve(features.rows.size() * row_width);
|
||||
new_positions.reserve(features.rows.size());
|
||||
|
||||
for (const auto & row : features.rows) {
|
||||
if (row.seq_id != seq_id || row.data == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
new_positions.push_back(row.pos);
|
||||
new_rows.insert(new_rows.end(), row.data, row.data + row_width);
|
||||
}
|
||||
|
||||
if (new_positions.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int32_t n_rows = (int32_t) new_positions.size();
|
||||
if (n_rows >= state.cross_ctx) {
|
||||
const float * src = feature_rows + (size_t) (n_rows - state.cross_ctx) * row_width;
|
||||
state.target_window.assign(src, src + (size_t) state.cross_ctx * row_width);
|
||||
const int32_t keep_from = n_rows - state.cross_ctx;
|
||||
state.target_window.assign(
|
||||
new_rows.begin() + (ptrdiff_t) keep_from * (ptrdiff_t) row_width,
|
||||
new_rows.end());
|
||||
state.target_window_pos.assign(new_positions.begin() + keep_from, new_positions.end());
|
||||
state.target_window_rows = state.cross_ctx;
|
||||
return;
|
||||
state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
|
||||
return true;
|
||||
}
|
||||
|
||||
const int32_t keep_old_rows = std::min<int32_t>(state.target_window_rows, state.cross_ctx - n_rows);
|
||||
std::vector<float> next_window((size_t) (keep_old_rows + n_rows) * row_width);
|
||||
std::vector<llama_pos> next_window_pos((size_t) (keep_old_rows + n_rows));
|
||||
|
||||
if (keep_old_rows > 0) {
|
||||
const float * old_src = state.target_window.data() + (size_t) (state.target_window_rows - keep_old_rows) * row_width;
|
||||
std::memcpy(next_window.data(), old_src, (size_t) keep_old_rows * row_width * sizeof(float));
|
||||
std::copy(state.target_window_pos.end() - keep_old_rows, state.target_window_pos.end(), next_window_pos.begin());
|
||||
}
|
||||
|
||||
std::memcpy(
|
||||
next_window.data() + (size_t) keep_old_rows * row_width,
|
||||
feature_rows,
|
||||
new_rows.data(),
|
||||
(size_t) n_rows * row_width * sizeof(float));
|
||||
std::copy(new_positions.begin(), new_positions.end(), next_window_pos.begin() + keep_old_rows);
|
||||
|
||||
state.target_window = std::move(next_window);
|
||||
state.target_window_pos = std::move(next_window_pos);
|
||||
state.target_window_rows = keep_old_rows + n_rows;
|
||||
state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
|
||||
return true;
|
||||
}
|
||||
|
||||
static void dflash_clear_target_features(common_speculative_state_dflash & state) {
|
||||
state.target_window.clear();
|
||||
state.target_window_pos.clear();
|
||||
state.target_window_rows = 0;
|
||||
state.last_target_pos = -1;
|
||||
}
|
||||
|
||||
static void dflash_context_shift(
|
||||
common_speculative_state_dflash & state,
|
||||
llama_pos kv_keep,
|
||||
llama_pos kv_discard,
|
||||
llama_pos kv_past) {
|
||||
if (kv_discard <= 0 || state.target_window_rows <= 0 || state.target_window.empty() || state.target_window_pos.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t row_width = (size_t) state.n_target_features;
|
||||
const llama_pos discard_begin = kv_keep;
|
||||
const llama_pos discard_end = kv_keep + kv_discard;
|
||||
|
||||
std::vector<float> shifted_rows;
|
||||
std::vector<llama_pos> shifted_positions;
|
||||
shifted_rows.reserve(state.target_window.size());
|
||||
shifted_positions.reserve(state.target_window_pos.size());
|
||||
|
||||
for (int32_t row = 0; row < state.target_window_rows; ++row) {
|
||||
llama_pos pos = state.target_window_pos[(size_t) row];
|
||||
if (pos >= discard_begin && pos < discard_end) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (pos >= discard_end && pos < kv_past) {
|
||||
pos -= kv_discard;
|
||||
}
|
||||
|
||||
const float * row_src = state.target_window.data() + (size_t) row * row_width;
|
||||
shifted_rows.insert(shifted_rows.end(), row_src, row_src + row_width);
|
||||
shifted_positions.push_back(pos);
|
||||
}
|
||||
|
||||
state.target_window = std::move(shifted_rows);
|
||||
state.target_window_pos = std::move(shifted_positions);
|
||||
state.target_window_rows = (int32_t) state.target_window_pos.size();
|
||||
state.last_target_pos = state.target_window_pos.empty() ? -1 : state.target_window_pos.back();
|
||||
}
|
||||
|
||||
static bool common_speculative_capture_target_features(common_speculative * spec, const common_speculative_feature_view & features) {
|
||||
@ -2366,12 +2443,9 @@ int32_t common_speculative_on_target_batch(
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<float> hidden_rows_storage;
|
||||
if (!common_speculative_feature_view_copy_batch_rows(features, batch, seq_id, &hidden_rows_storage)) {
|
||||
if (!dflash_append_target_features(*dflash_state, features, batch, seq_id)) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
dflash_append_target_features(*dflash_state, hidden_rows_storage.data(), batch.n_tokens);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -2439,6 +2513,10 @@ void common_speculative_context_shift(
|
||||
llama_kv_cache_seq_rm (ctx_mtp, seq_id, kv_keep, kv_keep + kv_discard);
|
||||
llama_kv_cache_seq_add(ctx_mtp, seq_id, kv_keep + kv_discard, kv_past, -kv_discard);
|
||||
}
|
||||
|
||||
if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) {
|
||||
dflash_context_shift(*dflash_state, kv_keep, kv_discard, kv_past);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
|
||||
@ -367,6 +367,9 @@ bool server_context::load_model(const gpt_params& params_) {
|
||||
if (params_dft.n_ctx == 0) {
|
||||
params_dft.n_ctx = params_base.speculative.n_ctx;
|
||||
}
|
||||
if (server_speculative_has_dflash(params_base.speculative) && params_dft.n_gpu_layers < 0) {
|
||||
params_dft.n_gpu_layers = params_base.n_gpu_layers;
|
||||
}
|
||||
params_dft.n_ctx = params_dft.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx;
|
||||
params_dft.n_parallel = 1;
|
||||
params_dft.n_batch = params_dft.n_ctx;
|
||||
@ -629,6 +632,7 @@ void server_slot::reset() {
|
||||
drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE;
|
||||
i_batch_dft.clear();
|
||||
spec_ckpt.clear();
|
||||
spec_prompt_warmup_failed = false;
|
||||
n_sent_token_probs = 0;
|
||||
infill = false;
|
||||
ga_i = 0;
|
||||
@ -717,7 +721,7 @@ void server_slot::add_token_string(const completion_token_output& token) {
|
||||
}
|
||||
|
||||
bool server_slot::can_speculate() const {
|
||||
return (!!spec || has_mtp);
|
||||
return !spec_prompt_warmup_failed && (!!spec || has_mtp);
|
||||
}
|
||||
|
||||
int server_slot::get_n_draft_max() const {
|
||||
@ -3347,7 +3351,7 @@ void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_sl
|
||||
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(kv_keep), slot.cache_tokens.pos_next(kv_keep + kv_discard));
|
||||
llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard);
|
||||
if (slot.has_mtp && slot.spec) {
|
||||
if (slot.spec) {
|
||||
common_speculative_context_shift(slot.spec, slot.id, kv_keep, kv_discard, kv_past);
|
||||
}
|
||||
if (slot.params.cache_prompt) {
|
||||
@ -4730,12 +4734,18 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (slot.spec_prompt_warmup_failed) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if ((slot.state != SLOT_STATE_PROCESSING || slot.n_decoded != 0) &&
|
||||
(slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_LOAD_PROMPT)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (common_speculative_on_target_seq_batch(slot.spec, ctx, batch_view, slot.id, true) != 0) {
|
||||
common_speculative_clear_sequence_hidden(slot.spec, slot.id);
|
||||
slot.spec_prompt_warmup_failed = true;
|
||||
LOG_ERROR("failed to warm up speculative target-feature state from prompt batch for slot %d\n", slot.id);
|
||||
}
|
||||
}
|
||||
|
||||
@ -176,6 +176,8 @@ struct server_slot {
|
||||
// saves recurrent state before a speculative batch so it can be restored on rejection
|
||||
server_speculative_checkpoint spec_ckpt;
|
||||
|
||||
bool spec_prompt_warmup_failed = false;
|
||||
|
||||
// speculative decoding stats
|
||||
int32_t n_draft_total = 0; // Total draft tokens generated
|
||||
int32_t n_draft_accepted = 0; // Draft tokens actually accepted
|
||||
|
||||
@ -28,6 +28,8 @@ ggml_cgraph * llm_build_context::build_dflash() {
|
||||
ggml_set_input(lctx.inp_dflash_kq_mask);
|
||||
cb(lctx.inp_dflash_kq_mask, "dflash_kq_mask", -1);
|
||||
|
||||
ggml_tensor * dflash_kq_mask = flash_attn ? ggml_cast(ctx0, lctx.inp_dflash_kq_mask, GGML_TYPE_F16) : lctx.inp_dflash_kq_mask;
|
||||
|
||||
ggml_tensor * tok_embd = model.tok_embd;
|
||||
if (tok_embd == nullptr) {
|
||||
tok_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab);
|
||||
@ -35,6 +37,7 @@ ggml_cgraph * llm_build_context::build_dflash() {
|
||||
|
||||
ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, tok_embd, cb);
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
ggml_tensor * inp_out_ids = (n_tokens > 1 && n_outputs < n_tokens) ? build_inp_out_ids() : nullptr;
|
||||
|
||||
ggml_tensor * fused_target = llm_build_lora_mm(lctx, ctx0, model.dflash_fc, lctx.inp_dflash_target_features);
|
||||
fused_target = llm_build_norm(ctx0, fused_target, hparams, model.dflash_hidden_norm, nullptr, LLM_NORM_RMS, cb, -1);
|
||||
@ -85,10 +88,9 @@ ggml_cgraph * llm_build_context::build_dflash() {
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_cast(ctx0, Qcur, GGML_TYPE_F16);
|
||||
Kcur = ggml_cast(ctx0, Kcur, GGML_TYPE_F16);
|
||||
Vcur = ggml_cast(ctx0, Vcur, GGML_TYPE_F16);
|
||||
cb(Qcur, "Qcur_f16", il);
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur_f16", il);
|
||||
cb(Vcur, "Vcur_f16", il);
|
||||
|
||||
@ -99,7 +101,7 @@ ggml_cgraph * llm_build_context::build_dflash() {
|
||||
cb(k, "k", il);
|
||||
cb(v, "v", il);
|
||||
|
||||
cur = ggml_flash_attn_ext(ctx0, q, k, v, lctx.inp_dflash_kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
||||
cur = ggml_flash_attn_ext(ctx0, q, k, v, dflash_kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
||||
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
||||
cb(cur, "flash_attn", il);
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
@ -136,7 +138,13 @@ ggml_cgraph * llm_build_context::build_dflash() {
|
||||
output = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab);
|
||||
}
|
||||
|
||||
ggml_tensor * result = build_output(lctx, ctx0, inpL, output, model.output_norm, cb);
|
||||
ggml_tensor * result_input = inpL;
|
||||
if (inp_out_ids) {
|
||||
result_input = ggml_get_rows(ctx0, result_input, inp_out_ids);
|
||||
cb(result_input, "result_output_rows", -1);
|
||||
}
|
||||
|
||||
ggml_tensor * result = build_output(lctx, ctx0, result_input, output, model.output_norm, cb);
|
||||
cb(result, "result_output", -1);
|
||||
ggml_build_forward_expand(gf, result);
|
||||
|
||||
|
||||
@ -281,7 +281,10 @@ struct llama_context {
|
||||
const float * dflash_target_features = nullptr;
|
||||
size_t dflash_target_features_n_floats = 0;
|
||||
int32_t dflash_target_features_n_rows = 0;
|
||||
const llama_pos * dflash_target_positions = nullptr;
|
||||
size_t dflash_target_positions_n = 0;
|
||||
std::vector<float> dflash_target_features_owned;
|
||||
std::vector<llama_pos> dflash_target_positions_owned;
|
||||
std::vector<float> dflash_target_features_padded;
|
||||
std::vector<float> dflash_feature_view_buffer;
|
||||
std::vector<llama_pos> dflash_pos_ctx_data;
|
||||
|
||||
@ -96,7 +96,8 @@ bool llama_set_dflash_target_features_copy(
|
||||
struct llama_context * ctx,
|
||||
const float * target_features,
|
||||
size_t n_floats,
|
||||
int32_t n_rows) {
|
||||
int32_t n_rows,
|
||||
const llama_pos * target_positions) {
|
||||
if (ctx == nullptr || target_features == nullptr || n_floats == 0 || n_rows <= 0) {
|
||||
return false;
|
||||
}
|
||||
@ -105,6 +106,15 @@ bool llama_set_dflash_target_features_copy(
|
||||
ctx->dflash_target_features = ctx->dflash_target_features_owned.data();
|
||||
ctx->dflash_target_features_n_floats = n_floats;
|
||||
ctx->dflash_target_features_n_rows = n_rows;
|
||||
if (target_positions != nullptr) {
|
||||
ctx->dflash_target_positions_owned.assign(target_positions, target_positions + n_rows);
|
||||
ctx->dflash_target_positions = ctx->dflash_target_positions_owned.data();
|
||||
ctx->dflash_target_positions_n = (size_t) n_rows;
|
||||
} else {
|
||||
ctx->dflash_target_positions_owned.clear();
|
||||
ctx->dflash_target_positions = nullptr;
|
||||
ctx->dflash_target_positions_n = 0;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -361,9 +371,25 @@ bool llama_spec_get_dflash_feature_view(
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<int32_t> row_indices((size_t) batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
row_indices[(size_t) i] = i;
|
||||
int32_t row_count = 0;
|
||||
int32_t row_width = 0;
|
||||
int32_t n_layers = 0;
|
||||
if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int32_t batch_row_offset = std::max<int32_t>(0, batch.n_tokens - row_count);
|
||||
std::vector<int32_t> row_indices;
|
||||
std::vector<int32_t> batch_indices;
|
||||
row_indices.reserve((size_t) (batch.n_tokens - batch_row_offset));
|
||||
batch_indices.reserve((size_t) (batch.n_tokens - batch_row_offset));
|
||||
for (int32_t i = batch_row_offset; i < batch.n_tokens; ++i) {
|
||||
row_indices.push_back(i - batch_row_offset);
|
||||
batch_indices.push_back(i);
|
||||
}
|
||||
|
||||
if (row_indices.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
view = {};
|
||||
@ -372,17 +398,17 @@ bool llama_spec_get_dflash_feature_view(
|
||||
return false;
|
||||
}
|
||||
|
||||
view.rows.reserve((size_t) batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
if (batch.n_seq_id[i] <= 0 || batch.seq_id[i] == nullptr) {
|
||||
view.rows.reserve(batch_indices.size());
|
||||
for (int32_t batch_index : batch_indices) {
|
||||
if (batch.n_seq_id[batch_index] <= 0 || batch.seq_id[batch_index] == nullptr) {
|
||||
view.rows.clear();
|
||||
return false;
|
||||
}
|
||||
|
||||
view.rows.push_back({
|
||||
/* .seq_id = */ batch.seq_id[i][0],
|
||||
/* .pos = */ batch.pos[i],
|
||||
/* .data = */ ctx->dflash_feature_view_buffer.data() + (size_t) i * (size_t) view.width,
|
||||
/* .seq_id = */ batch.seq_id[batch_index][0],
|
||||
/* .pos = */ batch.pos[batch_index],
|
||||
/* .data = */ ctx->dflash_feature_view_buffer.data() + view.rows.size() * (size_t) view.width,
|
||||
});
|
||||
}
|
||||
|
||||
@ -398,18 +424,26 @@ bool llama_spec_get_dflash_feature_view_for_seq(
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t row_count = 0;
|
||||
int32_t row_width = 0;
|
||||
int32_t n_layers = 0;
|
||||
if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int32_t batch_row_offset = std::max<int32_t>(0, batch.n_tokens - row_count);
|
||||
std::vector<int32_t> row_indices;
|
||||
row_indices.reserve((size_t) batch.n_tokens);
|
||||
std::vector<int32_t> batch_indices;
|
||||
batch_indices.reserve((size_t) batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
for (int32_t i = batch_row_offset; i < batch.n_tokens; ++i) {
|
||||
if (batch.n_seq_id[i] <= 0 || batch.seq_id[i] == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
|
||||
if (batch.seq_id[i][j] == seq_id) {
|
||||
row_indices.push_back(i);
|
||||
row_indices.push_back(i - batch_row_offset);
|
||||
batch_indices.push_back(i);
|
||||
break;
|
||||
}
|
||||
|
||||
@ -51,7 +51,8 @@ bool llama_set_dflash_target_features_copy(
|
||||
struct llama_context * ctx,
|
||||
const float * target_features,
|
||||
size_t n_floats,
|
||||
int32_t n_rows);
|
||||
int32_t n_rows,
|
||||
const llama_pos * target_positions);
|
||||
|
||||
bool llama_set_dflash_capture_layers(
|
||||
struct llama_context * ctx,
|
||||
|
||||
@ -4994,7 +4994,9 @@ static bool prepare_dflash_graph_inputs(
|
||||
}
|
||||
|
||||
const float * src = lctx.dflash_target_features;
|
||||
const llama_pos * src_pos = lctx.dflash_target_positions;
|
||||
const size_t total_floats = lctx.dflash_target_features_n_floats;
|
||||
const size_t total_positions = lctx.dflash_target_positions_n;
|
||||
const int32_t n_rows = lctx.dflash_target_features_n_rows;
|
||||
const int32_t width = (int32_t) target_hidden->ne[0];
|
||||
const int32_t cross_ctx = (int32_t) target_hidden->ne[1];
|
||||
@ -5014,20 +5016,26 @@ static bool prepare_dflash_graph_inputs(
|
||||
|
||||
lctx.dflash_target_features_padded.assign((size_t) cross_ctx * (size_t) width, 0.0f);
|
||||
const size_t dst_offset = (size_t) (cross_ctx - n_rows) * (size_t) width;
|
||||
const int32_t left_pad = cross_ctx - n_rows;
|
||||
std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset);
|
||||
ggml_backend_tensor_set(target_hidden, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(target_hidden));
|
||||
|
||||
lctx.dflash_pos_ctx_data.resize((size_t) cross_ctx);
|
||||
for (int32_t i = 0; i < cross_ctx; ++i) {
|
||||
lctx.dflash_pos_ctx_data[i] = i;
|
||||
std::fill(lctx.dflash_pos_ctx_data.begin(), lctx.dflash_pos_ctx_data.end(), 0);
|
||||
if (src_pos != nullptr && total_positions == (size_t) n_rows) {
|
||||
std::copy(src_pos, src_pos + n_rows, lctx.dflash_pos_ctx_data.begin() + (ptrdiff_t) left_pad);
|
||||
} else {
|
||||
for (int32_t i = 0; i < n_rows; ++i) {
|
||||
lctx.dflash_pos_ctx_data[(size_t) left_pad + (size_t) i] = i;
|
||||
}
|
||||
}
|
||||
ggml_backend_tensor_set(pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(pos_ctx));
|
||||
|
||||
lctx.dflash_kq_mask_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY);
|
||||
const int32_t left_pad = cross_ctx - n_rows;
|
||||
for (uint32_t j = 0; j < n_tokens; ++j) {
|
||||
float * row = lctx.dflash_kq_mask_data.data() + (size_t) j * (size_t) n_kv_total;
|
||||
for (int32_t i = left_pad; i < n_kv_total; ++i) {
|
||||
const int32_t visible_kv = cross_ctx + (int32_t) j + 1;
|
||||
for (int32_t i = left_pad; i < visible_kv; ++i) {
|
||||
row[i] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user