implement gpu argmax

This commit is contained in:
SamuelOliveirads 2026-06-04 20:45:12 -03:00
parent dc43cdf06b
commit 08e4590dcb
5 changed files with 62 additions and 2 deletions

View File

@ -321,7 +321,12 @@ struct common_speculative_state_dflash : public common_speculative_state {
result.reserve((size_t) n_keep);
const int64_t t_sample_us = ggml_time_us();
for (int32_t i = 0; i < n_keep; ++i) {
result.push_back(common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr));
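// Prefer the argmax token already computed by the draft graph (GPU argmax);
// fall back to the regular sampler when no argmax token is available.
// NOTE(review): the argmax lookup uses index i while the sampler fallback uses
// i + 1 - confirm that both refer to the same draft position.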
llama_token id = llama_get_dflash_draft_token_ith(ctx_dft, i);
if (id == LLAMA_TOKEN_NULL) {
id = common_sampler_sample_speculative(nullptr, ctx_dft, i + 1, nullptr);
}
result.push_back(id);
}
t_draft_sample_us += (uint64_t) (ggml_time_us() - t_sample_us);

View File

@ -53,7 +53,7 @@
#define LLAMA_STATE_SEQ_VERSION 3
#define LLAMA_SERVER_MAGIC 0x6c6d7376u // 'lmsv'
#define LLAMA_SERVER_VERSION 1
#define LLAMA_SERVER_VERSION 1
#ifdef __cplusplus
extern "C" {
@ -1096,6 +1096,10 @@ extern "C" {
// returns NULL for invalid ids.
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
// Get the argmax token ID for DFlash draft position i without materializing full logits.
// Returns LLAMA_TOKEN_NULL if no argmax token is available for position i; in that case
// the caller is expected to fall back to sampling from the logits.
LLAMA_API llama_token llama_get_dflash_draft_token_ith(struct llama_context * ctx, int32_t i);
// Get all output token embeddings.
// when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
// the embeddings for which llama_batch.logits[i] != 0 are stored contiguously

View File

@ -565,5 +565,11 @@ ggml_cgraph * llm_build_context::build_dflash() {
cb(result, "result_output", -1);
ggml_build_forward_expand(gf, result);
lctx.dflash_draft_tokens_tensor = nullptr;
ggml_tensor * draft_tokens = ggml_argmax(ctx0, result);
ggml_set_name(draft_tokens, "draft_argmax");
ggml_build_forward_expand(gf, draft_tokens);
lctx.dflash_draft_tokens_tensor = draft_tokens;
return gf;
}

View File

@ -393,6 +393,12 @@ struct llama_context {
int32_t & dflash_visible_cross_ctx = dflash.visible_cross_ctx;
std::vector<struct ggml_tensor *> & dflash_k_ctx_cache = dflash.kv.k_ctx_cache;
std::vector<struct ggml_tensor *> & dflash_v_ctx_cache = dflash.kv.v_ctx_cache;
// Argmax token IDs from the DFlash draft graph, computed via GPU argmax.
// Populated in llama_decode_internal after graph compute.
std::vector<llama_token> dflash_draft_tokens;
struct ggml_tensor * dflash_draft_tokens_tensor = nullptr;
std::vector<struct ggml_tensor *> & dflash_k_ctx_workspace = dflash.kv.k_ctx_workspace;
std::vector<struct ggml_tensor *> & dflash_v_ctx_workspace = dflash.kv.v_ctx_workspace;
struct ggml_context * & dflash_cache_ctx = dflash.kv.cache_ctx;

View File

@ -5716,6 +5716,17 @@ static int llama_decode_internal(
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
struct ggml_tensor * embd = nullptr;
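// DFlash: when the graph carries a GPU argmax tensor, the last graph node may not be
// "result_output". In that case search backwards for the "result_output" node so that
// `res` still refers to the logits tensor.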
if (lctx.dflash_draft_tokens_tensor != nullptr &&
strcmp(res->name, "result_output") != 0) {
for (int i = gf->n_nodes - 2; i >= 0; --i) {
if (strcmp(gf->nodes[i]->name, "result_output") == 0) {
res = gf->nodes[i];
break;
}
}
}
if (lctx.n_outputs == 0) {
// no output
res = nullptr;
@ -5813,7 +5824,28 @@ static int llama_decode_internal(
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}
lctx.dflash_draft_tokens.clear();
if (lctx.dflash_draft_tokens_tensor != nullptr) {
ggml_backend_t backend_argmax = ggml_backend_sched_get_tensor_backend(
lctx.sched, lctx.dflash_draft_tokens_tensor);
if (backend_argmax != nullptr) {
const int64_t n_tokens_argmax = lctx.dflash_draft_tokens_tensor->ne[0];
lctx.dflash_draft_tokens.resize((size_t) n_tokens_argmax);
ggml_backend_tensor_get_async(backend_argmax,
lctx.dflash_draft_tokens_tensor,
lctx.dflash_draft_tokens.data(), 0,
(size_t) n_tokens_argmax * sizeof(int32_t));
}
}
// extract logits
{
const bool dflash_skip_logits = (lctx.model.arch == LLM_ARCH_DFLASH_DRAFT
&& !lctx.dflash_draft_tokens.empty());
if (dflash_skip_logits) {
res = nullptr;
}
}
if (res) {
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
@ -10068,6 +10100,13 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
}
}
// Returns the argmax token ID stored for DFlash draft position `i`.
//
// ctx: context whose `dflash_draft_tokens` vector holds the argmax results.
// i:   zero-based index into `dflash_draft_tokens`.
//
// Returns LLAMA_TOKEN_NULL when `i` is negative or past the end of the vector
// (including when the vector is empty); callers are expected to fall back to
// the logits path in that case.
llama_token llama_get_dflash_draft_token_ith(struct llama_context * ctx, int32_t i) {
    // The draft tokens are copied back with ggml_backend_tensor_get_async, so the
    // host-side vector must not be read until pending backend work has finished.
    // Synchronize first, as llama_get_embeddings does before returning its data.
    llama_synchronize(ctx);

    const std::vector<llama_token> & draft_tokens = ctx->dflash_draft_tokens;
    if (i < 0 || (size_t) i >= draft_tokens.size()) {
        return LLAMA_TOKEN_NULL;
    }
    return draft_tokens[(size_t) i];
}
float * llama_get_embeddings(struct llama_context * ctx) {
llama_synchronize(ctx);