mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Cleanup: Unify location of m-rope repacking for token and embd (#1924)
* unify location of rope-position-array rewriting prior to ubatching * Reorder terms.
This commit is contained in:
parent
b1eb8bb0a1
commit
d1339249d7
@ -4239,18 +4239,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
#endif
|
||||
const int64_t n_tokens = batch.n_tokens;
|
||||
const int n_pos_per_embd = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1;
|
||||
if (batch.token && n_pos_per_embd == 4) {
|
||||
std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
pos_data[ i] = batch.pos[i];
|
||||
pos_data[ n_tokens + i] = batch.pos[i];
|
||||
pos_data[2 * n_tokens + i] = batch.pos[i];
|
||||
pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
|
||||
}
|
||||
ggml_backend_tensor_set(lctx.inp_pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(lctx.inp_pos));
|
||||
} else {
|
||||
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(lctx.inp_pos));
|
||||
}
|
||||
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(lctx.inp_pos));
|
||||
#if IK_PRINT_TIMING == 2
|
||||
auto tim2 = ggml_time_us();
|
||||
printf("set_inputs(pos): %d us\n", int(tim2-tim1));
|
||||
@ -5162,23 +5151,34 @@ static int llama_decode_internal(
|
||||
}
|
||||
|
||||
|
||||
// Repack the rope buffer for the ubatch depending on type.
|
||||
// * mrope: (section-major array of rope fields) [t; n][h; n][w; n][extra; n]
|
||||
// * others: (flat array ) [t; n]
|
||||
// Repack the rope buffer for the ubatch depending on whether the model uses mrope.
|
||||
// * standard rope: (flat array) [t; n]
|
||||
// * mrope (token): (section-major array of rope fields) [t; n][t; n][t; n][0; n]
|
||||
// * mrope (embd): (section-major array of rope fields) [t; n][h; n][w; n][extra; n]
|
||||
const uint8_t rope_params_per_token = (hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ||
|
||||
hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE) ? 4 : 1;
|
||||
|
||||
llama_pos * u_batch_pos;
|
||||
if (batch_all.pos && batch_all.embd && rope_params_per_token == 4) {
|
||||
pos.resize((size_t) n_tokens * rope_params_per_token);
|
||||
for (uint32_t i = 0; i < n_tokens; ++i) {
|
||||
pos[0*n_tokens + i] = batch_all.pos[0*n_tokens_all + cur_token + i]; // t
|
||||
pos[1*n_tokens + i] = batch_all.pos[1*n_tokens_all + cur_token + i]; // h
|
||||
pos[2*n_tokens + i] = batch_all.pos[2*n_tokens_all + cur_token + i]; // w
|
||||
pos[3*n_tokens + i] = batch_all.pos[3*n_tokens_all + cur_token + i]; // extra
|
||||
}
|
||||
u_batch_pos = pos.data();
|
||||
if (!batch_all.pos) {
|
||||
u_batch_pos = nullptr;
|
||||
} else if (rope_params_per_token == 1) {
|
||||
u_batch_pos = batch_all.pos + cur_token;
|
||||
} else {
|
||||
u_batch_pos = batch_all.pos ? batch_all.pos + cur_token : nullptr;
|
||||
pos.resize((size_t) n_tokens * rope_params_per_token);
|
||||
u_batch_pos = pos.data();
|
||||
for (uint32_t i = 0; i < n_tokens; ++i) {
|
||||
if (batch_all.token) {
|
||||
pos[ i] = batch_all.pos[cur_token + i]; // t
|
||||
pos[ n_tokens + i] = batch_all.pos[cur_token + i]; // t
|
||||
pos[2 * n_tokens + i] = batch_all.pos[cur_token + i]; // t
|
||||
pos[3 * n_tokens + i] = 0;
|
||||
} else { //embd
|
||||
pos[ i] = batch_all.pos[ cur_token + i]; // t
|
||||
pos[ n_tokens + i] = batch_all.pos[ n_tokens_all + cur_token + i]; // h
|
||||
pos[2 * n_tokens + i] = batch_all.pos[2 * n_tokens_all + cur_token + i]; // w
|
||||
pos[3 * n_tokens + i] = batch_all.pos[3 * n_tokens_all + cur_token + i]; // extra (this field is supplied but unused)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llama_batch u_batch = {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user