mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Change MTP graph input preparation with additional parameters and validation checks (#1866)
This commit is contained in:
parent
b3d39cff8b
commit
19e09e81d4
@ -4885,19 +4885,46 @@ static void llama_graph_compute(
|
||||
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
|
||||
}
|
||||
|
||||
static bool prepare_mtp_graph_inputs(struct llama_context & lctx) {
|
||||
static bool prepare_mtp_graph_inputs(
|
||||
struct llama_context & lctx,
|
||||
uint32_t cur_token,
|
||||
uint32_t n_tokens,
|
||||
uint32_t n_tokens_all) {
|
||||
ggml_tensor * dst = lctx.inp_mtp_states;
|
||||
const float * src = lctx.draft_input_hidden_state;
|
||||
const size_t expected_floats = ggml_nbytes(dst) / sizeof(float);
|
||||
const size_t total_floats = lctx.draft_input_hidden_state_n_floats;
|
||||
const float * src = lctx.draft_input_hidden_state;
|
||||
|
||||
if (!src) {
|
||||
LLAMA_LOG_ERROR("%s: Source hidden state is null\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (lctx.draft_input_hidden_state_n_floats != expected_floats) {
|
||||
if (total_floats != expected_floats) {
|
||||
if (n_tokens_all == 0 || total_floats % n_tokens_all != 0) {
|
||||
LLAMA_LOG_ERROR("%s: Source hidden state size mismatch (have %zu floats, need %zu)\n",
|
||||
__func__, total_floats, expected_floats);
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t row_width = total_floats / n_tokens_all;
|
||||
const size_t slice_floats = row_width * n_tokens;
|
||||
const size_t slice_offset = (size_t) cur_token * row_width;
|
||||
|
||||
if (slice_floats != expected_floats ||
|
||||
slice_offset > total_floats ||
|
||||
total_floats - slice_offset < expected_floats) {
|
||||
LLAMA_LOG_ERROR("%s: Source hidden state size mismatch (have %zu floats, need %zu)\n",
|
||||
__func__, total_floats, expected_floats);
|
||||
return false;
|
||||
}
|
||||
|
||||
src += slice_offset;
|
||||
}
|
||||
|
||||
if (src == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: Source hidden state size mismatch (have %zu floats, need %zu)\n",
|
||||
__func__, lctx.draft_input_hidden_state_n_floats, expected_floats);
|
||||
__func__, total_floats, expected_floats);
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -5192,7 +5219,7 @@ static int llama_decode_internal(
|
||||
}
|
||||
|
||||
if (cparams.mtp_op_type != MTP_OP_NONE) {
|
||||
if (!prepare_mtp_graph_inputs(lctx)) {
|
||||
if (!prepare_mtp_graph_inputs(lctx, cur_token, n_tokens, n_tokens_all)) {
|
||||
return GGML_STATUS_FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user