mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
parent
45dfd80371
commit
8b56d813a9
@ -14,6 +14,7 @@
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
|
||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||
@ -219,7 +220,6 @@ struct common_speculative_state_mtp : public common_speculative_state {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct common_speculative_state_draft : public common_speculative_state {
|
||||
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
|
||||
llama_context * ctx_dft;
|
||||
@ -1213,6 +1213,23 @@ void common_speculative_begin(common_speculative * spec, const llama_tokens & pr
|
||||
}
|
||||
}
|
||||
|
||||
struct mtp_last_embd {
|
||||
std::vector<float> embd;
|
||||
float prob;
|
||||
int last_id = -1;
|
||||
};
|
||||
|
||||
// Hopefully never called concurrently from multiple threads
|
||||
static mtp_last_embd & mtp_get_last_embd(const llama_context * ctx) {
|
||||
static std::unordered_map<const llama_context *, mtp_last_embd> map;
|
||||
auto & last = map[ctx];
|
||||
if (last.embd.empty()) {
|
||||
auto n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
last.embd.resize(n_embd);
|
||||
}
|
||||
return last;
|
||||
}
|
||||
|
||||
llama_tokens common_speculative_draft(
|
||||
common_speculative * spec,
|
||||
common_params_speculative & params,
|
||||
@ -1361,7 +1378,7 @@ std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
llama_token id_last,
|
||||
int32_t n_past,
|
||||
llama_seq_id seq_id) {
|
||||
|
||||
|
||||
llama_tokens drafts;
|
||||
drafts.reserve(n_draft);
|
||||
|
||||
@ -1372,12 +1389,28 @@ std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
llama_batch mtp_batch = llama_batch_init(1, 0, 1);
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_DRAFT_GEN);
|
||||
|
||||
float prob;
|
||||
auto prob_ptr = p_min > 0 ? &prob : nullptr;
|
||||
|
||||
llama_token current_input_id = id_last;
|
||||
int32_t current_n_past = n_past;
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
std::vector<float> draft_hidden_state(n_embd);
|
||||
|
||||
for (int i = 0; i < n_draft; ++i) {
|
||||
auto & last = mtp_get_last_embd(ctx);
|
||||
int i0 = 0;
|
||||
if (last.last_id >= 0) {
|
||||
if (last.prob < p_min) {
|
||||
return drafts;
|
||||
}
|
||||
current_input_id = last.last_id;
|
||||
last.last_id = -1;
|
||||
drafts.push_back(current_input_id);
|
||||
current_n_past++;
|
||||
llama_set_draft_input_hidden_state(ctx, last.embd.data());
|
||||
i0 = 1;
|
||||
}
|
||||
|
||||
for (int i = i0; i < n_draft; ++i) {
|
||||
mtp_batch.n_tokens = 0;
|
||||
common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true);
|
||||
|
||||
@ -1385,8 +1418,10 @@ std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
break;
|
||||
}
|
||||
|
||||
float prob;
|
||||
llama_token id_next = common_sampler_sample_speculative(smpl, ctx, 0, &prob);
|
||||
llama_token id_next = common_sampler_sample_speculative(smpl, ctx, 0, prob_ptr);
|
||||
if (i > 0 && prob_ptr && prob < p_min) {
|
||||
return drafts;
|
||||
}
|
||||
|
||||
drafts.push_back(id_next);
|
||||
|
||||
@ -1394,15 +1429,15 @@ std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
if (!emb) {
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
// Keep a stable copy because later decode steps reuse ctx->embd storage.
|
||||
memcpy(draft_hidden_state.data(), emb, n_embd * sizeof(float));
|
||||
llama_set_draft_input_hidden_state(ctx, draft_hidden_state.data());
|
||||
memcpy(last.embd.data(), emb, n_embd * sizeof(float));
|
||||
llama_set_draft_input_hidden_state(ctx, last.embd.data());
|
||||
|
||||
current_input_id = id_next;
|
||||
current_n_past++;
|
||||
|
||||
if (prob < p_min) {
|
||||
if (prob_ptr && prob < p_min) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -1431,8 +1466,8 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b
|
||||
llama_kv_cache_seq_rm(ctx, seq_id, start_pos, -1);
|
||||
}
|
||||
|
||||
LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens from pos %d...\n",
|
||||
is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens, (int)start_pos);
|
||||
LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens for seq_id %d from pos %d...\n",
|
||||
is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens, seq_id, (int)start_pos);
|
||||
|
||||
llama_batch mtp_batch = batch;
|
||||
if (is_prompt_warmup) {
|
||||
@ -1452,8 +1487,7 @@ void mtp_accept_tokens(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<llama_token> & ids,
|
||||
int32_t n_past_base,
|
||||
llama_seq_id seq_id
|
||||
) {
|
||||
llama_seq_id seq_id) {
|
||||
if (ids.empty()) {
|
||||
return;
|
||||
}
|
||||
@ -1465,5 +1499,13 @@ void mtp_accept_tokens(
|
||||
|
||||
mtp_update_kv_cache(ctx, accepted_batch, false);
|
||||
|
||||
auto & last = mtp_get_last_embd(ctx);
|
||||
auto embd = llama_get_embeddings_ith(ctx, ids.size() - 1);
|
||||
if (embd) {
|
||||
std::memcpy(last.embd.data(), embd, last.embd.size()*sizeof(float));
|
||||
llama_set_draft_input_hidden_state(ctx, last.embd.data());
|
||||
last.last_id = common_sampler_sample_speculative(nullptr, ctx, ids.size() - 1, &last.prob);
|
||||
}
|
||||
|
||||
llama_batch_free(accepted_batch);
|
||||
}
|
||||
|
||||
@ -147,8 +147,8 @@ struct ggml_tensor * llm_build_context::build_qwen35_mtp(
|
||||
struct ggml_tensor * prev_embeddings,
|
||||
int64_t n_embd_head,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_tensor * inp_pos
|
||||
) {
|
||||
struct ggml_tensor * inp_pos) {
|
||||
|
||||
const int il = hparams.n_layer - 1;
|
||||
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
@ -217,4 +217,4 @@ struct ggml_tensor * llm_build_context::build_qwen35_mtp(
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
return cur;
|
||||
}
|
||||
}
|
||||
|
||||
@ -4590,8 +4590,7 @@ static int llama_decode_internal(
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
// Do not process logits if MTP is only updating the KV cache.
|
||||
if (cparams.mtp_op_type != MTP_OP_WARMUP &&
|
||||
cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) {
|
||||
if (cparams.mtp_op_type != MTP_OP_WARMUP) { // && cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) {
|
||||
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
|
||||
GGML_ASSERT(backend_res != nullptr);
|
||||
GGML_ASSERT(lctx.logits != nullptr);
|
||||
@ -4627,7 +4626,8 @@ static int llama_decode_internal(
|
||||
}
|
||||
|
||||
// extract embeddings
|
||||
if (embd && (cparams.mtp_op_type == MTP_OP_NONE || cparams.mtp_op_type == MTP_OP_DRAFT_GEN)) {
|
||||
//if (embd && (cparams.mtp_op_type == MTP_OP_NONE || cparams.mtp_op_type == MTP_OP_DRAFT_GEN)) {
|
||||
if (embd && cparams.mtp_op_type != MTP_OP_WARMUP) {
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user