mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Change signature of llama_set_draft_input_hidden_state
This commit is contained in:
parent
bc549da0f7
commit
19a72d91a2
@ -1394,10 +1394,10 @@ std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
if (!emb) {
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
// Keep a stable copy because later decode steps reuse ctx->embd storage.
|
||||
memcpy(draft_hidden_state.data(), emb, n_embd * sizeof(float));
|
||||
llama_set_draft_input_hidden_state(ctx, draft_hidden_state.data());
|
||||
llama_set_draft_input_hidden_state(ctx, draft_hidden_state.data(), n_embd);
|
||||
|
||||
current_input_id = id_next;
|
||||
current_n_past++;
|
||||
|
||||
@ -3142,7 +3142,7 @@ void server_context::add_sampled_tokens() {
|
||||
if (!slot.mtp_hidden_state.empty()) {
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
const int n_hidden = slot.mtp_hidden_state.size() / n_embd;
|
||||
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data() + (n_hidden - 1) * n_embd);
|
||||
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data() + (n_hidden - 1) * n_embd, n_embd);
|
||||
} else {
|
||||
LOG_ERROR("MTP hidden state is empty during speculation", {});
|
||||
const float* emb_neg1 = llama_get_embeddings_ith(ctx, -1);
|
||||
@ -3150,7 +3150,7 @@ void server_context::add_sampled_tokens() {
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
slot.mtp_hidden_state.resize(n_embd);
|
||||
memcpy(slot.mtp_hidden_state.data(), emb_neg1, n_embd * sizeof(float));
|
||||
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data());
|
||||
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data(), slot.mtp_hidden_state.size());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3727,7 +3727,7 @@ static void restore_speculative_checkpoint(
|
||||
slot.mtp_hidden_state = mtp_hidden_state_pre;
|
||||
llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec);
|
||||
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
|
||||
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data());
|
||||
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data(), slot.mtp_hidden_state.size());
|
||||
mtp_accept_tokens(mtp_target, ids, mtp_n_past_base, slot.id);
|
||||
}
|
||||
|
||||
@ -3775,7 +3775,7 @@ static void restore_speculative_checkpoint(
|
||||
|
||||
llama_context * mtp_ctx_rej = common_speculative_get_mtp_ctx(slot.spec);
|
||||
llama_context * mtp_target_rej = mtp_ctx_rej ? mtp_ctx_rej : ctx;
|
||||
llama_set_draft_input_hidden_state(mtp_target_rej, slot.mtp_hidden_state.data());
|
||||
llama_set_draft_input_hidden_state(mtp_target_rej, slot.mtp_hidden_state.data(), slot.mtp_hidden_state.size());
|
||||
mtp_accept_tokens(mtp_target_rej, ids, slot.spec_ckpt.n_past, slot.id);
|
||||
|
||||
if (n_accepted > 1) {
|
||||
@ -3882,7 +3882,7 @@ void server_context::speculative_decoding_accept() {
|
||||
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
|
||||
|
||||
slot.mtp_hidden_state = std::move(mtp_hidden_state_pre);
|
||||
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data());
|
||||
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data(), slot.mtp_hidden_state.size());
|
||||
mtp_accept_tokens(mtp_target, ids, mtp_n_past_base, slot.id);
|
||||
}
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
|
||||
@ -4403,7 +4403,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
}
|
||||
}
|
||||
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
|
||||
llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data());
|
||||
llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data(), batch_mtp_hidden_state.size());
|
||||
mtp_update_kv_cache(mtp_target, batch_view, true);
|
||||
}
|
||||
|
||||
|
||||
@ -1560,7 +1560,7 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
|
||||
// Set which, if any, MTP operation the context will use
|
||||
LLAMA_API void llama_set_mtp_op_type(struct llama_context * ctx, enum llama_mtp_op_type mtp_op_type);
|
||||
|
||||
LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
|
||||
// Set the hidden-state buffer used as input for MTP drafting on this context.
// `hidden_state` points to `size` floats; the pointer and the count are stored
// on the context as-is (no copy is made by this call).
LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state, size_t size);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@ -217,4 +217,4 @@ struct ggml_tensor * llm_build_context::build_qwen35_mtp(
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
return cur;
|
||||
}
|
||||
}
|
||||
|
||||
@ -264,6 +264,7 @@ struct llama_context {
|
||||
void * abort_callback_data = nullptr;
|
||||
|
||||
const float * draft_input_hidden_state = nullptr;
|
||||
size_t draft_input_hidden_size = 0;
|
||||
|
||||
// input tensors
|
||||
struct ggml_tensor * inp_tokens; // I32 [n_batch]
|
||||
|
||||
@ -4213,8 +4213,14 @@ static bool prepare_mtp_graph_inputs(struct llama_context & lctx) {
|
||||
LLAMA_LOG_ERROR("%s: Source hidden state is null\n", __func__);
|
||||
return false;
|
||||
}
|
||||
auto nbytes = ggml_nbytes(dst);
|
||||
if (nbytes > lctx.draft_input_hidden_size*sizeof(float)) {
|
||||
LLAMA_LOG_ERROR("%s: saved hidden state size %zu is less than input MTP state %zu\n", __func__,
|
||||
lctx.draft_input_hidden_size, nbytes/sizeof(float));
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(dst, src, 0, ggml_nbytes(dst));
|
||||
ggml_backend_tensor_set(dst, src, 0, nbytes);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -9841,8 +9847,9 @@ void llama_set_offload_policy(struct llama_context * lctx, int op, bool on_or_of
|
||||
ggml_backend_sched_set_op_offload(lctx->sched, ggml_op(op), on_or_off);
|
||||
}
|
||||
|
||||
void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) {
|
||||
// Record the draft-input hidden state on the context.
// Only the pointer and the element count are stored; nothing is copied here.
// NOTE(review): because the pointer is kept, the caller presumably must keep
// the buffer alive until the MTP graph inputs are prepared - confirm against callers.
// `size` is a count of floats: prepare_mtp_graph_inputs compares
// size * sizeof(float) against the destination tensor's byte size and fails
// when the saved buffer is too small.
void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state, size_t size) {
|
||||
ctx->draft_input_hidden_state = hidden_state;
|
||||
ctx->draft_input_hidden_size = size;
|
||||
}
|
||||
|
||||
size_t llama_fill_from_utf8(void* utf8, void* cpts, void* scripts) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user