mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
MTP: better graph reuse (#1713)
This commit is contained in:
parent
1beaaa002d
commit
2dd3818083
@ -289,6 +289,7 @@ struct llama_context {
|
||||
|
||||
struct Prev;
|
||||
std::unique_ptr<Prev> prev;
|
||||
std::unique_ptr<Prev> prev_mtp;
|
||||
|
||||
void reset_scheduler();
|
||||
bool can_reuse_graph(const llama_batch & u_batch);
|
||||
|
||||
105
src/llama.cpp
105
src/llama.cpp
@ -548,6 +548,7 @@ struct llama_context::Prev {
|
||||
int all_seq_id;
|
||||
int n_outputs;
|
||||
int n_kv;
|
||||
int n_tokens;
|
||||
llama_mtp_op_type mtp_op_type;
|
||||
ggml_cgraph * graph;
|
||||
};
|
||||
@ -555,29 +556,53 @@ struct llama_context::Prev {
|
||||
void llama_context::reset_scheduler() {
|
||||
ggml_backend_sched_reset(sched);
|
||||
prev.reset();
|
||||
prev_mtp.reset();
|
||||
}
|
||||
|
||||
bool llama_context::can_reuse_graph(const llama_batch & u_batch) {
|
||||
if (!prev || !prev->graph) return false;
|
||||
if (u_batch.n_tokens > 1) return false;
|
||||
if (u_batch.embd) return false;
|
||||
if (!cparams.graph_reuse) return false;
|
||||
return u_batch.all_seq_id == prev->all_seq_id &&
|
||||
auto the_prev = cparams.mtp_op_type == MTP_OP_NONE ? prev.get() : prev_mtp.get();
|
||||
if (!the_prev || !the_prev->graph) return false;
|
||||
//if (u_batch.n_tokens > 1) return false;
|
||||
if (u_batch.embd) return false;
|
||||
return u_batch.all_seq_id == the_prev->all_seq_id &&
|
||||
kv_self.head > 0 &&
|
||||
kv_self.n == prev->n_kv &&
|
||||
n_outputs == prev->n_outputs &&
|
||||
cparams.mtp_op_type == prev->mtp_op_type &&
|
||||
kv_self.n == the_prev->n_kv &&
|
||||
n_outputs == the_prev->n_outputs &&
|
||||
u_batch.n_tokens == the_prev->n_tokens &&
|
||||
cparams.mtp_op_type == the_prev->mtp_op_type &&
|
||||
update_cache_copies();
|
||||
}
|
||||
|
||||
/*
|
||||
static void why_not_reuse_previous(const llama_batch & u_batch, const llama_context & ctx, const llama_context::Prev * the_prev) {
|
||||
if (!the_prev) { printf(" previous is null\n"); return; }
|
||||
if (!the_prev->graph) { printf(" previous graph is null\n"); return; }
|
||||
if (!ctx.cparams.graph_reuse) { printf(" graph_reuse is false\n"); return; }
|
||||
if (u_batch.embd) { printf(" ubatch.embd is not null\n"); return; }
|
||||
if (u_batch.all_seq_id != the_prev->all_seq_id) { printf(" all_seq_id is not the same\n"); return; }
|
||||
if (ctx.kv_self.head == 0) { printf(" kv_self.head = 0\n"); return; }
|
||||
if (ctx.kv_self.n != the_prev->n_kv) { printf(" kv_self.n is not the same\n"); return; }
|
||||
if (ctx.n_outputs != the_prev->n_outputs) { printf(" n_outputs is not the same\n"); return; }
|
||||
if (u_batch.n_tokens != the_prev->n_tokens) { printf(" n_tokens is not the same\n"); return; }
|
||||
if (ctx.cparams.mtp_op_type != the_prev->mtp_op_type) { printf(" mtp_op_type is not the same\n"); return; }
|
||||
printf(" update_cache_copies() must have failed\n");
|
||||
}
|
||||
*/
|
||||
|
||||
bool llama_context::update_cache_copies() {
|
||||
const int n_layer = model.mtp ? model.hparams.n_layer
|
||||
: model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
|
||||
const int n_layer = model.mtp && cparams.mtp_op_type != MTP_OP_NONE ?
|
||||
model.hparams.n_layer : model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
|
||||
auto layer_has_attention_kv = [&](int il) {
|
||||
return !model.hparams.is_recurrent(il);
|
||||
};
|
||||
if ((int)kv_self.k_l.size() != n_layer) return false;
|
||||
if (!(kv_self.v_l.empty() || (int)kv_self.v_l.size() == n_layer)) return false;
|
||||
|
||||
if ((int)kv_self.k_l.size() < n_layer) {
|
||||
return false;
|
||||
}
|
||||
if (!kv_self.v_l.empty() && (int)kv_self.v_l.size() < n_layer) {
|
||||
return false;
|
||||
}
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
if (!layer_has_attention_kv(il) || kv_self.k_l[il] == nullptr) {
|
||||
continue;
|
||||
@ -594,7 +619,9 @@ bool llama_context::update_cache_copies() {
|
||||
for (int id = 0; id < kl->n_device; ++id) {
|
||||
if (!kl->splits[id]) continue;
|
||||
auto& c = cache_copies[2*model.splits.size()*il + 2*id + 0];
|
||||
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kl->splits[id]) return false;
|
||||
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kl->splits[id]) {
|
||||
return false;
|
||||
}
|
||||
c.cpy->view_offs = kv_self.head*c.step;
|
||||
c.cpy->src[1]->data = (char *)kl->splits[id]->data + c.cpy->view_offs;
|
||||
c.cpy->data = c.cpy->src[1]->data;
|
||||
@ -603,29 +630,26 @@ bool llama_context::update_cache_copies() {
|
||||
for (int id = 0; id < vl->n_device; ++id) {
|
||||
if (!vl->splits[id]) continue;
|
||||
auto& c = cache_copies[2*model.splits.size()*il + 2*id + 1];
|
||||
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != vl->splits[id]) return false;
|
||||
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != vl->splits[id]) {
|
||||
return false;
|
||||
}
|
||||
c.cpy->view_offs = kv_self.head*c.step;
|
||||
c.cpy->src[1]->data = (char *)vl->splits[id]->data + c.cpy->view_offs;
|
||||
c.cpy->data = c.cpy->src[1]->data;
|
||||
}
|
||||
} else {
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
if (!layer_has_attention_kv(il) || kv_self.k_l[il] == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto& c = cache_copies[2*il+0];
|
||||
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.k_l[il]) return false;
|
||||
c.cpy->view_offs = kv_self.head*c.step;
|
||||
c.cpy->src[1]->data = (char *)kv_self.k_l[il]->data + c.cpy->view_offs;
|
||||
c.cpy->data = c.cpy->src[1]->data;
|
||||
auto& c = cache_copies[2*il+0];
|
||||
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.k_l[il]) {
|
||||
return false;
|
||||
}
|
||||
if (kv_self.v_l.empty()) return true;
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
if (!layer_has_attention_kv(il) || kv_self.v_l[il] == nullptr) {
|
||||
continue;
|
||||
}
|
||||
c.cpy->view_offs = kv_self.head*c.step;
|
||||
c.cpy->src[1]->data = (char *)kv_self.k_l[il]->data + c.cpy->view_offs;
|
||||
c.cpy->data = c.cpy->src[1]->data;
|
||||
if (!kv_self.v_l.empty() && kv_self.v_l[il]) {
|
||||
auto& c = cache_copies[2*il+1];
|
||||
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.v_l[il]) return false;
|
||||
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.v_l[il]) {
|
||||
return false;
|
||||
}
|
||||
c.cpy->view_offs = kv_self.head*c.step;
|
||||
c.cpy->src[1]->data = (char *)kv_self.v_l[il]->data + c.cpy->view_offs;
|
||||
c.cpy->data = c.cpy->src[1]->data;
|
||||
@ -4438,21 +4462,15 @@ static int llama_decode_internal(
|
||||
printf("prelude(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
|
||||
|
||||
//if (n_tokens_all == 1) {
|
||||
// printf("================= %s\n", __func__);
|
||||
// printf(" all_pos_0 = %d, all_pos_1 = %d, all_seq_id = %d\n", batch_all.all_pos_0, batch_all.all_pos_1, batch_all.all_seq_id);
|
||||
// printf(" embd = %p, logits = %p, token = %p\n", (const void *)batch_all.embd, (const void *)batch_all.logits, (const void *)batch_all.token);
|
||||
// printf(" n_outputs = %d, kv_self.n = %d\n", n_outputs, kv_self.n);
|
||||
//}
|
||||
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
|
||||
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
auto & prev = cparams.mtp_op_type == MTP_OP_NONE ? lctx.prev : lctx.prev_mtp;
|
||||
ggml_cgraph * gf = nullptr;
|
||||
if (!lctx.can_reuse_graph(u_batch)) {
|
||||
lctx.reset_scheduler();
|
||||
//lctx.reset_scheduler();
|
||||
ggml_backend_sched_reset(lctx.sched);
|
||||
prev.reset();
|
||||
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
|
||||
#if IK_PRINT_TIMING
|
||||
tim2 = ggml_time_us();
|
||||
@ -4476,14 +4494,15 @@ static int llama_decode_internal(
|
||||
tim2 = ggml_time_us();
|
||||
printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
|
||||
lctx.prev = std::make_unique<llama_context::Prev>(llama_context::Prev{
|
||||
//if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
|
||||
if (u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
|
||||
prev = std::make_unique<llama_context::Prev>(llama_context::Prev{
|
||||
(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n,
|
||||
cparams.mtp_op_type, gf});
|
||||
(int)u_batch.n_tokens, cparams.mtp_op_type, gf});
|
||||
}
|
||||
} else {
|
||||
//printf("Reusing graph\n");
|
||||
gf = lctx.prev->graph;
|
||||
//printf("Reusing graph with n_kv = %d, n_tokens = %d\n", (int)prev->n_kv, (int)prev->n_tokens);
|
||||
gf = prev->graph;
|
||||
}
|
||||
|
||||
if (cparams.mtp_op_type != MTP_OP_NONE) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user