mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Initial dflash implementation
This commit is contained in:
parent
6eff055a0c
commit
82cff238fe
2
.flake8
2
.flake8
@ -17,3 +17,5 @@ exclude =
|
||||
# This contains builds that we don't want to check
|
||||
dist # This is generated with `python build .` for package releases
|
||||
# max-complexity = 10
|
||||
per-file-ignores =
|
||||
gguf-py/gguf/constants.py: E201, E222
|
||||
|
||||
@ -148,6 +148,9 @@ common_params_speculative common_params_speculative::with_stage_overrides(const
|
||||
if (stage.has_p_min_override()) {
|
||||
result.p_min = stage.p_min;
|
||||
}
|
||||
if (stage.has_dflash_cross_ctx_override()) {
|
||||
result.dflash_cross_ctx = stage.dflash_cross_ctx;
|
||||
}
|
||||
if (stage.has_ngram_size_n_override()) {
|
||||
result.ngram_size_n = stage.ngram_size_n;
|
||||
result.ngram_mod.reset();
|
||||
@ -247,8 +250,12 @@ bool common_speculative_validate_chain(const common_params_speculative & params,
|
||||
return fail("speculative stage has n_min greater than n_max");
|
||||
}
|
||||
|
||||
if (stage.type == COMMON_SPECULATIVE_TYPE_DRAFT && !params.has_dft()) {
|
||||
return fail("draft speculative stage requires a draft model or draft params");
|
||||
if ((stage.type == COMMON_SPECULATIVE_TYPE_DRAFT || stage.type == COMMON_SPECULATIVE_TYPE_DFLASH) && !params.has_dft()) {
|
||||
return fail(common_speculative_type_to_str(stage.type) + " speculative stage requires a draft model or draft params");
|
||||
}
|
||||
|
||||
if (stage.type == COMMON_SPECULATIVE_TYPE_DFLASH && stage_params.dflash_cross_ctx < 1) {
|
||||
return fail("dflash speculative stage requires cross_ctx >= 1");
|
||||
}
|
||||
}
|
||||
|
||||
@ -871,6 +878,13 @@ static void common_speculative_stage_apply_kv(
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (key == "cross_ctx" || key == "dflash_cross_ctx") {
|
||||
stage.dflash_cross_ctx = std::stoi(value_raw);
|
||||
if (stage.dflash_cross_ctx < 1) {
|
||||
throw std::invalid_argument("speculative stage dflash cross_ctx must be at least 1");
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (key == "ngram_size_n") {
|
||||
stage.ngram_size_n = std::stoi(value_raw);
|
||||
if (stage.ngram_size_n < 1 || stage.ngram_size_n > 1024) {
|
||||
@ -1468,8 +1482,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
throw std::invalid_argument("--spec-type cannot be combined with --spec-stage; use only --spec-stage for explicit stage chains");
|
||||
}
|
||||
|
||||
const auto type = common_speculative_type_from_name(argv[i]);
|
||||
if (type == COMMON_SPECULATIVE_TYPE_NONE || type == COMMON_SPECULATIVE_TYPE_MTP || common_speculative_type_is_self_spec(type)) {
|
||||
const auto stage = common_speculative_stage_from_arg(argv[i]);
|
||||
const auto type = stage.type;
|
||||
if (type == COMMON_SPECULATIVE_TYPE_NONE || type == COMMON_SPECULATIVE_TYPE_DFLASH || type == COMMON_SPECULATIVE_TYPE_MTP || common_speculative_type_is_self_spec(type)) {
|
||||
params.speculative = params.speculative.with_stage_overrides(stage);
|
||||
params.speculative.type = type;
|
||||
if (type == COMMON_SPECULATIVE_TYPE_MTP) {
|
||||
params.has_mtp = true;
|
||||
@ -3178,7 +3194,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
||||
options.push_back({ "*", "--spec-stage SPEC[:k=v,...]", "explicit speculative stage. repeat once for a supported two-stage chain.\n"
|
||||
"examples: --spec-stage ngram-mod:n_max=64,n_min=2 --spec-stage mtp:n_max=1\n"
|
||||
"supported two-stage shape in this PR: self-spec first, then mtp or draft fallback" });
|
||||
options.push_back({ "*", "--spec-type Name [none | mtp | ngram-cache | ngram-simple | ngram-map-k | ngram-map-k4v | ngram-mod | suffix]", "single-stage speculative selection when --spec-stage is not used (default: %d)\n", (int)params.speculative.type});
|
||||
options.push_back({ "*", "--spec-type Name[:k=v,...] [none | dflash | mtp | ngram-cache | ngram-simple | ngram-map-k | ngram-map-k4v | ngram-mod | suffix]", "single-stage speculative selection when --spec-stage is not used (default: %d)\n", (int)params.speculative.type});
|
||||
options.push_back({ "*", "--spec-ngram-size-n N", "ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)\n",params.speculative.ngram_size_n });
|
||||
|
||||
options.push_back({ "*", "--spec-ngram-size-m N", "ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)\n", params.speculative.ngram_size_m });
|
||||
|
||||
@ -140,6 +140,7 @@ thinking_tokens thinking_tokens_from_string(const std::string& format);
|
||||
enum common_speculative_type {
|
||||
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
|
||||
COMMON_SPECULATIVE_TYPE_DFLASH, // DFlash draft model
|
||||
COMMON_SPECULATIVE_TYPE_MTP, // MTP model
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
|
||||
@ -162,6 +163,7 @@ struct common_speculative_stage_params {
|
||||
int32_t n_max = -1;
|
||||
int32_t n_min = -1;
|
||||
float p_min = -1.0f;
|
||||
int32_t dflash_cross_ctx = -1;
|
||||
|
||||
uint16_t ngram_size_n = 0;
|
||||
uint16_t ngram_size_m = 0;
|
||||
@ -173,6 +175,7 @@ struct common_speculative_stage_params {
|
||||
bool has_n_max_override() const { return n_max >= 0; }
|
||||
bool has_n_min_override() const { return n_min >= 0; }
|
||||
bool has_p_min_override() const { return p_min >= 0.0f; }
|
||||
bool has_dflash_cross_ctx_override() const { return dflash_cross_ctx >= 0; }
|
||||
bool has_ngram_size_n_override() const { return ngram_size_n > 0; }
|
||||
bool has_ngram_size_m_override() const { return ngram_size_m > 0; }
|
||||
bool has_ngram_min_hits_override() const { return ngram_min_hits > 0; }
|
||||
@ -204,6 +207,7 @@ struct common_params_speculative {
|
||||
int32_t n_max = 16; // number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 0; // minimum number of tokens to draft during speculative decoding
|
||||
std::vector<common_speculative_stage_params> stages; // explicit stage chain for single-spec or self-spec + model fallback
|
||||
int32_t dflash_cross_ctx = 512; // target-feature context window for DFlash
|
||||
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
@ -516,7 +520,7 @@ struct gpt_params {
|
||||
bool do_checkpoint = false; // do checkpoint for recurrent models only
|
||||
int32_t ctx_checkpoints_n = 32; // max number of context checkpoints per slot
|
||||
int32_t ctx_checkpoints_interval = 512; // minimum number of tokens between each context checkpoints
|
||||
int32_t ctx_checkpoints_tolerance = 5; // the number of tokens before the full prompt to create the checkpoint
|
||||
int32_t ctx_checkpoints_tolerance = 5; // the number of tokens before the full prompt to create the checkpoint
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram
|
||||
float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens
|
||||
|
||||
@ -24,6 +24,7 @@ void llama_set_mtp_target_context(struct llama_context * ctx, struct llama_conte
|
||||
const std::vector<enum common_speculative_type> common_speculative_types = {
|
||||
COMMON_SPECULATIVE_TYPE_NONE,
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT,
|
||||
COMMON_SPECULATIVE_TYPE_DFLASH,
|
||||
COMMON_SPECULATIVE_TYPE_MTP,
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
|
||||
@ -37,6 +38,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
|
||||
const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
|
||||
{"none", COMMON_SPECULATIVE_TYPE_NONE},
|
||||
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
|
||||
{"dflash", COMMON_SPECULATIVE_TYPE_DFLASH},
|
||||
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
|
||||
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
|
||||
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
|
||||
@ -117,6 +119,44 @@ static bool common_speculative_are_compatible(
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool common_speculative_are_dflash_compatible(
|
||||
const llama_model * model_tgt,
|
||||
const llama_model * model_dft) {
|
||||
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
|
||||
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
|
||||
|
||||
if (llama_vocab_type(vocab_tgt) != llama_vocab_type(vocab_dft)) {
|
||||
LOG_DBG("%s: DFlash draft model vocab type must match the target model\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
|
||||
const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
|
||||
const int vocab_diff = n_vocab_tgt > n_vocab_dft
|
||||
? n_vocab_tgt - n_vocab_dft
|
||||
: n_vocab_dft - n_vocab_tgt;
|
||||
|
||||
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
|
||||
LOG_DBG("%s: DFlash draft vocab size differs too much from the target model (%d vs %d)\n",
|
||||
__func__, n_vocab_dft, n_vocab_tgt);
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
|
||||
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
|
||||
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
|
||||
|
||||
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
||||
LOG_DBG("%s: DFlash draft token %d differs - target '%s', draft '%s'\n", __func__, i,
|
||||
common_token_to_piece(vocab_tgt, i).c_str(),
|
||||
common_token_to_piece(vocab_dft, i).c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// state of an implementation of speculative decoding
|
||||
//
|
||||
// each implementation has a unique type and a state that is implementation-specific
|
||||
@ -168,9 +208,18 @@ struct common_speculative_state {
|
||||
};
|
||||
|
||||
struct common_speculative_state_mtp;
|
||||
struct common_speculative_state_dflash;
|
||||
|
||||
static common_speculative_state_mtp * common_speculative_get_mtp_state(common_speculative * spec);
|
||||
static const common_speculative_state_mtp * common_speculative_get_mtp_state(const common_speculative * spec);
|
||||
static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec);
|
||||
static const common_speculative_state_dflash * common_speculative_get_dflash_state(const common_speculative * spec);
|
||||
static int32_t common_speculative_feature_width(const common_speculative * spec);
|
||||
static void dflash_append_target_features(
|
||||
common_speculative_state_dflash & state,
|
||||
const float * feature_rows,
|
||||
int32_t n_rows);
|
||||
static void dflash_clear_target_features(common_speculative_state_dflash & state);
|
||||
static void mtp_invalidate_cached_drafts(common_speculative_state_mtp & state);
|
||||
|
||||
static std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
@ -302,6 +351,134 @@ struct common_speculative_state_mtp : public common_speculative_state {
|
||||
}
|
||||
};
|
||||
|
||||
struct common_speculative_state_dflash : public common_speculative_state {
|
||||
llama_context * ctx_tgt;
|
||||
llama_context * ctx_dft;
|
||||
|
||||
llama_batch batch = {};
|
||||
|
||||
int32_t block_size = 0;
|
||||
int32_t mask_token_id = -1;
|
||||
int32_t n_target_features = 0;
|
||||
int32_t cross_ctx = 0;
|
||||
bool ready = false;
|
||||
|
||||
std::vector<int32_t> target_layer_ids;
|
||||
std::vector<float> target_window;
|
||||
int32_t target_window_rows = 0;
|
||||
|
||||
common_speculative_state_dflash(
|
||||
enum common_speculative_type type,
|
||||
llama_context * ctx_tgt,
|
||||
llama_context * ctx_dft,
|
||||
int32_t cross_ctx)
|
||||
: common_speculative_state(type)
|
||||
, ctx_tgt(ctx_tgt)
|
||||
, ctx_dft(ctx_dft)
|
||||
, cross_ctx(std::max(1, cross_ctx))
|
||||
{
|
||||
const llama_model * model_dft = llama_get_model(ctx_dft);
|
||||
|
||||
if (!common_speculative_are_dflash_compatible(llama_get_model(ctx_tgt), model_dft)) {
|
||||
LOG_ERR("%s: DFlash draft model vocab/tokenizer is incompatible with the target model\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
block_size = llama_model_dflash_block_size(model_dft);
|
||||
mask_token_id = llama_model_dflash_mask_token_id(model_dft);
|
||||
n_target_features = llama_model_dflash_n_target_features(model_dft);
|
||||
const int32_t n_target_layers = llama_model_dflash_n_target_layers(model_dft);
|
||||
|
||||
if (block_size <= 0 || mask_token_id < 0 || n_target_features <= 0 || n_target_layers <= 0) {
|
||||
LOG_ERR("%s: invalid DFlash metadata (block_size=%d, mask_token_id=%d, n_target_features=%d, n_target_layers=%d)\n",
|
||||
__func__, block_size, mask_token_id, n_target_features, n_target_layers);
|
||||
return;
|
||||
}
|
||||
|
||||
target_layer_ids.resize((size_t) n_target_layers);
|
||||
if (llama_model_dflash_target_layer_ids(model_dft, target_layer_ids.data(), n_target_layers) != n_target_layers) {
|
||||
LOG_ERR("%s: failed to read DFlash target layer ids\n", __func__);
|
||||
target_layer_ids.clear();
|
||||
return;
|
||||
}
|
||||
|
||||
if (!llama_set_dflash_capture_layers(ctx_tgt, target_layer_ids.data(), (int32_t) target_layer_ids.size())) {
|
||||
LOG_ERR("%s: failed to configure DFlash target capture callback\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
batch = llama_batch_init(std::max(1, block_size), 0, 1);
|
||||
ready = true;
|
||||
|
||||
LOG_INF("%s: DFlash context ready (n_ctx=%d, block_size=%d, cross_ctx=%d, n_target_features=%d)\n",
|
||||
__func__, llama_n_ctx(ctx_dft), block_size, this->cross_ctx, n_target_features);
|
||||
}
|
||||
|
||||
~common_speculative_state_dflash() override {
|
||||
llama_clear_dflash_capture(ctx_tgt);
|
||||
if (ctx_dft) {
|
||||
llama_free(ctx_dft);
|
||||
}
|
||||
if (batch.token != nullptr) {
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
}
|
||||
|
||||
void begin(const llama_tokens & prompt) override {
|
||||
GGML_UNUSED(prompt);
|
||||
target_window.clear();
|
||||
target_window_rows = 0;
|
||||
llama_kv_cache_clear(ctx_dft);
|
||||
}
|
||||
|
||||
void draft(
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
GGML_UNUSED(prompt_tgt);
|
||||
GGML_UNUSED(id_last);
|
||||
|
||||
result.clear();
|
||||
if (!ready || target_window_rows <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t n_draft = std::min<int32_t>(params.n_max, block_size);
|
||||
if (n_draft <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!llama_set_dflash_target_features_copy(ctx_dft, target_window.data(), target_window.size(), target_window_rows)) {
|
||||
LOG_ERR("%s: failed to set DFlash target features\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
llama_kv_cache_clear(ctx_dft);
|
||||
batch.n_tokens = 0;
|
||||
for (int32_t i = 0; i < n_draft; ++i) {
|
||||
common_batch_add(batch, mask_token_id, cross_ctx + i, { 0 }, true);
|
||||
}
|
||||
|
||||
if (llama_decode(ctx_dft, batch) != 0) {
|
||||
LOG_ERR("%s: llama_decode() failed for DFlash draft batch\n", __func__);
|
||||
batch.n_tokens = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
result.reserve((size_t) n_draft);
|
||||
for (int32_t i = 0; i < n_draft; ++i) {
|
||||
result.push_back(common_sampler_sample_speculative(nullptr, ctx_dft, i, nullptr));
|
||||
}
|
||||
|
||||
batch.n_tokens = 0;
|
||||
}
|
||||
|
||||
void accept(uint16_t n_accepted) override {
|
||||
GGML_UNUSED(n_accepted);
|
||||
}
|
||||
};
|
||||
|
||||
struct common_speculative_state_draft : public common_speculative_state {
|
||||
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
|
||||
llama_context * ctx_dft;
|
||||
@ -1088,6 +1265,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
|
||||
switch (type) {
|
||||
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
|
||||
case COMMON_SPECULATIVE_TYPE_DFLASH: return "dflash";
|
||||
case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
|
||||
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
|
||||
@ -1165,8 +1343,13 @@ common_speculative * common_speculative_init(
|
||||
}
|
||||
}
|
||||
|
||||
const bool has_dflash_stage = std::any_of(stages.begin(), stages.end(), [](const common_speculative_stage_params & stage) {
|
||||
return stage.type == COMMON_SPECULATIVE_TYPE_DFLASH;
|
||||
});
|
||||
|
||||
const bool needs_draft_ctx = std::any_of(stages.begin(), stages.end(), [¶ms](const common_speculative_stage_params & stage) {
|
||||
return stage.type == COMMON_SPECULATIVE_TYPE_DRAFT ||
|
||||
stage.type == COMMON_SPECULATIVE_TYPE_DFLASH ||
|
||||
(stage.type == COMMON_SPECULATIVE_TYPE_MTP && params.model_dft != nullptr);
|
||||
});
|
||||
|
||||
@ -1177,7 +1360,33 @@ common_speculative * common_speculative_init(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft);
|
||||
llama_context_params cparams_dft = params.cparams_dft;
|
||||
|
||||
if (has_dflash_stage) {
|
||||
if (!llama_model_share_dflash_io_tensors(params.model_dft, llama_get_model(ctx_tgt))) {
|
||||
LOG_ERR("%s: failed to share target IO tensors with DFlash draft model\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int32_t max_cross_ctx = 0;
|
||||
for (const auto & stage : stages) {
|
||||
if (stage.type != COMMON_SPECULATIVE_TYPE_DFLASH) {
|
||||
continue;
|
||||
}
|
||||
|
||||
max_cross_ctx = std::max(max_cross_ctx, params.with_stage_overrides(stage).dflash_cross_ctx);
|
||||
}
|
||||
|
||||
const int32_t block_size = llama_model_dflash_block_size(params.model_dft);
|
||||
if (block_size <= 0) {
|
||||
LOG_ERR("%s: invalid DFlash draft block size\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
cparams_dft.n_ctx = (uint32_t) (max_cross_ctx + block_size);
|
||||
}
|
||||
|
||||
ctx_dft = llama_init_from_model(params.model_dft, cparams_dft);
|
||||
if (ctx_dft == nullptr) {
|
||||
LOG_ERR("%s", "failed to create draft context\n");
|
||||
return nullptr;
|
||||
@ -1240,6 +1449,20 @@ common_speculative * common_speculative_init(
|
||||
));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_DFLASH: {
|
||||
auto state = std::make_unique<common_speculative_state_dflash>(
|
||||
config.type,
|
||||
ctx_tgt,
|
||||
ctx_dft,
|
||||
config.params.dflash_cross_ctx);
|
||||
if (!state->ready) {
|
||||
LOG_ERR("%s: failed to initialize DFlash speculative state\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
impls.push_back(std::move(state));
|
||||
ctx_dft = nullptr;
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_MTP: {
|
||||
llama_context * ctx_mtp = ctx_dft;
|
||||
if (!ctx_mtp) {
|
||||
@ -1604,6 +1827,10 @@ static bool common_speculative_collect_target_batch_features(
|
||||
const llama_batch & batch,
|
||||
common_speculative_feature_view & features) {
|
||||
features = {};
|
||||
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
|
||||
return llama_spec_get_dflash_feature_view(ctx, batch, features);
|
||||
}
|
||||
|
||||
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
|
||||
return true;
|
||||
}
|
||||
@ -1622,6 +1849,10 @@ static bool common_speculative_collect_target_seq_batch_features(
|
||||
llama_seq_id seq_id,
|
||||
common_speculative_feature_view & features) {
|
||||
features = {};
|
||||
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
|
||||
return llama_spec_get_dflash_feature_view_for_seq(ctx, batch, seq_id, features);
|
||||
}
|
||||
|
||||
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
|
||||
return true;
|
||||
}
|
||||
@ -1669,21 +1900,27 @@ int32_t common_speculative_on_target_seq_batch(
|
||||
const llama_batch & batch,
|
||||
llama_seq_id seq_id,
|
||||
bool is_prompt_warmup) {
|
||||
llama_context * ctx_mtp = common_speculative_get_companion_ctx(spec);
|
||||
ctx_mtp = ctx_mtp ? ctx_mtp : ctx_tgt;
|
||||
if (ctx_tgt == nullptr || ctx_mtp == nullptr || batch.n_tokens <= 0) {
|
||||
if (ctx_tgt == nullptr || batch.n_tokens <= 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const int n_embd_src = common_speculative_ctx_mtp_n_embd(ctx_tgt);
|
||||
const int n_embd_dst = common_speculative_ctx_mtp_n_embd(ctx_mtp);
|
||||
if (n_embd_src <= 0 || n_embd_dst <= 0) {
|
||||
return -1;
|
||||
}
|
||||
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
|
||||
llama_context * ctx_mtp = common_speculative_get_companion_ctx(spec);
|
||||
ctx_mtp = ctx_mtp ? ctx_mtp : ctx_tgt;
|
||||
if (ctx_mtp == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (n_embd_src != n_embd_dst) {
|
||||
LOG_ERR("MTP warmup hidden state width mismatch: n_embd_src = %d, n_embd_dst = %d\n", n_embd_src, n_embd_dst);
|
||||
return -1;
|
||||
const int n_embd_src = common_speculative_ctx_mtp_n_embd(ctx_tgt);
|
||||
const int n_embd_dst = common_speculative_ctx_mtp_n_embd(ctx_mtp);
|
||||
if (n_embd_src <= 0 || n_embd_dst <= 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (n_embd_src != n_embd_dst) {
|
||||
LOG_ERR("MTP warmup hidden state width mismatch: n_embd_src = %d, n_embd_dst = %d\n", n_embd_src, n_embd_dst);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
common_speculative_feature_view feature_view;
|
||||
@ -1723,6 +1960,10 @@ bool common_speculative_copy_output_hidden_rows(
|
||||
const std::vector<int32_t> & output_indices,
|
||||
std::vector<float> & hidden_rows) {
|
||||
hidden_rows.clear();
|
||||
if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_DFLASH)) {
|
||||
return llama_spec_copy_dflash_rows_from_output_indices(ctx, output_indices, hidden_rows);
|
||||
}
|
||||
|
||||
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) {
|
||||
return true;
|
||||
}
|
||||
@ -1760,13 +2001,13 @@ static bool common_speculative_apply_hidden_rows(
|
||||
llama_pos pos_base,
|
||||
const std::vector<llama_token> & ids,
|
||||
const std::vector<float> & hidden_rows) {
|
||||
auto * mtp_state = common_speculative_get_mtp_state(spec);
|
||||
if (mtp_state == nullptr || ids.empty()) {
|
||||
const int32_t feature_width = common_speculative_feature_width(spec);
|
||||
if (feature_width <= 0 || ids.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const size_t expected_floats = ids.size() * (size_t) mtp_state->n_embd;
|
||||
if (mtp_state->n_embd <= 0 || hidden_rows.size() != expected_floats) {
|
||||
const size_t expected_floats = ids.size() * (size_t) feature_width;
|
||||
if (hidden_rows.size() != expected_floats) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -1777,7 +2018,7 @@ static bool common_speculative_apply_hidden_rows(
|
||||
|
||||
common_speculative_feature_view feature_view;
|
||||
const bool have_feature_view = common_speculative_feature_view_from_hidden_rows(
|
||||
hidden_rows, mtp_state->n_embd, seq_id, pos_base, feature_view);
|
||||
hidden_rows, feature_width, seq_id, pos_base, feature_view);
|
||||
const int32_t ret = have_feature_view
|
||||
? common_speculative_on_target_batch(spec, accepted_batch, feature_view, false)
|
||||
: -1;
|
||||
@ -1794,7 +2035,7 @@ bool common_speculative_commit_accepted_hidden_rows(
|
||||
llama_token sampled_before,
|
||||
const std::vector<llama_token> & ids,
|
||||
const std::vector<float> & hidden_rows) {
|
||||
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) || ids.empty()) {
|
||||
if (common_speculative_feature_width(spec) <= 0 || ids.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -1815,7 +2056,7 @@ bool common_speculative_commit_accepted_output(
|
||||
llama_token sampled_before,
|
||||
const std::vector<llama_token> & ids,
|
||||
const std::vector<int32_t> & output_indices) {
|
||||
if (!common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) || ids.empty()) {
|
||||
if (common_speculative_feature_width(spec) <= 0 || ids.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -1898,6 +2139,40 @@ static const common_speculative_state_mtp * common_speculative_get_mtp_state(con
|
||||
return common_speculative_get_mtp_state(const_cast<common_speculative *>(spec));
|
||||
}
|
||||
|
||||
static common_speculative_state_dflash * common_speculative_get_dflash_state(common_speculative * spec) {
|
||||
if (!spec) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (auto & impl : spec->impls) {
|
||||
if (impl->type != COMMON_SPECULATIVE_TYPE_DFLASH) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto * dflash_state = dynamic_cast<common_speculative_state_dflash *>(impl.get())) {
|
||||
return dflash_state;
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static const common_speculative_state_dflash * common_speculative_get_dflash_state(const common_speculative * spec) {
|
||||
return common_speculative_get_dflash_state(const_cast<common_speculative *>(spec));
|
||||
}
|
||||
|
||||
static int32_t common_speculative_feature_width(const common_speculative * spec) {
|
||||
if (const auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) {
|
||||
return dflash_state->n_target_features;
|
||||
}
|
||||
|
||||
if (const auto * mtp_state = common_speculative_get_mtp_state(spec); mtp_state != nullptr) {
|
||||
return mtp_state->n_embd;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static mtp_last_embd & mtp_get_last_embd(common_speculative_state_mtp & state, llama_seq_id seq_id) {
|
||||
auto & last = state.draft_cache_by_seq[seq_id];
|
||||
if ((int) last.embd.size() != state.n_embd) {
|
||||
@ -1941,6 +2216,44 @@ static void mtp_clear_target_hidden(common_speculative_state_mtp & state, llama_
|
||||
state.draft_cache_by_seq.erase(seq_id);
|
||||
}
|
||||
|
||||
static void dflash_append_target_features(
|
||||
common_speculative_state_dflash & state,
|
||||
const float * feature_rows,
|
||||
int32_t n_rows) {
|
||||
if (feature_rows == nullptr || n_rows <= 0 || state.n_target_features <= 0 || state.cross_ctx <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t row_width = (size_t) state.n_target_features;
|
||||
if (n_rows >= state.cross_ctx) {
|
||||
const float * src = feature_rows + (size_t) (n_rows - state.cross_ctx) * row_width;
|
||||
state.target_window.assign(src, src + (size_t) state.cross_ctx * row_width);
|
||||
state.target_window_rows = state.cross_ctx;
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t keep_old_rows = std::min<int32_t>(state.target_window_rows, state.cross_ctx - n_rows);
|
||||
std::vector<float> next_window((size_t) (keep_old_rows + n_rows) * row_width);
|
||||
|
||||
if (keep_old_rows > 0) {
|
||||
const float * old_src = state.target_window.data() + (size_t) (state.target_window_rows - keep_old_rows) * row_width;
|
||||
std::memcpy(next_window.data(), old_src, (size_t) keep_old_rows * row_width * sizeof(float));
|
||||
}
|
||||
|
||||
std::memcpy(
|
||||
next_window.data() + (size_t) keep_old_rows * row_width,
|
||||
feature_rows,
|
||||
(size_t) n_rows * row_width * sizeof(float));
|
||||
|
||||
state.target_window = std::move(next_window);
|
||||
state.target_window_rows = keep_old_rows + n_rows;
|
||||
}
|
||||
|
||||
static void dflash_clear_target_features(common_speculative_state_dflash & state) {
|
||||
state.target_window.clear();
|
||||
state.target_window_rows = 0;
|
||||
}
|
||||
|
||||
static bool common_speculative_capture_target_features(common_speculative * spec, const common_speculative_feature_view & features) {
|
||||
auto * mtp_state = common_speculative_get_mtp_state(spec);
|
||||
if (mtp_state == nullptr || features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || features.width <= 0) {
|
||||
@ -1973,11 +2286,13 @@ bool common_speculative_has_sequence_hidden(const common_speculative * spec, lla
|
||||
|
||||
void common_speculative_clear_sequence_hidden(common_speculative * spec, llama_seq_id seq_id) {
|
||||
auto * mtp_state = common_speculative_get_mtp_state(spec);
|
||||
if (mtp_state == nullptr) {
|
||||
return;
|
||||
if (mtp_state != nullptr) {
|
||||
mtp_clear_target_hidden(*mtp_state, seq_id);
|
||||
}
|
||||
|
||||
mtp_clear_target_hidden(*mtp_state, seq_id);
|
||||
if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) {
|
||||
dflash_clear_target_features(*dflash_state);
|
||||
}
|
||||
}
|
||||
|
||||
llama_context * common_speculative_get_companion_ctx(common_speculative * spec) {
|
||||
@ -1985,6 +2300,10 @@ llama_context * common_speculative_get_companion_ctx(common_speculative * spec)
|
||||
return mtp_state->ctx_mtp;
|
||||
}
|
||||
|
||||
if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) {
|
||||
return dflash_state->ctx_dft;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -2023,6 +2342,39 @@ int32_t common_speculative_on_target_batch(
|
||||
const llama_batch & batch,
|
||||
const common_speculative_feature_view & features,
|
||||
bool is_prompt_warmup) {
|
||||
if (auto * dflash_state = common_speculative_get_dflash_state(spec); dflash_state != nullptr) {
|
||||
GGML_UNUSED(is_prompt_warmup);
|
||||
|
||||
if (features.kind != COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE || batch.n_tokens <= 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (features.width != dflash_state->n_target_features) {
|
||||
LOG_ERR("%s: DFlash feature width mismatch: got %d expected %d\n",
|
||||
__func__, features.width, dflash_state->n_target_features);
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (batch.n_seq_id == nullptr || batch.seq_id == nullptr || batch.n_seq_id[0] <= 0 || batch.seq_id[0] == nullptr) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
const llama_seq_id seq_id = batch.seq_id[0][0];
|
||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||
if (batch.n_seq_id[i] != 1 || batch.seq_id[i] == nullptr || batch.seq_id[i][0] != seq_id) {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<float> hidden_rows_storage;
|
||||
if (!common_speculative_feature_view_copy_batch_rows(features, batch, seq_id, &hidden_rows_storage)) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
dflash_append_target_features(*dflash_state, hidden_rows_storage.data(), batch.n_tokens);
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto * mtp_state = common_speculative_get_mtp_state(spec);
|
||||
if (mtp_state == nullptr) {
|
||||
return 0;
|
||||
|
||||
@ -64,6 +64,7 @@ class Model:
|
||||
model_name: str | None
|
||||
metadata_override: Path | None
|
||||
dir_model_card: Path
|
||||
target_model_dir: Path | None
|
||||
|
||||
# subclasses should define this!
|
||||
model_arch: gguf.MODEL_ARCH
|
||||
@ -71,7 +72,8 @@ class Model:
|
||||
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
|
||||
use_temp_file: bool = False, eager: bool = False,
|
||||
metadata_override: Path | None = None, model_name: str | None = None,
|
||||
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
|
||||
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False,
|
||||
target_model_dir: Path | None = None):
|
||||
if type(self) is Model:
|
||||
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
|
||||
|
||||
@ -93,6 +95,7 @@ class Model:
|
||||
self.metadata_override = metadata_override
|
||||
self.model_name = model_name
|
||||
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
|
||||
self.target_model_dir = target_model_dir
|
||||
|
||||
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
|
||||
if self.ftype == gguf.LlamaFileType.GUESSED:
|
||||
@ -459,6 +462,14 @@ class Model:
|
||||
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
@staticmethod
|
||||
def load_text_hparams(dir_model: Path) -> dict[str, Any]:
|
||||
hparams = Model.load_hparams(dir_model)
|
||||
text_config = hparams.get("text_config")
|
||||
if isinstance(text_config, dict):
|
||||
return {**hparams, **text_config}
|
||||
return hparams
|
||||
|
||||
@classmethod
|
||||
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
|
||||
assert names
|
||||
@ -500,13 +511,14 @@ class Model:
|
||||
return seems_special
|
||||
|
||||
# used for GPT-2 BPE and WordPiece vocabs
|
||||
def get_vocab_base(self) -> tuple[list[str], list[int], str]:
|
||||
def get_vocab_base(self, dir_model: Path | None = None, vocab_size: int | None = None) -> tuple[list[str], list[int], str]:
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
|
||||
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
|
||||
dir_model = dir_model or self.dir_model
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
|
||||
vocab_size = vocab_size or self.hparams.get("vocab_size", len(tokenizer.vocab))
|
||||
assert max(tokenizer.vocab.values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
@ -594,6 +606,18 @@ class Model:
|
||||
if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea":
|
||||
# ref: https://huggingface.co/Qwen/Qwen1.5-7B
|
||||
res = "qwen2"
|
||||
if chkhsh == "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4":
|
||||
# ref: https://huggingface.co/Qwen/Qwen3.5-9B-Instruct
|
||||
res = "qwen35"
|
||||
if chkhsh == "99cc61242f7106804ce24fdf3a6451e4a55251078dffd5453c806e11b2310db3":
|
||||
# ref: https://huggingface.co/Qwen/Qwen3.5-27B
|
||||
res = "qwen35"
|
||||
if chkhsh == "1444df51289cfa8063b96f0e62b1125440111bc79a52003ea14b6eac7016fd5f":
|
||||
# ref: https://huggingface.co/z-lab/Qwen3.5-27B-DFlash (uses Qwen3.5 tokenizer)
|
||||
res = "qwen35"
|
||||
if chkhsh == "4f53cda18c2baa0c0354bb5f9a3ecbe5ed12ab4d8e11ba873c2f11161202b945":
|
||||
# ref: https://huggingface.co/Qwen/Qwen3.6-35B-A3B (identical pre-tokenizer regex to qwen35)
|
||||
res = "qwen35"
|
||||
if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166":
|
||||
# ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf
|
||||
res = "olmo"
|
||||
@ -681,19 +705,20 @@ class Model:
|
||||
return res
|
||||
# Marker: End get_vocab_base_pre
|
||||
|
||||
def _set_vocab_gpt2(self) -> None:
|
||||
tokens, toktypes, tokpre = self.get_vocab_base()
|
||||
def _set_vocab_gpt2(self, dir_model: Path | None = None, vocab_size: int | None = None) -> None:
|
||||
dir_model = dir_model or self.dir_model
|
||||
tokens, toktypes, tokpre = self.get_vocab_base(dir_model=dir_model, vocab_size=vocab_size)
|
||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_qwen(self):
|
||||
dir_model = self.dir_model
|
||||
hparams = self.hparams
|
||||
def _set_vocab_qwen(self, dir_model: Path | None = None, hparams: dict[str, Any] | None = None):
|
||||
dir_model = dir_model or self.dir_model
|
||||
hparams = hparams or self.hparams
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
@ -2246,15 +2271,118 @@ class Qwen2MoeModel(Model):
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@Model.register("Qwen3ForCausalLM")
|
||||
class Qwen3Model(Qwen2Model):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3
|
||||
|
||||
|
||||
@Model.register("Qwen3MoeForCausalLM")
|
||||
class Qwen3MoeModel(Qwen2MoeModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3MOE
|
||||
|
||||
|
||||
@Model.register("DFlashDraftModel")
|
||||
class DFlashDraftModel(Qwen3Model):
|
||||
model_arch = gguf.MODEL_ARCH.DFLASH_DRAFT
|
||||
|
||||
_target_hparams: dict[str, Any] | None = None
|
||||
|
||||
def _require_target_model_dir(self) -> Path:
|
||||
if self.target_model_dir is None:
|
||||
raise ValueError("DFlashDraftModel conversion requires --target-model-dir <matching target model directory>")
|
||||
return self.target_model_dir
|
||||
|
||||
def _get_target_hparams(self) -> dict[str, Any]:
|
||||
if self._target_hparams is None:
|
||||
self._target_hparams = Model.load_text_hparams(self._require_target_model_dir())
|
||||
return self._target_hparams
|
||||
|
||||
def set_vocab(self):
|
||||
target_hparams = self._get_target_hparams()
|
||||
self._set_vocab_gpt2(
|
||||
dir_model=self._require_target_model_dir(),
|
||||
vocab_size=target_hparams.get("vocab_size"),
|
||||
)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
self.gguf_writer.add_causal_attention(False)
|
||||
self.gguf_writer.add_rope_dimension_count(self.hparams.get("head_dim", 128))
|
||||
|
||||
arch = self.gguf_writer.arch
|
||||
dflash_cfg = self.hparams.get("dflash_config")
|
||||
dflash_cfg = dflash_cfg if isinstance(dflash_cfg, dict) else {}
|
||||
|
||||
def dflash_required_value(name: str) -> Any:
|
||||
if name in dflash_cfg:
|
||||
return dflash_cfg[name]
|
||||
if name in self.hparams:
|
||||
return self.hparams[name]
|
||||
raise ValueError(f"DFlashDraftModel conversion requires explicit {name} metadata")
|
||||
|
||||
block_size = int(dflash_required_value("block_size"))
|
||||
self.gguf_writer.add_uint32(f"{arch}.dflash.block_size", block_size)
|
||||
|
||||
mask_token_id = int(dflash_required_value("mask_token_id"))
|
||||
self.gguf_writer.add_uint32(f"{arch}.dflash.mask_token_id", mask_token_id)
|
||||
|
||||
target_layer_ids = [int(layer_id) for layer_id in dflash_required_value("target_layer_ids")]
|
||||
if len(target_layer_ids) == 0:
|
||||
raise ValueError("DFlashDraftModel conversion requires at least one target_layer_id")
|
||||
self.gguf_writer.add_array(f"{arch}.dflash.target_layer_ids", target_layer_ids)
|
||||
|
||||
if "n_target_features" in dflash_cfg:
|
||||
n_target_features = int(dflash_cfg["n_target_features"])
|
||||
elif "n_target_features" in self.hparams:
|
||||
n_target_features = int(self.hparams["n_target_features"])
|
||||
else:
|
||||
draft_hidden_size = self.hparams.get("hidden_size")
|
||||
if draft_hidden_size is None:
|
||||
raise ValueError("DFlashDraftModel: draft config is missing hidden_size")
|
||||
|
||||
n_target_features = int(draft_hidden_size) * len(target_layer_ids)
|
||||
|
||||
target_hparams = self._get_target_hparams()
|
||||
target_hidden_size = target_hparams.get("hidden_size")
|
||||
if target_hidden_size is not None and int(target_hidden_size) != int(draft_hidden_size):
|
||||
logger.warning(
|
||||
"DFlashDraftModel: target hidden_size=%d differs from draft hidden_size=%d; using draft hidden width for n_target_features",
|
||||
int(target_hidden_size),
|
||||
int(draft_hidden_size),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"DFlashDraftModel: inferred n_target_features=%d from draft hidden_size=%d and n_target_layers=%d",
|
||||
n_target_features,
|
||||
int(draft_hidden_size),
|
||||
len(target_layer_ids),
|
||||
)
|
||||
|
||||
self.gguf_writer.add_uint32(f"{arch}.dflash.n_target_features", n_target_features)
|
||||
|
||||
logger.info(
|
||||
"DFlashDraftModel metadata: block_size=%s mask_token_id=%s target_layer_ids=%s n_target_features=%s",
|
||||
block_size,
|
||||
mask_token_id,
|
||||
target_layer_ids,
|
||||
n_target_features,
|
||||
)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if name == "fc.weight":
|
||||
return [(f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.DFLASH_FC]}.weight", data_torch)]
|
||||
if name == "hidden_norm.weight":
|
||||
return [(f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.DFLASH_HIDDEN_NORM]}.weight", data_torch)]
|
||||
if name == "norm.weight":
|
||||
name = "model.norm.weight"
|
||||
elif name.startswith("layers."):
|
||||
name = f"model.{name}"
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@Model.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM")
|
||||
class Ernie4_5Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.ERNIE4_5
|
||||
@ -4385,6 +4513,7 @@ class JaisModel(Model):
|
||||
super().prepare_tensors()
|
||||
self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
|
||||
|
||||
|
||||
@Model.register("MiniMaxM2ForCausalLM")
|
||||
class MiniMaxM2Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.MINIMAXM2
|
||||
@ -4457,10 +4586,12 @@ class SmolLM3Model(LlamaModel):
|
||||
chat_template = tokenizer.chat_template.replace("[:]", "")
|
||||
self.gguf_writer.add_chat_template(chat_template)
|
||||
|
||||
|
||||
@Model.register("SeedOssForCausalLM")
|
||||
class SeedOssModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.SEED_OSS
|
||||
|
||||
|
||||
@Model.register("Dots1ForCausalLM")
|
||||
class Dots1Model(Qwen2MoeModel):
|
||||
model_arch = gguf.MODEL_ARCH.DOTS1
|
||||
@ -4621,6 +4752,7 @@ class Glm4MoeModel(Model):
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration")
|
||||
class ChatGLMModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.CHATGLM
|
||||
@ -4803,6 +4935,7 @@ class ChatGLMModel(Model):
|
||||
name = name.removeprefix("transformer.")
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@Model.register("BailingMoeV2ForCausalLM")
|
||||
class BailingMoeV2Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.BAILINGMOE2
|
||||
@ -5028,6 +5161,10 @@ def parse_args() -> argparse.Namespace:
|
||||
"--metadata", type=Path,
|
||||
help="Specify the path for an authorship metadata override file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-model-dir", type=Path,
|
||||
help="matching target model directory; required for DFlash conversion to reuse tokenizer and infer target feature width",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@ -5107,7 +5244,8 @@ def main() -> None:
|
||||
metadata_override=args.metadata, model_name=args.model_name,
|
||||
split_max_tensors=args.split_max_tensors,
|
||||
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
|
||||
small_first_shard=args.no_tensor_first_split)
|
||||
small_first_shard=args.no_tensor_first_split,
|
||||
target_model_dir=args.target_model_dir)
|
||||
|
||||
if args.vocab_only:
|
||||
logger.info("Exporting model vocab...")
|
||||
|
||||
@ -78,6 +78,10 @@ models = [
|
||||
{"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", },
|
||||
{"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", },
|
||||
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
|
||||
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", "chkhsh": "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4", },
|
||||
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-27B", "chkhsh": "99cc61242f7106804ce24fdf3a6451e4a55251078dffd5453c806e11b2310db3", },
|
||||
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/z-lab/Qwen3.5-27B-DFlash", "chkhsh": "1444df51289cfa8063b96f0e62b1125440111bc79a52003ea14b6eac7016fd5f", },
|
||||
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.6-35B-A3B", "chkhsh": "4f53cda18c2baa0c0354bb5f9a3ecbe5ed12ab4d8e11ba873c2f11161202b945", },
|
||||
{"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
|
||||
{"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
|
||||
{"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM!
|
||||
@ -154,39 +158,46 @@ for model in models:
|
||||
if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM:
|
||||
continue
|
||||
|
||||
# Skip if the tokenizer folder does not exist or there are other download issues previously
|
||||
if not os.path.exists(f"models/tokenizers/{name}"):
|
||||
logger.warning(f"Directory for tokenizer {name} not found. Skipping...")
|
||||
continue
|
||||
chkhsh = model.get("chkhsh")
|
||||
|
||||
# create the tokenizer
|
||||
try:
|
||||
if name == "t5":
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
|
||||
except (OSError, TypeError) as e:
|
||||
logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}")
|
||||
continue # Skip to the next model if the tokenizer can't be loaded
|
||||
if chkhsh is None:
|
||||
# Skip if the tokenizer folder does not exist or there are other download issues previously
|
||||
if not os.path.exists(f"models/tokenizers/{name}"):
|
||||
logger.warning(f"Directory for tokenizer {name} not found. Skipping...")
|
||||
continue
|
||||
|
||||
chktok = tokenizer.encode(CHK_TXT)
|
||||
chkhsh = sha256(str(chktok).encode()).hexdigest()
|
||||
# create the tokenizer
|
||||
try:
|
||||
if name == "t5":
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
|
||||
except (OSError, TypeError) as e:
|
||||
logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}")
|
||||
continue # Skip to the next model if the tokenizer can't be loaded
|
||||
|
||||
chktok = tokenizer.encode(CHK_TXT)
|
||||
chkhsh = sha256(str(chktok).encode()).hexdigest()
|
||||
|
||||
logger.info(f"model: {name}")
|
||||
logger.info(f"tokt: {tokt}")
|
||||
logger.info(f"repo: {model['repo']}")
|
||||
logger.info(f"chktok: {chktok}")
|
||||
logger.info(f"chkhsh: {chkhsh}")
|
||||
|
||||
# print the "pre_tokenizer" content from the tokenizer.json
|
||||
with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
normalizer = cfg["normalizer"]
|
||||
logger.info("normalizer: " + json.dumps(normalizer, indent=4))
|
||||
pre_tokenizer = cfg["pre_tokenizer"]
|
||||
logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4))
|
||||
if "ignore_merges" in cfg["model"]:
|
||||
logger.info("ignore_merges: " + json.dumps(cfg["model"]["ignore_merges"], indent=4))
|
||||
if model.get("chkhsh") is None:
|
||||
logger.info(f"chktok: {chktok}")
|
||||
|
||||
# print the "pre_tokenizer" content from the tokenizer.json
|
||||
with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
normalizer = cfg["normalizer"]
|
||||
logger.info("normalizer: " + json.dumps(normalizer, indent=4))
|
||||
pre_tokenizer = cfg["pre_tokenizer"]
|
||||
logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4))
|
||||
if "ignore_merges" in cfg["model"]:
|
||||
logger.info("ignore_merges: " + json.dumps(cfg["model"]["ignore_merges"], indent=4))
|
||||
else:
|
||||
logger.info("using manually provided tokenizer hash")
|
||||
|
||||
logger.info("")
|
||||
|
||||
@ -353,6 +364,6 @@ logger.info("\nRun the following commands to generate the vocab files for testin
|
||||
for model in models:
|
||||
name = model["name"]
|
||||
|
||||
print(f"python3 convert_hf_to_gguf.py models/tokenizers/{name}/ --outfile models/ggml-vocab-{name}.gguf --vocab-only") # noqa: NP100
|
||||
logger.info(f"python3 convert_hf_to_gguf.py models/tokenizers/{name}/ --outfile models/ggml-vocab-{name}.gguf --vocab-only") # noqa: NP100
|
||||
|
||||
logger.info("\n")
|
||||
|
||||
@ -134,6 +134,14 @@ static bool server_speculative_has_mtp(const common_params_speculative & spec) {
|
||||
return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP);
|
||||
}
|
||||
|
||||
static bool server_speculative_has_dflash(const common_params_speculative & spec) {
|
||||
return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_DFLASH);
|
||||
}
|
||||
|
||||
static bool server_speculative_has_target_features(const common_params_speculative & spec) {
|
||||
return server_speculative_has_mtp(spec) || server_speculative_has_dflash(spec);
|
||||
}
|
||||
|
||||
static bool server_speculative_same_stage_types(
|
||||
const common_params_speculative & lhs,
|
||||
const common_params_speculative & rhs) {
|
||||
@ -217,6 +225,18 @@ static common_speculative_stage_params server_parse_speculative_stage_json(const
|
||||
}
|
||||
|
||||
server_context::~server_context() {
|
||||
// Speculative state may reference the live target context during teardown.
|
||||
for (server_slot& slot : slots) {
|
||||
if (slot.ctx_sampling != nullptr) {
|
||||
common_sampler_free(slot.ctx_sampling);
|
||||
}
|
||||
slot.spec_ckpt.clear();
|
||||
common_speculative_free(slot.spec);
|
||||
slot.spec = nullptr;
|
||||
slot.ctx_dft = nullptr;
|
||||
llama_batch_free(slot.batch_spec);
|
||||
}
|
||||
|
||||
if (ctx) {
|
||||
llama_free(ctx);
|
||||
ctx = nullptr;
|
||||
@ -238,19 +258,6 @@ server_context::~server_context() {
|
||||
model_draft = nullptr;
|
||||
}
|
||||
|
||||
// Clear any sampling context
|
||||
for (server_slot& slot : slots) {
|
||||
if (slot.ctx_sampling != nullptr) {
|
||||
common_sampler_free(slot.ctx_sampling);
|
||||
}
|
||||
slot.spec_ckpt.clear();
|
||||
if (slot.ctx_dft) {
|
||||
llama_free(slot.ctx_dft);
|
||||
}
|
||||
common_speculative_free(slot.spec);
|
||||
llama_batch_free(slot.batch_spec);
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
|
||||
@ -286,6 +293,13 @@ bool server_context::load_model(const gpt_params& params_) {
|
||||
params_base.speculative.model_dft = nullptr;
|
||||
}
|
||||
|
||||
if (server_speculative_has_dflash(params_base.speculative) && params_base.n_parallel > 1) {
|
||||
LOG_ERROR("DFlash is currently limited to a single server slot (-np 1).\n", {
|
||||
{"n_parallel", params_base.n_parallel},
|
||||
});
|
||||
return false;
|
||||
}
|
||||
|
||||
bool has_draft_model = !params_base.speculative.model.empty() || !params_base.speculative.params.empty();
|
||||
std::string& mmproj_path = params_base.mmproj.path;
|
||||
if (!mmproj_path.empty()) {
|
||||
@ -470,7 +484,7 @@ void server_context::init() {
|
||||
bool can_spec = true;
|
||||
if (!params_base.dry_run) {
|
||||
can_spec = common_speculative_is_compat(ctx);
|
||||
}
|
||||
}
|
||||
if (!can_spec) {
|
||||
SRV_WRN("%s", "speculative decoding not supported by this context\n");
|
||||
}
|
||||
@ -1656,7 +1670,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
int32_t banbuffer_size = json_value(data, "banbuffer_size", 0);
|
||||
slot.n_buffer = 0; // Ensure buffer calculation starts fresh for this slot
|
||||
slot.rewind_count_max = json_value(data, "rewind_count_max", -1);
|
||||
|
||||
|
||||
const auto& banned_strings = data.find("banned_strings");
|
||||
if (banned_strings != data.end() && banned_strings->is_array()) {
|
||||
slot.ban_phrases.clear();
|
||||
@ -2805,7 +2819,7 @@ static size_t load_server_tokens_from_file(const std::string & filename, server
|
||||
size_t pos = 0;
|
||||
json token_json;
|
||||
if (file.is_open()) {
|
||||
file >> token_json;
|
||||
file >> token_json;
|
||||
pos = file.tellg();
|
||||
file.close();
|
||||
}
|
||||
@ -3727,7 +3741,7 @@ bool server_context::create_checkpoint(server_slot & slot) {
|
||||
|
||||
slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin());
|
||||
}
|
||||
|
||||
|
||||
auto & cur = slot.server_cached_prompt.checkpoints.emplace_back();
|
||||
server_prompt_checkpoint_update(cur, ctx, slot.id, slot.cache_tokens.n_tokens(), pos_min, pos_max, slot.n_past_offset);
|
||||
|
||||
@ -4060,7 +4074,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
||||
slot.do_checkpoint = true;
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
LOG_VERBOSE("prompt processing progress", {
|
||||
{"id_slot", slot.id},
|
||||
@ -4143,7 +4157,7 @@ static void restore_speculative_checkpoint(
|
||||
common_speculative_type spec_type_used,
|
||||
llama_token sampled_before,
|
||||
const std::vector<llama_token> & ids, int n_draft,
|
||||
const std::vector<float> & mtp_hidden_state_pre, int32_t mtp_n_past_base) {
|
||||
const std::vector<float> & spec_feature_rows_pre, int32_t spec_n_past_base) {
|
||||
if (slot.spec_ckpt.per_step_enabled) {
|
||||
const int step = (int)ids.size() - 1;
|
||||
llama_spec_ckpt_restore(ctx, slot.id, slot.spec_ckpt.n_past, step);
|
||||
@ -4155,16 +4169,16 @@ static void restore_speculative_checkpoint(
|
||||
common_sampler_accept(slot.ctx_sampling, ctx, id, true);
|
||||
}
|
||||
|
||||
// Update MTP KV cache and hidden state using embeddings collected before checkpoint restore.
|
||||
if (slot.has_mtp && !mtp_hidden_state_pre.empty()) {
|
||||
// Update speculative target features using rows collected before checkpoint restore.
|
||||
if (server_speculative_has_target_features(slot.params.speculative) && !spec_feature_rows_pre.empty()) {
|
||||
if (!common_speculative_commit_accepted_hidden_rows(
|
||||
slot.spec,
|
||||
spec_type_used,
|
||||
slot.id,
|
||||
mtp_n_past_base,
|
||||
spec_n_past_base,
|
||||
sampled_before,
|
||||
ids,
|
||||
mtp_hidden_state_pre)) {
|
||||
spec_feature_rows_pre)) {
|
||||
common_speculative_clear_sequence_hidden(slot.spec, slot.id);
|
||||
} else if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) {
|
||||
SLT_DBG(slot, "%s", "synced MTP target hidden state from accepted-prefix rows after per-step restore");
|
||||
@ -4201,7 +4215,7 @@ static void restore_speculative_checkpoint(
|
||||
if (ret != 0) {
|
||||
SLT_ERR(slot, "failed to re-decode accepted tokens after checkpoint restore: %d\n", ret);
|
||||
}
|
||||
if (slot.has_mtp) {
|
||||
if (server_speculative_has_target_features(slot.params.speculative)) {
|
||||
const int n_accepted = (int)ids.size();
|
||||
std::vector<int32_t> redecoded_indices(n_accepted);
|
||||
for (int j = 0; j < n_accepted; ++j) {
|
||||
@ -4272,20 +4286,20 @@ void server_context::speculative_decoding_accept() {
|
||||
}
|
||||
|
||||
const bool any_rejected = (ids.size() - 1) < n_draft;
|
||||
int32_t mtp_n_past_base = 0;
|
||||
std::vector<float> mtp_hidden_state_pre;
|
||||
int32_t spec_n_past_base = 0;
|
||||
std::vector<float> spec_feature_rows_pre;
|
||||
std::vector<int32_t> accepted_output_indices;
|
||||
if (slot.has_mtp) {
|
||||
if (server_speculative_has_target_features(slot.params.speculative)) {
|
||||
const int32_t n_pre_spec_tokens = slot.cache_tokens.n_tokens() - (int32_t)(slot.drafted.size() + 1);
|
||||
mtp_n_past_base = slot.cache_tokens.pos_next(n_pre_spec_tokens);
|
||||
spec_n_past_base = slot.cache_tokens.pos_next(n_pre_spec_tokens);
|
||||
|
||||
if (!ids.empty()) {
|
||||
accepted_output_indices.assign(slot.i_batch_dft.begin(), slot.i_batch_dft.begin() + ids.size());
|
||||
}
|
||||
|
||||
if (any_rejected && slot.spec_ckpt.valid && !accepted_output_indices.empty()) {
|
||||
if (!common_speculative_copy_output_hidden_rows(slot.spec, ctx, accepted_output_indices, mtp_hidden_state_pre)) {
|
||||
mtp_hidden_state_pre.clear();
|
||||
if (!common_speculative_copy_output_hidden_rows(slot.spec, ctx, accepted_output_indices, spec_feature_rows_pre)) {
|
||||
spec_feature_rows_pre.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -4317,15 +4331,15 @@ void server_context::speculative_decoding_accept() {
|
||||
|
||||
// for recurrent/hybrid models: if any drafts were rejected, restore recurrent state
|
||||
if (any_rejected && slot.spec_ckpt.valid) {
|
||||
restore_speculative_checkpoint(slot, ctx, model, spec_type_used, sampled_before, ids, n_draft, mtp_hidden_state_pre, mtp_n_past_base);
|
||||
restore_speculative_checkpoint(slot, ctx, model, spec_type_used, sampled_before, ids, n_draft, spec_feature_rows_pre, spec_n_past_base);
|
||||
} else {
|
||||
if (slot.has_mtp && !accepted_output_indices.empty()) {
|
||||
if (server_speculative_has_target_features(slot.params.speculative) && !accepted_output_indices.empty()) {
|
||||
if (!common_speculative_commit_accepted_output(
|
||||
slot.spec,
|
||||
ctx,
|
||||
spec_type_used,
|
||||
slot.id,
|
||||
mtp_n_past_base,
|
||||
spec_n_past_base,
|
||||
sampled_before,
|
||||
ids,
|
||||
accepted_output_indices)) {
|
||||
@ -4395,15 +4409,15 @@ void server_context::release_slot_after_final_response(server_slot & slot) {
|
||||
void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) {
|
||||
int count = 0;
|
||||
bool released = false;
|
||||
|
||||
|
||||
int32_t start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
|
||||
|
||||
for (auto& it : results) {
|
||||
bool has_next = process_token(it, slot);
|
||||
|
||||
|
||||
// Clean up positional bans for the token we just confirmed/sent
|
||||
slot.positional_bans.erase(start_pos + count);
|
||||
|
||||
|
||||
count++;
|
||||
if (!has_next) {
|
||||
if (slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
|
||||
@ -4436,7 +4450,7 @@ inline int32_t check_ban_phrase(server_slot& slot) {
|
||||
|
||||
std::string string_buffer;
|
||||
std::vector<size_t> token_offsets;
|
||||
|
||||
|
||||
for (const auto& it : slot.token_buffer) {
|
||||
token_offsets.push_back(string_buffer.size());
|
||||
string_buffer += it.text_to_send;
|
||||
@ -4488,10 +4502,10 @@ inline int32_t check_ban_phrase(server_slot& slot) {
|
||||
if (found) {
|
||||
int32_t token_idx = -1;
|
||||
for (size_t i = 0; i < token_offsets.size(); ++i) {
|
||||
size_t len = (i == token_offsets.size() - 1)
|
||||
? string_buffer.size() - token_offsets[i]
|
||||
size_t len = (i == token_offsets.size() - 1)
|
||||
? string_buffer.size() - token_offsets[i]
|
||||
: token_offsets[i+1] - token_offsets[i];
|
||||
|
||||
|
||||
if (best_start >= token_offsets[i] && best_start < token_offsets[i] + len) {
|
||||
token_idx = (int32_t)i;
|
||||
break;
|
||||
@ -4509,7 +4523,7 @@ inline int32_t check_ban_phrase(server_slot& slot) {
|
||||
|
||||
inline void rewind_context(server_slot& slot, int32_t ban_pos) {
|
||||
slot.rewind_count++;
|
||||
|
||||
|
||||
int32_t buffer_start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
|
||||
int32_t n_keep_buffer = ban_pos - buffer_start_pos;
|
||||
if (n_keep_buffer < 0) n_keep_buffer = 0;
|
||||
@ -4518,9 +4532,9 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) {
|
||||
int32_t n = 0;
|
||||
for (auto result = slot.token_buffer.begin() + n_keep_buffer; result != slot.token_buffer.end(); result++) {
|
||||
llama_token banned_tok = result->tok;
|
||||
|
||||
|
||||
if (n == 0) {
|
||||
LLAMA_LOG_DEBUG("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n",
|
||||
LLAMA_LOG_DEBUG("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n",
|
||||
ban_pos, banned_tok, result->text_to_send.c_str());
|
||||
}
|
||||
|
||||
@ -4533,7 +4547,7 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) {
|
||||
}
|
||||
|
||||
int32_t n_rewind_total = (slot.n_past + 1) - ban_pos;
|
||||
|
||||
|
||||
size_t n_keep_cache = 0;
|
||||
if (ban_pos > 0) {
|
||||
n_keep_cache = (size_t)(ban_pos - 1);
|
||||
@ -4546,13 +4560,13 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) {
|
||||
if (n_keep_cache < slot.cache_tokens.size()) {
|
||||
slot.sampled = slot.cache_tokens[n_keep_cache];
|
||||
} else {
|
||||
slot.sampled = 0;
|
||||
slot.sampled = 0;
|
||||
}
|
||||
|
||||
// Truncate cache
|
||||
slot.cache_tokens.keep_first(n_keep_cache);
|
||||
slot.n_past = slot.cache_tokens.n_tokens();
|
||||
|
||||
|
||||
// Remove from KV cache
|
||||
llama_kv_cache_seq_rm(slot.ctx, slot.id, slot.cache_tokens.pos_next(slot.n_past), -1);
|
||||
|
||||
@ -4590,13 +4604,13 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
|
||||
// Automatic / Heuristic logic
|
||||
// Account for strings + regex + regex_ci
|
||||
size_t total_bans = slot.ban_phrases.size() + slot.ban_regex.size() + slot.ban_regex_ci.size();
|
||||
|
||||
|
||||
// Heuristic: Allow if under 20 OR under 2 * total_bans
|
||||
// Conversely: Stop if >= 20 AND > 2 * total_bans
|
||||
if (slot.rewind_count >= 20 && slot.rewind_count > 2 * total_bans) {
|
||||
allow_rewind = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (slot.rewind_count_max > 0) {
|
||||
// Strict limit logic
|
||||
if (slot.rewind_count >= slot.rewind_count_max) {
|
||||
@ -4613,7 +4627,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
|
||||
else if (buffer_full || !next_token) {
|
||||
slot.rewind_status = false;
|
||||
slot.rewind_count = 0;
|
||||
|
||||
|
||||
if (!next_token) {
|
||||
// send all remaining tokens
|
||||
send_token_results(slot.token_buffer, slot);
|
||||
@ -4625,7 +4639,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
|
||||
}
|
||||
else {
|
||||
// buffer the result, wait for more tokens to validate string
|
||||
slot.sampled = result.tok;
|
||||
slot.sampled = result.tok;
|
||||
}
|
||||
}
|
||||
|
||||
@ -4710,9 +4724,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
continue; // continue loop of n_batch
|
||||
}
|
||||
|
||||
if (server_speculative_has_mtp(params_base.speculative)) {
|
||||
if (server_speculative_has_target_features(params_base.speculative)) {
|
||||
for (auto & slot : slots) {
|
||||
if (!slot.spec || !slot.has_mtp) {
|
||||
if (!slot.spec || !server_speculative_has_target_features(slot.params.speculative)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -4722,7 +4736,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
}
|
||||
|
||||
if (common_speculative_on_target_seq_batch(slot.spec, ctx, batch_view, slot.id, true) != 0) {
|
||||
LOG_ERROR("failed to warm up MTP state from prompt batch for slot %d\n", slot.id);
|
||||
LOG_ERROR("failed to warm up speculative target-feature state from prompt batch for slot %d\n", slot.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -236,6 +236,8 @@ class MODEL_ARCH(IntEnum):
|
||||
GEMMA3 = auto()
|
||||
GEMMA4 = auto()
|
||||
GEMMA4_MTP = auto()
|
||||
DFLASH = auto()
|
||||
DFLASH_DRAFT = auto()
|
||||
STARCODER2 = auto()
|
||||
MAMBA = auto()
|
||||
XVERSE = auto()
|
||||
@ -260,6 +262,7 @@ class MODEL_ARCH(IntEnum):
|
||||
SMOLLM3 = auto()
|
||||
SEED_OSS = auto()
|
||||
|
||||
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
TOKEN_EMBD = auto()
|
||||
TOKEN_EMBD_NORM = auto()
|
||||
@ -366,6 +369,8 @@ class MODEL_TENSOR(IntEnum):
|
||||
MTP_POST_PROJ = auto()
|
||||
MTP_TOKEN_ORDERING = auto()
|
||||
MTP_CENTROIDS = auto()
|
||||
DFLASH_FC = auto()
|
||||
DFLASH_HIDDEN_NORM = auto()
|
||||
|
||||
|
||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
@ -402,6 +407,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.GEMMA3: "gemma3",
|
||||
MODEL_ARCH.GEMMA4: "gemma4",
|
||||
MODEL_ARCH.GEMMA4_MTP: "gemma4_mtp",
|
||||
MODEL_ARCH.DFLASH: "dflash",
|
||||
MODEL_ARCH.DFLASH_DRAFT: "dflash-draft",
|
||||
MODEL_ARCH.STARCODER2: "starcoder2",
|
||||
MODEL_ARCH.MAMBA: "mamba",
|
||||
MODEL_ARCH.XVERSE: "xverse",
|
||||
@ -534,6 +541,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.MTP_POST_PROJ: "mtp_post_proj",
|
||||
MODEL_TENSOR.MTP_TOKEN_ORDERING: "mtp_token_ordering",
|
||||
MODEL_TENSOR.MTP_CENTROIDS: "mtp_centroids",
|
||||
MODEL_TENSOR.DFLASH_FC: "dflash_fc",
|
||||
MODEL_TENSOR.DFLASH_HIDDEN_NORM: "dflash_hidden_norm",
|
||||
}
|
||||
|
||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
@ -1235,6 +1244,38 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
|
||||
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
|
||||
],
|
||||
MODEL_ARCH.DFLASH: [
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.DFLASH_FC,
|
||||
MODEL_TENSOR.DFLASH_HIDDEN_NORM,
|
||||
],
|
||||
MODEL_ARCH.DFLASH_DRAFT: [
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.DFLASH_FC,
|
||||
MODEL_TENSOR.DFLASH_HIDDEN_NORM,
|
||||
],
|
||||
MODEL_ARCH.BITNET: [
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
@ -1644,89 +1685,89 @@ class ExpertGatingFuncType(IntEnum):
|
||||
# ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE.
|
||||
class LlamaFileType(IntEnum):
|
||||
ALL_F32 = 0
|
||||
MOSTLY_F16 = 1 #except 1d tensors
|
||||
MOSTLY_Q4_0 = 2 #except 1d tensors
|
||||
MOSTLY_Q4_1 = 3 #except 1d tensors
|
||||
MOSTLY_Q8_0 = 7 #except 1d tensors
|
||||
MOSTLY_Q5_0 = 8 #except 1d tensors
|
||||
MOSTLY_Q5_1 = 9 #except 1d tensors
|
||||
MOSTLY_Q2_K = 10 #except 1d tensors
|
||||
MOSTLY_Q3_K_S = 11 #except 1d tensors
|
||||
MOSTLY_Q3_K_M = 12 #except 1d tensors
|
||||
MOSTLY_Q3_K_L = 13 #except 1d tensors
|
||||
MOSTLY_Q4_K_S = 14 #except 1d tensors
|
||||
MOSTLY_Q4_K_M = 15 #except 1d tensors
|
||||
MOSTLY_Q5_K_S = 16 #except 1d tensors
|
||||
MOSTLY_Q5_K_M = 17 #except 1d tensors
|
||||
MOSTLY_Q6_K = 18 #except 1d tensors
|
||||
MOSTLY_IQ2_XXS = 19 #except 1d tensors
|
||||
MOSTLY_IQ2_XS = 20 #except 1d tensors
|
||||
MOSTLY_Q2_K_S = 21 #except 1d tensors
|
||||
MOSTLY_IQ3_XS = 22 #except 1d tensors
|
||||
MOSTLY_IQ3_XXS = 23 #except 1d tensors
|
||||
MOSTLY_IQ1_S = 24 #except 1d tensors
|
||||
MOSTLY_IQ4_NL = 25 #except 1d tensors
|
||||
MOSTLY_IQ3_S = 26 #except 1d tensors
|
||||
MOSTLY_IQ3_M = 27 #except 1d tensors
|
||||
MOSTLY_IQ2_S = 28 #except 1d tensors
|
||||
MOSTLY_IQ2_M = 29 #except 1d tensors
|
||||
MOSTLY_IQ4_XS = 30 #except 1d tensors
|
||||
MOSTLY_IQ1_M = 31 #except 1d tensors
|
||||
MOSTLY_BF16 = 32 #except 1d tensors
|
||||
MOSTLY_Q4_0_4_4 = 33 #except 1d tensors
|
||||
MOSTLY_Q4_0_4_8 = 34 #except 1d tensors
|
||||
MOSTLY_Q4_0_8_8 = 35 #except 1d tensors
|
||||
MOSTLY_MXFP4 = 38 #except 1d tensors, 38 to be compatible with mainline
|
||||
MOSTLY_F16 = 1 # except 1d tensors
|
||||
MOSTLY_Q4_0 = 2 # except 1d tensors
|
||||
MOSTLY_Q4_1 = 3 # except 1d tensors
|
||||
MOSTLY_Q8_0 = 7 # except 1d tensors
|
||||
MOSTLY_Q5_0 = 8 # except 1d tensors
|
||||
MOSTLY_Q5_1 = 9 # except 1d tensors
|
||||
MOSTLY_Q2_K = 10 # except 1d tensors
|
||||
MOSTLY_Q3_K_S = 11 # except 1d tensors
|
||||
MOSTLY_Q3_K_M = 12 # except 1d tensors
|
||||
MOSTLY_Q3_K_L = 13 # except 1d tensors
|
||||
MOSTLY_Q4_K_S = 14 # except 1d tensors
|
||||
MOSTLY_Q4_K_M = 15 # except 1d tensors
|
||||
MOSTLY_Q5_K_S = 16 # except 1d tensors
|
||||
MOSTLY_Q5_K_M = 17 # except 1d tensors
|
||||
MOSTLY_Q6_K = 18 # except 1d tensors
|
||||
MOSTLY_IQ2_XXS = 19 # except 1d tensors
|
||||
MOSTLY_IQ2_XS = 20 # except 1d tensors
|
||||
MOSTLY_Q2_K_S = 21 # except 1d tensors
|
||||
MOSTLY_IQ3_XS = 22 # except 1d tensors
|
||||
MOSTLY_IQ3_XXS = 23 # except 1d tensors
|
||||
MOSTLY_IQ1_S = 24 # except 1d tensors
|
||||
MOSTLY_IQ4_NL = 25 # except 1d tensors
|
||||
MOSTLY_IQ3_S = 26 # except 1d tensors
|
||||
MOSTLY_IQ3_M = 27 # except 1d tensors
|
||||
MOSTLY_IQ2_S = 28 # except 1d tensors
|
||||
MOSTLY_IQ2_M = 29 # except 1d tensors
|
||||
MOSTLY_IQ4_XS = 30 # except 1d tensors
|
||||
MOSTLY_IQ1_M = 31 # except 1d tensors
|
||||
MOSTLY_BF16 = 32 # except 1d tensors
|
||||
MOSTLY_Q4_0_4_4 = 33 # except 1d tensors
|
||||
MOSTLY_Q4_0_4_8 = 34 # except 1d tensors
|
||||
MOSTLY_Q4_0_8_8 = 35 # except 1d tensors
|
||||
MOSTLY_MXFP4 = 38 # except 1d tensors, 38 to be compatible with mainline
|
||||
|
||||
MOSTLY_Q6_0 = 135 #except 1d tensors
|
||||
MOSTLY_IQ1_BN = 136 #except 1d tensors
|
||||
MOSTLY_IQ2_BN = 137 #except 1d tensors
|
||||
MOSTLY_IQ2_K = 138 #except 1d tensors
|
||||
MOSTLY_IQ3_K = 139 #except 1d tensors
|
||||
MOSTLY_IQ4_K = 140 #except 1d tensors
|
||||
MOSTLY_IQ5_K = 141 #except 1d tensors
|
||||
MOSTLY_IQ6_K = 142 #except 1d tensors
|
||||
MOSTLY_IQ4_KS = 145 #except 1d tensors
|
||||
MOSTLY_IQ3_KL = 146 #except 1d tensors
|
||||
MOSTLY_IQ2_KS = 147 #except 1d tensors
|
||||
MOSTLY_IQ4_KSS = 148 #except 1d tensors
|
||||
MOSTLY_Q8_KV = 149 #except 1d tensors
|
||||
MOSTLY_IQ5_KS = 150 #except 1d tensors
|
||||
MOSTLY_IQ2_KT = 151 #except 1d tensors
|
||||
MOSTLY_IQ3_KT = 152 #except 1d tensors
|
||||
MOSTLY_IQ4_KT = 153 #except 1d tensors
|
||||
MOSTLY_IQ3_KS = 154 #except 1d tensors
|
||||
MOSTLY_IQ2_KL = 155 #except 1d tensors
|
||||
MOSTLY_IQ1_KT = 156 #except 1d tensors
|
||||
MOSTLY_Q6_0 = 135 # except 1d tensors
|
||||
MOSTLY_IQ1_BN = 136 # except 1d tensors
|
||||
MOSTLY_IQ2_BN = 137 # except 1d tensors
|
||||
MOSTLY_IQ2_K = 138 # except 1d tensors
|
||||
MOSTLY_IQ3_K = 139 # except 1d tensors
|
||||
MOSTLY_IQ4_K = 140 # except 1d tensors
|
||||
MOSTLY_IQ5_K = 141 # except 1d tensors
|
||||
MOSTLY_IQ6_K = 142 # except 1d tensors
|
||||
MOSTLY_IQ4_KS = 145 # except 1d tensors
|
||||
MOSTLY_IQ3_KL = 146 # except 1d tensors
|
||||
MOSTLY_IQ2_KS = 147 # except 1d tensors
|
||||
MOSTLY_IQ4_KSS = 148 # except 1d tensors
|
||||
MOSTLY_Q8_KV = 149 # except 1d tensors
|
||||
MOSTLY_IQ5_KS = 150 # except 1d tensors
|
||||
MOSTLY_IQ2_KT = 151 # except 1d tensors
|
||||
MOSTLY_IQ3_KT = 152 # except 1d tensors
|
||||
MOSTLY_IQ4_KT = 153 # except 1d tensors
|
||||
MOSTLY_IQ3_KS = 154 # except 1d tensors
|
||||
MOSTLY_IQ2_KL = 155 # except 1d tensors
|
||||
MOSTLY_IQ1_KT = 156 # except 1d tensors
|
||||
|
||||
MOSTLY_Q4_0_R8 = 202 #except 1d tensors
|
||||
MOSTLY_Q8_0_R8 = 207 #except 1d tensors
|
||||
MOSTLY_Q5_0_R4 = 208 #except 1d tensors
|
||||
MOSTLY_Q2_K_R4 = 210 #except 1d tensors
|
||||
MOSTLY_Q3_K_R4 = 211 #except 1d tensors
|
||||
MOSTLY_Q4_K_R4 = 214 #except 1d tensors
|
||||
MOSTLY_Q5_K_R4 = 216 #except 1d tensors
|
||||
MOSTLY_Q6_K_R4 = 218 #except 1d tensors
|
||||
MOSTLY_IQ2_XXS_R4 = 219 #except 1d tensors
|
||||
MOSTLY_IQ2_XS_R4 = 220 #except 1d tensors
|
||||
MOSTLY_IQ3_XXS_R4 = 223 #except 1d tensors
|
||||
MOSTLY_IQ1_S_R4 = 224 #except 1d tensors
|
||||
MOSTLY_IQ4_NL_R4 = 225 #except 1d tensors
|
||||
MOSTLY_IQ3_S_R4 = 226 #except 1d tensors
|
||||
MOSTLY_IQ2_M_R4 = 229 #except 1d tensors
|
||||
MOSTLY_IQ4_XS_R8 = 230 #except 1d tensors
|
||||
MOSTLY_IQ1_M_R4 = 231 #except 1d tensors
|
||||
MOSTLY_Q6_0_R4 = 335 #except 1d tensors
|
||||
MOSTLY_BF16_R16 = 232 #except 1d tensors
|
||||
MOSTLY_IQ2_BN_R4 = 337 #except 1d tensors
|
||||
MOSTLY_IQ2_K_R4 = 338 #except 1d tensors
|
||||
MOSTLY_IQ3_K_R4 = 339 #except 1d tensors
|
||||
MOSTLY_IQ4_K_R4 = 340 #except 1d tensors
|
||||
MOSTLY_IQ5_K_R4 = 341 #except 1d tensors
|
||||
MOSTLY_IQ4_KS_R4 = 345 #except 1d tensors
|
||||
MOSTLY_IQ5_KS_R4 = 350 #except 1d tensors
|
||||
MOSTLY_Q8_KV_R8 = 398 #except 1d tensors
|
||||
MOSTLY_Q8_K_R8 = 399 #except 1d tensors
|
||||
MOSTLY_Q4_0_R8 = 202 # except 1d tensors
|
||||
MOSTLY_Q8_0_R8 = 207 # except 1d tensors
|
||||
MOSTLY_Q5_0_R4 = 208 # except 1d tensors
|
||||
MOSTLY_Q2_K_R4 = 210 # except 1d tensors
|
||||
MOSTLY_Q3_K_R4 = 211 # except 1d tensors
|
||||
MOSTLY_Q4_K_R4 = 214 # except 1d tensors
|
||||
MOSTLY_Q5_K_R4 = 216 # except 1d tensors
|
||||
MOSTLY_Q6_K_R4 = 218 # except 1d tensors
|
||||
MOSTLY_IQ2_XXS_R4 = 219 # except 1d tensors
|
||||
MOSTLY_IQ2_XS_R4 = 220 # except 1d tensors
|
||||
MOSTLY_IQ3_XXS_R4 = 223 # except 1d tensors
|
||||
MOSTLY_IQ1_S_R4 = 224 # except 1d tensors
|
||||
MOSTLY_IQ4_NL_R4 = 225 # except 1d tensors
|
||||
MOSTLY_IQ3_S_R4 = 226 # except 1d tensors
|
||||
MOSTLY_IQ2_M_R4 = 229 # except 1d tensors
|
||||
MOSTLY_IQ4_XS_R8 = 230 # except 1d tensors
|
||||
MOSTLY_IQ1_M_R4 = 231 # except 1d tensors
|
||||
MOSTLY_Q6_0_R4 = 335 # except 1d tensors
|
||||
MOSTLY_BF16_R16 = 232 # except 1d tensors
|
||||
MOSTLY_IQ2_BN_R4 = 337 # except 1d tensors
|
||||
MOSTLY_IQ2_K_R4 = 338 # except 1d tensors
|
||||
MOSTLY_IQ3_K_R4 = 339 # except 1d tensors
|
||||
MOSTLY_IQ4_K_R4 = 340 # except 1d tensors
|
||||
MOSTLY_IQ5_K_R4 = 341 # except 1d tensors
|
||||
MOSTLY_IQ4_KS_R4 = 345 # except 1d tensors
|
||||
MOSTLY_IQ5_KS_R4 = 350 # except 1d tensors
|
||||
MOSTLY_Q8_KV_R8 = 398 # except 1d tensors
|
||||
MOSTLY_Q8_K_R8 = 399 # except 1d tensors
|
||||
|
||||
GUESSED = 1024 # not specified in the model file
|
||||
|
||||
@ -1771,7 +1812,7 @@ class GGUFValueType(IntEnum):
|
||||
# Items here are (block size, type size)
|
||||
QK_K = 256
|
||||
|
||||
#Values generated programatically
|
||||
# Values generated programatically
|
||||
GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
|
||||
GGMLQuantizationType.F32 : ( 1, 4),
|
||||
GGMLQuantizationType.F16 : ( 1, 2),
|
||||
|
||||
@ -97,6 +97,7 @@ add_library(llama
|
||||
graphs/build_gemma2.cpp
|
||||
graphs/build_gemma3.cpp
|
||||
graphs/build_gemma4.cpp
|
||||
graphs/build_dflash.cpp
|
||||
graphs/build_mamba.cpp
|
||||
graphs/build_command_r.cpp
|
||||
graphs/build_olmo.cpp
|
||||
|
||||
144
src/graphs/build_dflash.cpp
Normal file
144
src/graphs/build_dflash.cpp
Normal file
@ -0,0 +1,144 @@
|
||||
#include "../llama-build-context.h"
|
||||
#include "../llama-context.h"
|
||||
#include "../llama-model.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
ggml_cgraph * llm_build_context::build_dflash() {
|
||||
const int64_t n_embd_head_k = hparams.n_embd_head_k(0);
|
||||
const int64_t n_embd_head_v = hparams.n_embd_head_v(0);
|
||||
const int64_t n_target_features = hparams.dflash_n_target_features;
|
||||
const int64_t ctx_len = std::max<int64_t>(1, (int64_t) cparams.n_ctx - (int64_t) hparams.dflash_block_size);
|
||||
const int64_t n_kv_total = ctx_len + n_tokens;
|
||||
|
||||
GGML_ASSERT(n_embd_head_k == n_embd_head_v);
|
||||
GGML_ASSERT(n_target_features > 0);
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes((int) std::max<int64_t>(n_tokens, ctx_len)) + 32 * n_layer, false);
|
||||
|
||||
lctx.inp_dflash_target_features = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_target_features, ctx_len);
|
||||
ggml_set_input(lctx.inp_dflash_target_features);
|
||||
cb(lctx.inp_dflash_target_features, "dflash_target_features", -1);
|
||||
|
||||
lctx.inp_dflash_pos_ctx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctx_len);
|
||||
ggml_set_input(lctx.inp_dflash_pos_ctx);
|
||||
cb(lctx.inp_dflash_pos_ctx, "dflash_pos_ctx", -1);
|
||||
|
||||
lctx.inp_dflash_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
ggml_set_input(lctx.inp_dflash_kq_mask);
|
||||
cb(lctx.inp_dflash_kq_mask, "dflash_kq_mask", -1);
|
||||
|
||||
ggml_tensor * tok_embd = model.tok_embd;
|
||||
if (tok_embd == nullptr) {
|
||||
tok_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab);
|
||||
}
|
||||
|
||||
ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, tok_embd, cb);
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
ggml_tensor * fused_target = llm_build_lora_mm(lctx, ctx0, model.dflash_fc, lctx.inp_dflash_target_features);
|
||||
fused_target = llm_build_norm(ctx0, fused_target, hparams, model.dflash_hidden_norm, nullptr, LLM_NORM_RMS, cb, -1);
|
||||
cb(fused_target, "dflash_target_fused", -1);
|
||||
|
||||
const float kq_scale = 1.0f / std::sqrt((float) n_embd_head_k);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
ggml_tensor * cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens);
|
||||
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il);
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
ggml_tensor * Kcur_noise = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
||||
Kcur_noise = ggml_reshape_3d(ctx0, Kcur_noise, n_embd_head_k, n_head_kv, n_tokens);
|
||||
Kcur_noise = llm_build_norm(ctx0, Kcur_noise, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
|
||||
Kcur_noise = ggml_rope_ext(ctx0, Kcur_noise, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Kcur_noise, "Kcur_noise", il);
|
||||
|
||||
ggml_tensor * Vcur_noise = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
||||
Vcur_noise = ggml_reshape_3d(ctx0, Vcur_noise, n_embd_head_v, n_head_kv, n_tokens);
|
||||
cb(Vcur_noise, "Vcur_noise", il);
|
||||
|
||||
ggml_tensor * Kcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, fused_target);
|
||||
Kcur_ctx = ggml_reshape_3d(ctx0, Kcur_ctx, n_embd_head_k, n_head_kv, ctx_len);
|
||||
Kcur_ctx = llm_build_norm(ctx0, Kcur_ctx, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
|
||||
Kcur_ctx = ggml_rope_ext(ctx0, Kcur_ctx, lctx.inp_dflash_pos_ctx, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Kcur_ctx, "Kcur_ctx", il);
|
||||
|
||||
ggml_tensor * Vcur_ctx = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, fused_target);
|
||||
Vcur_ctx = ggml_reshape_3d(ctx0, Vcur_ctx, n_embd_head_v, n_head_kv, ctx_len);
|
||||
cb(Vcur_ctx, "Vcur_ctx", il);
|
||||
|
||||
ggml_tensor * Kcur = ggml_concat(ctx0, Kcur_ctx, Kcur_noise, 2);
|
||||
ggml_tensor * Vcur = ggml_concat(ctx0, Vcur_ctx, Vcur_noise, 2);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_cast(ctx0, Qcur, GGML_TYPE_F16);
|
||||
Kcur = ggml_cast(ctx0, Kcur, GGML_TYPE_F16);
|
||||
Vcur = ggml_cast(ctx0, Vcur, GGML_TYPE_F16);
|
||||
cb(Qcur, "Qcur_f16", il);
|
||||
cb(Kcur, "Kcur_f16", il);
|
||||
cb(Vcur, "Vcur_f16", il);
|
||||
|
||||
ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
|
||||
ggml_tensor * v = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 0, 2, 1, 3));
|
||||
cb(q, "q", il);
|
||||
cb(k, "k", il);
|
||||
cb(v, "v", il);
|
||||
|
||||
cur = ggml_flash_attn_ext(ctx0, q, k, v, lctx.inp_dflash_kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
||||
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
||||
cb(cur, "flash_attn", il);
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
cur = ggml_reshape_2d(ctx0, cur, model.layers[il].wo->ne[0], n_tokens);
|
||||
cb(cur, "flash_attn_reshaped", il);
|
||||
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, inpSA);
|
||||
cb(cur, "attn_residual", il);
|
||||
|
||||
ggml_tensor * ffn_residual = cur;
|
||||
cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_post_norm", il);
|
||||
|
||||
cur = llm_build_ffn(ctx0, lctx, nullptr, cur,
|
||||
model.layers[il].ffn_up, nullptr, nullptr,
|
||||
model.layers[il].ffn_gate, nullptr, nullptr,
|
||||
model.layers[il].ffn_down, nullptr, nullptr,
|
||||
nullptr,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, false, false);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_residual);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
ggml_tensor * output = model.output;
|
||||
if (output == nullptr) {
|
||||
output = ggml_new_tensor_2d(ctx0, GGML_TYPE_Q4_0, n_embd, hparams.n_vocab);
|
||||
}
|
||||
|
||||
ggml_tensor * result = build_output(lctx, ctx0, inpL, output, model.output_norm, cb);
|
||||
cb(result, "result_output", -1);
|
||||
ggml_build_forward_expand(gf, result);
|
||||
|
||||
return gf;
|
||||
}
|
||||
@ -79,6 +79,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_MISTRAL4, "mistral4" },
|
||||
{ LLM_ARCH_GEMMA4, "gemma4" },
|
||||
{ LLM_ARCH_GEMMA4_MTP, "gemma4_mtp" },
|
||||
{ LLM_ARCH_DFLASH_DRAFT, "dflash-draft" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@ -145,6 +146,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_MTP_USE_ORDERED_EMBEDDINGS, "%s.use_ordered_embeddings" },
|
||||
{ LLM_KV_MTP_CENTROID_COUNT, "%s.centroid_count" },
|
||||
{ LLM_KV_MTP_CENTROID_TOP_K, "%s.centroid_top_k" },
|
||||
{ LLM_KV_DFLASH_BLOCK_SIZE, "%s.dflash.block_size" },
|
||||
{ LLM_KV_DFLASH_MASK_TOKEN_ID, "%s.dflash.mask_token_id" },
|
||||
{ LLM_KV_DFLASH_TARGET_LAYER_IDS, "%s.dflash.target_layer_ids" },
|
||||
{ LLM_KV_DFLASH_N_TARGET_FEATURES, "%s.dflash.n_target_features" },
|
||||
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
|
||||
@ -279,4 +284,3 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -78,6 +78,7 @@ enum llm_arch {
|
||||
LLM_ARCH_MISTRAL4,
|
||||
LLM_ARCH_GEMMA4,
|
||||
LLM_ARCH_GEMMA4_MTP,
|
||||
LLM_ARCH_DFLASH_DRAFT,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@ -138,6 +139,10 @@ enum llm_kv {
|
||||
LLM_KV_MTP_USE_ORDERED_EMBEDDINGS,
|
||||
LLM_KV_MTP_CENTROID_COUNT,
|
||||
LLM_KV_MTP_CENTROID_TOP_K,
|
||||
LLM_KV_DFLASH_BLOCK_SIZE,
|
||||
LLM_KV_DFLASH_MASK_TOKEN_ID,
|
||||
LLM_KV_DFLASH_TARGET_LAYER_IDS,
|
||||
LLM_KV_DFLASH_N_TARGET_FEATURES,
|
||||
|
||||
LLM_KV_ATTENTION_HEAD_COUNT,
|
||||
LLM_KV_ATTENTION_HEAD_COUNT_KV,
|
||||
@ -367,6 +372,8 @@ enum llm_tensor {
|
||||
LLM_TENSOR_MTP_POST_PROJ,
|
||||
LLM_TENSOR_MTP_TOKEN_ORDERING,
|
||||
LLM_TENSOR_MTP_CENTROIDS,
|
||||
LLM_TENSOR_DFLASH_FC,
|
||||
LLM_TENSOR_DFLASH_HIDDEN_NORM,
|
||||
|
||||
LLM_TENSOR_UNKNOWN,
|
||||
};
|
||||
|
||||
@ -112,6 +112,9 @@ void llm_build_context::init() {
|
||||
lctx.inp_pos_bucket = nullptr;
|
||||
lctx.inp_embd_enc = nullptr;
|
||||
lctx.inp_KQ_mask_cross = nullptr;
|
||||
lctx.inp_dflash_target_features = nullptr;
|
||||
lctx.inp_dflash_pos_ctx = nullptr;
|
||||
lctx.inp_dflash_kq_mask = nullptr;
|
||||
}
|
||||
|
||||
void llm_build_context::free() {
|
||||
@ -2372,6 +2375,10 @@ ggml_cgraph * llm_build_context::llama_build_graph(
|
||||
{
|
||||
result = llm.build_gemma4_mtp();
|
||||
} break;
|
||||
case LLM_ARCH_DFLASH_DRAFT:
|
||||
{
|
||||
result = llm.build_dflash();
|
||||
} break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
result = llm.build_starcoder2();
|
||||
|
||||
@ -242,6 +242,8 @@ struct llm_build_context {
|
||||
|
||||
ggml_cgraph * build_gemma4_mtp();
|
||||
|
||||
ggml_cgraph * build_dflash();
|
||||
|
||||
ggml_cgraph * build_starcoder2();
|
||||
|
||||
ggml_cgraph * build_mamba();
|
||||
|
||||
@ -278,6 +278,25 @@ struct llama_context {
|
||||
size_t draft_input_hidden_state_n_floats = 0;
|
||||
std::vector<float> draft_input_hidden_state_owned;
|
||||
|
||||
const float * dflash_target_features = nullptr;
|
||||
size_t dflash_target_features_n_floats = 0;
|
||||
int32_t dflash_target_features_n_rows = 0;
|
||||
std::vector<float> dflash_target_features_owned;
|
||||
std::vector<float> dflash_target_features_padded;
|
||||
std::vector<float> dflash_feature_view_buffer;
|
||||
std::vector<llama_pos> dflash_pos_ctx_data;
|
||||
std::vector<float> dflash_kq_mask_data;
|
||||
|
||||
struct dflash_capture_state {
|
||||
std::vector<int32_t> layer_ids;
|
||||
std::vector<std::vector<float>> layer_rows;
|
||||
int32_t row_count = 0;
|
||||
int32_t row_width = 0;
|
||||
ggml_backend_sched_eval_callback prev_cb_eval = nullptr;
|
||||
void * prev_cb_eval_user_data = nullptr;
|
||||
};
|
||||
std::unique_ptr<dflash_capture_state> dflash_capture;
|
||||
|
||||
// input tensors
|
||||
struct ggml_tensor * inp_tokens; // I32 [n_batch]
|
||||
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
|
||||
@ -297,6 +316,9 @@ struct llama_context {
|
||||
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
|
||||
struct ggml_tensor * inp_scale = nullptr; // F32 [n_tokens]
|
||||
struct ggml_tensor * inp_mtp_states = nullptr;
|
||||
struct ggml_tensor * inp_dflash_target_features = nullptr; // F32 [n_target_features, cross_ctx]
|
||||
struct ggml_tensor * inp_dflash_pos_ctx = nullptr; // I32 [cross_ctx]
|
||||
struct ggml_tensor * inp_dflash_kq_mask = nullptr; // F32 [cross_ctx + n_batch, GGML_PAD(n_batch)]
|
||||
|
||||
ggml_backend_t ggml_backend_by_name(const char * name);
|
||||
|
||||
@ -320,4 +342,3 @@ struct llama_context {
|
||||
void set_mtp_op_type(llama_mtp_op_type value);
|
||||
|
||||
};
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
#include "llama-model-loader.h"
|
||||
#include "llama-model.h"
|
||||
|
||||
#include <limits>
|
||||
#include <map>
|
||||
|
||||
#define LLAMA_MAX_EXPERTS 512 // Qwen3 Next
|
||||
@ -36,6 +37,89 @@ static inline const char * llm_expert_gating_func_name(llm_expert_gating_func_ty
|
||||
}
|
||||
}
|
||||
|
||||
static bool load_dflash_target_layer_ids(
|
||||
llama_model_loader & ml,
|
||||
const std::string & key,
|
||||
llama_hparams & hparams,
|
||||
bool required) {
|
||||
const int kid = gguf_find_key(ml.meta, key.c_str());
|
||||
if (kid < 0 || gguf_get_kv_type(ml.meta, kid) != GGUF_TYPE_ARRAY) {
|
||||
if (required) {
|
||||
throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const enum gguf_type type = gguf_get_arr_type(ml.meta, kid);
|
||||
if (type != GGUF_TYPE_UINT32 && type != GGUF_TYPE_INT32) {
|
||||
throw std::runtime_error(format("dflash: %s must be a uint32/int32 array", key.c_str()));
|
||||
}
|
||||
|
||||
const size_t n = gguf_get_arr_n(ml.meta, kid);
|
||||
if (n == 0) {
|
||||
throw std::runtime_error(format("dflash: %s must not be empty", key.c_str()));
|
||||
}
|
||||
if (n > 8) {
|
||||
throw std::runtime_error(format("dflash: %s has %zu entries, max is 8", key.c_str(), n));
|
||||
}
|
||||
|
||||
hparams.dflash_n_target_layers = (uint32_t) n;
|
||||
for (uint32_t & id : hparams.dflash_target_layer_ids) {
|
||||
id = 0;
|
||||
}
|
||||
|
||||
const void * data = gguf_get_arr_data(ml.meta, kid);
|
||||
for (uint32_t i = 0; i < hparams.dflash_n_target_layers; ++i) {
|
||||
if (type == GGUF_TYPE_INT32) {
|
||||
const int32_t id = ((const int32_t *) data)[i];
|
||||
if (id < 0) {
|
||||
throw std::runtime_error(format("dflash: %s contains negative layer id %d", key.c_str(), id));
|
||||
}
|
||||
hparams.dflash_target_layer_ids[i] = (uint32_t) id;
|
||||
} else {
|
||||
hparams.dflash_target_layer_ids[i] = ((const uint32_t *) data)[i];
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static void validate_dflash_hparams(llama_hparams & hparams, llm_arch arch) {
|
||||
if (hparams.dflash_block_size <= 1) {
|
||||
throw std::runtime_error(format("%s: dflash block_size must be > 1", llama_model_arch_name(arch)));
|
||||
}
|
||||
if (hparams.dflash_n_target_layers == 0) {
|
||||
throw std::runtime_error(format("%s: dflash target_layer_ids are required", llama_model_arch_name(arch)));
|
||||
}
|
||||
|
||||
if (arch == LLM_ARCH_DFLASH_DRAFT && hparams.n_embd > 0) {
|
||||
const uint32_t expected_n_target_features = hparams.n_embd * hparams.dflash_n_target_layers;
|
||||
if (expected_n_target_features > 0 && hparams.dflash_n_target_features != expected_n_target_features) {
|
||||
LLAMA_LOG_WARN(
|
||||
"%s: overriding dflash n_target_features from %u to %u based on n_embd=%u and n_target_layers=%u\n",
|
||||
llama_model_arch_name(arch),
|
||||
hparams.dflash_n_target_features,
|
||||
expected_n_target_features,
|
||||
hparams.n_embd,
|
||||
hparams.dflash_n_target_layers);
|
||||
hparams.dflash_n_target_features = expected_n_target_features;
|
||||
}
|
||||
}
|
||||
|
||||
if (hparams.dflash_n_target_features == 0) {
|
||||
throw std::runtime_error(format(
|
||||
"%s: dflash n_target_features must be > 0",
|
||||
llama_model_arch_name(arch)));
|
||||
}
|
||||
if (hparams.dflash_n_target_features % hparams.dflash_n_target_layers != 0) {
|
||||
throw std::runtime_error(format(
|
||||
"%s: dflash n_target_features=%u must be divisible by n_target_layers=%u",
|
||||
llama_model_arch_name(arch),
|
||||
hparams.dflash_n_target_features,
|
||||
hparams.dflash_n_target_layers));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void llm_load_hparams(
|
||||
llama_model_loader & ml,
|
||||
@ -774,6 +858,18 @@ void llm_load_hparams(
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_DFLASH_DRAFT:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_DFLASH_BLOCK_SIZE, hparams.dflash_block_size, false);
|
||||
ml.get_key(LLM_KV_DFLASH_MASK_TOKEN_ID, hparams.dflash_mask_token_id, false);
|
||||
ml.get_key(LLM_KV_DFLASH_N_TARGET_FEATURES, hparams.dflash_n_target_features, false);
|
||||
load_dflash_target_layer_ids(ml, LLM_KV(model.arch)(LLM_KV_DFLASH_TARGET_LAYER_IDS), hparams, false);
|
||||
validate_dflash_hparams(hparams, model.arch);
|
||||
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer;
|
||||
model.type = e_model::MODEL_UNKNOWN;
|
||||
} break;
|
||||
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
|
||||
@ -140,6 +140,13 @@ struct llama_hparams {
|
||||
uint32_t mtp_num_centroids = 0;
|
||||
uint32_t mtp_centroid_top_k = 0;
|
||||
|
||||
// DFlash draft model metadata
|
||||
uint32_t dflash_block_size = 16;
|
||||
uint32_t dflash_mask_token_id = 0;
|
||||
uint32_t dflash_n_target_features = 0;
|
||||
uint32_t dflash_n_target_layers = 0;
|
||||
uint32_t dflash_target_layer_ids[8] = {};
|
||||
|
||||
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
|
||||
llama_token dec_start_token_id = -1;
|
||||
@ -159,6 +166,10 @@ struct llama_hparams {
|
||||
if (this->n_ctx_train != other.n_ctx_train) return true;
|
||||
if (this->n_embd != other.n_embd) return true;
|
||||
if (this->mtp_backbone_n_embd != other.mtp_backbone_n_embd) return true;
|
||||
if (this->dflash_block_size != other.dflash_block_size) return true;
|
||||
if (this->dflash_mask_token_id != other.dflash_mask_token_id) return true;
|
||||
if (this->dflash_n_target_features != other.dflash_n_target_features) return true;
|
||||
if (this->dflash_n_target_layers != other.dflash_n_target_layers) return true;
|
||||
if (this->n_layer != other.n_layer) return true;
|
||||
if (this->n_rot != other.n_rot) return true;
|
||||
if (this->n_swa != other.n_swa) return true;
|
||||
@ -189,6 +200,9 @@ struct llama_hparams {
|
||||
if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
|
||||
if (this->ssm_n_group != other.ssm_n_group) return true;
|
||||
if (this->recurrent_layer_arr != other.recurrent_layer_arr) return true;
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
if (this->dflash_target_layer_ids[i] != other.dflash_target_layer_ids[i]) return true;
|
||||
}
|
||||
|
||||
if (this->dec_start_token_id != other.dec_start_token_id) return true;
|
||||
|
||||
|
||||
@ -98,6 +98,8 @@ struct create_tensors_helper : public create_tensors_helper_interface {
|
||||
|
||||
bool create_gemma4_mtp_tensors(const LLM_TN & tn);
|
||||
|
||||
bool create_dflash_tensors(const LLM_TN & tn);
|
||||
|
||||
bool create_starcoder2_tensors(const LLM_TN & tn);
|
||||
|
||||
bool create_mamba_tensors(const LLM_TN & tn);
|
||||
@ -2192,6 +2194,43 @@ bool create_tensors_helper::create_gemma4_mtp_tensors(const LLM_TN & tn) {
|
||||
return use_mmap_buffer;
|
||||
}
|
||||
|
||||
bool create_tensors_helper::create_dflash_tensors(const LLM_TN & tn) {
|
||||
LOADING_PRELUDE
|
||||
|
||||
const bool use_split_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN;
|
||||
|
||||
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
if (model.output == nullptr && model.tok_embd != nullptr) {
|
||||
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
||||
}
|
||||
model.dflash_fc = create_tensor(ctx_output, tn(LLM_TENSOR_DFLASH_FC, "weight"), {(int64_t) hparams.dflash_n_target_features, n_embd}, 0);
|
||||
model.dflash_hidden_norm = create_tensor(ctx_output, tn(LLM_TENSOR_DFLASH_HIDDEN_NORM, "weight"), {n_embd}, 0);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_split = use_split_ctx ? ctx_for_layer_split(i) : ctx_for_layer(i);
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.attn_post_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
||||
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head, n_embd}, 0);
|
||||
|
||||
layer.attn_q_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
layer.attn_k_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
|
||||
layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
|
||||
return use_mmap_buffer;
|
||||
}
|
||||
|
||||
bool create_tensors_helper::create_starcoder2_tensors(const LLM_TN & tn) {
|
||||
LOADING_PRELUDE
|
||||
|
||||
@ -4263,6 +4302,8 @@ bool create_tensors_helper::create_tensors() {
|
||||
use_mmap_buffer = create_gemma4_tensors(tn); break;
|
||||
case LLM_ARCH_GEMMA4_MTP:
|
||||
use_mmap_buffer = create_gemma4_mtp_tensors(tn); break;
|
||||
case LLM_ARCH_DFLASH_DRAFT:
|
||||
use_mmap_buffer = create_dflash_tensors(tn); break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
use_mmap_buffer = create_starcoder2_tensors(tn); break;
|
||||
case LLM_ARCH_MAMBA:
|
||||
|
||||
@ -825,6 +825,27 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||
{ LLM_TENSOR_MTP_CENTROIDS, "mtp_centroids" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_DFLASH_DRAFT,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_DFLASH_FC, "dflash_fc" },
|
||||
{ LLM_TENSOR_DFLASH_HIDDEN_NORM, "dflash_hidden_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_STARCODER2,
|
||||
{
|
||||
|
||||
@ -428,6 +428,8 @@ struct llama_model {
|
||||
struct ggml_tensor * mtp_post_proj = nullptr;
|
||||
struct ggml_tensor * mtp_token_ordering = nullptr;
|
||||
struct ggml_tensor * mtp_centroids = nullptr;
|
||||
struct ggml_tensor * dflash_fc = nullptr;
|
||||
struct ggml_tensor * dflash_hidden_norm = nullptr;
|
||||
|
||||
struct ggml_tensor * output_norm;
|
||||
struct ggml_tensor * output_norm_b;
|
||||
@ -621,4 +623,3 @@ struct LLM_TN {
|
||||
std::string llama_model_ftype_name(llama_ftype ftype);
|
||||
|
||||
const char * llama_model_type_name(e_model type);
|
||||
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
#include "llama-spec-features.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <random>
|
||||
|
||||
#include "llama-model.h"
|
||||
@ -18,6 +21,63 @@ uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx) {
|
||||
return hparams.n_embd;
|
||||
}
|
||||
|
||||
int32_t llama_model_dflash_block_size(const struct llama_model * model) {
|
||||
return model ? (int32_t) model->hparams.dflash_block_size : 0;
|
||||
}
|
||||
|
||||
int32_t llama_model_dflash_mask_token_id(const struct llama_model * model) {
|
||||
return model ? (int32_t) model->hparams.dflash_mask_token_id : -1;
|
||||
}
|
||||
|
||||
int32_t llama_model_dflash_n_target_layers(const struct llama_model * model) {
|
||||
return model ? (int32_t) model->hparams.dflash_n_target_layers : 0;
|
||||
}
|
||||
|
||||
int32_t llama_model_dflash_n_target_features(const struct llama_model * model) {
|
||||
return model ? (int32_t) model->hparams.dflash_n_target_features : 0;
|
||||
}
|
||||
|
||||
int32_t llama_model_dflash_target_layer_ids(
|
||||
const struct llama_model * model,
|
||||
int32_t * layer_ids,
|
||||
int32_t capacity) {
|
||||
if (model == nullptr || layer_ids == nullptr || capacity <= 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const int32_t n_layers = std::min<int32_t>((int32_t) model->hparams.dflash_n_target_layers, capacity);
|
||||
for (int32_t i = 0; i < n_layers; ++i) {
|
||||
layer_ids[i] = (int32_t) model->hparams.dflash_target_layer_ids[i];
|
||||
}
|
||||
|
||||
return n_layers;
|
||||
}
|
||||
|
||||
bool llama_model_share_dflash_io_tensors(
|
||||
struct llama_model * draft_model,
|
||||
const struct llama_model * target_model) {
|
||||
if (draft_model == nullptr || target_model == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (draft_model->arch != LLM_ARCH_DFLASH_DRAFT) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (draft_model->tok_embd == nullptr) {
|
||||
draft_model->tok_embd = target_model->tok_embd;
|
||||
}
|
||||
|
||||
if (draft_model->output == nullptr) {
|
||||
draft_model->output = target_model->output ? target_model->output : target_model->tok_embd;
|
||||
if (draft_model->output == nullptr) {
|
||||
draft_model->output = draft_model->tok_embd;
|
||||
}
|
||||
}
|
||||
|
||||
return draft_model->tok_embd != nullptr && draft_model->output != nullptr;
|
||||
}
|
||||
|
||||
bool llama_set_draft_input_hidden_state_copy(
|
||||
struct llama_context * ctx,
|
||||
const float * hidden_state,
|
||||
@ -32,6 +92,211 @@ bool llama_set_draft_input_hidden_state_copy(
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_set_dflash_target_features_copy(
|
||||
struct llama_context * ctx,
|
||||
const float * target_features,
|
||||
size_t n_floats,
|
||||
int32_t n_rows) {
|
||||
if (ctx == nullptr || target_features == nullptr || n_floats == 0 || n_rows <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ctx->dflash_target_features_owned.assign(target_features, target_features + n_floats);
|
||||
ctx->dflash_target_features = ctx->dflash_target_features_owned.data();
|
||||
ctx->dflash_target_features_n_floats = n_floats;
|
||||
ctx->dflash_target_features_n_rows = n_rows;
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool llama_dflash_parse_layer_id(const struct ggml_tensor * tensor, int32_t & layer_id) {
|
||||
if (tensor == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static constexpr const char * prefix = "l_out-";
|
||||
if (std::strncmp(tensor->name, prefix, std::strlen(prefix)) != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
char * end = nullptr;
|
||||
const long raw = std::strtol(tensor->name + std::strlen(prefix), &end, 10);
|
||||
if (end == tensor->name + std::strlen(prefix) || *end != '\0') {
|
||||
return false;
|
||||
}
|
||||
|
||||
layer_id = (int32_t) raw;
|
||||
if (layer_id >= 1000) {
|
||||
layer_id %= 1000;
|
||||
}
|
||||
|
||||
return layer_id >= 0;
|
||||
}
|
||||
|
||||
static int32_t llama_dflash_find_layer_index(const struct llama_context * ctx, int32_t layer_id) {
|
||||
if (ctx == nullptr || !ctx->dflash_capture) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
const auto & layer_ids = ctx->dflash_capture->layer_ids;
|
||||
const auto it = std::find(layer_ids.begin(), layer_ids.end(), layer_id);
|
||||
return it == layer_ids.end() ? -1 : (int32_t) std::distance(layer_ids.begin(), it);
|
||||
}
|
||||
|
||||
static bool llama_dflash_capture_eval_callback(struct ggml_tensor * tensor, bool ask, void * user_data) {
|
||||
auto * ctx = static_cast<llama_context *>(user_data);
|
||||
if (ctx == nullptr || !ctx->dflash_capture) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t layer_id = -1;
|
||||
if (!llama_dflash_parse_layer_id(tensor, layer_id)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int32_t layer_idx = llama_dflash_find_layer_index(ctx, layer_id);
|
||||
if (layer_idx < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ask) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const int32_t row_width = (int32_t) tensor->ne[0];
|
||||
const int32_t row_count = row_width > 0 ? (int32_t) (ggml_nelements(tensor) / (int64_t) row_width) : 0;
|
||||
if (row_width <= 0 || row_count <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto & capture = *ctx->dflash_capture;
|
||||
auto & rows = capture.layer_rows[(size_t) layer_idx];
|
||||
rows.resize((size_t) row_count * (size_t) row_width);
|
||||
ggml_backend_tensor_get(tensor, rows.data(), 0, ggml_nbytes(tensor));
|
||||
capture.row_width = row_width;
|
||||
capture.row_count = row_count;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_set_dflash_capture_layers(
|
||||
struct llama_context * ctx,
|
||||
const int32_t * layer_ids,
|
||||
int32_t n_layers) {
|
||||
if (ctx == nullptr || layer_ids == nullptr || n_layers <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto capture = std::make_unique<llama_context::dflash_capture_state>();
|
||||
capture->layer_ids.assign(layer_ids, layer_ids + n_layers);
|
||||
capture->layer_rows.resize((size_t) n_layers);
|
||||
capture->prev_cb_eval = ctx->cparams.cb_eval;
|
||||
capture->prev_cb_eval_user_data = ctx->cparams.cb_eval_user_data;
|
||||
ctx->dflash_capture = std::move(capture);
|
||||
ctx->dflash_feature_view_buffer.clear();
|
||||
|
||||
ctx->cparams.cb_eval = llama_dflash_capture_eval_callback;
|
||||
ctx->cparams.cb_eval_user_data = ctx;
|
||||
if (ctx->sched != nullptr) {
|
||||
ggml_backend_sched_set_eval_callback(ctx->sched, ctx->cparams.cb_eval, ctx->cparams.cb_eval_user_data);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void llama_clear_dflash_capture(struct llama_context * ctx) {
|
||||
if (ctx == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_backend_sched_eval_callback prev_cb_eval = nullptr;
|
||||
void * prev_cb_eval_user_data = nullptr;
|
||||
if (ctx->dflash_capture) {
|
||||
prev_cb_eval = ctx->dflash_capture->prev_cb_eval;
|
||||
prev_cb_eval_user_data = ctx->dflash_capture->prev_cb_eval_user_data;
|
||||
}
|
||||
|
||||
ctx->dflash_capture.reset();
|
||||
ctx->dflash_feature_view_buffer.clear();
|
||||
|
||||
if (ctx->cparams.cb_eval == llama_dflash_capture_eval_callback && ctx->cparams.cb_eval_user_data == ctx) {
|
||||
ctx->cparams.cb_eval = prev_cb_eval;
|
||||
ctx->cparams.cb_eval_user_data = prev_cb_eval_user_data;
|
||||
if (ctx->sched != nullptr) {
|
||||
ggml_backend_sched_set_eval_callback(ctx->sched, prev_cb_eval, prev_cb_eval_user_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static bool llama_spec_prepare_dflash_capture(
|
||||
struct llama_context * ctx,
|
||||
int32_t & row_count,
|
||||
int32_t & row_width,
|
||||
int32_t & n_layers) {
|
||||
if (ctx == nullptr || !ctx->dflash_capture) {
|
||||
return false;
|
||||
}
|
||||
|
||||
llama_synchronize(ctx);
|
||||
|
||||
auto & capture = *ctx->dflash_capture;
|
||||
row_count = capture.row_count;
|
||||
row_width = capture.row_width;
|
||||
n_layers = (int32_t) capture.layer_ids.size();
|
||||
if (row_count <= 0 || row_width <= 0 || n_layers <= 0 || capture.layer_rows.size() != (size_t) n_layers) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto & rows : capture.layer_rows) {
|
||||
if (rows.size() != (size_t) row_count * (size_t) row_width) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool llama_spec_materialize_dflash_rows(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<int32_t> & row_indices,
|
||||
std::vector<float> & rows_out,
|
||||
int32_t & combined_width) {
|
||||
rows_out.clear();
|
||||
combined_width = 0;
|
||||
if (ctx == nullptr || row_indices.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t row_count = 0;
|
||||
int32_t row_width = 0;
|
||||
int32_t n_layers = 0;
|
||||
if (!llama_spec_prepare_dflash_capture(ctx, row_count, row_width, n_layers)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
combined_width = row_width * n_layers;
|
||||
rows_out.resize((size_t) row_indices.size() * (size_t) combined_width);
|
||||
|
||||
const auto & layer_rows = ctx->dflash_capture->layer_rows;
|
||||
for (size_t out_row = 0; out_row < row_indices.size(); ++out_row) {
|
||||
int32_t row_index = row_indices[out_row];
|
||||
if (row_index < 0) {
|
||||
row_index += row_count;
|
||||
}
|
||||
if (row_index < 0 || row_index >= row_count) {
|
||||
rows_out.clear();
|
||||
combined_width = 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
float * dst = rows_out.data() + out_row * (size_t) combined_width;
|
||||
for (int32_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) {
|
||||
const float * src = layer_rows[(size_t) layer_idx].data() + (size_t) row_index * (size_t) row_width;
|
||||
std::memcpy(dst + (size_t) layer_idx * (size_t) row_width, src, (size_t) row_width * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool llama_spec_prepare_hidden_feature_view(
|
||||
struct llama_context * ctx,
|
||||
int32_t n_rows,
|
||||
@ -88,6 +353,92 @@ bool llama_spec_get_hidden_feature_view(
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_spec_get_dflash_feature_view(
|
||||
struct llama_context * ctx,
|
||||
const llama_batch & batch,
|
||||
llama_spec_feature_view & view) {
|
||||
if (ctx == nullptr || batch.n_tokens <= 0 || batch.pos == nullptr || batch.n_seq_id == nullptr || batch.seq_id == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<int32_t> row_indices((size_t) batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
row_indices[(size_t) i] = i;
|
||||
}
|
||||
|
||||
view = {};
|
||||
view.kind = LLAMA_SPEC_FEATURE_HIDDEN_STATE;
|
||||
if (!llama_spec_materialize_dflash_rows(ctx, row_indices, ctx->dflash_feature_view_buffer, view.width)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
view.rows.reserve((size_t) batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
if (batch.n_seq_id[i] <= 0 || batch.seq_id[i] == nullptr) {
|
||||
view.rows.clear();
|
||||
return false;
|
||||
}
|
||||
|
||||
view.rows.push_back({
|
||||
/* .seq_id = */ batch.seq_id[i][0],
|
||||
/* .pos = */ batch.pos[i],
|
||||
/* .data = */ ctx->dflash_feature_view_buffer.data() + (size_t) i * (size_t) view.width,
|
||||
});
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_spec_get_dflash_feature_view_for_seq(
|
||||
struct llama_context * ctx,
|
||||
const llama_batch & batch,
|
||||
llama_seq_id seq_id,
|
||||
llama_spec_feature_view & view) {
|
||||
if (ctx == nullptr || batch.n_tokens <= 0 || batch.pos == nullptr || batch.n_seq_id == nullptr || batch.seq_id == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<int32_t> row_indices;
|
||||
row_indices.reserve((size_t) batch.n_tokens);
|
||||
std::vector<int32_t> batch_indices;
|
||||
batch_indices.reserve((size_t) batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
if (batch.n_seq_id[i] <= 0 || batch.seq_id[i] == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
|
||||
if (batch.seq_id[i][j] == seq_id) {
|
||||
row_indices.push_back(i);
|
||||
batch_indices.push_back(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (row_indices.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
view = {};
|
||||
view.kind = LLAMA_SPEC_FEATURE_HIDDEN_STATE;
|
||||
if (!llama_spec_materialize_dflash_rows(ctx, row_indices, ctx->dflash_feature_view_buffer, view.width)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
view.rows.reserve(row_indices.size());
|
||||
for (size_t i = 0; i < batch_indices.size(); ++i) {
|
||||
const int32_t batch_index = batch_indices[i];
|
||||
view.rows.push_back({
|
||||
/* .seq_id = */ seq_id,
|
||||
/* .pos = */ batch.pos[batch_index],
|
||||
/* .data = */ ctx->dflash_feature_view_buffer.data() + i * (size_t) view.width,
|
||||
});
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_spec_get_hidden_feature_view_for_seq(
|
||||
struct llama_context * ctx,
|
||||
const llama_batch & batch,
|
||||
@ -179,4 +530,17 @@ bool llama_spec_copy_hidden_rows_from_output_indices(
|
||||
}
|
||||
|
||||
return hidden_rows.size() == (size_t) output_indices.size() * view.width;
|
||||
}
|
||||
}
|
||||
|
||||
bool llama_spec_copy_dflash_rows_from_output_indices(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<int32_t> & output_indices,
|
||||
std::vector<float> & hidden_rows) {
|
||||
int32_t combined_width = 0;
|
||||
if (!llama_spec_materialize_dflash_rows(ctx, output_indices, hidden_rows, combined_width)) {
|
||||
hidden_rows.clear();
|
||||
return false;
|
||||
}
|
||||
|
||||
return hidden_rows.size() == (size_t) output_indices.size() * (size_t) combined_width;
|
||||
}
|
||||
|
||||
@ -25,16 +25,57 @@ struct llama_spec_feature_view {
|
||||
|
||||
uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx);
|
||||
|
||||
int32_t llama_model_dflash_block_size(const struct llama_model * model);
|
||||
|
||||
int32_t llama_model_dflash_mask_token_id(const struct llama_model * model);
|
||||
|
||||
int32_t llama_model_dflash_n_target_layers(const struct llama_model * model);
|
||||
|
||||
int32_t llama_model_dflash_n_target_features(const struct llama_model * model);
|
||||
|
||||
int32_t llama_model_dflash_target_layer_ids(
|
||||
const struct llama_model * model,
|
||||
int32_t * layer_ids,
|
||||
int32_t capacity);
|
||||
|
||||
bool llama_model_share_dflash_io_tensors(
|
||||
struct llama_model * draft_model,
|
||||
const struct llama_model * target_model);
|
||||
|
||||
bool llama_set_draft_input_hidden_state_copy(
|
||||
struct llama_context * ctx,
|
||||
const float * hidden_state,
|
||||
size_t n_floats);
|
||||
|
||||
bool llama_set_dflash_target_features_copy(
|
||||
struct llama_context * ctx,
|
||||
const float * target_features,
|
||||
size_t n_floats,
|
||||
int32_t n_rows);
|
||||
|
||||
bool llama_set_dflash_capture_layers(
|
||||
struct llama_context * ctx,
|
||||
const int32_t * layer_ids,
|
||||
int32_t n_layers);
|
||||
|
||||
void llama_clear_dflash_capture(struct llama_context * ctx);
|
||||
|
||||
bool llama_spec_get_hidden_feature_view(
|
||||
struct llama_context * ctx,
|
||||
const llama_batch & batch,
|
||||
llama_spec_feature_view & view);
|
||||
|
||||
bool llama_spec_get_dflash_feature_view(
|
||||
struct llama_context * ctx,
|
||||
const llama_batch & batch,
|
||||
llama_spec_feature_view & view);
|
||||
|
||||
bool llama_spec_get_dflash_feature_view_for_seq(
|
||||
struct llama_context * ctx,
|
||||
const llama_batch & batch,
|
||||
llama_seq_id seq_id,
|
||||
llama_spec_feature_view & view);
|
||||
|
||||
bool llama_spec_get_hidden_feature_view_for_seq(
|
||||
struct llama_context * ctx,
|
||||
const llama_batch & batch,
|
||||
@ -51,4 +92,9 @@ bool llama_spec_get_hidden_feature_view_from_output_index(
|
||||
bool llama_spec_copy_hidden_rows_from_output_indices(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<int32_t> & output_indices,
|
||||
std::vector<float> & hidden_rows);
|
||||
std::vector<float> & hidden_rows);
|
||||
|
||||
bool llama_spec_copy_dflash_rows_from_output_indices(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<int32_t> & output_indices,
|
||||
std::vector<float> & hidden_rows);
|
||||
|
||||
@ -3125,6 +3125,10 @@ static std::pair<std::vector<double>, double> get_layer_sizes(const llama_model_
|
||||
name == "mtp_centroids.weight" || name == "mtp_token_ordering.weight") {
|
||||
continue;
|
||||
}
|
||||
if (name == "dflash_fc.weight" || name == "dflash_hidden_norm.weight") {
|
||||
output_misc_size += size;
|
||||
continue;
|
||||
}
|
||||
auto pos = name.find("blk.");
|
||||
if (pos != 0) {
|
||||
LLAMA_LOG_WARN("Oops: tensor with strange name %s\n", name.c_str());
|
||||
@ -4977,6 +4981,61 @@ static bool prepare_mtp_graph_inputs(
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool prepare_dflash_graph_inputs(
|
||||
struct llama_context & lctx,
|
||||
uint32_t n_tokens) {
|
||||
ggml_tensor * target_hidden = lctx.inp_dflash_target_features;
|
||||
ggml_tensor * pos_ctx = lctx.inp_dflash_pos_ctx;
|
||||
ggml_tensor * kq_mask = lctx.inp_dflash_kq_mask;
|
||||
|
||||
if (target_hidden == nullptr || pos_ctx == nullptr || kq_mask == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash graph inputs are not initialized\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
const float * src = lctx.dflash_target_features;
|
||||
const size_t total_floats = lctx.dflash_target_features_n_floats;
|
||||
const int32_t n_rows = lctx.dflash_target_features_n_rows;
|
||||
const int32_t width = (int32_t) target_hidden->ne[0];
|
||||
const int32_t cross_ctx = (int32_t) target_hidden->ne[1];
|
||||
const int32_t n_mask_tokens = (int32_t) kq_mask->ne[1];
|
||||
const int32_t n_kv_total = (int32_t) kq_mask->ne[0];
|
||||
|
||||
if (src == nullptr || total_floats == 0 || n_rows <= 0) {
|
||||
LLAMA_LOG_ERROR("%s: missing DFlash target features\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (n_rows > cross_ctx || total_floats != (size_t) n_rows * (size_t) width) {
|
||||
LLAMA_LOG_ERROR("%s: invalid DFlash target feature shape (rows=%d width=%d floats=%zu cross_ctx=%d)\n",
|
||||
__func__, n_rows, width, total_floats, cross_ctx);
|
||||
return false;
|
||||
}
|
||||
|
||||
lctx.dflash_target_features_padded.assign((size_t) cross_ctx * (size_t) width, 0.0f);
|
||||
const size_t dst_offset = (size_t) (cross_ctx - n_rows) * (size_t) width;
|
||||
std::copy(src, src + total_floats, lctx.dflash_target_features_padded.begin() + (ptrdiff_t) dst_offset);
|
||||
ggml_backend_tensor_set(target_hidden, lctx.dflash_target_features_padded.data(), 0, ggml_nbytes(target_hidden));
|
||||
|
||||
lctx.dflash_pos_ctx_data.resize((size_t) cross_ctx);
|
||||
for (int32_t i = 0; i < cross_ctx; ++i) {
|
||||
lctx.dflash_pos_ctx_data[i] = i;
|
||||
}
|
||||
ggml_backend_tensor_set(pos_ctx, lctx.dflash_pos_ctx_data.data(), 0, ggml_nbytes(pos_ctx));
|
||||
|
||||
lctx.dflash_kq_mask_data.assign((size_t) n_kv_total * (size_t) n_mask_tokens, -INFINITY);
|
||||
const int32_t left_pad = cross_ctx - n_rows;
|
||||
for (uint32_t j = 0; j < n_tokens; ++j) {
|
||||
float * row = lctx.dflash_kq_mask_data.data() + (size_t) j * (size_t) n_kv_total;
|
||||
for (int32_t i = left_pad; i < n_kv_total; ++i) {
|
||||
row[i] = 0.0f;
|
||||
}
|
||||
}
|
||||
ggml_backend_tensor_set(kq_mask, lctx.dflash_kq_mask_data.data(), 0, ggml_nbytes(kq_mask));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// decode a batch of tokens by evaluating the transformer
|
||||
//
|
||||
// - lctx: llama context
|
||||
@ -5269,6 +5328,12 @@ static int llama_decode_internal(
|
||||
}
|
||||
}
|
||||
|
||||
if (lctx.model.arch == LLM_ARCH_DFLASH_DRAFT) {
|
||||
if (!prepare_dflash_graph_inputs(lctx, n_tokens)) {
|
||||
return GGML_STATUS_FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
// the output is always the last tensor in the graph
|
||||
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
||||
struct ggml_tensor * embd = nullptr;
|
||||
@ -7371,6 +7436,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_STEP35:
|
||||
case LLM_ARCH_GEMMA4:
|
||||
case LLM_ARCH_GEMMA4_MTP:
|
||||
case LLM_ARCH_DFLASH_DRAFT:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user