Cleanup: Unify location of m-rope repacking for token and embd (#1924)

* Unify location of rope-position-array rewriting prior to ubatching.

* Reorder terms.
This commit is contained in:
Farmadupe 2026-06-12 07:27:50 +01:00 committed by GitHub
parent b1eb8bb0a1
commit d1339249d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -4239,18 +4239,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
#endif
const int64_t n_tokens = batch.n_tokens;
const int n_pos_per_embd = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1;
if (batch.token && n_pos_per_embd == 4) {
std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
for (int i = 0; i < n_tokens; ++i) {
pos_data[ i] = batch.pos[i];
pos_data[ n_tokens + i] = batch.pos[i];
pos_data[2 * n_tokens + i] = batch.pos[i];
pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
}
ggml_backend_tensor_set(lctx.inp_pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(lctx.inp_pos));
} else {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(lctx.inp_pos));
}
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(lctx.inp_pos));
#if IK_PRINT_TIMING == 2
auto tim2 = ggml_time_us();
printf("set_inputs(pos): %d us\n", int(tim2-tim1));
@@ -5162,23 +5151,34 @@ static int llama_decode_internal(
}
// Repack the rope buffer for the ubatch depending on type.
// * mrope: (section-major array of rope fields) [t; n][h; n][w; n][extra; n]
// * others: (flat array ) [t; n]
// Repack the rope buffer for the ubatch depending on whether the model uses mrope.
// * standard rope: (flat array) [t; n]
// * mrope (token): (section-major array of rope fields) [t; n][t; n][t; n][0; n]
// * mrope (embd): (section-major array of rope fields) [t; n][h; n][w; n][extra; n]
const uint8_t rope_params_per_token = (hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ||
hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE) ? 4 : 1;
llama_pos * u_batch_pos;
if (batch_all.pos && batch_all.embd && rope_params_per_token == 4) {
pos.resize((size_t) n_tokens * rope_params_per_token);
for (uint32_t i = 0; i < n_tokens; ++i) {
pos[0*n_tokens + i] = batch_all.pos[0*n_tokens_all + cur_token + i]; // t
pos[1*n_tokens + i] = batch_all.pos[1*n_tokens_all + cur_token + i]; // h
pos[2*n_tokens + i] = batch_all.pos[2*n_tokens_all + cur_token + i]; // w
pos[3*n_tokens + i] = batch_all.pos[3*n_tokens_all + cur_token + i]; // extra
}
u_batch_pos = pos.data();
if (!batch_all.pos) {
u_batch_pos = nullptr;
} else if (rope_params_per_token == 1) {
u_batch_pos = batch_all.pos + cur_token;
} else {
u_batch_pos = batch_all.pos ? batch_all.pos + cur_token : nullptr;
pos.resize((size_t) n_tokens * rope_params_per_token);
u_batch_pos = pos.data();
for (uint32_t i = 0; i < n_tokens; ++i) {
if (batch_all.token) {
pos[ i] = batch_all.pos[cur_token + i]; // t
pos[ n_tokens + i] = batch_all.pos[cur_token + i]; // t
pos[2 * n_tokens + i] = batch_all.pos[cur_token + i]; // t
pos[3 * n_tokens + i] = 0;
} else { //embd
pos[ i] = batch_all.pos[ cur_token + i]; // t
pos[ n_tokens + i] = batch_all.pos[ n_tokens_all + cur_token + i]; // h
pos[2 * n_tokens + i] = batch_all.pos[2 * n_tokens_all + cur_token + i]; // w
pos[3 * n_tokens + i] = batch_all.pos[3 * n_tokens_all + cur_token + i]; // extra (this field is supplied but unused)
}
}
}
llama_batch u_batch = {