Change MTP graph input preparation with additional parameters and validation checks (#1866)

This commit is contained in:
Samuel Oliveira Alves 2026-05-23 01:22:04 -03:00 committed by GitHub
parent b3d39cff8b
commit 19e09e81d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4885,19 +4885,46 @@ static void llama_graph_compute(
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
}
static bool prepare_mtp_graph_inputs(struct llama_context & lctx) {
static bool prepare_mtp_graph_inputs(
struct llama_context & lctx,
uint32_t cur_token,
uint32_t n_tokens,
uint32_t n_tokens_all) {
ggml_tensor * dst = lctx.inp_mtp_states;
const float * src = lctx.draft_input_hidden_state;
const size_t expected_floats = ggml_nbytes(dst) / sizeof(float);
const size_t total_floats = lctx.draft_input_hidden_state_n_floats;
const float * src = lctx.draft_input_hidden_state;
if (!src) {
LLAMA_LOG_ERROR("%s: Source hidden state is null\n", __func__);
return false;
}
if (lctx.draft_input_hidden_state_n_floats != expected_floats) {
if (total_floats != expected_floats) {
if (n_tokens_all == 0 || total_floats % n_tokens_all != 0) {
LLAMA_LOG_ERROR("%s: Source hidden state size mismatch (have %zu floats, need %zu)\n",
__func__, total_floats, expected_floats);
return false;
}
const size_t row_width = total_floats / n_tokens_all;
const size_t slice_floats = row_width * n_tokens;
const size_t slice_offset = (size_t) cur_token * row_width;
if (slice_floats != expected_floats ||
slice_offset > total_floats ||
total_floats - slice_offset < expected_floats) {
LLAMA_LOG_ERROR("%s: Source hidden state size mismatch (have %zu floats, need %zu)\n",
__func__, total_floats, expected_floats);
return false;
}
src += slice_offset;
}
if (src == nullptr) {
LLAMA_LOG_ERROR("%s: Source hidden state size mismatch (have %zu floats, need %zu)\n",
__func__, lctx.draft_input_hidden_state_n_floats, expected_floats);
__func__, total_floats, expected_floats);
return false;
}
@ -5192,7 +5219,7 @@ static int llama_decode_internal(
}
if (cparams.mtp_op_type != MTP_OP_NONE) {
if (!prepare_mtp_graph_inputs(lctx)) {
if (!prepare_mtp_graph_inputs(lctx, cur_token, n_tokens, n_tokens_all)) {
return GGML_STATUS_FAILED;
}
}