mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
implement gpu argmax
This commit is contained in:
parent
dc43cdf06b
commit
08e4590dcb
@ -321,7 +321,12 @@ struct common_speculative_state_dflash : public common_speculative_state {
|
||||
result.reserve((size_t) n_keep);
|
||||
const int64_t t_sample_us = ggml_time_us();
|
||||
for (int32_t i = 0; i < n_keep; ++i) {
|
||||
result.push_back(common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr));
|
||||
// Use argmax in GPU when available
|
||||
llama_token id = llama_get_dflash_draft_token_ith(ctx_dft, i);
|
||||
if (id == LLAMA_TOKEN_NULL) {
|
||||
id = common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr);
|
||||
}
|
||||
result.push_back(id);
|
||||
}
|
||||
t_draft_sample_us += (uint64_t) (ggml_time_us() - t_sample_us);
|
||||
|
||||
|
||||
@ -53,7 +53,7 @@
|
||||
#define LLAMA_STATE_SEQ_VERSION 3
|
||||
|
||||
#define LLAMA_SERVER_MAGIC 0x6c6d7376u // 'lmsv'
|
||||
#define LLAMA_SERVER_VERSION 1
|
||||
#define LLAMA_SERVER_VERSION 1
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
@ -1096,6 +1096,10 @@ extern "C" {
|
||||
// returns NULL for invalid ids.
|
||||
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
// Get the argmax token ID for DFlash draft position i without materializing full logits.
|
||||
// Returns LLAMA_TOKEN_NULL if no argmax result is available for position i; callers should then fall back to the logits path.
|
||||
LLAMA_API llama_token llama_get_dflash_draft_token_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
// Get all output token embeddings.
|
||||
// when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
|
||||
// the embeddings for which llama_batch.logits[i] != 0 are stored contiguously
|
||||
|
||||
@ -565,5 +565,11 @@ ggml_cgraph * llm_build_context::build_dflash() {
|
||||
cb(result, "result_output", -1);
|
||||
ggml_build_forward_expand(gf, result);
|
||||
|
||||
lctx.dflash_draft_tokens_tensor = nullptr;
|
||||
ggml_tensor * draft_tokens = ggml_argmax(ctx0, result);
|
||||
ggml_set_name(draft_tokens, "draft_argmax");
|
||||
ggml_build_forward_expand(gf, draft_tokens);
|
||||
lctx.dflash_draft_tokens_tensor = draft_tokens;
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
@ -393,6 +393,12 @@ struct llama_context {
|
||||
int32_t & dflash_visible_cross_ctx = dflash.visible_cross_ctx;
|
||||
std::vector<struct ggml_tensor *> & dflash_k_ctx_cache = dflash.kv.k_ctx_cache;
|
||||
std::vector<struct ggml_tensor *> & dflash_v_ctx_cache = dflash.kv.v_ctx_cache;
|
||||
|
||||
// Argmax token IDs from the DFlash draft graph, computed via GPU argmax.
|
||||
// Populated in llama_decode_internal after graph compute.
|
||||
std::vector<llama_token> dflash_draft_tokens;
|
||||
struct ggml_tensor * dflash_draft_tokens_tensor = nullptr;
|
||||
|
||||
std::vector<struct ggml_tensor *> & dflash_k_ctx_workspace = dflash.kv.k_ctx_workspace;
|
||||
std::vector<struct ggml_tensor *> & dflash_v_ctx_workspace = dflash.kv.v_ctx_workspace;
|
||||
struct ggml_context * & dflash_cache_ctx = dflash.kv.cache_ctx;
|
||||
|
||||
@ -5716,6 +5716,17 @@ static int llama_decode_internal(
|
||||
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
||||
struct ggml_tensor * embd = nullptr;
|
||||
|
||||
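// DFlash appends a draft_argmax node after result_output, so the last graph node may not be the logits tensor; search backwards for result_output in that case.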
|
||||
if (lctx.dflash_draft_tokens_tensor != nullptr &&
|
||||
strcmp(res->name, "result_output") != 0) {
|
||||
for (int i = gf->n_nodes - 2; i >= 0; --i) {
|
||||
if (strcmp(gf->nodes[i]->name, "result_output") == 0) {
|
||||
res = gf->nodes[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (lctx.n_outputs == 0) {
|
||||
// no output
|
||||
res = nullptr;
|
||||
@ -5813,7 +5824,28 @@ static int llama_decode_internal(
|
||||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||
//}
|
||||
|
||||
lctx.dflash_draft_tokens.clear();
|
||||
if (lctx.dflash_draft_tokens_tensor != nullptr) {
|
||||
ggml_backend_t backend_argmax = ggml_backend_sched_get_tensor_backend(
|
||||
lctx.sched, lctx.dflash_draft_tokens_tensor);
|
||||
if (backend_argmax != nullptr) {
|
||||
const int64_t n_tokens_argmax = lctx.dflash_draft_tokens_tensor->ne[0];
|
||||
lctx.dflash_draft_tokens.resize((size_t) n_tokens_argmax);
|
||||
ggml_backend_tensor_get_async(backend_argmax,
|
||||
lctx.dflash_draft_tokens_tensor,
|
||||
lctx.dflash_draft_tokens.data(), 0,
|
||||
(size_t) n_tokens_argmax * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
|
||||
// extract logits
|
||||
{
|
||||
const bool dflash_skip_logits = (lctx.model.arch == LLM_ARCH_DFLASH_DRAFT
|
||||
&& !lctx.dflash_draft_tokens.empty());
|
||||
if (dflash_skip_logits) {
|
||||
res = nullptr;
|
||||
}
|
||||
}
|
||||
if (res) {
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
@ -10068,6 +10100,13 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the GPU-computed argmax token for DFlash draft position i.
//
// ctx : context whose dflash_draft_tokens buffer was filled by the last decode.
// i   : zero-based index into that buffer.
//
// Returns LLAMA_TOKEN_NULL when i is out of range (including negative i, which
// wraps to a huge value in the size_t comparison) or when no argmax result was
// produced, so the caller can fall back to sampling from logits.
//
// NOTE(review): the caller's logits fallback samples index i + 1 for the same i;
// confirm that row i of the argmax tensor corresponds to that logits row.
llama_token llama_get_dflash_draft_token_ith(struct llama_context * ctx, int32_t i) {
    // The buffer is filled with ggml_backend_tensor_get_async during decode, so
    // wait for pending backend work before reading it on the host (same as the
    // other output getters, e.g. llama_get_embeddings).
    llama_synchronize(ctx);

    const size_t idx = (size_t) i;
    if (idx >= ctx->dflash_draft_tokens.size()) {
        return LLAMA_TOKEN_NULL;
    }
    return ctx->dflash_draft_tokens[idx];
}
|
||||
|
||||
float * llama_get_embeddings(struct llama_context * ctx) {
|
||||
llama_synchronize(ctx);
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user