ik_llama.cpp/src/llama-spec-features-dflash.cpp
SamuelOliveirads 6cae8c7ba2 clean logs
2026-06-14 21:07:57 -03:00

730 lines
26 KiB
C++

#include "llama-spec-features.h"
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <random>
#include "llama-model.h"
#include "llama-context.h"
void llama_reset_dflash_kv_cache_state(struct llama_context * ctx) {
if (ctx == nullptr) {
return;
}
ctx->dflash.kv.cache_write_pos = 0;
ctx->dflash.kv.cache_n_filled = 0;
ctx->dflash.kv.cache_update_rows = 0;
ctx->dflash.kv.cache_view_write_pos = 0;
ctx->dflash.kv.cache_view_n_filled = 0;
ctx->dflash.kv.cache_applied_window_version = 0;
ctx->dflash.kv.cache_valid = false;
ctx->dflash.kv.cache_view_valid = false;
ctx->dflash.kv.workspace_write_pos = 0;
ctx->dflash.kv.workspace_n_filled = 0;
ctx->dflash.kv.workspace_applied_window_version = 0;
ctx->dflash.kv.workspace_valid = false;
ctx->dflash.kv.workspace_sync_pending = false;
for (ggml_backend_buffer_t buf : ctx->dflash.kv.cache_bufs) {
if (buf != nullptr) {
ggml_backend_buffer_clear(buf, 0);
}
}
}
llama_dflash_kv_cache_transition llama_plan_dflash_kv_cache_transition_for_ctx(
const struct llama_context * ctx,
const llama_dflash_window_update & window_update,
int32_t n_rows) {
if (ctx == nullptr) {
llama_dflash_kv_cache_transition plan;
plan.rebuild_cache = true;
plan.append_rows = std::clamp(window_update.append_rows, 0, n_rows);
plan.next_n_filled = n_rows;
return plan;
}
const int32_t cross_ctx = ctx->dflash.visible_cross_ctx > 0
? ctx->dflash.visible_cross_ctx
: std::max<int32_t>(1, (int32_t) ctx->cparams.n_ctx - (int32_t) ctx->model.hparams.dflash_block_size);
return llama_plan_dflash_kv_cache_transition(
cross_ctx,
ctx->dflash.kv.cache_n_filled,
ctx->dflash.kv.cache_write_pos,
ctx->dflash.kv.cache_valid,
ctx->dflash.kv.cache_applied_window_version,
window_update.version,
window_update.keep_rows,
window_update.append_rows,
window_update.replace,
n_rows);
}
void llama_set_dflash_visible_cross_ctx(
struct llama_context * ctx,
int32_t cross_ctx) {
if (ctx == nullptr) {
return;
}
ctx->dflash.visible_cross_ctx = std::max<int32_t>(0, cross_ctx);
}
int32_t llama_get_dflash_visible_cross_ctx(
const struct llama_context * ctx) {
return ctx != nullptr ? ctx->dflash.visible_cross_ctx : 0;
}
int32_t llama_model_dflash_block_size(const struct llama_model * model) {
return model ? (int32_t) model->hparams.dflash_block_size : 0;
}
int32_t llama_model_dflash_mask_token_id(const struct llama_model * model) {
return model ? (int32_t) model->hparams.dflash_mask_token_id : -1;
}
int32_t llama_model_dflash_n_target_layers(const struct llama_model * model) {
return model ? (int32_t) model->hparams.dflash_n_target_layers : 0;
}
int32_t llama_model_dflash_n_target_features(const struct llama_model * model) {
return model ? (int32_t) model->hparams.dflash_n_target_features : 0;
}
int32_t llama_model_dflash_target_layer_ids(
const struct llama_model * model,
int32_t * layer_ids,
int32_t capacity) {
if (model == nullptr || layer_ids == nullptr || capacity <= 0) {
return 0;
}
const int32_t n_layers = std::min<int32_t>((int32_t) model->hparams.dflash_n_target_layers, capacity);
for (int32_t i = 0; i < n_layers; ++i) {
layer_ids[i] = (int32_t) model->hparams.dflash_target_layer_ids[i];
}
return n_layers;
}
int32_t llama_model_dflash_target_mask_token_id(const struct llama_model * model) {
if (model == nullptr) {
return (int32_t) LLAMA_TOKEN_NULL;
}
return (int32_t) model->vocab.token_mask();
}
const struct ggml_tensor * llama_model_dflash_output_tensor(
const struct llama_model * model) {
if (model == nullptr) {
return nullptr;
}
if (model->output_mtp != nullptr) {
return model->output_mtp;
}
if (model->output != nullptr) {
return model->output;
}
return model->tok_embd;
}
int32_t llama_model_dflash_io_mode(
const struct llama_model * draft_model,
const struct llama_model * target_model) {
if (draft_model == nullptr || target_model == nullptr || draft_model->arch != LLM_ARCH_DFLASH_DRAFT) {
return LLAMA_DFLASH_IO_MODE_INVALID;
}
const ggml_tensor * draft_output = llama_model_dflash_output_tensor(draft_model);
const ggml_tensor * target_output = llama_model_dflash_output_tensor(target_model);
if (draft_model->tok_embd == nullptr || draft_output == nullptr || target_model->tok_embd == nullptr || target_output == nullptr) {
return LLAMA_DFLASH_IO_MODE_INVALID;
}
const bool shared_tok = draft_model->tok_embd == target_model->tok_embd;
const bool shared_output = draft_output == target_output;
if (shared_tok && shared_output) {
return LLAMA_DFLASH_IO_MODE_SHARED;
}
if (!shared_tok && !shared_output) {
return LLAMA_DFLASH_IO_MODE_SELF_CONTAINED;
}
return LLAMA_DFLASH_IO_MODE_MIXED;
}
bool llama_model_dflash_io_tensors_match(
const struct llama_model * draft_model,
int32_t n_embd,
int32_t n_vocab) {
const ggml_tensor * output = llama_model_dflash_output_tensor(draft_model);
if (draft_model == nullptr || draft_model->tok_embd == nullptr || output == nullptr || n_embd <= 0 || n_vocab <= 0) {
return false;
}
return (int32_t) draft_model->tok_embd->ne[0] == n_embd &&
(int32_t) draft_model->tok_embd->ne[1] == n_vocab &&
(int32_t) output->ne[0] == n_embd &&
(int32_t) output->ne[1] == n_vocab;
}
bool llama_model_share_dflash_io_tensors(
struct llama_model * draft_model,
const struct llama_model * target_model) {
if (draft_model == nullptr || target_model == nullptr) {
return false;
}
if (draft_model->arch != LLM_ARCH_DFLASH_DRAFT) {
return true;
}
if (draft_model->tok_embd == nullptr) {
draft_model->tok_embd = target_model->tok_embd;
}
if (draft_model->output == nullptr) {
draft_model->output = target_model->output ? target_model->output : target_model->tok_embd;
if (draft_model->output == nullptr) {
draft_model->output = draft_model->tok_embd;
}
}
const bool uses_shared_tok = draft_model->tok_embd == target_model->tok_embd;
const bool uses_shared_output = draft_model->output == target_model->output ||
draft_model->output == target_model->tok_embd;
if (draft_model->output_mtp == nullptr && target_model->output_mtp != nullptr && uses_shared_tok && uses_shared_output) {
draft_model->output_mtp = target_model->output_mtp;
}
const struct ggml_tensor * output = llama_model_dflash_output_tensor(draft_model);
return draft_model->tok_embd != nullptr && output != nullptr;
}
static bool llama_set_dflash_target_features_impl(
struct llama_context * ctx,
const float * target_features,
size_t n_floats,
int32_t n_rows,
const llama_pos * target_positions,
bool copy_data,
const llama_dflash_window_update * window_update) {
const bool have_full_features = target_features != nullptr && n_floats > 0;
const bool have_append_features = window_update != nullptr &&
window_update->append_features != nullptr &&
window_update->append_floats > 0 &&
window_update->append_rows > 0;
if (ctx == nullptr || n_rows <= 0 || (!have_full_features && !have_append_features)) {
return false;
}
if (have_full_features && copy_data) {
ctx->dflash.target.features_owned.assign(target_features, target_features + n_floats);
ctx->dflash.target.features = ctx->dflash.target.features_owned.data();
} else if (have_full_features) {
ctx->dflash.target.features_owned.clear();
ctx->dflash.target.features = target_features;
} else {
ctx->dflash.target.features_owned.clear();
ctx->dflash.target.features = nullptr;
}
ctx->dflash.target.features_n_floats = have_full_features ? n_floats : 0;
ctx->dflash.target.features_n_rows = n_rows;
if (have_append_features && copy_data) {
ctx->dflash.target.append_features_owned.assign(
window_update->append_features,
window_update->append_features + window_update->append_floats);
ctx->dflash.target.append_features = ctx->dflash.target.append_features_owned.data();
} else if (have_append_features) {
ctx->dflash.target.append_features_owned.clear();
ctx->dflash.target.append_features = window_update->append_features;
} else {
ctx->dflash.target.append_features_owned.clear();
ctx->dflash.target.append_features = nullptr;
}
ctx->dflash.target.append_features_n_floats = have_append_features ? window_update->append_floats : 0;
ctx->dflash.target.append_features_n_rows = have_append_features ? window_update->append_rows : 0;
ctx->dflash.target.version = window_update != nullptr && window_update->version > 0
? window_update->version
: ctx->dflash.target.version + 1;
ctx->dflash.target.keep_rows = window_update != nullptr
? std::max<int32_t>(0, std::min(n_rows, window_update->keep_rows))
: 0;
ctx->dflash.target.append_rows = window_update != nullptr
? std::max<int32_t>(0, std::min(n_rows, window_update->append_rows))
: n_rows;
ctx->dflash.target.replace = window_update != nullptr
? window_update->replace
: true;
if (ctx->dflash.target.keep_rows + ctx->dflash.target.append_rows > n_rows) {
ctx->dflash.target.keep_rows = std::max<int32_t>(0, n_rows - ctx->dflash.target.append_rows);
}
const int32_t cross_ctx = ctx->dflash.visible_cross_ctx > 0
? ctx->dflash.visible_cross_ctx
: std::max<int32_t>(1, (int32_t) ctx->cparams.n_ctx - (int32_t) ctx->model.hparams.dflash_block_size);
const llama_dflash_window_update cache_window_update = {
ctx->dflash.target.version,
ctx->dflash.target.keep_rows,
ctx->dflash.target.append_rows,
ctx->dflash.target.replace,
ctx->dflash.target.append_features,
ctx->dflash.target.append_features_n_floats,
};
const llama_dflash_kv_cache_transition cache_plan = llama_plan_dflash_kv_cache_transition_for_ctx(ctx, cache_window_update, n_rows);
if (cache_plan.cache_up_to_date) {
ctx->dflash.kv.cache_view_n_filled = ctx->dflash.kv.cache_n_filled;
ctx->dflash.kv.cache_view_write_pos = ctx->dflash.kv.cache_write_pos;
ctx->dflash.kv.cache_view_valid = ctx->dflash.kv.cache_valid;
} else if (cross_ctx > 0) {
ctx->dflash.kv.cache_view_n_filled = cache_plan.next_n_filled;
ctx->dflash.kv.cache_view_write_pos = cache_plan.next_write_pos;
ctx->dflash.kv.cache_view_valid = cache_plan.next_n_filled > 0;
}
if (target_positions != nullptr) {
if (copy_data) {
ctx->dflash.target.positions_owned.assign(target_positions, target_positions + n_rows);
ctx->dflash.target.positions = ctx->dflash.target.positions_owned.data();
} else {
ctx->dflash.target.positions_owned.clear();
ctx->dflash.target.positions = target_positions;
}
ctx->dflash.target.positions_n = (size_t) n_rows;
} else {
ctx->dflash.target.positions_owned.clear();
ctx->dflash.target.positions = nullptr;
ctx->dflash.target.positions_n = 0;
}
return true;
}
bool llama_set_dflash_target_features_copy(
struct llama_context * ctx,
const float * target_features,
size_t n_floats,
int32_t n_rows,
const llama_pos * target_positions,
const llama_dflash_window_update * window_update) {
return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, true, window_update);
}
bool llama_set_dflash_target_features_view(
struct llama_context * ctx,
const float * target_features,
size_t n_floats,
int32_t n_rows,
const llama_pos * target_positions,
const llama_dflash_window_update * window_update) {
return llama_set_dflash_target_features_impl(ctx, target_features, n_floats, n_rows, target_positions, false, window_update);
}
static bool llama_dflash_parse_layer_id(const struct ggml_tensor * tensor, int32_t & layer_id) {
if (tensor == nullptr) {
return false;
}
static constexpr const char * prefix = "l_out-";
if (std::strncmp(tensor->name, prefix, std::strlen(prefix)) != 0) {
return false;
}
char * end = nullptr;
const long raw = std::strtol(tensor->name + std::strlen(prefix), &end, 10);
if (end == tensor->name + std::strlen(prefix) || *end != '\0') {
return false;
}
layer_id = (int32_t) raw;
if (layer_id >= 1000) {
layer_id %= 1000;
}
return layer_id >= 0;
}
static int32_t llama_dflash_find_layer_index(const struct llama_context * ctx, int32_t layer_id) {
if (ctx == nullptr || !ctx->dflash.capture) {
return -1;
}
const auto & layer_ids = ctx->dflash.capture->layer_ids;
const auto it = std::find(layer_ids.begin(), layer_ids.end(), layer_id);
return it == layer_ids.end() ? -1 : (int32_t) std::distance(layer_ids.begin(), it);
}
static bool llama_dflash_capture_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
auto * ctx = static_cast<llama_context *>(user_data);
if (ctx == nullptr || !ctx->dflash.capture) {
return false;
}
int32_t layer_id = -1;
if (!llama_dflash_parse_layer_id(tensor, layer_id)) {
return false;
}
const int32_t layer_idx = llama_dflash_find_layer_index(ctx, layer_id);
if (layer_idx < 0) {
return false;
}
if (ask) {
return true;
}
const int32_t row_width = (int32_t) tensor->ne[0];
const int32_t row_count = row_width > 0 ? (int32_t) (ggml_nelements(tensor) / (int64_t) row_width) : 0;
if (row_width <= 0 || row_count <= 0) {
return false;
}
auto & capture = *ctx->dflash.capture;
if (capture.capture_batch_id == 0) {
capture.capture_batch_id = 1;
}
if (capture.layer_seen_batch_id.size() != capture.layer_ids.size()) {
capture.layer_seen_batch_id.assign(capture.layer_ids.size(), 0);
}
auto & rows = capture.layer_rows[(size_t) layer_idx];
rows.resize((size_t) row_count * (size_t) row_width);
ggml_backend_tensor_get(tensor, rows.data(), 0, ggml_nbytes(tensor));
capture.row_width = row_width;
capture.row_count = row_count;
capture.layer_seen_batch_id[(size_t) layer_idx] = capture.capture_batch_id;
return true;
}
bool llama_set_dflash_capture_layers(
struct llama_context * ctx,
const int32_t * layer_ids,
int32_t n_layers) {
if (ctx == nullptr || layer_ids == nullptr || n_layers <= 0) {
return false;
}
auto capture = std::make_unique<llama_context::dflash_runtime::capture_state>();
capture->layer_ids.assign(layer_ids, layer_ids + n_layers);
capture->layer_rows.resize((size_t) n_layers);
capture->layer_seen_batch_id.assign((size_t) n_layers, 0);
capture->prev_cb_eval = ctx->cparams.cb_eval;
capture->prev_cb_eval_user_data = ctx->cparams.cb_eval_user_data;
ctx->dflash.capture = std::move(capture);
ctx->dflash.feature_view_buffer.clear();
ctx->cparams.cb_eval = llama_dflash_capture_eval_callback;
ctx->cparams.cb_eval_user_data = ctx;
if (ctx->sched != nullptr) {
ggml_backend_sched_set_eval_callback(ctx->sched, ctx->cparams.cb_eval, ctx->cparams.cb_eval_user_data);
}
return true;
}
void llama_clear_dflash_capture(struct llama_context * ctx) {
if (ctx == nullptr) {
return;
}
ggml_backend_sched_eval_callback prev_cb_eval = nullptr;
void * prev_cb_eval_user_data = nullptr;
if (ctx->dflash.capture) {
prev_cb_eval = ctx->dflash.capture->prev_cb_eval;
prev_cb_eval_user_data = ctx->dflash.capture->prev_cb_eval_user_data;
}
ctx->dflash.capture.reset();
ctx->dflash.feature_view_buffer.clear();
if (ctx->cparams.cb_eval == llama_dflash_capture_eval_callback && ctx->cparams.cb_eval_user_data == ctx) {
ctx->cparams.cb_eval = prev_cb_eval;
ctx->cparams.cb_eval_user_data = prev_cb_eval_user_data;
if (ctx->sched != nullptr) {
ggml_backend_sched_set_eval_callback(ctx->sched, prev_cb_eval, prev_cb_eval_user_data);
}
}
}
void llama_begin_dflash_capture_batch(struct llama_context * ctx) {
if (ctx == nullptr || !ctx->dflash.capture) {
return;
}
auto & capture = *ctx->dflash.capture;
capture.capture_batch_id++;
capture.row_count = 0;
capture.row_width = 0;
std::fill(capture.layer_seen_batch_id.begin(), capture.layer_seen_batch_id.end(), 0);
}
void llama_finish_dflash_capture_batch(
struct llama_context * ctx,
bool is_prompt_warmup) {
if (ctx == nullptr || !ctx->dflash.capture) {
return;
}
GGML_UNUSED(is_prompt_warmup);
auto & capture = *ctx->dflash.capture;
// Reset the batch-local reference shape so the next decode only compares layers within
// the same batch, not against the previous prompt/verify batch.
capture.row_count = 0;
capture.row_width = 0;
}
static bool llama_spec_prepare_dflash_capture(
struct llama_context * ctx,
int32_t & row_count,
int32_t & row_width,
int32_t & n_layers) {
if (ctx == nullptr || !ctx->dflash.capture) {
return false;
}
llama_synchronize(ctx);
auto & capture = *ctx->dflash.capture;
row_count = capture.row_count;
row_width = capture.row_width;
n_layers = (int32_t) capture.layer_ids.size();
if (row_count <= 0 || row_width <= 0 || n_layers <= 0 || capture.layer_rows.size() != (size_t) n_layers) {
return false;
}
if (capture.capture_batch_id == 0 || capture.layer_seen_batch_id.size() != (size_t) n_layers) {
LLAMA_LOG_WARN("%s: DFlash capture batch markers are not initialized (batch_id=%llu layers=%zu expected=%d)\n",
__func__,
(unsigned long long) capture.capture_batch_id,
capture.layer_seen_batch_id.size(),
n_layers);
return false;
}
for (int32_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) {
if (capture.layer_seen_batch_id[(size_t) layer_idx] != capture.capture_batch_id) {
LLAMA_LOG_WARN("%s: DFlash capture is stale for layer %d (seen_batch=%llu current_batch=%llu rows=%d width=%d)\n",
__func__,
capture.layer_ids[(size_t) layer_idx],
(unsigned long long) capture.layer_seen_batch_id[(size_t) layer_idx],
(unsigned long long) capture.capture_batch_id,
row_count,
row_width);
return false;
}
const auto & rows = capture.layer_rows[(size_t) layer_idx];
if (rows.size() != (size_t) row_count * (size_t) row_width) {
LLAMA_LOG_WARN("%s: DFlash capture rows mismatch for layer %d: got=%zu expected=%zu (rows=%d width=%d)\n",
__func__, capture.layer_ids[(size_t) layer_idx], rows.size(),
(size_t) row_count * (size_t) row_width, row_count, row_width);
return false;
}
}
return true;
}
static bool llama_spec_materialize_dflash_rows_prepared(
struct llama_context * ctx,
int32_t row_count,
int32_t row_width,
int32_t n_layers,
const std::vector<int32_t> & row_indices,
std::vector<float> & rows_out,
int32_t & combined_width);
static bool llama_spec_materialize_dflash_rows(
struct llama_context * ctx,
const std::vector<int32_t> & row_indices,
std::vector<float> & rows_out,
int32_t & combined_width) {
int32_t row_count = 0;
int32_t row_width = 0;
int32_t n_layers = 0;
if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) {
return false;
}
return llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, rows_out, combined_width);
}
static bool llama_spec_materialize_dflash_rows_prepared(
struct llama_context * ctx,
int32_t row_count,
int32_t row_width,
int32_t n_layers,
const std::vector<int32_t> & row_indices,
std::vector<float> & rows_out,
int32_t & combined_width) {
rows_out.clear();
combined_width = 0;
if (ctx == nullptr || row_indices.empty()) {
return false;
}
if (row_count <= 0 || row_width <= 0 || n_layers <= 0 || ctx->dflash.capture == nullptr) {
return false;
}
combined_width = row_width * n_layers;
rows_out.resize((size_t) row_indices.size() * (size_t) combined_width);
const auto & layer_rows = ctx->dflash.capture->layer_rows;
for (size_t out_row = 0; out_row < row_indices.size(); ++out_row) {
int32_t row_index = row_indices[out_row];
if (row_index < 0) {
row_index += row_count;
}
if (row_index < 0 || row_index >= row_count) {
rows_out.clear();
combined_width = 0;
return false;
}
float * dst = rows_out.data() + out_row * (size_t) combined_width;
for (int32_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) {
const float * src = layer_rows[(size_t) layer_idx].data() + (size_t) row_index * (size_t) row_width;
std::memcpy(dst + (size_t) layer_idx * (size_t) row_width, src, (size_t) row_width * sizeof(float));
}
}
return true;
}
bool llama_spec_get_dflash_feature_view(
struct llama_context * ctx,
const llama_batch & batch,
llama_spec_feature_view & view) {
if (ctx == nullptr || batch.n_tokens <= 0 || batch.pos == nullptr || batch.n_seq_id == nullptr || batch.seq_id == nullptr) {
return false;
}
int32_t row_count = 0;
int32_t row_width = 0;
int32_t n_layers = 0;
if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) {
return false;
}
const int32_t batch_row_offset = std::max<int32_t>(0, batch.n_tokens - row_count);
std::vector<int32_t> row_indices;
std::vector<int32_t> batch_indices;
row_indices.reserve((size_t) (batch.n_tokens - batch_row_offset));
batch_indices.reserve((size_t) (batch.n_tokens - batch_row_offset));
for (int32_t i = batch_row_offset; i < batch.n_tokens; ++i) {
row_indices.push_back(i - batch_row_offset);
batch_indices.push_back(i);
}
if (row_indices.empty()) {
return false;
}
view = {};
view.kind = LLAMA_SPEC_FEATURE_HIDDEN_STATE;
if (!llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, ctx->dflash.feature_view_buffer, view.width)) {
return false;
}
view.rows.reserve(batch_indices.size());
for (int32_t batch_index : batch_indices) {
if (batch.n_seq_id[batch_index] <= 0 || batch.seq_id[batch_index] == nullptr) {
view.rows.clear();
return false;
}
view.rows.push_back({
/* .seq_id = */ batch.seq_id[batch_index][0],
/* .pos = */ batch.pos[batch_index],
/* .data = */ ctx->dflash.feature_view_buffer.data() + view.rows.size() * (size_t) view.width,
});
}
return true;
}
bool llama_spec_get_dflash_feature_view_for_seq(
struct llama_context * ctx,
const llama_batch & batch,
llama_seq_id seq_id,
llama_spec_feature_view & view) {
if (ctx == nullptr || batch.n_tokens <= 0 || batch.pos == nullptr || batch.n_seq_id == nullptr || batch.seq_id == nullptr) {
return false;
}
int32_t row_count = 0;
int32_t row_width = 0;
int32_t n_layers = 0;
if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) {
return false;
}
const int32_t batch_row_offset = std::max<int32_t>(0, batch.n_tokens - row_count);
std::vector<int32_t> row_indices;
row_indices.reserve((size_t) batch.n_tokens);
std::vector<int32_t> batch_indices;
batch_indices.reserve((size_t) batch.n_tokens);
for (int32_t i = batch_row_offset; i < batch.n_tokens; ++i) {
if (batch.n_seq_id[i] <= 0 || batch.seq_id[i] == nullptr) {
return false;
}
for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
if (batch.seq_id[i][j] == seq_id) {
row_indices.push_back(i - batch_row_offset);
batch_indices.push_back(i);
break;
}
}
}
if (row_indices.empty()) {
return false;
}
view = {};
view.kind = LLAMA_SPEC_FEATURE_HIDDEN_STATE;
if (!llama_spec_materialize_dflash_rows_prepared(ctx, row_count, row_width, n_layers, row_indices, ctx->dflash.feature_view_buffer, view.width)) {
return false;
}
view.rows.reserve(row_indices.size());
for (size_t i = 0; i < batch_indices.size(); ++i) {
const int32_t batch_index = batch_indices[i];
view.rows.push_back({
/* .seq_id = */ seq_id,
/* .pos = */ batch.pos[batch_index],
/* .data = */ ctx->dflash.feature_view_buffer.data() + i * (size_t) view.width,
});
}
return true;
}
bool llama_spec_copy_dflash_rows_from_output_indices(
struct llama_context * ctx,
const std::vector<int32_t> & output_indices,
std::vector<float> & hidden_rows) {
int32_t combined_width = 0;
if (!llama_spec_materialize_dflash_rows(ctx, output_indices, hidden_rows, combined_width)) {
hidden_rows.clear();
return false;
}
return hidden_rows.size() == (size_t) output_indices.size() * (size_t) combined_width;
}