diff --git a/src/llama.cpp b/src/llama.cpp index 662b8ac9..a867a915 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4239,18 +4239,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { #endif const int64_t n_tokens = batch.n_tokens; const int n_pos_per_embd = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1; - if (batch.token && n_pos_per_embd == 4) { - std::vector pos_data(n_tokens*n_pos_per_embd); - for (int i = 0; i < n_tokens; ++i) { - pos_data[ i] = batch.pos[i]; - pos_data[ n_tokens + i] = batch.pos[i]; - pos_data[2 * n_tokens + i] = batch.pos[i]; - pos_data[3 * n_tokens + i] = 0; // 4th dim is 0 - } - ggml_backend_tensor_set(lctx.inp_pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(lctx.inp_pos)); - } else { - ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(lctx.inp_pos)); - } + ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(lctx.inp_pos)); #if IK_PRINT_TIMING == 2 auto tim2 = ggml_time_us(); printf("set_inputs(pos): %d us\n", int(tim2-tim1)); @@ -5162,23 +5151,34 @@ static int llama_decode_internal( } - // Repack the rope buffer for the ubatch depending on type. - // * mrope: (section-major array of rope fields) [t; n][h; n][w; n][extra; n] - // * others: (flat array ) [t; n] + // Repack the rope buffer for the ubatch depending on whether the model uses mrope. + // * standard rope: (flat array) [t; n] + // * mrope (token): (section-major array of rope fields) [t; n][t; n][t; n][0; n] + // * mrope (embd): (section-major array of rope fields) [t; n][h; n][w; n][extra; n] const uint8_t rope_params_per_token = (hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE) ? 4 : 1; + llama_pos * u_batch_pos; - if (batch_all.pos && batch_all.embd && rope_params_per_token == 4) { - pos.resize((size_t) n_tokens * rope_params_per_token); - for (uint32_t i = 0; i < n_tokens; ++i) { - pos[0*n_tokens + i] = batch_all.pos[0*n_tokens_all + cur_token + i]; // t - pos[1*n_tokens + i] = batch_all.pos[1*n_tokens_all + cur_token + i]; // h - pos[2*n_tokens + i] = batch_all.pos[2*n_tokens_all + cur_token + i]; // w - pos[3*n_tokens + i] = batch_all.pos[3*n_tokens_all + cur_token + i]; // extra - } - u_batch_pos = pos.data(); + if (!batch_all.pos) { + u_batch_pos = nullptr; + } else if (rope_params_per_token == 1) { + u_batch_pos = batch_all.pos + cur_token; } else { - u_batch_pos = batch_all.pos ? batch_all.pos + cur_token : nullptr; + pos.resize((size_t) n_tokens * rope_params_per_token); + u_batch_pos = pos.data(); + for (uint32_t i = 0; i < n_tokens; ++i) { + if (batch_all.token) { + pos[ i] = batch_all.pos[cur_token + i]; // t + pos[ n_tokens + i] = batch_all.pos[cur_token + i]; // t + pos[2 * n_tokens + i] = batch_all.pos[cur_token + i]; // t + pos[3 * n_tokens + i] = 0; + } else { //embd + pos[ i] = batch_all.pos[ cur_token + i]; // t + pos[ n_tokens + i] = batch_all.pos[ n_tokens_all + cur_token + i]; // h + pos[2 * n_tokens + i] = batch_all.pos[2 * n_tokens_all + cur_token + i]; // w + pos[3 * n_tokens + i] = batch_all.pos[3 * n_tokens_all + cur_token + i]; // extra (this field is supplied but unused) + } + } } llama_batch u_batch = {