From 08e4590dcb00a17e4b11e17b35a0c46a088287d3 Mon Sep 17 00:00:00 2001 From: SamuelOliveirads Date: Thu, 4 Jun 2026 20:45:12 -0300 Subject: [PATCH] implement gpu argmax --- common/speculative-impl.h | 7 ++++++- include/llama.h | 6 +++++- src/graphs/build_dflash.cpp | 6 ++++++ src/llama-context.h | 6 ++++++ src/llama.cpp | 39 +++++++++++++++++++++++++++++++++++++ 5 files changed, 62 insertions(+), 2 deletions(-) diff --git a/common/speculative-impl.h b/common/speculative-impl.h index 47603461..ccec2e9e 100644 --- a/common/speculative-impl.h +++ b/common/speculative-impl.h @@ -321,7 +321,12 @@ struct common_speculative_state_dflash : public common_speculative_state { result.reserve((size_t) n_keep); const int64_t t_sample_us = ggml_time_us(); for (int32_t i = 0; i < n_keep; ++i) { - result.push_back(common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr)); + // Use argmax in GPU when available + llama_token id = llama_get_dflash_draft_token_ith(ctx_dft, i); + if (id == LLAMA_TOKEN_NULL) { + id = common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr); + } + result.push_back(id); } t_draft_sample_us += (uint64_t) (ggml_time_us() - t_sample_us); diff --git a/include/llama.h b/include/llama.h index f89d82ef..754a0643 100644 --- a/include/llama.h +++ b/include/llama.h @@ -53,7 +53,7 @@ #define LLAMA_STATE_SEQ_VERSION 3 #define LLAMA_SERVER_MAGIC 0x6c6d7376u // 'lmsv' -#define LLAMA_SERVER_VERSION 1 +#define LLAMA_SERVER_VERSION 1 #ifdef __cplusplus extern "C" { @@ -1096,6 +1096,10 @@ extern "C" { // returns NULL for invalid ids. LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); + // Get the argmax token ID for DFlash draft position i without materializing full logits. + // Returns LLAMA_TOKEN_NULL if argmax is not available (falls back to logits path). + LLAMA_API llama_token llama_get_dflash_draft_token_ith(struct llama_context * ctx, int32_t i); + // Get all output token embeddings. // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model, // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously diff --git a/src/graphs/build_dflash.cpp b/src/graphs/build_dflash.cpp index 80c45c1e..4cbfc147 100644 --- a/src/graphs/build_dflash.cpp +++ b/src/graphs/build_dflash.cpp @@ -565,5 +565,11 @@ ggml_cgraph * llm_build_context::build_dflash() { cb(result, "result_output", -1); ggml_build_forward_expand(gf, result); + lctx.dflash_draft_tokens_tensor = nullptr; + ggml_tensor * draft_tokens = ggml_argmax(ctx0, result); + ggml_set_name(draft_tokens, "draft_argmax"); + ggml_build_forward_expand(gf, draft_tokens); + lctx.dflash_draft_tokens_tensor = draft_tokens; + return gf; } diff --git a/src/llama-context.h b/src/llama-context.h index ebd4ded3..fec7edbb 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -393,6 +393,12 @@ struct llama_context { int32_t & dflash_visible_cross_ctx = dflash.visible_cross_ctx; std::vector & dflash_k_ctx_cache = dflash.kv.k_ctx_cache; std::vector & dflash_v_ctx_cache = dflash.kv.v_ctx_cache; + + // Argmax token IDs from the DFlash draft graph, computed via GPU argmax. + // Populated in llama_decode_internal after graph compute. + std::vector dflash_draft_tokens; + struct ggml_tensor * dflash_draft_tokens_tensor = nullptr; + std::vector & dflash_k_ctx_workspace = dflash.kv.k_ctx_workspace; std::vector & dflash_v_ctx_workspace = dflash.kv.v_ctx_workspace; struct ggml_context * & dflash_cache_ctx = dflash.kv.cache_ctx; diff --git a/src/llama.cpp b/src/llama.cpp index 75482d80..0e2430e4 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5716,6 +5716,17 @@ static int llama_decode_internal( struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; struct ggml_tensor * embd = nullptr; + // DFlash GPU argmax draft_argmax node + if (lctx.dflash_draft_tokens_tensor != nullptr && + strcmp(res->name, "result_output") != 0) { + for (int i = gf->n_nodes - 2; i >= 0; --i) { + if (strcmp(gf->nodes[i]->name, "result_output") == 0) { + res = gf->nodes[i]; + break; + } + } + } + if (lctx.n_outputs == 0) { // no output res = nullptr; @@ -5813,7 +5824,28 @@ static int llama_decode_internal( // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} + lctx.dflash_draft_tokens.clear(); + if (lctx.dflash_draft_tokens_tensor != nullptr) { + ggml_backend_t backend_argmax = ggml_backend_sched_get_tensor_backend( + lctx.sched, lctx.dflash_draft_tokens_tensor); + if (backend_argmax != nullptr) { + const int64_t n_tokens_argmax = lctx.dflash_draft_tokens_tensor->ne[0]; + lctx.dflash_draft_tokens.resize((size_t) n_tokens_argmax); + ggml_backend_tensor_get_async(backend_argmax, + lctx.dflash_draft_tokens_tensor, + lctx.dflash_draft_tokens.data(), 0, + (size_t) n_tokens_argmax * sizeof(int32_t)); + } + } + // extract logits + { + const bool dflash_skip_logits = (lctx.model.arch == LLM_ARCH_DFLASH_DRAFT + && !lctx.dflash_draft_tokens.empty()); + if (dflash_skip_logits) { + res = nullptr; + } + } if (res) { #if IK_PRINT_TIMING tim1 = ggml_time_us(); @@ -10068,6 +10100,13 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { } } +llama_token llama_get_dflash_draft_token_ith(struct llama_context * ctx, int32_t i) { + if ((size_t) i >= ctx->dflash_draft_tokens.size()) { + return LLAMA_TOKEN_NULL; + } + return ctx->dflash_draft_tokens[(size_t) i]; +} + float * llama_get_embeddings(struct llama_context * ctx) { llama_synchronize(ctx);