Change signature of llama_set_draft_input_hidden_state

This commit is contained in:
Kawrakow 2026-05-03 05:05:58 +00:00
parent bc549da0f7
commit 19a72d91a2
6 changed files with 20 additions and 12 deletions

View File

@@ -1394,10 +1394,10 @@ std::vector<llama_token> mtp_speculative_gen_draft(
if (!emb) {
break;
}
// Keep a stable copy because later decode steps reuse ctx->embd storage.
memcpy(draft_hidden_state.data(), emb, n_embd * sizeof(float));
llama_set_draft_input_hidden_state(ctx, draft_hidden_state.data());
llama_set_draft_input_hidden_state(ctx, draft_hidden_state.data(), n_embd);
current_input_id = id_next;
current_n_past++;

View File

@@ -3142,7 +3142,7 @@ void server_context::add_sampled_tokens() {
if (!slot.mtp_hidden_state.empty()) {
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
const int n_hidden = slot.mtp_hidden_state.size() / n_embd;
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data() + (n_hidden - 1) * n_embd);
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data() + (n_hidden - 1) * n_embd, n_embd);
} else {
LOG_ERROR("MTP hidden state is empty during speculation", {});
const float* emb_neg1 = llama_get_embeddings_ith(ctx, -1);
@@ -3150,7 +3150,7 @@ void server_context::add_sampled_tokens() {
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
slot.mtp_hidden_state.resize(n_embd);
memcpy(slot.mtp_hidden_state.data(), emb_neg1, n_embd * sizeof(float));
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data());
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data(), slot.mtp_hidden_state.size());
}
}
}
@@ -3727,7 +3727,7 @@ static void restore_speculative_checkpoint(
slot.mtp_hidden_state = mtp_hidden_state_pre;
llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec);
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data());
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data(), slot.mtp_hidden_state.size());
mtp_accept_tokens(mtp_target, ids, mtp_n_past_base, slot.id);
}
@@ -3775,7 +3775,7 @@ static void restore_speculative_checkpoint(
llama_context * mtp_ctx_rej = common_speculative_get_mtp_ctx(slot.spec);
llama_context * mtp_target_rej = mtp_ctx_rej ? mtp_ctx_rej : ctx;
llama_set_draft_input_hidden_state(mtp_target_rej, slot.mtp_hidden_state.data());
llama_set_draft_input_hidden_state(mtp_target_rej, slot.mtp_hidden_state.data(), slot.mtp_hidden_state.size());
mtp_accept_tokens(mtp_target_rej, ids, slot.spec_ckpt.n_past, slot.id);
if (n_accepted > 1) {
@@ -3882,7 +3882,7 @@ void server_context::speculative_decoding_accept() {
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
slot.mtp_hidden_state = std::move(mtp_hidden_state_pre);
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data());
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data(), slot.mtp_hidden_state.size());
mtp_accept_tokens(mtp_target, ids, mtp_n_past_base, slot.id);
}
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
@@ -4403,7 +4403,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
}
}
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data());
llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data(), batch_mtp_hidden_state.size());
mtp_update_kv_cache(mtp_target, batch_view, true);
}

View File

@@ -1560,7 +1560,7 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
// Set which, if any, MTP operation the context will use
LLAMA_API void llama_set_mtp_op_type(struct llama_context * ctx, enum llama_mtp_op_type mtp_op_type);
LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
// Set the hidden-state input for MTP drafting on `ctx`.
//   hidden_state: pointer to float data; the implementation stores the pointer itself
//                 rather than copying the data, so the buffer presumably has to remain
//                 valid until the context has consumed it - confirm against callers.
//   size:         number of float elements available at `hidden_state`.
LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state, size_t size);
#ifdef __cplusplus
}

View File

@@ -217,4 +217,4 @@ struct ggml_tensor * llm_build_context::build_qwen35_mtp(
cb(cur, "result_output", -1);
return cur;
}
}

View File

@@ -264,6 +264,7 @@ struct llama_context {
void * abort_callback_data = nullptr;
const float * draft_input_hidden_state = nullptr;
size_t draft_input_hidden_size = 0;
// input tensors
struct ggml_tensor * inp_tokens; // I32 [n_batch]

View File

@@ -4213,8 +4213,14 @@ static bool prepare_mtp_graph_inputs(struct llama_context & lctx) {
LLAMA_LOG_ERROR("%s: Source hidden state is null\n", __func__);
return false;
}
auto nbytes = ggml_nbytes(dst);
if (nbytes > lctx.draft_input_hidden_size*sizeof(float)) {
LLAMA_LOG_ERROR("%s: saved hidden state size %zu is less than input MTP state %zu\n", __func__,
lctx.draft_input_hidden_size, nbytes/sizeof(float));
return false;
}
ggml_backend_tensor_set(dst, src, 0, ggml_nbytes(dst));
ggml_backend_tensor_set(dst, src, 0, nbytes);
return true;
}
@@ -9841,8 +9847,9 @@ void llama_set_offload_policy(struct llama_context * lctx, int op, bool on_or_of
ggml_backend_sched_set_op_offload(lctx->sched, ggml_op(op), on_or_off);
}
void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) {
// Record the draft (MTP) input hidden state on the context.
//
//   ctx:          context to update; dereferenced without a null check.
//   hidden_state: pointer to float data. Only the pointer is stored - nothing is copied
//                 here - so the caller's buffer presumably must stay alive and unchanged
//                 until the context reads it (TODO confirm against the consumers).
//   size:         element count (floats, not bytes) available at `hidden_state`;
//                 prepare_mtp_graph_inputs multiplies it by sizeof(float) before comparing
//                 it with the destination tensor's byte size.
void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state, size_t size) {
// Non-owning: keep the caller's pointer as-is.
ctx->draft_input_hidden_state = hidden_state;
// Remember how many floats are behind the pointer so later consumers can bounds-check.
ctx->draft_input_hidden_size = size;
}
size_t llama_fill_from_utf8(void* utf8, void* cpts, void* scripts) {