WIP: Qwen3Next (#1266)

* qwen3next: add architecture support and recurrent-state fixes

* qwen3next: optimize broadcast sub and single-seq ssm conv

* cuda: build MoE row mapping on device in mul_mat_id

* cuda: add guarded multi-seq fast path for ssm_conv

* docs: update qwen3next perf report for cuda MoE/SSM tuning

* cuda: reduce qwen3next moe/ssm sync overhead and refresh eval

* qwen3next: split cpu/cuda eval builds and tune PP scheduling

* qwen3next: harden seq-state flow and support optional dense FFN layers

* qwen3next: trim delta-net graph overhead in chunking path

* qwen3next: remove redundant v_conv cont in delta path

* qwen3next: avoid extra cont on linear attention output

* qwen3next: drop redundant cont before recurrent state flatten

* qwen3next: keep recurrent state in 4d layout through delta path

* qwen3next: add fused delta-net op and wire model path

* tests: add backend-op coverage for ggml_delta_net

* qwen3next: add runtime switch for fused delta-net path

* docs: refresh qwen3next perf review and benchmark matrix

* qwen3next: default fused delta-net off and document quality checks

* qwen3next: add decode-only fused delta mode

* qwen3next: make fused delta safe by default and fix fused tensor layout

* qwen3next: warn when forcing fused decode mode

* qwen3next: add fused-delta regression runner script

* qwen3next: integrate fused regression into eval harness

* qwen3next: clean up chunked delta-net shape handling

* qwen3next: add absolute sanity guards to fused regression

* qwen3next: add unified regression runner script

* qwen3next: disable flash-attn for cpu-only contexts

* docs: reconcile qwen3next status and remaining upstream gaps

* common: add qwen3next fused-delta runtime flag

* cuda: add qwen3next delta-net kernel dispatch override

* docs: update qwen3next quality and serving baseline findings

* qwen3next: keep fused delta on safe path and remove PR artifacts

* qwen3next: align autoregressive delta-net decode layout

* Revert "qwen3next: align autoregressive delta-net decode layout"

This reverts commit 9241164a5ea9e032a2456fbf2dd0bf798b264fd7.

* cuda: port solve-tri fast-paths for qwen3next delta-net

* qwen3next: add fused-delta runtime flag and drop env toggle

* qwen3next: make fused delta single-flag and default on

* Account for GPU arch differences

* Revert "cuda: build MoE row mapping on device in mul_mat_id"

This reverts commit 89e9ecfa840b04e88699ab3803eb732cd78727f9.

* qwen3next: drop non-essential MoE scheduling and split heuristics

* qwen3next: avoid generic ggml_sub broadcast changes

* llama: restore only_active_experts log message

* Remove unnecessary hacks, disable fusion for now.

* qwen3next: port hybrid recurrent state memory semantics

* qwen3next: clean up recurrent state slot plumbing

* qwen3next: fix hybrid V-cache layout plumbing

* qwen3next: guard recurrent state slots against kv capacity

* qwen3next: persist recurrent state in session data

- serialize/restore qwen3next cache.s_l in state/session paths\n- bump session and sequence-state file versions for format change\n- fallback to single-token chunking for mixed repeated seq_id batches

* qwen3next: drop unused fused-delta builder path

- remove dead build_delta_net_fused lambda\n- remove unused llm_build_context::fused_delta member

* qwen3next: remove unused fused-delta CLI/context plumbing

- drop -fd/-no-fd options and related YAML dump field\n- remove fused_delta fields from public/internal context params\n- remove fused_delta assignment and logging in context init

* ggml: remove unused DELTA_NET operator stack

* Missing include

* Reorder ops/unary ops

So we don't change again the enum values of the mul mat ops

* Minor

* Discard unnecessary changes in llama-build-context.cpp

* Minor

* Revert "Discard unnecessary changes in llama-build-context.cpp"

This reverts commit edadb80ed68c4c0831e9c22609a9a3af19be9735.

* Increase GGML_SCHED_MAX_SPLITS - required for larger u-batches

* Fix CPU concat in the TG case: 7.25 -> 10.5 t/s for Qwen3Next

* Fix CPU sum_rows: 10.5 -> 13.6 t/s for Qwen3Next

It was single-threaded and was taking ~25% of the computation time
during TG. It is now down to 2%.

Strangely enough, I measure 13.6 t/s with llama-bench, but if I
let the model give me an actual response with llama-cli, I get close
to 17 t/s.

* Fix CPU scale: 13.6 -> 16.7 t/s for Qwen3Next

For Qwen3Next there is a scale op on a largish tensor (548k elements)
that has a single row for TG, so was done in a single thread.
We now simply use blocks of 1024 elements.

* Optimize CPU mul: 16.7 -> 17.6 t/s for Qwen3Next

* CPU: fuse transpose -> cont -> sum_rows -> transpos: 17.6 -> 23.1 t/s for Qwen3Next

* Optimize CPU repeat: 176 -> 200 t/s for Qwen3Next PP-512

* Multithreading for OP_SUB

* Don't commit with timing trace on

* Multithread neg and sigmoid

* Be able to turn on/off fusion more easily (CPU)

* Name the mul_mat ops so we know where the time goes

* WIP

* Much better PP on CUDA

* CUDA: fuse transpose -> cont -> sum_rows -> transpose

Needs non-coontiguous variant of sum_rows.
On the CPU this gave 30+% improvement in TG performance,
on CUDA ist is disapointing 6-7%. I guess, this is because
Georgi's cont CPU implementation was so bad that skipping
it made such a big difference.

* CUDA: faster mul for special case relevant for Qwen3Next

Worth 1% in TG

* Fix CPU OP_CONT

---------

Co-authored-by: yurko <yurko@local>
Co-authored-by: Yurko <yurko@example.com>
Co-authored-by: yurko <yurko@pop-os.tail5a1a6b.ts.net>
Co-authored-by: Yurko Hoshko <YurkoHoshko@users.noreply.github.com>
This commit is contained in:
Kawrakow 2026-02-16 06:50:28 +01:00 committed by GitHub
parent 528cadb07b
commit e30198a553
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
35 changed files with 4600 additions and 232 deletions

View File

@ -21,6 +21,7 @@
#include <climits>
#include <cmath>
#include <codecvt>
#include <cstdlib>
#include <cstdarg>
#include <cstring>
#include <ctime>
@ -494,6 +495,7 @@ void gpt_params_parse_from_env(gpt_params & params) {
get_env("LLAMA_ARG_CONT_BATCHING", params.cont_batching);
get_env("LLAMA_ARG_HOST", params.hostname);
get_env("LLAMA_ARG_PORT", params.port);
}
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {

View File

@ -673,6 +673,12 @@ extern "C" {
GGML_OP_ADD_REL_POS,
GGML_OP_UNARY,
GGML_OP_CUMSUM,
GGML_OP_L2_NORM,
GGML_OP_TRI,
GGML_OP_FILL,
GGML_OP_SOLVE_TRI,
GGML_OP_MAP_UNARY,
GGML_OP_MAP_BINARY,
@ -713,6 +719,8 @@ extern "C" {
GGML_UNARY_OP_SWIGLU,
GGML_UNARY_OP_SWIGLU_OAI,
GGML_UNARY_OP_GELU,
GGML_UNARY_OP_EXP,
GGML_UNARY_OP_SOFTPLUS,
GGML_UNARY_OP_COUNT,
};
@ -739,6 +747,13 @@ extern "C" {
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
};
enum ggml_tri_type {
GGML_TRI_TYPE_LOWER,
GGML_TRI_TYPE_UPPER,
GGML_TRI_TYPE_LOWER_DIAG,
GGML_TRI_TYPE_UPPER_DIAG,
};
// ggml object
struct ggml_object {
size_t offs;
@ -1189,6 +1204,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_softplus(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_softplus_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
// return scalar
GGML_API struct ggml_tensor * ggml_sum(
struct ggml_context * ctx,
@ -1199,6 +1222,10 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_cumsum(
struct ggml_context * ctx,
struct ggml_tensor * a);
// mean along rows
GGML_API struct ggml_tensor * ggml_mean(
struct ggml_context * ctx,
@ -1217,6 +1244,15 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
// repeat a to specified shape
GGML_API struct ggml_tensor * ggml_repeat_4d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3);
// sums repetitions in a into shape of b
GGML_API struct ggml_tensor * ggml_repeat_back(
struct ggml_context * ctx,
@ -1455,6 +1491,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_exp(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_exp_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
// normalize along rows
GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
@ -1514,6 +1558,17 @@ extern "C" {
int n_groups,
float eps);
// l2 normalize along rows
GGML_API struct ggml_tensor * ggml_l2_norm(
struct ggml_context * ctx,
struct ggml_tensor * a,
float eps);
GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float eps);
// a - x
// b - dy
GGML_API struct ggml_tensor * ggml_rms_norm_back(
@ -2283,6 +2338,23 @@ extern "C" {
int dim,
int max_period);
// convert matrix to triangular form by zeroing values outside selected half
GGML_API struct ggml_tensor * ggml_tri(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_tri_type type);
// fill tensor with constant c
GGML_API struct ggml_tensor * ggml_fill(
struct ggml_context * ctx,
struct ggml_tensor * a,
float c);
GGML_API struct ggml_tensor * ggml_fill_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float c);
// sort rows
enum ggml_sort_order {
GGML_SORT_ORDER_ASC,
@ -2426,6 +2498,15 @@ extern "C" {
struct ggml_tensor * pw,
struct ggml_tensor * ph);
// Solve Ax = B where A is triangular
GGML_API struct ggml_tensor * ggml_solve_tri(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
bool left,
bool lower,
bool uni);
// custom operators
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);

View File

@ -1103,7 +1103,7 @@ static bool ggml_is_view_op(enum ggml_op op) {
#endif
#ifndef GGML_SCHED_MAX_SPLITS
#define GGML_SCHED_MAX_SPLITS 2048
#define GGML_SCHED_MAX_SPLITS 4096
#endif
#ifndef GGML_SCHED_MAX_SPLIT_INPUTS

View File

@ -18,9 +18,11 @@
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/convert.cuh"
#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/cumsum.cuh"
#include "ggml-cuda/diagmask.cuh"
#include "ggml-cuda/dmmv.cuh"
#include "ggml-cuda/fattn.cuh"
#include "ggml-cuda/fill.cuh"
#include "ggml-cuda/getrows.cuh"
#include "ggml-cuda/im2col.cuh"
#include "ggml-cuda/mmq.cuh"
@ -46,10 +48,13 @@
#include "ggml-cuda/conv2d.cuh"
#include "ggml-cuda/conv2d-dw.cuh"
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/solve_tri.cuh"
#include "ggml-cuda/ssm-conv.cuh"
#include "ggml-cuda/argmax.cuh"
#include "ggml-cuda/multiadd.cuh"
#include "ggml-cuda/hadamard.cuh"
#include "ggml-cuda/reduce.cuh"
#include "ggml-cuda/tri.cuh"
#include <algorithm>
#include <array>
@ -2011,9 +2016,11 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
const int64_t r3 = ne13/ne03;
if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
//printf("Using cublasGemmStridedBatchedEx for %s\n", dst->name);
// with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
const int64_t smb = ne12 == 1 ? s13 : s12;
//const int64_t smb = ne12 == 1 ? s13 : s12;
const int64_t smb = ne12 == 1 ? nb13/nb10 : nb12/nb10;
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
// use cublasGemmStridedBatchedEx
@ -2027,6 +2034,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
//printf("Using cublasGemmBatchedEx for %s\n", dst->name);
//printf(" src0: %ld x %ld x %ld x %ld; %zu x %zu x %zu x %zu\n",src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
//printf(" src1: %ld x %ld x %ld x %ld; %zu x %zu x %zu x %zu\n",src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
// use cublasGemmBatchedEx
const int64_t ne23 = ne12*ne13;
@ -2238,22 +2248,29 @@ static int ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
if (any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
//printf("%s(%s): using ggml_cuda_mul_mat_vec_p021\n", __func__, dst->name);
// FP32 precision KQ single-batch for batch size 1 without FlashAttention
ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
} else if (any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
//printf("%s(%s): using ggml_cuda_mul_mat_vec_nc\n", __func__, dst->name);
// FP32 precision KQV single-batch for batch size 1 without FlashAttention
ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
} else if (src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
} else if ((src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32) && (src1->type == src0->type || !any_gpus_with_slow_fp16)
&& !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
//printf("%s(%s): ggml_cuda_mul_mat_batched_cublas\n", __func__, dst->name);
// KQ + KQV multi-batch without FlashAttention
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
} else if (use_dequantize_mul_mat_vec) {
//printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_dequantize_mul_mat_vec)\n", __func__, dst->name);
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
} else if (use_mul_mat_vec_q) {
//printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_mul_mat_vec_q)\n", __func__, dst->name);
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
} else if (use_mul_mat_q) {
//printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_mul_mat_q)\n", __func__, dst->name);
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
} else {
//printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_mul_mat_cublas)\n", __func__, dst->name);
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
}
return node_n;
@ -2822,11 +2839,6 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
return i;
}
std::vector<char> ids_host(ggml_nbytes(ids));
const char * ids_dev = (const char *) ids->data;
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
ggml_tensor src0_1_row = *src0_1;
ggml_tensor src0_2_row; if (src0_2) src0_2_row = *src0_2;
ggml_tensor src1_row = *src1;
@ -3199,7 +3211,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
}
break;
case GGML_OP_CONT:
ggml_cuda_dup(ctx, dst);
if (fusion && i + 2 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_SUM_ROWS &&
cgraph->nodes[i+2]->op == GGML_OP_TRANSPOSE &&
dst->src[0]->op == GGML_OP_TRANSPOSE) {
ggml_cuda_op_sum_rows_nc(ctx, cgraph->nodes[i+1]);
i += 2;
} else {
ggml_cuda_dup(ctx, dst);
}
break;
case GGML_OP_ADD:
if (fusion && i + 2 < cgraph->n_nodes &&
@ -3242,6 +3262,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_acc(ctx, dst);
break;
case GGML_OP_MUL:
//printf("mul(%s): %d, %d, %d, %ld x %ld x %ld x %ld * %ld x %ld x %ld x %ld\n", dst->name, ggml_is_contiguous(dst->src[0]), ggml_is_contiguous(dst->src[1]), ggml_is_contiguous(dst),
// dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->src[0]->ne[3],
// dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], dst->src[1]->ne[3]);
ggml_cuda_op_mul(ctx, dst);
break;
case GGML_OP_FUSED_MUL_UNARY:
@ -3250,6 +3273,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_DIV:
ggml_cuda_op_div(ctx, dst);
break;
case GGML_OP_SUB:
ggml_cuda_op_sub(ctx, dst);
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(dst)) {
case GGML_UNARY_OP_GELU:
@ -3273,6 +3299,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_RELU:
ggml_cuda_op_relu(ctx, dst);
break;
case GGML_UNARY_OP_NEG:
ggml_cuda_op_neg(ctx, dst);
break;
case GGML_UNARY_OP_SIGMOID:
if (fusion && i + 5 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
@ -3305,6 +3334,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_HARDSWISH:
ggml_cuda_op_hardswish(ctx, dst);
break;
case GGML_UNARY_OP_EXP:
ggml_cuda_op_exp(ctx, dst);
break;
case GGML_UNARY_OP_SOFTPLUS:
ggml_cuda_op_softplus(ctx, dst);
break;
default:
return -1;
}
@ -3339,6 +3374,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GROUP_NORM:
ggml_cuda_op_group_norm(ctx, dst);
break;
case GGML_OP_L2_NORM:
ggml_cuda_op_l2_norm(ctx, dst);
break;
case GGML_OP_CONCAT:
if (fusion && i + 2 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
@ -3554,6 +3592,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_sum_rows(ctx, dst);
}
break;
case GGML_OP_CUMSUM:
ggml_cuda_op_cumsum(ctx, dst);
break;
case GGML_OP_ARGSORT:
if (fusion && i + 5 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
@ -3573,6 +3614,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GROUPED_TOPK:
ggml_cuda_op_grouped_topk(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_cuda_op_ssm_conv(ctx, dst);
break;
case GGML_OP_TRI:
ggml_cuda_op_tri(ctx, dst);
break;
case GGML_OP_FILL:
ggml_cuda_op_fill(ctx, dst);
break;
case GGML_OP_SOLVE_TRI:
ggml_cuda_op_solve_tri(ctx, dst);
break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_cuda_flash_attn_ext(ctx, dst);
break;
@ -3594,6 +3647,10 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
}
#if IK_PRINT_TIMING
if (auto err = cudaStreamSynchronize(ctx.stream()); err != cudaSuccess) {
GGML_CUDA_LOG_ERROR("%s: %s failed\n", __func__, ggml_op_desc(dst));
CUDA_CHECK(err);
}
int64_t tim2 = ggml_time_us();
printf("%s(%s): %d us\n", ggml_op_name(dst->op), dst->name, (int)(tim2 - tim1));
#endif
@ -4149,6 +4206,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_NEG:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@ -4342,6 +4402,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
return true;
case GGML_OP_L2_NORM:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
break;
@ -4356,6 +4418,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_MUL_MULTI_ADD:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_SUB:
case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_SCALE:
case GGML_OP_SOFTCAP:
@ -4389,6 +4452,38 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
return true;
case GGML_OP_CUMSUM:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_TRI:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->src[0]->type == op->type;
case GGML_OP_FILL:
return ggml_is_contiguous(op) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
case GGML_OP_SOLVE_TRI:
return ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]) &&
ggml_is_contiguous(op) &&
op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32 &&
op->src[0]->ne[0] == op->src[0]->ne[1] &&
op->src[0]->ne[1] == op->src[1]->ne[1] &&
op->src[0]->ne[2] == op->src[1]->ne[2] &&
op->src[0]->ne[3] == op->src[1]->ne[3];
case GGML_OP_SSM_CONV:
return op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_F32 &&
op->src[2]->type == GGML_TYPE_F32 &&
op->src[3]->type == GGML_TYPE_I32 &&
op->type == GGML_TYPE_F32 &&
op->src[0]->nb[0] == sizeof(float) &&
op->src[1]->nb[0] == sizeof(float) &&
op->src[2]->nb[0] == sizeof(float) &&
op->src[3]->nb[0] == sizeof(int32_t) &&
op->src[2]->ne[0] == op->src[0]->ne[0] + 1 &&
op->src[2]->ne[1] == op->src[0]->ne[1] &&
op->src[1]->ne[0] == op->src[0]->ne[1] &&
op->src[3]->ne[0] == op->src[0]->ne[2];
case GGML_OP_FLASH_ATTN_EXT:
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;

View File

@ -24,6 +24,10 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
return a / b;
}
static __device__ __forceinline__ float op_sub(const float a, const float b) {
return a - b;
}
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
int ne0, int ne1, int ne2, int ne3,
@ -512,14 +516,37 @@ static void ggml_cuda_op_scale_tensor(ggml_backend_cuda_context & ctx, ggml_tens
scale_f32_cuda_l(src0_d, dst_d, dst->src[1]->data, ggml_nelements(src0), stream);
}
static __global__ void k_mul_fast(int ne0, int nelem, const float * x, const float * y, float * z) {
int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= nelem) {
return;
}
int i1 = i / ne0;
z[i] = x[i] * y[i1];
}
void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (ggml_nelements(dst->src[1]) == 1 && dst->src[1]->type == GGML_TYPE_F32 && dst->src[0]->type == GGML_TYPE_F32) {
ggml_cuda_op_scale_tensor(ctx, dst);
return;
}
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
auto src0 = dst->src[0];
auto src1 = dst->src[1];
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
src1->ne[0] == 1 && src0->ne[1] == src1->ne[1] && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3]) {
constexpr int kBlockSize = 256;
int nelem = ggml_nelements(src0);
int nblock = (nelem + kBlockSize - 1)/kBlockSize;
k_mul_fast<<<nblock, kBlockSize, 0, ctx.stream()>>>(src0->ne[0], nelem, (const float *)src0->data, (const float *)src1->data, (float *)dst->data);
return;
}
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(src0, src1, dst, src0->data, src1->data, dst->data, ctx.stream());
}
void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

View File

@ -2,6 +2,7 @@
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -0,0 +1,76 @@
#include "cumsum.cuh"
#define CUDA_CUMSUM_BLOCK_SIZE 256
static __global__ void cumsum_f32_kernel(
const float * src, float * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t d0, const int64_t d1, const int64_t d2, const int64_t d3) {
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.y;
const int64_t i3 = blockIdx.z;
if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
return;
}
const float * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
float * dst_row = dst + i1 * d1 + i2 * d2 + i3 * d3;
extern __shared__ float s_scan[];
float carry = 0.0f;
for (int64_t start = 0; start < ne00; start += blockDim.x) {
const int tile_n = (int) ((ne00 - start) < (int64_t) blockDim.x ? (ne00 - start) : (int64_t) blockDim.x);
float value = 0.0f;
if (threadIdx.x < tile_n) {
value = src_row[(start + threadIdx.x) * s00];
}
s_scan[threadIdx.x] = value;
__syncthreads();
for (int offset = 1; offset < blockDim.x; offset <<= 1) {
float add = 0.0f;
if (threadIdx.x >= offset) {
add = s_scan[threadIdx.x - offset];
}
__syncthreads();
if (threadIdx.x >= offset) {
s_scan[threadIdx.x] += add;
}
__syncthreads();
}
if (threadIdx.x < tile_n) {
dst_row[(start + threadIdx.x) * d0] = s_scan[threadIdx.x] + carry;
}
__syncthreads();
if (threadIdx.x == tile_n - 1) {
carry += s_scan[threadIdx.x];
}
__syncthreads();
}
}
void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
int block_size = WARP_SIZE;
while (block_size < src0->ne[0] && block_size < CUDA_CUMSUM_BLOCK_SIZE) {
block_size <<= 1;
}
dim3 grid_dims(src0->ne[1], src0->ne[2], src0->ne[3]);
cumsum_f32_kernel<<<grid_dims, block_size, block_size * sizeof(float), ctx.stream()>>>(
(const float *) src0->data,
(float *) dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0] / sizeof(float), src0->nb[1] / sizeof(float), src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
dst->nb[0] / sizeof(float), dst->nb[1] / sizeof(float), dst->nb[2] / sizeof(float), dst->nb[3] / sizeof(float));
}

View File

@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -0,0 +1,34 @@
#include "fill.cuh"
#include "convert.cuh"
#define CUDA_FILL_BLOCK_SIZE 256
template <typename T>
static __global__ void fill_kernel(T * dst, const int64_t k, const T value) {
const int64_t i = (int64_t) blockDim.x * blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = value;
}
void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst));
float value = 0.0f;
memcpy(&value, dst->op_params, sizeof(float));
const int64_t k = ggml_nelements(dst);
const int64_t num_blocks = (k + CUDA_FILL_BLOCK_SIZE - 1) / CUDA_FILL_BLOCK_SIZE;
switch (dst->type) {
case GGML_TYPE_F32:
fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, ctx.stream()>>>((float *) dst->data, k, value);
break;
case GGML_TYPE_F16:
fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, ctx.stream()>>>((half *) dst->data, k, ggml_cuda_cast<half>(value));
break;
default:
GGML_ABORT("unsupported type");
}
}

View File

@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -185,6 +185,38 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
}
}
template <int block_size>
static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x * blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
float tmp = 0.0f;
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row * ncols + col];
tmp += xi * xi;
}
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = lane_id < block_size / WARP_SIZE ? s_sum[lane_id] : 0.0f;
tmp = warp_reduce_sum(tmp);
}
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
for (int col = tid; col < ncols; col += block_size) {
dst[row * ncols + col] = scale * x[row * ncols + col];
}
}
template <int block_size>
static __global__ void rms_norm_f32_nc(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
@ -230,6 +262,49 @@ static __global__ void rms_norm_f32_nc(
}
}
template <int block_size>
static __global__ void l2_norm_f32_nc(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
const int row = blockIdx.x;
const int channel = blockIdx.y;
const int sample = blockIdx.z;
const int tid = threadIdx.x;
x += sample * stride_sample + channel * stride_channel + row * stride_row;
dst += ((sample * nchannels + channel) * nrows + row) * ncols;
float tmp = 0.0f;
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
tmp += xi * xi;
}
tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * x[col];
}
}
template <int block_size, typename src_t>
static __global__ void fused_rms_norm_f32(const src_t * x, const float * y, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
@ -387,6 +462,31 @@ static void rms_norm_f32_nc_cuda(
}
}
static void l2_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
constexpr int kBlockSize = 256;
if (ncols < 1024) {
const dim3 block_dims(kBlockSize, 1, 1);
l2_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
l2_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
}
}
static void l2_norm_f32_nc_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
l2_norm_f32_nc<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
l2_norm_f32_nc<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
template <typename src_t>
static void fused_rms_norm_f32_cuda(const src_t * x, const float * y, float * dst,
const int ncols, const int nrows, const float eps, bool is_norm, cudaStream_t stream) {
@ -527,6 +627,32 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
}
}
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
float eps = 0.0f;
memcpy(&eps, dst->op_params, sizeof(float));
const int64_t ne00 = src0->ne[0];
if (ggml_is_contiguous(src0)) {
const int64_t nrows = ggml_nrows(src0);
l2_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
} else {
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(src0->nb[0] == ts0);
const int64_t s01 = src0->nb[1] / ts0;
const int64_t s02 = src0->nb[2] / ts0;
const int64_t s03 = src0->nb[3] / ts0;
l2_norm_f32_nc_cuda(src0_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
}
}
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst, bool is_norm) {
if (!dst->src[1]) {
ggml_cuda_op_rms_norm(ctx, dst);

View File

@ -6,6 +6,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst, bool is_norm = false);
void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add, ggml_tensor * dst);

View File

@ -0,0 +1,914 @@
#include "common.cuh"
#include "ggml.h"
#include "solve_tri.cuh"
#include "ggml-cuda.h"
#include <cublas_v2.h>
#include <cstdio>
#define MAX_N_FAST 64
#define MAX_K_FAST 64
// This branch does not carry the fast-div helpers from upstream CUDA common code.
// Keep the PR kernel logic but back it with plain div/mod wrappers.
static inline uint3 init_fastdiv_values(uint32_t d) {
return make_uint3(d, 0u, 0u);
}
static __device__ __forceinline__ uint2 fast_div_modulo(uint32_t n, const uint3 d) {
return make_uint2(n / d.x, n % d.x);
}
// Kernel to set up pointer arrays for batched cuBLAS TRSM
// This avoids host-device copy during CUDA graph capture
static __global__ void setup_trsm_batch_pointers(
const float * A,
float * X,
const float ** A_ptrs,
float ** X_ptrs,
const int64_t ne02,
const int64_t total_batches,
const size_t nb02, // stride for A dim 2 (in floats)
const size_t nb03, // stride for A dim 3 (in floats)
const size_t nb2, // stride for X dim 2 (in floats)
const size_t nb3 // stride for X dim 3 (in floats)
) {
const int64_t batch_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (batch_idx >= total_batches) return;
// Decompose batch_idx into i02, i03
const int64_t i02 = batch_idx % ne02;
const int64_t i03 = batch_idx / ne02;
A_ptrs[batch_idx] = A + i02 * nb02 + i03 * nb03;
X_ptrs[batch_idx] = X + i02 * nb2 + i03 * nb3;
}
// Latency-optimized kernel for n=64, k=64 (single-token generation)
static __global__ void solve_tri_f32_64x64_latency(
const float * __restrict__ A,
const float * __restrict__ B,
float * __restrict__ X,
const uint3 ne02,
const size_t nb02,
const size_t nb03,
const size_t nb12,
const size_t nb13,
const size_t nb2,
const size_t nb3)
{
const int batch_idx = blockIdx.x;
const int lane = threadIdx.x;
const int warp_id = threadIdx.y;
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
const int64_t i02 = i02_i03.y;
const int64_t i03 = i02_i03.x;
const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
// Shared memory: A is 64x64, X is 64x65 (padded for bank conflicts)
__shared__ float sA[64 * 64];
__shared__ float sX[64 * 65];
__shared__ float sDiagInv[64]; // Precomputed 1/diagonal
const int tid = lane + warp_id * WARP_SIZE;
// Cooperative load of A matrix (4096 elements / 512 threads = 8 per thread)
#pragma unroll 8
for (int i = tid; i < 64 * 64; i += 512) {
sA[i] = A_batch[i];
}
// Cooperative load of B matrix into sX with padding
#pragma unroll 8
for (int i = tid; i < 64 * 64; i += 512) {
const int row = i / 64;
const int col = i % 64;
sX[row * 65 + col] = B_batch[i];
}
__syncthreads();
// Precompute diagonal inverses (first 2 warps handle this)
if (warp_id == 0) {
if (lane < 32) {
sDiagInv[lane] = 1.0f / sA[lane * 64 + lane];
}
}
if (warp_id == 1) {
if (lane < 32) {
sDiagInv[32 + lane] = 1.0f / sA[(32 + lane) * 64 + (32 + lane)];
}
}
__syncthreads();
// Each warp handles 4 columns: cols = warp_id*4 to warp_id*4+3
const int col_base = warp_id * 4;
#pragma unroll 1
for (int row = 0; row < 64; ++row) {
float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f;
if (row > 0) {
for (int j = lane; j < row; j += WARP_SIZE) {
const float a_val = sA[row * 64 + j];
sum0 += a_val * sX[j * 65 + col_base + 0];
sum1 += a_val * sX[j * 65 + col_base + 1];
sum2 += a_val * sX[j * 65 + col_base + 2];
sum3 += a_val * sX[j * 65 + col_base + 3];
}
}
sum0 = warp_reduce_sum(sum0);
sum1 = warp_reduce_sum(sum1);
sum2 = warp_reduce_sum(sum2);
sum3 = warp_reduce_sum(sum3);
if (lane == 0) {
const float inv_diag = sDiagInv[row];
sX[row * 65 + col_base + 0] = (sX[row * 65 + col_base + 0] - sum0) * inv_diag;
sX[row * 65 + col_base + 1] = (sX[row * 65 + col_base + 1] - sum1) * inv_diag;
sX[row * 65 + col_base + 2] = (sX[row * 65 + col_base + 2] - sum2) * inv_diag;
sX[row * 65 + col_base + 3] = (sX[row * 65 + col_base + 3] - sum3) * inv_diag;
}
__syncthreads();
}
// Cooperative write results back
#pragma unroll 8
for (int i = tid; i < 64 * 64; i += 512) {
const int row = i / 64;
const int col = i % 64;
X_batch[i] = sX[row * 65 + col];
}
}
static __global__ void solve_tri_f32_64x64_opt(const float * __restrict__ A,
const float * __restrict__ B,
float * __restrict__ X,
const uint3 ne02,
const size_t nb02,
const size_t nb03,
const size_t nb12,
const size_t nb13,
const size_t nb2,
const size_t nb3) {
const int batch_idx = blockIdx.x;
const int lane = threadIdx.x;
const int warp_id = threadIdx.y;
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
const int64_t i02 = i02_i03.y;
const int64_t i03 = i02_i03.x;
const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
// Shared memory: A is 64x64, sXt is 64x65 (padded)
__shared__ float sA[64 * 64];
__shared__ float sXt[64 * 65];
const int tid = lane + warp_id * WARP_SIZE;
// Cooperative load of A matrix (4096 elements / 1024 threads = 4 per thread)
#pragma unroll 4
for (int i = tid; i < 64 * 64; i += 1024) {
sA[i] = A_batch[i];
}
// Cooperative load of B matrix transposed into sXt
// sXt[col * 65 + row] = B[row * 64 + col]
#pragma unroll 4
for (int i = tid; i < 64 * 64; i += 1024) {
const int row = i / 64;
const int col = i % 64;
sXt[col * 65 + row] = B_batch[row * 64 + col];
}
__syncthreads();
// Each warp handles 2 columns: col0 = warp_id*2, col1 = warp_id*2 + 1
const int col0 = warp_id * 2;
const int col1 = warp_id * 2 + 1;
// Forward substitution with all columns processed in parallel
// Each row depends on previous rows, but different columns are independent
#pragma unroll 1
for (int row = 0; row < 64; ++row) {
// Each lane computes partial sum for indices it handles
float sum0 = 0.0f;
float sum1 = 0.0f;
// Sum over j < row
// For row <= 32: each lane handles at most 1 element
// For row > 32: each lane handles at most 2 elements
if (lane < row) {
const float a_val = sA[row * 64 + lane];
sum0 = a_val * sXt[col0 * 65 + lane];
sum1 = a_val * sXt[col1 * 65 + lane];
}
if (row > WARP_SIZE) {
const int j2 = lane + WARP_SIZE;
if (j2 < row) {
const float a_val2 = sA[row * 64 + j2];
sum0 += a_val2 * sXt[col0 * 65 + j2];
sum1 += a_val2 * sXt[col1 * 65 + j2];
}
}
// Warp-level reduction
sum0 = warp_reduce_sum(sum0);
sum1 = warp_reduce_sum(sum1);
// Lane 0 computes and stores the result
if (lane == 0) {
const float a_diag = sA[row * 64 + row];
const float inv_diag = 1.0f / a_diag;
sXt[col0 * 65 + row] = (sXt[col0 * 65 + row] - sum0) * inv_diag;
sXt[col1 * 65 + row] = (sXt[col1 * 65 + row] - sum1) * inv_diag;
}
// Sync within warp to ensure writes are visible before next row reads
__syncwarp();
}
__syncthreads();
// Cooperative write of results back (transpose sXt to X)
#pragma unroll 4
for (int i = tid; i < 64 * 64; i += 1024) {
const int row = i / 64;
const int col = i % 64;
X_batch[row * 64 + col] = sXt[col * 65 + row];
}
}
static __global__ void solve_tri_f32_128x128_opt(const float * __restrict__ A,
const float * __restrict__ B,
float * __restrict__ X,
const uint3 ne02,
const size_t nb02,
const size_t nb03,
const size_t nb12,
const size_t nb13,
const size_t nb2,
const size_t nb3,
const int n,
const int k) {
const int batch_idx = blockIdx.x;
const int lane = threadIdx.x;
const int warp_id = threadIdx.y;
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
const int64_t i02 = i02_i03.y;
const int64_t i03 = i02_i03.x;
const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
// Shared memory with padding to avoid bank conflicts
// Layout: sA[128][128] + sXt[128][129]
extern __shared__ char smem_raw[];
float * sA = (float *)smem_raw; // 128×128 (zero-initialized for unused parts)
float * sXt = sA + 128 * 128; // 128×129 (padded)
const int tid = lane + warp_id * WARP_SIZE;
// Zero-initialize shared memory first (important for variable n, k)
#pragma unroll 16
for (int i = tid; i < 128 * 128; i += 1024) {
sA[i] = 0.0f;
}
#pragma unroll 16
for (int i = tid; i < 128 * 129; i += 1024) {
sXt[i] = 0.0f;
}
__syncthreads();
// Cooperative load of A matrix (n×n elements)
for (int i = tid; i < n * n; i += 1024) {
const int row = i / n;
const int col = i % n;
sA[row * 128 + col] = A_batch[row * n + col];
}
// Cooperative load of B matrix transposed into sXt
// sXt[col * 129 + row] = B[row * k + col]
for (int i = tid; i < n * k; i += 1024) {
const int row = i / k;
const int col = i % k;
sXt[col * 129 + row] = B_batch[row * k + col];
}
__syncthreads();
// Each warp handles columns: col_base to col_base+3
// But only process if col < k
const int col_base = warp_id * 4;
// Forward substitution with all columns processed in parallel
for (int row = 0; row < n; ++row) {
float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f;
// Sum over j < row - each lane handles multiple elements
for (int j = lane; j < row; j += WARP_SIZE) {
const float a_val = sA[row * 128 + j];
if (col_base + 0 < k) sum0 += a_val * sXt[(col_base + 0) * 129 + j];
if (col_base + 1 < k) sum1 += a_val * sXt[(col_base + 1) * 129 + j];
if (col_base + 2 < k) sum2 += a_val * sXt[(col_base + 2) * 129 + j];
if (col_base + 3 < k) sum3 += a_val * sXt[(col_base + 3) * 129 + j];
}
// Warp-level reduction
sum0 = warp_reduce_sum(sum0);
sum1 = warp_reduce_sum(sum1);
sum2 = warp_reduce_sum(sum2);
sum3 = warp_reduce_sum(sum3);
// Lane 0 computes and stores the result
if (lane == 0) {
const float inv_diag = 1.0f / sA[row * 128 + row];
if (col_base + 0 < k) {
sXt[(col_base + 0) * 129 + row] = (sXt[(col_base + 0) * 129 + row] - sum0) * inv_diag;
}
if (col_base + 1 < k) {
sXt[(col_base + 1) * 129 + row] = (sXt[(col_base + 1) * 129 + row] - sum1) * inv_diag;
}
if (col_base + 2 < k) {
sXt[(col_base + 2) * 129 + row] = (sXt[(col_base + 2) * 129 + row] - sum2) * inv_diag;
}
if (col_base + 3 < k) {
sXt[(col_base + 3) * 129 + row] = (sXt[(col_base + 3) * 129 + row] - sum3) * inv_diag;
}
}
__syncwarp();
}
__syncthreads();
// Cooperative write of results back (transpose sXt to X)
for (int i = tid; i < n * k; i += 1024) {
const int row = i / k;
const int col = i % k;
X_batch[row * k + col] = sXt[col * 129 + row];
}
}
static __global__ void solve_tri_f32_256x256_tiled(const float * __restrict__ A,
const float * __restrict__ B,
float * __restrict__ X,
const uint3 ne02,
const size_t nb02,
const size_t nb03,
const size_t nb12,
const size_t nb13,
const size_t nb2,
const size_t nb3,
const int n,
const int k) {
const int batch_idx = blockIdx.x;
const int lane = threadIdx.x;
const int warp_id = threadIdx.y;
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
const int64_t i02 = i02_i03.y;
const int64_t i03 = i02_i03.x;
const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
// Tiled approach using 64×64 tiles to fit in shared memory
constexpr int TILE_SIZE = 64;
extern __shared__ char smem_raw[];
float * sA_tile = (float *)smem_raw; // 64×64 = 16KB
float * sXt_tile = sA_tile + TILE_SIZE * TILE_SIZE; // 64×65 = 16.25KB (padded)
float * sA_off = sXt_tile + TILE_SIZE * (TILE_SIZE+1); // 64×64 = 16KB (for off-diagonal blocks)
const int tid = lane + warp_id * WARP_SIZE;
// Initialize X = B (we'll solve in-place conceptually, using global memory)
for (int i = tid; i < n * k; i += 1024) {
X_batch[i] = B_batch[i];
}
__syncthreads();
// Process tile-by-tile along the diagonal
for (int tile_row = 0; tile_row < n; tile_row += TILE_SIZE) {
const int tile_n = min(TILE_SIZE, n - tile_row); // Actual rows in this tile
// Zero-init and load diagonal tile of A
for (int i = tid; i < TILE_SIZE * TILE_SIZE; i += 1024) {
sA_tile[i] = 0.0f;
}
__syncthreads();
for (int i = tid; i < tile_n * tile_n; i += 1024) {
int local_row = i / tile_n;
int local_col = i % tile_n;
sA_tile[local_row * TILE_SIZE + local_col] = A_batch[(tile_row + local_row) * n + tile_row + local_col];
}
__syncthreads();
// For each column tile of X
for (int tile_col = 0; tile_col < k; tile_col += TILE_SIZE) {
const int tile_k = min(TILE_SIZE, k - tile_col); // Actual columns in this tile
// Zero-init and load X tile transposed
for (int i = tid; i < TILE_SIZE * (TILE_SIZE+1); i += 1024) {
sXt_tile[i] = 0.0f;
}
__syncthreads();
for (int i = tid; i < tile_n * tile_k; i += 1024) {
int local_row = i / tile_k;
int local_col = i % tile_k;
sXt_tile[local_col * (TILE_SIZE+1) + local_row] =
X_batch[(tile_row + local_row) * k + tile_col + local_col];
}
__syncthreads();
// Apply updates from previous tile rows
for (int prev_tile = 0; prev_tile < tile_row; prev_tile += TILE_SIZE) {
const int prev_n = min(TILE_SIZE, n - prev_tile);
// Zero-init and load off-diagonal block
for (int i = tid; i < TILE_SIZE * TILE_SIZE; i += 1024) {
sA_off[i] = 0.0f;
}
__syncthreads();
for (int i = tid; i < tile_n * prev_n; i += 1024) {
int local_row = i / prev_n;
int local_col = i % prev_n;
sA_off[local_row * TILE_SIZE + local_col] = A_batch[(tile_row + local_row) * n + prev_tile + local_col];
}
__syncthreads();
// Update: X_tile -= A_off @ X_prev
int col0 = warp_id * 2;
int col1 = warp_id * 2 + 1;
for (int row = 0; row < tile_n; row++) {
float sum0 = 0.0f, sum1 = 0.0f;
for (int j = lane; j < prev_n; j += WARP_SIZE) {
float a_val = sA_off[row * TILE_SIZE + j];
if (col0 < tile_k) {
float x_prev0 = X_batch[(prev_tile + j) * k + tile_col + col0];
sum0 += a_val * x_prev0;
}
if (col1 < tile_k) {
float x_prev1 = X_batch[(prev_tile + j) * k + tile_col + col1];
sum1 += a_val * x_prev1;
}
}
sum0 = warp_reduce_sum(sum0);
sum1 = warp_reduce_sum(sum1);
if (lane == 0) {
if (col0 < tile_k) {
sXt_tile[col0 * (TILE_SIZE+1) + row] -= sum0;
}
if (col1 < tile_k) {
sXt_tile[col1 * (TILE_SIZE+1) + row] -= sum1;
}
}
__syncwarp();
}
__syncthreads();
}
// Solve the diagonal tile
int col0 = warp_id * 2;
int col1 = warp_id * 2 + 1;
for (int row = 0; row < tile_n; ++row) {
float sum0 = 0.0f, sum1 = 0.0f;
if (lane < row) {
float a_val = sA_tile[row * TILE_SIZE + lane];
if (col0 < tile_k) sum0 = a_val * sXt_tile[col0 * (TILE_SIZE+1) + lane];
if (col1 < tile_k) sum1 = a_val * sXt_tile[col1 * (TILE_SIZE+1) + lane];
}
if (row > WARP_SIZE) {
int j2 = lane + WARP_SIZE;
if (j2 < row) {
float a_val2 = sA_tile[row * TILE_SIZE + j2];
if (col0 < tile_k) sum0 += a_val2 * sXt_tile[col0 * (TILE_SIZE+1) + j2];
if (col1 < tile_k) sum1 += a_val2 * sXt_tile[col1 * (TILE_SIZE+1) + j2];
}
}
sum0 = warp_reduce_sum(sum0);
sum1 = warp_reduce_sum(sum1);
if (lane == 0) {
float inv_diag = 1.0f / sA_tile[row * TILE_SIZE + row];
if (col0 < tile_k) {
sXt_tile[col0 * (TILE_SIZE+1) + row] =
(sXt_tile[col0 * (TILE_SIZE+1) + row] - sum0) * inv_diag;
}
if (col1 < tile_k) {
sXt_tile[col1 * (TILE_SIZE+1) + row] =
(sXt_tile[col1 * (TILE_SIZE+1) + row] - sum1) * inv_diag;
}
}
__syncwarp();
}
__syncthreads();
// Write solved tile back to global memory
for (int i = tid; i < tile_n * tile_k; i += 1024) {
int local_row = i / tile_k;
int local_col = i % tile_k;
X_batch[(tile_row + local_row) * k + tile_col + local_col] =
sXt_tile[local_col * (TILE_SIZE+1) + local_row];
}
__syncthreads();
}
}
}
// When ncols_template == 0 the bounds for the loops in this function are not
// known and can't be unrolled. As we want to keep pragma unroll for all other
// cases we supress the clang transformation warning here.
#ifdef __clang__
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
// Template parameters: n_template/k_template are the matrix dimensions when known at compile time (0 = runtime)
// threads_y_template is the number of threads in y dimension (max 32 to stay within 1024 thread limit)
template <int n_template, int k_template, int threads_y_template>
static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
const float * __restrict__ B,
float * __restrict__ X,
const uint3 ne02,
const size_t nb02,
const size_t nb03,
const size_t nb12,
const size_t nb13,
const size_t nb2,
const size_t nb3,
const int n_arg,
const int k_arg) {
const int n = n_template == 0 ? n_arg : n_template;
const int k = k_template == 0 ? k_arg : k_template;
const int threads_y = threads_y_template == 0 ? blockDim.y : threads_y_template;
const int batch_idx = blockIdx.x;
const int lane = threadIdx.x;
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
const int64_t i02 = i02_i03.y;
const int64_t i03 = i02_i03.x;
const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
__shared__ float sA[MAX_N_FAST * MAX_N_FAST];
__shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
const int offset = threadIdx.x + threadIdx.y * blockDim.x;
const int block_threads = blockDim.x * blockDim.y;
// Load A matrix into shared memory
#pragma unroll
for (int i = 0; i < n * n; i += block_threads) {
int i0 = i + offset;
if (i0 < n * n) {
sA[i0] = A_batch[i0];
}
}
const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
const int cols_per_thread = (k + threads_y - 1) / threads_y;
// Load B matrix into shared memory (transposed as sXt)
// Each thread handles multiple columns when k > threads_y
for (int c = 0; c < cols_per_thread; c++) {
const int col_idx = threadIdx.y + c * threads_y;
if (col_idx < k) {
#pragma unroll
for (int i = 0; i < rows_per_warp; i++) {
const int i0 = lane + i * WARP_SIZE;
if (i0 < n) {
sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
}
}
}
}
__syncthreads();
// Solve for each column this thread handles
for (int c = 0; c < cols_per_thread; c++) {
const int col_idx = threadIdx.y + c * threads_y;
if (col_idx >= k) {
continue;
}
#pragma unroll
for (int row = 0; row < n; ++row) {
float sum = 0.0f;
{
int j = lane;
if (j < row) {
sum += sA[row * n + j] * sXt[col_idx * n + j];
}
}
if (row >= WARP_SIZE) {
int j = WARP_SIZE + lane;
if (j < row) {
sum += sA[row * n + j] * sXt[col_idx * n + j];
}
}
sum = warp_reduce_sum(sum);
if (lane == 0) {
const float b_val = sXt[col_idx * n + row];
const float a_diag = sA[row * n + row];
// no safeguards for division by zero because that indicates corrupt
// data anyway
sXt[col_idx * n + row] = (b_val - sum) / a_diag;
}
}
// Sync between columns to ensure writes are visible
if (c + 1 < cols_per_thread) {
__syncwarp();
}
}
__syncthreads();
// Write results back
for (int c = 0; c < cols_per_thread; c++) {
const int col_idx = threadIdx.y + c * threads_y;
if (col_idx < k) {
#pragma unroll
for (int i = 0; i < rows_per_warp; i++) {
const int i0 = lane + i * WARP_SIZE;
if (i0 < n) {
X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
}
}
}
}
}
#ifdef __clang__
# pragma clang diagnostic pop
#endif // __clang__
// cuBLAS batched TRSM fallback for larger matrices or as robust path
// Solves A * X = B where A is lower triangular
// This function modifies X in-place (X should be initialized with B)
static void solve_tri_f32_cublas(
ggml_backend_cuda_context & ctx,
const float * A,
float * X, // Input: B, Output: solution X (in-place)
int n,
int k,
int64_t ne02,
int64_t ne03,
size_t nb02,
size_t nb03,
size_t nb2,
size_t nb3,
cudaStream_t stream
) {
const int64_t total_batches = ne02 * ne03;
// Allocate pointer arrays on device
ggml_cuda_pool_alloc<const float *> A_ptrs(ctx.pool(), total_batches);
ggml_cuda_pool_alloc<float *> X_ptrs(ctx.pool(), total_batches);
// Set up pointer arrays on device (CUDA graph compatible)
{
const int block_size = 256;
const int grid_size = (total_batches + block_size - 1) / block_size;
setup_trsm_batch_pointers<<<grid_size, block_size, 0, stream>>>(
A, X,
A_ptrs.get(), X_ptrs.get(),
ne02, total_batches,
nb02, nb03, nb2, nb3
);
CUDA_CHECK(cudaGetLastError());
}
// Get cuBLAS handle and set stream
cublasHandle_t handle = ctx.cublas_handle();
cublasSetStream(handle, stream);
// Save current math mode and set to default for accuracy
// (TF32 can cause numerical issues with triangular solves)
cublasMath_t prev_math_mode;
cublasGetMathMode(handle, &prev_math_mode);
cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH);
const float alpha = 1.0f;
cublasStatus_t status = cublasStrsmBatched(
handle,
CUBLAS_SIDE_RIGHT, // A is on the right: X * A = B
CUBLAS_FILL_MODE_UPPER, // A^T is upper (since A is lower in row-major)
CUBLAS_OP_N, // No additional transpose
CUBLAS_DIAG_NON_UNIT, // Diagonal is not assumed to be 1
k, // m: rows of X^T (columns of X)
n, // n: columns of X^T (rows of X) = size of A
&alpha,
(const float **)A_ptrs.get(), n, // lda = n (leading dimension)
(float **)X_ptrs.get(), k, // ldb = k (leading dimension of X^T)
total_batches
);
// Restore previous math mode
cublasSetMathMode(handle, prev_math_mode);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "cuBLAS batched TRSM failed: %d\n", (int) status);
}
}
static void solve_tri_f32_cuda(const float * A,
const float * B,
float * X,
int n,
int k,
int64_t ne02,
int64_t ne03,
size_t nb02,
size_t nb03,
size_t nb12,
size_t nb13,
size_t nb2,
size_t nb3,
cudaStream_t stream) {
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
dim3 grid(ne02 * ne03);
// Handle large matrices first (256×256 and 65-128 range)
// Route sizes 65-256 to the tiled kernel
if (n > 64 || k > 64) {
// Use the tiled kernel which works for any size up to 256
// and only requires ~48KB shared memory (within standard limits)
dim3 threads_256(WARP_SIZE, 32); // 1024 threads
// Shared memory: 64×64 + 64×65 + 64×64 = 16KB + 16.25KB + 16KB = ~48KB
const size_t smem_size = (64 * 64 + 64 * 65 + 64 * 64) * sizeof(float);
// Configure extended shared memory for this kernel
static bool smem_configured_tiled = false;
if (!smem_configured_tiled) {
cudaFuncSetAttribute(solve_tri_f32_256x256_tiled,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
smem_configured_tiled = true;
}
solve_tri_f32_256x256_tiled<<<grid, threads_256, smem_size, stream>>>(
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
return;
}
// Limit threads_y to 32 to ensure we don't exceed 1024 threads per block (32 * 32 = 1024)
const int threads_y = k <= 32 ? k : 32;
dim3 threads(WARP_SIZE, threads_y);
if (n == 64) {
switch (k) {
case 64:
{
// Use optimized kernel for n=64, k=64 case (common in Qwen3 Next DeltaNet)
// Block config: 32x32 = 1024 threads (32 warps)
dim3 threads_64x64(WARP_SIZE, 32);
solve_tri_f32_64x64_opt
<<<grid, threads_64x64, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
}
break;
case 48:
// k=48 needs 2 columns per thread (threads_y=32, some threads handle 1, some 2)
solve_tri_f32_fast<64, 48, 32>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 40:
// k=40 needs 2 columns per thread (threads_y=32, some threads handle 1, some 2)
solve_tri_f32_fast<64, 40, 32>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 32:
solve_tri_f32_fast<64, 32, 32>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 16:
solve_tri_f32_fast<64, 16, 16>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 14:
solve_tri_f32_fast<64, 14, 14>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 12:
solve_tri_f32_fast<64, 12, 12>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 10:
solve_tri_f32_fast<64, 10, 10>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 8:
solve_tri_f32_fast<64, 8, 8>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 6:
solve_tri_f32_fast<64, 6, 6>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 4:
solve_tri_f32_fast<64, 4, 4>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 2:
solve_tri_f32_fast<64, 2, 2>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 1:
solve_tri_f32_fast<64, 1, 1>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
default:
solve_tri_f32_fast<0, 0, 0>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
}
} else { // run general case
solve_tri_f32_fast<0, 0, 0>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
}
}
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // A (triangular n x n matrix)
const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns)
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(src0->ne[0] == src0->ne[1]);
GGML_ASSERT(src0->ne[1] == src1->ne[1]);
GGML_ASSERT(src0->ne[2] == src1->ne[2]);
GGML_ASSERT(src0->ne[3] == src1->ne[3]);
const int n = src0->ne[0];
const int k = src1->ne[0];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
solve_tri_f32_cuda(
(const float *) src0->data,
(const float *) src1->data,
(float *) dst->data,
n, k,
ne02, ne03,
src0->nb[2] / sizeof(float),
src0->nb[3] / sizeof(float),
src1->nb[2] / sizeof(float),
src1->nb[3] / sizeof(float),
dst->nb[2] / sizeof(float),
dst->nb[3] / sizeof(float),
ctx.stream());
return;
}
if (dst->data != src1->data) {
const int64_t total_batches = ne02 * ne03;
const size_t X_size = (size_t) n * (size_t) k * (size_t) total_batches * sizeof(float);
CUDA_CHECK(cudaMemcpyAsync(dst->data, src1->data, X_size, cudaMemcpyDeviceToDevice, ctx.stream()));
}
solve_tri_f32_cublas(
ctx,
(const float *) src0->data,
(float *) dst->data,
n, k,
ne02, ne03,
src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
dst->nb[2] / sizeof(float), dst->nb[3] / sizeof(float),
ctx.stream());
}

View File

@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -0,0 +1,608 @@
#include "ssm-conv.cuh"
#define CUDA_SSM_CONV_BLOCK_SIZE 256
template <int split_n_t>
static __global__ void ssm_conv_single_seq_f32(
const float * src0,
const float * src1,
const float * src2,
float * dst_x,
int nc,
int nr,
int n_t,
int src0_s0,
int src0_s1,
int src1_s1) {
const int row = blockIdx.x * blockDim.x + threadIdx.x;
if (row >= nr) {
return;
}
const int t0 = blockIdx.y * split_n_t;
if (t0 >= n_t) {
return;
}
const float * state_row = src0 + (size_t) row * src0_s1;
const float * c_row = src2 + (size_t) row * nc;
#pragma unroll
for (int it = 0; it < split_n_t; ++it) {
const int t = t0 + it;
if (t >= n_t) {
break;
}
float sumf = 0.0f;
for (int j = 0; j < nc; ++j) {
const int idx = t + j;
const float x = idx < nc - 1
? state_row[(size_t) idx * src0_s0]
: src1[row + (size_t) (idx - (nc - 1)) * src1_s1];
sumf += x * c_row[j];
}
dst_x[row + (size_t) t * nr] = sumf;
}
}
template <int split_n_t>
static __global__ void ssm_conv_single_seq_f32_nc4(
const float * src0,
const float * src1,
const float * src2,
float * dst_x,
int nr,
int n_t,
int src0_s0,
int src0_s1,
int src1_s1) {
const int row = blockIdx.x * blockDim.x + threadIdx.x;
if (row >= nr) {
return;
}
const int t0 = blockIdx.y * split_n_t;
if (t0 >= n_t) {
return;
}
const float * state_row = src0 + (size_t) row * src0_s1;
const float * c_row = src2 + (size_t) row * 4;
const float c0 = c_row[0];
const float c1 = c_row[1];
const float c2 = c_row[2];
const float c3 = c_row[3];
#pragma unroll
for (int it = 0; it < split_n_t; ++it) {
const int t = t0 + it;
if (t >= n_t) {
break;
}
const int i0 = t;
const int i1 = t + 1;
const int i2 = t + 2;
const int i3 = t + 3;
const float x0 = i0 < 3 ? state_row[(size_t) i0 * src0_s0] : src1[row + (size_t) (i0 - 3) * src1_s1];
const float x1 = i1 < 3 ? state_row[(size_t) i1 * src0_s0] : src1[row + (size_t) (i1 - 3) * src1_s1];
const float x2 = i2 < 3 ? state_row[(size_t) i2 * src0_s0] : src1[row + (size_t) (i2 - 3) * src1_s1];
const float x3 = i3 < 3 ? state_row[(size_t) i3 * src0_s0] : src1[row + (size_t) (i3 - 3) * src1_s1];
dst_x[row + (size_t) t * nr] = x0 * c0 + x1 * c1 + x2 * c2 + x3 * c3;
}
}
static __global__ void ssm_conv_single_seq_final_state_f32(
const float * src0,
const float * src1,
float * dst_state,
int nc,
int nr,
int n_t,
int src0_s0,
int src0_s1,
int src1_s1) {
const int row = blockIdx.x * blockDim.x + threadIdx.x;
if (row >= nr) {
return;
}
const float * state_row = src0 + (size_t) row * src0_s1;
float * dst_row = dst_state + (size_t) row * nc;
for (int j = 0; j < nc; ++j) {
const int idx = n_t - 1 + j;
dst_row[j] = idx < nc - 1
? state_row[(size_t) idx * src0_s0]
: src1[row + (size_t) (idx - (nc - 1)) * src1_s1];
}
}
static __global__ void ssm_conv_init_states_f32_nc4(
const float * src0,
float * state,
int nr,
int n_kv) {
const int row = blockIdx.x * blockDim.x + threadIdx.x;
const int seq = blockIdx.y;
if (row >= nr || seq >= n_kv) {
return;
}
const float * src_row = src0 + (size_t) seq * nr * 3 + (size_t) row * 3;
float * state_row = state + (size_t) seq * nr * 4 + (size_t) row * 4;
state_row[1] = src_row[0];
state_row[2] = src_row[1];
state_row[3] = src_row[2];
}
static __global__ void ssm_conv_init_states_f32(
const float * src0,
float * state,
int nc,
int nr,
int n_kv) {
const int row = blockIdx.x * blockDim.x + threadIdx.x;
const int seq = blockIdx.y;
if (row >= nr || seq >= n_kv) {
return;
}
const float * src_row = src0 + (size_t) seq * nr * (nc - 1) + (size_t) row * (nc - 1);
float * state_row = state + (size_t) seq * nr * nc + (size_t) row * nc;
for (int i0 = 0; i0 < nc - 1; ++i0) {
state_row[1 + i0] = src_row[i0];
}
}
static __global__ void ssm_conv_validate_unique_seq_map(
const int32_t * src3,
int32_t * seq_ids,
int32_t * seq_seen,
int32_t * fast_path_ok,
int n_t,
int n_kv,
int src3_nb1) {
const int t = blockIdx.x * blockDim.x + threadIdx.x;
if (t >= n_t) {
return;
}
const int32_t * sq = src3 + (size_t) t * src3_nb1;
const int32_t seq0 = sq[0];
if (seq0 < 0 || seq0 >= n_kv) {
atomicExch(fast_path_ok, 0);
return;
}
// Fast path supports one sequence per token (no copy-to-multiple-sequences routing).
if (n_kv > 1) {
const int32_t seq1 = sq[1];
if (seq1 >= 0 && seq1 < n_kv) {
atomicExch(fast_path_ok, 0);
return;
}
}
seq_ids[t] = seq0;
if (atomicAdd(seq_seen + seq0, 1) != 0) {
// Sequence is updated by multiple tokens in the same batch => recurrent dependency across t.
atomicExch(fast_path_ok, 0);
}
}
static __global__ void ssm_conv_multi_seq_unique_f32_kernel(
const float * src0,
const float * src1,
const float * src2,
const int32_t * seq_ids,
const int32_t * fast_path_ok,
float * dst_x,
float * dst_state,
int nc,
int nr,
int n_t,
int src1_nb1) {
if (fast_path_ok != nullptr && fast_path_ok[0] == 0) {
return;
}
const int row = blockIdx.x * blockDim.x + threadIdx.x;
const int t = blockIdx.y;
if (row >= nr || t >= n_t) {
return;
}
const int seq = seq_ids[t];
const float * src_state_row = src0 + (size_t) seq * nr * (nc - 1) + (size_t) row * (nc - 1);
float * state_row = dst_state + (size_t) seq * nr * nc + (size_t) row * nc;
const float * c_row = src2 + (size_t) row * nc;
float sumf = 0.0f;
for (int i0 = 0; i0 < nc - 1; ++i0) {
const float v = src_state_row[i0];
state_row[i0] = v;
sumf += v * c_row[i0];
}
const float x = src1[row + (size_t) t * src1_nb1];
state_row[nc - 1] = x;
sumf += x * c_row[nc - 1];
dst_x[row + (size_t) t * nr] = sumf;
}
static __global__ void ssm_conv_multi_seq_unique_f32_kernel_nc4(
const float * src0,
const float * src1,
const float * src2,
const int32_t * seq_ids,
const int32_t * fast_path_ok,
float * dst_x,
float * dst_state,
int nr,
int n_t,
int src1_nb1) {
if (fast_path_ok != nullptr && fast_path_ok[0] == 0) {
return;
}
const int row = blockIdx.x * blockDim.x + threadIdx.x;
const int t = blockIdx.y;
if (row >= nr || t >= n_t) {
return;
}
const int seq = seq_ids[t];
const float * src_state_row = src0 + (size_t) seq * nr * 3 + (size_t) row * 3;
float * state_row = dst_state + (size_t) seq * nr * 4 + (size_t) row * 4;
const float * c_row = src2 + (size_t) row * 4;
const float s0 = src_state_row[0];
const float s1 = src_state_row[1];
const float s2 = src_state_row[2];
const float x = src1[row + (size_t) t * src1_nb1];
state_row[0] = s0;
state_row[1] = s1;
state_row[2] = s2;
state_row[3] = x;
dst_x[row + (size_t) t * nr] = s0 * c_row[0] + s1 * c_row[1] + s2 * c_row[2] + x * c_row[3];
}
static __global__ void ssm_conv_f32_kernel(
const float * src0,
const float * src1,
const float * src2,
const int32_t * src3,
const int32_t * fast_path_ok,
float * dst_x,
float * dst_state,
int nc,
int nr,
int n_t,
int n_kv,
int src1_nb1,
int src3_nb1) {
if (fast_path_ok != nullptr && fast_path_ok[0] != 0) {
return;
}
const int row = blockIdx.x * blockDim.x + threadIdx.x;
if (row >= nr) {
return;
}
const float * c_row = src2 + (size_t) row * nc;
for (int t = 0; t < n_t; ++t) {
const int32_t * sq = src3 + (size_t) t * src3_nb1;
const int seq0 = sq[0];
if (seq0 < 0 || seq0 >= n_kv) {
continue;
}
float * state_row = dst_state + (size_t) seq0 * nr * nc + (size_t) row * nc;
const float * src_state_row;
if (t == 0) {
src_state_row = src0 + (size_t) seq0 * nr * (nc - 1) + (size_t) row * (nc - 1);
} else {
src_state_row = state_row + 1;
}
for (int i0 = 0; i0 < nc - 1; ++i0) {
state_row[i0] = src_state_row[i0];
}
state_row[nc - 1] = src1[row + (size_t) t * src1_nb1];
for (int i3 = 1; i3 < n_kv; ++i3) {
const int seq = sq[i3];
if (seq < 0 || seq >= n_kv) {
break;
}
float * state_row_copy = dst_state + (size_t) seq * nr * nc + (size_t) row * nc;
for (int i0 = 0; i0 < nc; ++i0) {
state_row_copy[i0] = state_row[i0];
}
}
float sumf = 0.0f;
for (int i0 = 0; i0 < nc; ++i0) {
sumf += state_row[i0] * c_row[i0];
}
dst_x[row + (size_t) t * nr] = sumf;
}
}
template <bool has_multi_seq>
static __global__ void ssm_conv_f32_kernel_nc4(
const float * src0,
const float * src1,
const float * src2,
const int32_t * src3,
const int32_t * fast_path_ok,
float * dst_x,
float * dst_state,
int nr,
int n_t,
int n_kv,
int src1_nb1,
int src3_nb1) {
if (fast_path_ok != nullptr && fast_path_ok[0] != 0) {
return;
}
const int row = blockIdx.x * blockDim.x + threadIdx.x;
if (row >= nr) {
return;
}
const float * c_row = src2 + (size_t) row * 4;
const float c0 = c_row[0];
const float c1 = c_row[1];
const float c2 = c_row[2];
const float c3 = c_row[3];
for (int t = 0; t < n_t; ++t) {
const int32_t * sq = src3 + (size_t) t * src3_nb1;
const int seq0 = sq[0];
if (seq0 < 0 || seq0 >= n_kv) {
continue;
}
float * state_row = dst_state + (size_t) seq0 * nr * 4 + (size_t) row * 4;
const float * src_state_row;
if (t == 0) {
src_state_row = src0 + (size_t) seq0 * nr * 3 + (size_t) row * 3;
} else {
src_state_row = state_row + 1;
}
const float s0 = src_state_row[0];
const float s1 = src_state_row[1];
const float s2 = src_state_row[2];
const float x = src1[row + (size_t) t * src1_nb1];
state_row[0] = s0;
state_row[1] = s1;
state_row[2] = s2;
state_row[3] = x;
if constexpr (has_multi_seq) {
for (int i3 = 1; i3 < n_kv; ++i3) {
const int seq = sq[i3];
if (seq < 0 || seq >= n_kv) {
break;
}
float * state_row_copy = dst_state + (size_t) seq * nr * 4 + (size_t) row * 4;
state_row_copy[0] = s0;
state_row_copy[1] = s1;
state_row_copy[2] = s2;
state_row_copy[3] = x;
}
}
dst_x[row + (size_t) t * nr] = s0 * c0 + s1 * c1 + s2 * c2 + x * c3;
}
}
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // conv_state: [d_conv - 1, d_inner, n_kv]
const ggml_tensor * src1 = dst->src[1]; // x: [d_inner, n_tokens]
const ggml_tensor * src2 = dst->src[2]; // conv1d.weight: [d_conv, d_inner]
const ggml_tensor * src3 = dst->src[3]; // state_seq: [n_kv, n_tokens]
const int nc = src2->ne[0];
const int nr = src0->ne[1];
const int n_t = src1->ne[1];
const int n_kv = src0->ne[2];
GGML_ASSERT((int64_t) nr * n_t + (int64_t) nc * nr * n_kv == ggml_nelements(dst));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_F32);
GGML_ASSERT(src3->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
GGML_ASSERT(src2->nb[1] == src2->ne[0] * sizeof(float));
GGML_ASSERT(src2->nb[2] == src2->ne[1] * src2->ne[0] * sizeof(float));
GGML_ASSERT(src2->ne[0] == src0->ne[0] + 1);
GGML_ASSERT(src2->ne[1] == src0->ne[1]);
GGML_ASSERT(src1->ne[0] == src0->ne[1]);
GGML_ASSERT(src3->ne[0] == src0->ne[2]);
GGML_ASSERT(src3->ne[1] == src1->ne[1]);
float * dst_data = (float *) dst->data;
float * dst_x = dst_data;
float * dst_state = dst_data + (size_t) nr * n_t;
const dim3 block_dims(CUDA_SSM_CONV_BLOCK_SIZE, 1, 1);
const dim3 row_grid((nr + CUDA_SSM_CONV_BLOCK_SIZE - 1) / CUDA_SSM_CONV_BLOCK_SIZE, 1, 1);
ggml_cuda_pool_alloc<int32_t> fast_path_ok_d(ctx.pool());
const int32_t * multi_seq_fast_path_ok = nullptr;
// Fast path for single-sequence recurrent updates (Qwen3Next prompt/decode path).
// In this case, outputs are independent given the initial conv state, so we parallelize over token blocks.
if (n_kv == 1 && src3->ne[0] == 1) {
GGML_ASSERT(n_t > 0);
const int src0_s0 = src0->nb[0] / sizeof(float);
const int src0_s1 = src0->nb[1] / sizeof(float);
const int src1_s1 = src1->nb[1] / sizeof(float);
constexpr int split_n_t = 32;
const dim3 token_grid(row_grid.x, (n_t + split_n_t - 1) / split_n_t, 1);
if (nc == 4) {
ssm_conv_single_seq_f32_nc4<split_n_t><<<token_grid, block_dims, 0, ctx.stream()>>>(
(const float *) src0->data,
(const float *) src1->data,
(const float *) src2->data,
dst_x,
nr, n_t,
src0_s0, src0_s1, src1_s1);
} else {
ssm_conv_single_seq_f32<split_n_t><<<token_grid, block_dims, 0, ctx.stream()>>>(
(const float *) src0->data,
(const float *) src1->data,
(const float *) src2->data,
dst_x,
nc, nr, n_t,
src0_s0, src0_s1, src1_s1);
}
ssm_conv_single_seq_final_state_f32<<<row_grid, block_dims, 0, ctx.stream()>>>(
(const float *) src0->data,
(const float *) src1->data,
dst_state,
nc, nr, n_t,
src0_s0, src0_s1, src1_s1);
return;
}
if (n_kv > 1) {
const dim3 init_grid(row_grid.x, n_kv, 1);
if (nc == 4) {
ssm_conv_init_states_f32_nc4<<<init_grid, block_dims, 0, ctx.stream()>>>(
(const float *) src0->data,
dst_state,
nr, n_kv);
} else {
ssm_conv_init_states_f32<<<init_grid, block_dims, 0, ctx.stream()>>>(
(const float *) src0->data,
dst_state,
nc, nr, n_kv);
}
// Fast path for multi-sequence decode-like batches:
// one token per unique sequence, no copy-to-multiple-sequences routing.
ggml_cuda_pool_alloc<int32_t> seq_ids(ctx.pool(), n_t);
ggml_cuda_pool_alloc<int32_t> seq_seen(ctx.pool(), n_kv);
int32_t fast_path_ok = 1;
fast_path_ok_d.alloc(1);
CUDA_CHECK(cudaMemsetAsync(seq_seen.get(), 0, n_kv * sizeof(int32_t), ctx.stream()));
CUDA_CHECK(cudaMemcpyAsync(fast_path_ok_d.get(), &fast_path_ok, sizeof(int32_t), cudaMemcpyHostToDevice, ctx.stream()));
constexpr int seq_map_block_size = 256;
const dim3 seq_map_grid((n_t + seq_map_block_size - 1) / seq_map_block_size, 1, 1);
ssm_conv_validate_unique_seq_map<<<seq_map_grid, seq_map_block_size, 0, ctx.stream()>>>(
(const int32_t *) src3->data,
seq_ids.get(),
seq_seen.get(),
fast_path_ok_d.get(),
n_t,
n_kv,
src3->nb[1] / sizeof(int32_t));
CUDA_CHECK(cudaGetLastError());
multi_seq_fast_path_ok = fast_path_ok_d.get();
const dim3 token_grid(row_grid.x, n_t, 1);
if (nc == 4) {
ssm_conv_multi_seq_unique_f32_kernel_nc4<<<token_grid, block_dims, 0, ctx.stream()>>>(
(const float *) src0->data,
(const float *) src1->data,
(const float *) src2->data,
seq_ids.get(),
multi_seq_fast_path_ok,
dst_x,
dst_state,
nr, n_t,
src1->nb[1] / sizeof(float));
} else {
ssm_conv_multi_seq_unique_f32_kernel<<<token_grid, block_dims, 0, ctx.stream()>>>(
(const float *) src0->data,
(const float *) src1->data,
(const float *) src2->data,
seq_ids.get(),
multi_seq_fast_path_ok,
dst_x,
dst_state,
nc, nr, n_t,
src1->nb[1] / sizeof(float));
}
}
if (nc == 4) {
if (n_kv > 1) {
ssm_conv_f32_kernel_nc4<true><<<row_grid, block_dims, 0, ctx.stream()>>>(
(const float *) src0->data,
(const float *) src1->data,
(const float *) src2->data,
(const int32_t *) src3->data,
multi_seq_fast_path_ok,
dst_x,
dst_state,
nr, n_t, n_kv,
src1->nb[1] / sizeof(float),
src3->nb[1] / sizeof(int32_t));
} else {
ssm_conv_f32_kernel_nc4<false><<<row_grid, block_dims, 0, ctx.stream()>>>(
(const float *) src0->data,
(const float *) src1->data,
(const float *) src2->data,
(const int32_t *) src3->data,
nullptr,
dst_x,
dst_state,
nr, n_t, n_kv,
src1->nb[1] / sizeof(float),
src3->nb[1] / sizeof(int32_t));
}
} else {
ssm_conv_f32_kernel<<<row_grid, block_dims, 0, ctx.stream()>>>(
(const float *) src0->data,
(const float *) src1->data,
(const float *) src2->data,
(const int32_t *) src3->data,
multi_seq_fast_path_ok,
dst_x,
dst_state,
nc, nr, n_t, n_kv,
src1->nb[1] / sizeof(float),
src3->nb[1] / sizeof(int32_t));
}
}

View File

@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -16,6 +16,25 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc
}
}
static __global__ void k_sum_rows_nc_f32(const char * x, char * y, const int ncols,
size_t nb00, size_t nb01, size_t nb02, size_t nb03, size_t nb1, size_t nb2, size_t nb3) {
const char * src = x + nb03*blockIdx.z + nb02*blockIdx.y + nb01*blockIdx.x;
float * dst = (float *)(y + nb3*blockIdx.z + nb2*blockIdx.y + nb1*blockIdx.x);
const int col = threadIdx.x;
float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += *(const float *)(src + i*nb00);
}
sum = warp_reduce_sum(sum);
if (col == 0) {
dst[0] = sum;
}
}
static __global__ void k_sum_rows_div_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, float s, float b) {
const int row = blockIdx.x;
const int col = threadIdx.x;
@ -43,6 +62,12 @@ void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int
const dim3 block_nums(nrows, 1, 1);
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}
static void sum_rows_f32_cuda_nc(const char * x, char * dst, int ne0, int ne1, int ne2, int ne3,
size_t nb00, size_t nb01, size_t nb02, size_t nb03, size_t nb1, size_t nb2, size_t nb3, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(ne1, ne2, ne3);
k_sum_rows_nc_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ne0, nb00, nb01, nb02, nb03, nb1, nb2, nb3);
}
static void sum_rows_div_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, float s, float b, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
@ -52,19 +77,30 @@ static void sum_rows_div_f32_cuda(const float * x, float * dst, const int ncols,
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
}
void ggml_cuda_op_sum_rows_nc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src = dst->src[0]->src[0];
GGML_ASSERT(src->op == GGML_OP_TRANSPOSE);
GGML_ASSERT(dst->type == GGML_TYPE_F32 && src->type == GGML_TYPE_F32);
cudaStream_t stream = ctx.stream();
sum_rows_f32_cuda_nc((const char *)src->data, (char *)dst->data, src->ne[0], src->ne[1], src->ne[2], src->ne[3],
src->nb[0], src->nb[1], src->nb[2], src->nb[3], dst->nb[1], dst->nb[2], dst->nb[3], stream);
}
void ggml_cuda_op_sum_rows_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

View File

@ -5,3 +5,5 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
void ggml_cuda_op_sum_rows_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sum_rows_nc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

120
ggml/src/ggml-cuda/tri.cu Normal file
View File

@ -0,0 +1,120 @@
#include "tri.cuh"
#include "convert.cuh"
template<typename T, bool prefix_keep, int add_to_split>
static __global__ void tri_kernel(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
const int64_t i3 = blockIdx.z;
const int64_t i2 = blockIdx.y;
const int64_t i1 = blockIdx.x;
const int64_t split_point = i1 + add_to_split;
(void) nb00;
(void) nb0;
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
const T * src_row = src + i1 * nb01 + i2 * nb02 + i3 * nb03;
T * dst_row = dst + i1 * nb1 + i2 * nb2 + i3 * nb3;
if constexpr (prefix_keep) {
for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
dst_row[i0] = src_row[i0];
}
for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
dst_row[i0] = ggml_cuda_cast<T, float>(0.0f);
}
} else {
for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
dst_row[i0] = ggml_cuda_cast<T, float>(0.0f);
}
for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
dst_row[i0] = src_row[i0];
}
}
}
template<typename T>
static void tri_cuda(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
const ggml_tri_type ttype,
cudaStream_t stream) {
dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);
dim3 grid_dims(ne01, ne02, ne03);
const size_t type_size = sizeof(T);
const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 1 : 0;
const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG);
if (prefix_keep) {
if (add_to_split == 0) {
tri_kernel<T, true, 0><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
} else {
tri_kernel<T, true, 1><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
}
} else {
if (add_to_split == 0) {
tri_kernel<T, false, 0><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
} else {
tri_kernel<T, false, 1><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
}
}
}
void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tri_type ttype = static_cast<ggml_tri_type>(((const int32_t *) dst->op_params)[0]);
GGML_ASSERT(src0->type == dst->type);
switch (src0->type) {
case GGML_TYPE_F32:
tri_cuda(
(const float *) src0->data, (float *) dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
ttype, ctx.stream()
);
break;
case GGML_TYPE_F16:
tri_cuda(
(const half *) src0->data, (half *) dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
ttype, ctx.stream()
);
break;
default:
GGML_ABORT("fatal error");
}
}

View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_TRI_BLOCK_SIZE 256
void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -735,6 +735,10 @@ static __device__ __forceinline__ float op_exp(float x) {
return expf(x);
}
static __device__ __forceinline__ float op_softplus(float x) {
return (x > 20.0f) ? x : logf(1.0f + expf(x));
}
static __device__ __forceinline__ float op_sin(float x) {
return sinf(x);
}
@ -831,6 +835,10 @@ void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_exp>(ctx, dst);
}
void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_softplus>(ctx, dst);
}
// === gated ops
template <float (*op)(float), typename T>
@ -942,4 +950,3 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_elu>(ctx, dst);
}

View File

@ -53,6 +53,8 @@ void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

File diff suppressed because it is too large Load Diff

View File

@ -47,10 +47,10 @@
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 8
#define LLAMA_SESSION_VERSION 9
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
#define LLAMA_STATE_SEQ_VERSION 2
#define LLAMA_STATE_SEQ_VERSION 3
#ifdef __cplusplus
extern "C" {

View File

@ -27,6 +27,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
{ LLM_ARCH_QWEN3, "qwen3" },
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
{ LLM_ARCH_QWEN3NEXT, "qwen3next" },
{ LLM_ARCH_QWEN3VL, "qwen3vl" },
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
{ LLM_ARCH_PHI2, "phi2" },
@ -186,6 +187,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
{ LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
{ LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
{ LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" },
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
@ -242,4 +244,3 @@ const char * llama_model_arch_name(llm_arch arch) {
}
return it->second;
}

View File

@ -26,6 +26,7 @@ enum llm_arch {
LLM_ARCH_QWEN2VL,
LLM_ARCH_QWEN3,
LLM_ARCH_QWEN3MOE,
LLM_ARCH_QWEN3NEXT,
LLM_ARCH_QWEN3VL,
LLM_ARCH_QWEN3VLMOE,
LLM_ARCH_PHI2,
@ -180,6 +181,7 @@ enum llm_kv {
LLM_KV_SSM_CONV_KERNEL,
LLM_KV_SSM_STATE_SIZE,
LLM_KV_SSM_TIME_STEP_RANK,
LLM_KV_SSM_GROUP_COUNT,
LLM_KV_TOKENIZER_MODEL,
LLM_KV_TOKENIZER_PRE,
@ -278,8 +280,11 @@ enum llm_tensor {
LLM_TENSOR_SSM_X,
LLM_TENSOR_SSM_DT,
LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_A_NOSCAN,
LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
LLM_TENSOR_SSM_BETA_ALPHA,
LLM_TENSOR_ATTN_Q_A,
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,

View File

@ -6,6 +6,28 @@
#include "ggml.h"
#include <unordered_set>
#include <algorithm>
static inline uint32_t llama_kv_qnext_state_slots(const llama_kv_cache & kv_self) {
uint32_t n_slots = 0;
for (const ggml_tensor * t : kv_self.s_l) {
if (t == nullptr) {
continue;
}
const uint32_t layer_slots = (uint32_t) t->ne[1];
if (n_slots == 0) {
n_slots = layer_slots;
} else {
GGML_ASSERT(n_slots == layer_slots);
}
}
return n_slots;
}
llm_build_context::llm_build_context(
llama_context & lctx,
const llama_batch & batch,
@ -84,6 +106,7 @@ void llm_build_context::init() {
lctx.inp_s_copy = nullptr;
lctx.inp_s_mask = nullptr;
lctx.inp_s_seq = nullptr;
lctx.inp_s_seq_qnext = nullptr;
lctx.inp_pos_bucket = nullptr;
lctx.inp_embd_enc = nullptr;
lctx.inp_KQ_mask_cross = nullptr;
@ -118,6 +141,12 @@ ggml_cgraph * llm_build_context::build_k_shift() {
ggml_set_input(lctx.inp_K_shift);
for (int il = 0; il < n_layer; ++il) {
if (model.arch == LLM_ARCH_QWEN3NEXT && hparams.is_recurrent(il)) {
continue;
}
if (kv_self.k_l[il] == nullptr) {
continue;
}
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
struct ggml_tensor * rope_factors = build_rope_factors(il);
@ -161,21 +190,34 @@ ggml_cgraph * llm_build_context::build_k_shift() {
ggml_cgraph * llm_build_context::build_s_copy() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
GGML_ASSERT(kv_self.recurrent);
const uint32_t qnext_state_slots = llama_kv_qnext_state_slots(kv_self);
const bool has_qnext_state = qnext_state_slots > 0;
GGML_ASSERT(kv_self.recurrent || has_qnext_state);
struct ggml_tensor * state_copy = build_inp_s_copy();
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
if (kv_self.recurrent) {
struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
conv_states = ggml_get_rows(ctx0, conv_states, state_copy);
ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy);
conv_states = ggml_get_rows(ctx0, conv_states, state_copy);
ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy);
// TODO: name the intermediate tensors with cb()
// TODO: name the intermediate tensors with cb()
ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
}
if (kv_self.s_l.size() > (size_t) il && kv_self.s_l[il] != nullptr) {
struct ggml_tensor * qnext_states_all = ggml_reshape_2d(ctx0, kv_self.s_l[il], hparams.n_embd_v_s(), kv_self.s_l[il]->ne[1]);
GGML_ASSERT((uint32_t) qnext_states_all->ne[1] == qnext_state_slots);
struct ggml_tensor * qnext_state_copy = ggml_view_1d(ctx0, state_copy, qnext_state_slots, 0);
struct ggml_tensor * qnext_states = ggml_get_rows(ctx0, qnext_states_all, qnext_state_copy);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, qnext_states, kv_self.s_l[il]));
}
}
return gf;
@ -198,6 +240,12 @@ ggml_cgraph * llm_build_context::build_defrag(const std::vector<uint32_t> & ids)
}
for (int il = 0; il < n_layer; ++il) {
if (model.arch == LLM_ARCH_QWEN3NEXT && hparams.is_recurrent(il)) {
continue;
}
if (kv_self.k_l[il] == nullptr) {
continue;
}
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
@ -214,7 +262,7 @@ ggml_cgraph * llm_build_context::build_defrag(const std::vector<uint32_t> & ids)
ggml_tensor * view_v_src = nullptr;
ggml_tensor * view_v_dst = nullptr;
if (kv_self.v_l.size() > il) {
if (kv_self.v_l.size() > il && kv_self.v_l[il] != nullptr) {
// Note: with MLA the V cache may not be present.
if (flash_attn) {
// NOTE: the V cache is not transposed when using flash attention
@ -509,12 +557,12 @@ void llm_build_context::llm_build_kv_store(
struct ggml_tensor * v_cache_view = nullptr;
if (cparams.flash_attn) {
if (!kv.v_trans) {
v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa,
(kv_head)*ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa));
lctx.cache_copies[2*il+1].step = ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa);
} else {
// note: the V cache is transposed when not using flash attention
// note: the V cache is transposed for legacy non-FA layouts
v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
( n_ctx)*ggml_element_size(kv.v_l[il]),
(kv_head)*ggml_element_size(kv.v_l[il]));
@ -1454,12 +1502,21 @@ static ggml_tensor * llm_build_kqv(
} else {
// split cached v into n_head heads
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
struct ggml_tensor * v;
if (kv.v_trans) {
v = ggml_view_3d(ctx, kv.v_l[il],
n_kv, n_embd_head_v, n_head_kv,
ggml_element_size(kv.v_l[il])*n_ctx,
ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
0);
} else {
v = ggml_view_3d(ctx, kv.v_l[il],
n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
0);
v = ggml_cont(ctx, ggml_transpose(ctx, v));
}
cb(v, "v", il);
auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024);
@ -4248,6 +4305,822 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
return gf;
}
ggml_cgraph * llm_build_context::build_qwen3next() {
static constexpr int QWEN3NEXT_CHUNK_SIZE = 64;
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
GGML_ASSERT(batch.n_tokens > 0);
const bool has_explicit_seq_info = batch.n_seq_id != nullptr && batch.seq_id != nullptr;
std::vector<llama_seq_id> token_seq_ids(batch.n_tokens, 0);
for (int i = 0; i < batch.n_tokens; ++i) {
if (has_explicit_seq_info) {
GGML_ASSERT(batch.n_seq_id[i] > 0 && "qwen3next expects each token to belong to at least one sequence");
GGML_ASSERT(batch.n_seq_id[i] == 1 && "qwen3next does not support multi-sequence tokens yet");
token_seq_ids[i] = batch.seq_id[i][0];
} else {
token_seq_ids[i] = 0;
}
}
const llama_seq_id seq_id = token_seq_ids[0];
const bool all_same_seq = std::all_of(token_seq_ids.begin(), token_seq_ids.end(), [&](llama_seq_id s) {
return s == seq_id;
});
bool has_unique_seq_ids = true;
if (!all_same_seq) {
std::unordered_set<llama_seq_id> seen;
seen.reserve(token_seq_ids.size());
for (llama_seq_id s : token_seq_ids) {
if (!seen.insert(s).second) {
has_unique_seq_ids = false;
break;
}
}
}
GGML_ASSERT(hparams.ssm_n_group > 0);
GGML_ASSERT(hparams.ssm_dt_rank > 0);
GGML_ASSERT(hparams.ssm_d_conv > 0);
GGML_ASSERT(hparams.ssm_d_inner % hparams.ssm_dt_rank == 0);
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
const int64_t head_k_dim = hparams.ssm_d_state;
const int64_t num_k_heads = hparams.ssm_n_group;
const int64_t num_v_heads = hparams.ssm_dt_rank;
const int64_t head_v_dim = hparams.ssm_d_inner / num_v_heads;
const int64_t key_dim = head_k_dim * num_k_heads;
const int64_t value_dim = head_v_dim * num_v_heads;
const int64_t conv_dim = key_dim * 2 + value_dim;
const int64_t conv_state_dim = (hparams.ssm_d_conv - 1) * conv_dim;
const int64_t ssm_state_dim = head_v_dim * head_v_dim * num_v_heads;
const int64_t state_dim = conv_state_dim + ssm_state_dim;
const uint32_t qnext_state_slots = llama_kv_qnext_state_slots(kv_self);
GGML_ASSERT(qnext_state_slots > 0);
GGML_ASSERT(hparams.n_embd_v_s() == (uint32_t) state_dim);
// Reserve-graph builds may not carry explicit sequence IDs, in which case
// the fallback sequence slot is 0.
const uint32_t state_seq_id = (uint32_t) seq_id;
for (llama_seq_id s : token_seq_ids) {
GGML_ASSERT(s >= 0);
GGML_ASSERT((uint32_t) s < qnext_state_slots);
}
const bool reset_state = batch.pos != nullptr && batch.pos[0] == 0;
auto get_slice_2d = [&](ggml_tensor * t, int64_t c) -> ggml_tensor * {
return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
};
auto build_delta_net_chunking = [&](ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
ggml_tensor * causal_mask, ggml_tensor * identity,
ggml_tensor * diag_mask, int il) -> std::pair<ggml_tensor *, ggml_tensor *> {
const int64_t S_k = q->ne[0];
const int64_t H_k = q->ne[1];
const int64_t n_tokens = q->ne[2];
const int64_t n_seqs = q->ne[3];
const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[1];
GGML_ASSERT(n_seqs == 1);
GGML_ASSERT(v->ne[2] == n_tokens);
GGML_ASSERT(k->ne[2] == n_tokens);
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
GGML_ASSERT(H_k == H_v);
const float eps_norm = hparams.f_norm_rms_eps;
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
const float scale = 1.0f / sqrtf(S_v);
q = ggml_scale(ctx0, q, scale);
beta = ggml_sigmoid(ctx0, beta);
cb(q, "q_in", il);
cb(k, "k_in", il);
cb(v, "v_in", il);
cb(beta, "beta_in", il);
cb(g, "g_in", il);
q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_v, n_seqs);
beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
cb(q, "q_perm", il);
cb(k, "k_perm", il);
cb(v, "v_perm", il);
cb(beta, "beta_perm", il);
cb(g, "g_perm", il);
cb(state,"state_in", il);
const int64_t chunk_size = QWEN3NEXT_CHUNK_SIZE;
const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
const int64_t n_chunks = (n_tokens + pad) / chunk_size;
q = ggml_pad(ctx0, q, 0, pad, 0, 0);
k = ggml_pad(ctx0, k, 0, pad, 0, 0);
v = ggml_pad(ctx0, v, 0, pad, 0, 0);
g = ggml_pad(ctx0, g, pad, 0, 0, 0);
beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
cb(q, "q_pad", il);
cb(k, "k_pad", il);
cb(v, "v_pad", il);
cb(beta, "beta_pad", il);
cb(g, "g_pad", il);
ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
ggml_tensor * k_beta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, beta, k->ne[0], beta->ne[1], beta->ne[2], beta->ne[3]), k);
cb(v_beta, "v_beta", il);
cb(k_beta, "k_beta", il);
q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs);
k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs);
k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_v * n_seqs);
v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs);
v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_v * n_seqs);
beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
cb(g_cumsum, "g_cumsum", il);
ggml_tensor * gcs_i =
ggml_repeat_4d(ctx0, g_cumsum, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * gcs_j_broadcast =
ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
cb(decay_mask, "decay_mask", il);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
decay_mask = ggml_exp(ctx0, decay_mask);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
cb(kmulkbeta, "kk_beta", il);
ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
cb(attn, "attn_pre_solve", il);
ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
ggml_tensor * identity_repeat =
ggml_repeat_4d(ctx0, identity, attn_lower->ne[0], attn_lower->ne[1], attn_lower->ne[2], attn_lower->ne[3]);
ggml_tensor * lhs = ggml_neg(ctx0, ggml_sub(ctx0, attn_lower, identity_repeat));
ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
attn = ggml_mul(ctx0, lin_solve, causal_mask);
attn = ggml_add(ctx0, attn, identity);
cb(attn, "attn_solved", il);
v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
cb(v, "v_beta", il);
ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
cb(g_cumsum_t, "g_cumsum_t", il);
ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t);
cb(gexp, "gexp", il);
ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
cb(kbeta_gexp, "kbeta_gexp", il);
auto attn_kbeta = ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)));
cb(attn_kbeta, "attn_kbeta", il);
ggml_tensor * k_cumdecay = ggml_cont(ctx0, ggml_transpose(ctx0, attn_kbeta));
cb(k_cumdecay, "k_cumdecay", il);
ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
cb(attn_kq, "attn_kq_pre", il);
attn_kq = ggml_mul(ctx0, decay_mask, attn_kq);
attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
cb(attn_kq, "attn_kq", il);
ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
(g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
g_last = ggml_cont(ctx0, g_last);
cb(g_last, "g_last", il);
ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
cb(g_last_exp, "g_last_exp", il);
ggml_tensor * g_last_repeat =
ggml_repeat_4d(ctx0, g_last, chunk_size, 1, n_chunks, H_v * n_seqs);
ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last_repeat));
cb(g_diff, "g_diff", il);
ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, 1, chunk_size, n_chunks, g_diff_exp->ne[3]);
ggml_tensor * key_gdiff = ggml_mul(ctx0, ggml_repeat_4d(ctx0, g_diff_exp_t, k->ne[0], g_diff_exp_t->ne[1], g_diff_exp_t->ne[2], g_diff_exp_t->ne[3]), k);
cb(key_gdiff, "key_gdiff", il);
ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
cb(key_gdiff_t, "key_gdiff_t", il);
cb(state, "new_state", il);
ggml_tensor * core_attn_out = nullptr;
for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
ggml_tensor * q_chunk = get_slice_2d(q, chunk);
ggml_tensor * v_chunk = get_slice_2d(v, chunk);
ggml_tensor * gexp_chunk = get_slice_2d(gexp, chunk);
ggml_tensor * k_cumdecay_chunk = get_slice_2d(k_cumdecay, chunk);
ggml_tensor * attn_chunk = get_slice_2d(attn_kq, chunk);
cb(attn_chunk, "attn_chunk", il);
ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
//printf("v_prime_chunk: %ld x %ld x %ld x %ld, %s x %ld x %ld x %ld x %ld, %s\n", state_t->ne[0], state_t->ne[1], state_t->ne[2], state_t->ne[3], ggml_type_name(state_t->type),
// k_cumdecay_chunk->ne[0], k_cumdecay_chunk->ne[1], k_cumdecay_chunk->ne[2], k_cumdecay_chunk->ne[3], ggml_type_name(k_cumdecay_chunk->type));
ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
cb(v_prime, "v_prime_chunk", il);
ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
cb(v_new, "v_new_chunk", il);
ggml_tensor * q_g_exp = ggml_mul(ctx0, ggml_repeat_4d(ctx0, gexp_chunk, q_chunk->ne[0], gexp_chunk->ne[1], gexp_chunk->ne[2], gexp_chunk->ne[3]), q_chunk);
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
cb(attn_inter, "attn_inter_chunk", il);
//printf("v_attn_chunk: %ld x %ld x %ld x %ld, %s x %ld x %ld x %ld x %ld, %s\n", v_new_t->ne[0], v_new_t->ne[1], v_new_t->ne[2], v_new_t->ne[3], ggml_type_name(v_new_t->type),
// attn_chunk->ne[0], attn_chunk->ne[1], attn_chunk->ne[2], attn_chunk->ne[3], ggml_type_name(attn_chunk->type));
ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
cb(v_attn, "v_attn_chunk", il);
ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
cb(core_attn_out_chunk, "core_attn_out_chunk", il);
core_attn_out = core_attn_out == nullptr
? core_attn_out_chunk
: ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
ggml_tensor * k_gdiff_t = get_slice_2d(key_gdiff_t, chunk);
//printf("kgdmulvnew: %ld x %ld x %ld x %ld, %s x %ld x %ld x %ld x %ld, %s\n", v_new_t->ne[0], v_new_t->ne[1], v_new_t->ne[2], v_new_t->ne[3], ggml_type_name(v_new_t->type),
// k_gdiff_t->ne[0], k_gdiff_t->ne[1], k_gdiff_t->ne[2], k_gdiff_t->ne[3], ggml_type_name(k_gdiff_t->type));
ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);
cb(kgdmulvnew, "kgdmulvnew", il);
ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(g_last_exp, chunk));
state = ggml_add(ctx0,
ggml_mul(ctx0, state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
}
ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
S_v, n_tokens, H_v, n_seqs,
ggml_row_size(core_attn_out->type, S_v),
ggml_row_size(core_attn_out->type, S_v * QWEN3NEXT_CHUNK_SIZE * n_chunks),
ggml_row_size(core_attn_out->type, S_v * QWEN3NEXT_CHUNK_SIZE * n_chunks * H_v), 0);
cb(output_tokens, "output_tokens", il);
output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
output_tokens = ggml_cont(ctx0, output_tokens);
return {output_tokens, state};
};
auto build_delta_net_autoregressive = [&](ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
int il) -> std::pair<ggml_tensor *, ggml_tensor *> {
const int64_t H_k = q->ne[1];
const int64_t n_tokens = q->ne[2];
const int64_t n_seqs = q->ne[3];
const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[1];
GGML_ASSERT(n_tokens == 1);
GGML_ASSERT(n_seqs == 1);
GGML_ASSERT(H_k == H_v);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
const float eps_norm = hparams.f_norm_rms_eps;
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
const float scale = 1.0f / sqrtf(S_v);
q = ggml_scale(ctx0, q, scale);
beta = ggml_sigmoid(ctx0, beta);
cb(q, "q_in", il);
cb(k, "k_in", il);
cb(v, "v_in", il);
cb(beta, "beta_in", il);
cb(g, "g_in", il);
ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
g_t = ggml_exp(ctx0, g_t);
state = ggml_mul(ctx0, state, g_t);
ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed);
kv_mem = ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem));
cb(kv_mem, "kv_mem_t_cont", il);
kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, kv_mem));
ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem);
cb(v_diff, "v_diff", il);
ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t);
ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
state = ggml_add(ctx0, state, k_t_delta);
ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs);
ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed);
state_q = ggml_cont(ctx0, ggml_transpose(ctx0, state_q));
cb(state_q, "state_q_t_cont", il);
ggml_tensor * core_attn_out = ggml_transpose(ctx0, ggml_sum_rows(ctx0, state_q));
cb(core_attn_out, "output_tokens", il);
cb(state, "new_state", il);
return {core_attn_out, state};
};
auto build_qkvz = [&](ggml_tensor * input, int il) -> std::pair<ggml_tensor *, ggml_tensor *> {
const int64_t n_tok = input->ne[1];
if (model.layers[il].wqkv) {
ggml_tensor * qkv_mixed = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, input);
cb(qkv_mixed, "qkv_mixed", il);
qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_tok, 1);
cb(qkv_mixed, "linear_attn_qkv_mixed", il);
ggml_tensor * z = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv_gate, input);
cb(z, "z", il);
return { qkv_mixed, z };
}
ggml_tensor * mixed_qkvz = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_in, input);
cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
const int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tok, 1);
int64_t split_sizes_qkvz[4] = {
head_k_dim,
head_k_dim,
head_v_dim * num_v_heads / num_k_heads,
head_v_dim * num_v_heads / num_k_heads
};
ggml_tensor * query = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_tok, 1,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
cb(query, "q", il);
ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tok, 1,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
split_sizes_qkvz[0] * ggml_element_size(mixed_qkvz_reshaped));
cb(key, "k", il);
ggml_tensor * value = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tok, 1,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
(split_sizes_qkvz[0] + split_sizes_qkvz[1]) * ggml_element_size(mixed_qkvz_reshaped));
cb(value, "v", il);
ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tok, 1,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
(split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * ggml_element_size(mixed_qkvz_reshaped));
z = ggml_cont(ctx0, z);
cb(z, "z", il);
ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_tok, 1);
cb(query_flat, "query_flat", il);
ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_tok, 1);
cb(key_flat, "key_flat", il);
ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_tok, 1);
cb(value_flat, "value_flat", il);
ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
cb(qkv_mixed, "qkv_mixed", il);
return { qkv_mixed, z };
};
auto build_layer_attn = [&](ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * KQ_mask, int il) -> ggml_tensor * {
ggml_tensor * Qcur_full = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur_full, "Qcur_full", il);
Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1);
ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
ggml_tensor * gate = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full));
cb(Qcur, "Qcur", il);
cb(gate, "gate", il);
Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cb(Qcur, "Qcur_reshaped", il);
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
cb(gate, "gate_reshaped", il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
ggml_tensor * attn = llm_build_kv(ctx0, lctx, kv_self, gf,
nullptr, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv,
hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale,
cb, il);
cb(attn, "attn_pregate", il);
ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
cb(gate_sigmoid, "gate_sigmoid", il);
attn = ggml_mul(ctx0, attn, gate_sigmoid);
cb(attn, "attn_gated", il);
attn = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, attn);
cb(attn, "attn_output", il);
return attn;
};
auto build_layer_ffn = [&](ggml_tensor * cur, int il) -> ggml_tensor * {
const bool has_moe = model.layers[il].ffn_gate_inp != nullptr;
const bool has_dense = model.layers[il].ffn_gate != nullptr && model.layers[il].ffn_up != nullptr && model.layers[il].ffn_down != nullptr;
if (has_moe) {
ggml_tensor * moe_out =
llm_build_moe_ffn(ctx0, lctx, cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used, LLM_FFN_SILU,
true, false, 0.0f, LLM_EXPERT_GATING_FUNC_SOFTMAX,
cb, il, gf, false);
cb(moe_out, "ffn_moe_out", il);
const bool has_shexp = model.layers[il].ffn_up_shexp != nullptr &&
model.layers[il].ffn_gate_shexp != nullptr &&
model.layers[il].ffn_down_shexp != nullptr &&
model.layers[il].ffn_gate_inp_shexp != nullptr;
if (has_shexp) {
ggml_tensor * ffn_shexp =
llm_build_ffn(ctx0, lctx, nullptr, cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(ffn_shexp, "ffn_shexp", il);
ggml_tensor * shared_gate = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_gate_inp_shexp, cur);
cb(shared_gate, "shared_expert_gate", il);
if (shared_gate->ne[1] == 1) {
ffn_shexp = ggml_fused_mul_unary(ctx0, shared_gate, ffn_shexp, GGML_UNARY_OP_SIGMOID);
} else {
shared_gate = ggml_sigmoid(ctx0, shared_gate);
cb(shared_gate, "shared_expert_gate_sigmoid", il);
ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
}
cb(ffn_shexp, "ffn_shexp_gated", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
} else {
cur = moe_out;
}
cb(cur, "ffn_out", il);
return cur;
}
GGML_ASSERT(has_dense);
cur = llm_build_ffn(ctx0, lctx, nullptr, cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
return cur;
};
auto build_layer_attn_linear_core = [&](ggml_tensor * cur, ggml_tensor * causal_mask, ggml_tensor * identity,
ggml_tensor * diag_mask, ggml_tensor * inp_s_seq_qnext,
uint32_t state_seq_id_local, bool reset_state_local, int il) -> ggml_tensor * {
const int64_t n_tok = cur->ne[1];
auto qkvz = build_qkvz(cur, il);
ggml_tensor * qkv_mixed = qkvz.first;
ggml_tensor * z = qkvz.second;
ggml_tensor * mixed_ba = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_beta_alpha, cur);
cb(mixed_ba, "linear_attn_mixed_ba", il);
int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tok, 1);
int64_t split_sizes_ba[2] = {
num_v_heads / num_k_heads,
num_v_heads / num_k_heads
};
ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_tok, 1,
mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0);
cb(b, "b", il);
ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_tok, 1,
mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3],
split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
cb(a, "a", il);
ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_tok, 1);
ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_tok, 1);
ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
cb(alpha_softplus, "a_softplus", il);
ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);
cb(gate, "gate", il);
size_t state_row_size = 0;
ggml_tensor * state_all = nullptr;
GGML_ASSERT((size_t) il < kv_self.s_l.size() && kv_self.s_l[il] != nullptr);
ggml_tensor * state_storage = kv_self.s_l[il];
GGML_ASSERT(state_storage->type == GGML_TYPE_F32);
GGML_ASSERT(state_storage->ne[0] >= state_dim);
GGML_ASSERT((uint32_t) state_storage->ne[1] == qnext_state_slots);
state_row_size = state_storage->nb[1];
GGML_ASSERT(ggml_nbytes(state_storage) >= state_row_size * qnext_state_slots);
state_all = ggml_view_2d(ctx0, state_storage, state_dim, qnext_state_slots, state_row_size, 0);
ggml_tensor * state_dst = ggml_view_2d(ctx0, state_all, state_dim, 1, state_row_size, state_seq_id_local * state_row_size);
ggml_tensor * state_f32 = state_dst;
if (state_f32->type != GGML_TYPE_F32) {
state_f32 = ggml_cast(ctx0, state_f32, GGML_TYPE_F32);
}
if (reset_state_local) {
state_f32 = ggml_scale(ctx0, state_f32, 0.0f);
}
ggml_tensor * conv_state_flat = ggml_view_2d(ctx0, state_f32, conv_state_dim, 1, state_f32->nb[1], 0);
ggml_tensor * ssm_state_flat = ggml_view_2d(ctx0, state_f32, ssm_state_dim, 1, state_f32->nb[1],
conv_state_dim * ggml_element_size(state_f32));
ggml_tensor * conv_states = ggml_reshape_3d(ctx0, conv_state_flat, hparams.ssm_d_conv - 1, conv_dim, 1);
ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_state_flat, head_v_dim, head_v_dim, num_v_heads, 1);
cb(conv_states, "conv_states", il);
cb(state, "state_predelta", il);
ggml_tensor * conv_output_raw = ggml_ssm_conv(ctx0, conv_states, qkv_mixed, model.layers[il].ssm_conv1d, inp_s_seq_qnext);
cb(conv_output_raw, "conv_output_raw", il);
ggml_tensor * conv_output = ggml_view_2d(ctx0, conv_output_raw, conv_dim, n_tok, conv_dim * ggml_element_size(conv_output_raw), 0);
ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output);
cb(conv_output_silu, "conv_output_silu", il);
ggml_tensor * q_conv = ggml_view_2d(ctx0, conv_output_silu, key_dim, n_tok, conv_output_silu->nb[1], 0);
ggml_tensor * k_conv = ggml_view_2d(ctx0, conv_output_silu, key_dim, n_tok, conv_output_silu->nb[1],
key_dim * ggml_element_size(conv_output_silu));
ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_output_silu, head_v_dim, num_v_heads, n_tok, 1,
ggml_row_size(conv_output_silu->type, head_v_dim),
conv_output_silu->nb[1],
conv_output_silu->nb[1] * n_tok,
2 * key_dim * ggml_element_size(conv_output_silu));
q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_tok, 1);
k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_tok, 1);
v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_tok, 1);
cb(q_conv, "q_conv_cont", il);
cb(k_conv, "k_conv_cont", il);
cb(v_conv, "v_conv_cont", il);
if (num_k_heads != num_v_heads) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
const int64_t repeat_factor = num_v_heads / num_k_heads;
ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_tok);
ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_tok);
ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tok, 1);
ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tok, 1);
q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_tok, 1);
k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_tok, 1);
}
cb(q_conv, "q_conv_predelta", il);
cb(k_conv, "k_conv_predelta", il);
cb(v_conv, "v_conv_predelta", il);
std::pair<ggml_tensor *, ggml_tensor *> attn_out;
GGML_ASSERT(causal_mask != nullptr);
GGML_ASSERT(identity != nullptr);
GGML_ASSERT(diag_mask != nullptr);
attn_out = n_tok == 1
? build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il)
: build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
ggml_tensor * output = attn_out.first;
ggml_tensor * new_state = attn_out.second;
cb(output, "attn_output", il);
cb(new_state, "new_state", il);
ggml_tensor * new_conv_states = ggml_view_2d(ctx0, conv_output_raw, hparams.ssm_d_conv - 1, conv_dim,
hparams.ssm_d_conv * ggml_element_size(conv_output_raw),
(1 + conv_dim * n_tok) * ggml_element_size(conv_output_raw));
ggml_tensor * new_conv_flat = ggml_reshape_2d(ctx0, ggml_cont(ctx0, new_conv_states), conv_state_dim, 1);
ggml_tensor * new_ssm_flat = ggml_reshape_2d(ctx0, new_state, ssm_state_dim, 1);
ggml_tensor * new_state_flat = ggml_concat(ctx0, new_conv_flat, new_ssm_flat, 0);
ggml_tensor * state_update = new_state_flat;
if (state_dst->type != GGML_TYPE_F32) {
state_update = ggml_cast(ctx0, state_update, state_dst->type);
}
ggml_build_forward_expand(gf, ggml_cpy(ctx0, state_update, state_dst));
ggml_tensor * attn_out_2d = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_tok);
ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_tok);
ggml_tensor * attn_out_norm = llm_build_norm(ctx0, attn_out_2d, hparams, model.layers[il].ssm_norm, nullptr, LLM_NORM_RMS, cb, il);
ggml_tensor * gated_silu = ggml_silu(ctx0, z_2d);
attn_out_norm = ggml_mul(ctx0, attn_out_norm, gated_silu);
ggml_tensor * final_output = ggml_reshape_2d(ctx0, attn_out_norm, value_dim, n_tok);
cb(final_output, "final_output", il);
ggml_tensor * out = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, final_output);
cb(out, "linear_attn_out", il);
return ggml_reshape_2d(ctx0, out, n_embd, n_tok);
};
auto build_layer_attn_linear = [&](ggml_tensor * cur, ggml_tensor * causal_mask, ggml_tensor * identity,
ggml_tensor * diag_mask, int il) -> ggml_tensor * {
GGML_ASSERT(lctx.inp_s_seq_qnext != nullptr);
if (all_same_seq) {
return build_layer_attn_linear_core(cur, causal_mask, identity, diag_mask, lctx.inp_s_seq_qnext, state_seq_id, reset_state, il);
}
GGML_ASSERT(has_unique_seq_ids && "qwen3next mixed-sequence batches require unique sequence IDs per token");
ggml_tensor * out = nullptr;
for (int64_t i = 0; i < n_tokens; ++i) {
ggml_tensor * cur_i = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (size_t) i * cur->nb[1]);
ggml_tensor * inp_s_seq_qnext_i = ggml_view_2d(ctx0, lctx.inp_s_seq_qnext, 1, 1, lctx.inp_s_seq_qnext->nb[1], (size_t) i * lctx.inp_s_seq_qnext->nb[1]);
const bool reset_state_i = batch.pos != nullptr && batch.pos[i] == 0;
const uint32_t state_seq_id_i = (uint32_t) token_seq_ids[i];
ggml_tensor * out_i = build_layer_attn_linear_core(cur_i, causal_mask, identity, diag_mask, inp_s_seq_qnext_i, state_seq_id_i, reset_state_i, il);
out = out == nullptr ? out_i : ggml_concat(ctx0, out, out_i, 1);
}
return out;
};
ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr;
ggml_tensor * KQ_mask = build_inp_KQ_mask();
lctx.inp_s_seq_qnext = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, 1, n_tokens);
cb(lctx.inp_s_seq_qnext, "inp_s_seq_qnext", -1);
ggml_set_input(lctx.inp_s_seq_qnext);
ggml_tensor * causal_mask = nullptr;
ggml_tensor * identity = nullptr;
ggml_tensor * diag_mask = nullptr;
causal_mask = ggml_tri(ctx0,
ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE, QWEN3NEXT_CHUNK_SIZE), 1.0f),
GGML_TRI_TYPE_LOWER);
identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE), 1.0f));
diag_mask = ggml_add(ctx0, causal_mask, identity);
ggml_build_forward_expand(gf, causal_mask);
ggml_build_forward_expand(gf, identity);
ggml_build_forward_expand(gf, diag_mask);
ggml_tensor * cur = nullptr;
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
GGML_ASSERT(model.layers[il].attn_norm != nullptr);
GGML_ASSERT(model.layers[il].attn_post_norm != nullptr);
const bool has_moe = model.layers[il].ffn_gate_inp != nullptr;
const bool has_dense = model.layers[il].ffn_gate != nullptr &&
model.layers[il].ffn_up != nullptr &&
model.layers[il].ffn_down != nullptr;
GGML_ASSERT(has_moe || has_dense);
if (has_moe) {
GGML_ASSERT(model.layers[il].ffn_up_exps != nullptr);
GGML_ASSERT(model.layers[il].ffn_gate_exps != nullptr);
GGML_ASSERT(model.layers[il].ffn_down_exps != nullptr);
}
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
if (hparams.is_recurrent(il)) {
GGML_ASSERT(model.layers[il].ssm_conv1d != nullptr);
GGML_ASSERT(model.layers[il].ssm_dt != nullptr);
GGML_ASSERT(model.layers[il].ssm_a != nullptr);
GGML_ASSERT(model.layers[il].ssm_beta_alpha != nullptr);
GGML_ASSERT(model.layers[il].ssm_norm != nullptr);
GGML_ASSERT(model.layers[il].ssm_out != nullptr);
GGML_ASSERT(model.layers[il].wqkv != nullptr || model.layers[il].ssm_in != nullptr);
GGML_ASSERT(model.layers[il].wqkv_gate != nullptr || model.layers[il].ssm_in != nullptr);
cur = build_layer_attn_linear(cur, causal_mask, identity, diag_mask, il);
} else {
GGML_ASSERT(model.layers[il].wq != nullptr);
GGML_ASSERT(model.layers[il].wk != nullptr);
GGML_ASSERT(model.layers[il].wv != nullptr);
GGML_ASSERT(model.layers[il].wo != nullptr);
GGML_ASSERT(model.layers[il].attn_q_norm != nullptr);
GGML_ASSERT(model.layers[il].attn_k_norm != nullptr);
cur = build_layer_attn(cur, inp_pos, KQ_mask, il);
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "attn_residual", il);
ggml_tensor * ffn_residual = cur;
ggml_tensor * attn_post_norm = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(attn_post_norm, "attn_post_norm", il);
cur = build_layer_ffn(attn_post_norm, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_residual);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
inpL = cur;
}
cur = llm_build_norm(ctx0, inpL, hparams, model.output_norm, nullptr, LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
ggml_cgraph * llm_build_context::build_qwen3vl() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
@ -9273,6 +10146,10 @@ ggml_cgraph * llm_build_context::llama_build_graph(
{
result = llm.build_qwen3moe();
} break;
case LLM_ARCH_QWEN3NEXT:
{
result = llm.build_qwen3next();
} break;
case LLM_ARCH_QWEN3VL:
{
result = llm.build_qwen3vl();

View File

@ -204,6 +204,8 @@ struct llm_build_context {
ggml_cgraph * build_qwen3vlmoe();
ggml_cgraph * build_qwen3next();
ggml_cgraph * build_phi2();
ggml_cgraph * build_phi3();

View File

@ -56,6 +56,7 @@ struct llama_kv_cache {
std::vector<struct ggml_tensor *> k_l; // per layer
std::vector<struct ggml_tensor *> v_l;
std::vector<struct ggml_tensor *> s_l; // per layer recurrent state storage (Qwen3Next)
std::vector<llama_split_tensor> split_k_l;
std::vector<llama_split_tensor> split_v_l;
@ -202,6 +203,7 @@ struct llama_context {
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
struct ggml_tensor * inp_s_seq_qnext; // I32 [1, n_batch]
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]

View File

@ -5,7 +5,7 @@
#include <map>
#define LLAMA_MAX_EXPERTS 384 // Kimi-K2
#define LLAMA_MAX_EXPERTS 512 // Qwen3 Next
static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_TYPES = {
{ LLAMA_ROPE_SCALING_TYPE_NONE, "none" },
@ -83,6 +83,7 @@ void llm_load_hparams(
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), false);
ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer);
ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);
@ -453,6 +454,28 @@ void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_QWEN3NEXT:
{
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
// Upstream convention: every 4th layer is full attention, others are recurrent.
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0);
}
switch (hparams.n_layer) {
case 48: model.type = e_model::MODEL_80B_A3B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_QWEN3VLMOE:
{
ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false);

View File

@ -89,6 +89,10 @@ struct llama_hparams {
uint32_t ssm_d_inner = 0;
uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0;
uint32_t ssm_n_group = 0;
// for hybrid state-space models (e.g. qwen3next)
std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
float f_clamp_kqv = 0.0f;
float f_max_alibi_bias = 0.0f;
@ -169,6 +173,8 @@ struct llama_hparams {
if (this->ssm_d_inner != other.ssm_d_inner) return true;
if (this->ssm_d_state != other.ssm_d_state) return true;
if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
if (this->ssm_n_group != other.ssm_n_group) return true;
if (this->recurrent_layer_arr != other.recurrent_layer_arr) return true;
if (this->dec_start_token_id != other.dec_start_token_id) return true;
@ -246,6 +252,10 @@ struct llama_hparams {
}
uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
if (ssm_n_group > 0) {
// qwen3next keeps all recurrent state in the V-cache tail
return 0;
}
// corresponds to Mamba's conv_states size
// TODO: maybe support other convolution strides than 1
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
@ -253,10 +263,26 @@ struct llama_hparams {
}
uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
if (ssm_n_group > 0) {
// qwen3next recurrent state packs:
// 1) conv state: (d_conv - 1) * (2 * key_dim + value_dim)
// 2) delta-net state: head_v_dim * head_v_dim * num_v_heads
const uint32_t key_dim = ssm_d_state * ssm_n_group;
const uint32_t value_dim = ssm_d_inner;
const uint32_t conv_dim = 2 * key_dim + value_dim;
const uint32_t conv_state_dim = (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * conv_dim;
const uint32_t head_v_dim = ssm_dt_rank > 0 ? ssm_d_inner / ssm_dt_rank : 0;
const uint32_t ssm_state_dim = head_v_dim * head_v_dim * ssm_dt_rank;
return conv_state_dim + ssm_state_dim;
}
// corresponds to Mamba's ssm_states size
return ssm_d_state * ssm_d_inner;
}
bool is_recurrent(uint32_t il) const {
return il < n_layer ? recurrent_layer_arr[il] : false;
}
static bool is_float_close(float a, float b, float abs_tol) {
// Check for non-negative tolerance
if (abs_tol < 0.0) {

View File

@ -73,6 +73,8 @@ struct create_tensors_helper : public create_tensors_helper_interface {
bool create_qwen3_moe_tensors(const LLM_TN & tn);
bool create_qwen3next_tensors(const LLM_TN & tn);
bool create_phi2_tensors(const LLM_TN & tn);
bool create_phi3_tensors(const LLM_TN & tn);
@ -1291,6 +1293,99 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) {
return use_mmap_buffer;
}
bool create_tensors_helper::create_qwen3next_tensors(const LLM_TN & tn) {
LOADING_PRELUDE
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
// output
{
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
if (model.output == NULL) {
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
}
}
const bool has_moe_hparams = n_expert > 0 && n_expert_used > 0;
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : (has_moe_hparams ? n_ff / n_expert_used : n_ff);
const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp;
const int64_t head_k_dim = hparams.ssm_d_state;
const int64_t num_k_heads = hparams.ssm_n_group;
const int64_t num_v_heads = hparams.ssm_dt_rank;
const int64_t head_v_dim = hparams.ssm_d_inner / num_v_heads;
const int64_t key_dim = head_k_dim * num_k_heads;
const int64_t value_dim = head_v_dim * num_v_heads;
const int64_t conv_dim = key_dim * 2 + value_dim;
const int64_t qkvz_dim = key_dim * 2 + value_dim * 2;
const int64_t ba_dim = num_v_heads * 2;
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd});
if (!hparams.is_recurrent(i)) {
// Full-attention layer
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head * 2});
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k});
layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k});
} else {
// Recurrent linear-attention layer
layer.ssm_in = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, qkvz_dim},
llama_model_loader::TENSOR_NOT_REQUIRED);
layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, key_dim * 2 + value_dim},
llama_model_loader::TENSOR_NOT_REQUIRED);
layer.wqkv_gate = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, value_dim},
llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ssm_conv1d = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {hparams.ssm_d_conv, conv_dim});
layer.ssm_dt = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {hparams.ssm_dt_rank});
layer.ssm_a = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_A_NOSCAN, i), {hparams.ssm_dt_rank});
layer.ssm_beta_alpha = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), {n_embd, ba_dim});
layer.ssm_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_NORM, "weight", i), {head_v_dim});
layer.ssm_out = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {value_dim, n_embd});
}
auto ffn_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer;
// Dense FFN path (optional, e.g. mlp_only_layers)
layer.ffn_gate = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_up = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_down = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
// MoE path (optional per-layer)
layer.ffn_gate_inp = nullptr;
if (n_expert > 0) {
layer.ffn_gate_inp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
}
if (layer.ffn_gate_inp != nullptr) {
if (n_expert_used == 0) {
throw std::runtime_error("n_expert_used must be > 0 when QWEN3NEXT MoE tensors are present");
}
use_mmap_buffer &= !create_std_ffn_exps(n_embd, tn, i, llama_model_loader::TENSOR_NOT_REQUIRED, n_ff_exp);
}
// Shared expert path (optional per-layer)
layer.ffn_gate_inp_shexp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
if (layer.ffn_gate_inp_shexp != nullptr) {
layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
}
}
return use_mmap_buffer;
}
bool create_tensors_helper::create_mimo2_tensors(const LLM_TN & tn) {
LOADING_PRELUDE
@ -3221,6 +3316,8 @@ bool create_tensors_helper::create_tensors() {
case LLM_ARCH_QWEN3MOE:
case LLM_ARCH_QWEN3VLMOE:
use_mmap_buffer = create_qwen3_moe_tensors(tn); break;
case LLM_ARCH_QWEN3NEXT:
use_mmap_buffer = create_qwen3next_tensors(tn); break;
case LLM_ARCH_PHI2:
use_mmap_buffer = create_phi2_tensors(tn); break;
case LLM_ARCH_PHI3:

View File

@ -429,6 +429,39 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_QWEN3NEXT,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_QWEN3VL,
{
@ -1648,6 +1681,7 @@ const char * llama_model_type_name(e_model type) {
case MODEL_16B_A1B: return "16B.A1B";
case MODEL_21B_A3B: return "21B.A3B";
case MODEL_30B_A3B: return "30B.A3B";
case MODEL_80B_A3B: return "80B.A3B";
case MODEL_80B_A13B: return "80B.A13B";
case MODEL_100B_A6B: return "100B.A6B";
case MODEL_106B_A12B: return "106B.A12B";

View File

@ -107,6 +107,7 @@ enum e_model {
MODEL_16B_A1B,
MODEL_21B_A3B, // Ernie MoE small
MODEL_30B_A3B,
MODEL_80B_A3B, // Qwen3-Next
MODEL_80B_A13B,
MODEL_100B_A6B,
MODEL_106B_A12B,
@ -289,6 +290,8 @@ struct llama_layer {
struct ggml_tensor * ssm_x = nullptr;
struct ggml_tensor * ssm_dt = nullptr;
struct ggml_tensor * ssm_out = nullptr;
struct ggml_tensor * ssm_norm = nullptr;
struct ggml_tensor * ssm_beta_alpha = nullptr;
// mamba
struct ggml_tensor * ssm_conv1d = nullptr;

View File

@ -568,9 +568,15 @@ bool llama_context::can_reuse_graph(const llama_batch & u_batch) {
bool llama_context::update_cache_copies() {
int n_layer = model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
auto layer_has_attention_kv = [&](int il) {
return !(model.arch == LLM_ARCH_QWEN3NEXT && model.hparams.is_recurrent(il));
};
if ((int)kv_self.k_l.size() != n_layer) return false;
if (!(kv_self.v_l.empty() || (int)kv_self.v_l.size() == n_layer)) return false;
for (int il = 0; il < n_layer; ++il) {
if (!layer_has_attention_kv(il) || kv_self.k_l[il] == nullptr) {
continue;
}
auto kl = (ggml_split_tensor_t *)kv_self.k_l[il]->extra;
if (kl) {
GGML_ASSERT(model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN);
@ -597,6 +603,9 @@ bool llama_context::update_cache_copies() {
}
} else {
for (int il = 0; il < n_layer; ++il) {
if (!layer_has_attention_kv(il) || kv_self.k_l[il] == nullptr) {
continue;
}
auto& c = cache_copies[2*il+0];
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.k_l[il]) return false;
c.cpy->view_offs = kv_self.head*c.step;
@ -605,6 +614,9 @@ bool llama_context::update_cache_copies() {
}
if (kv_self.v_l.empty()) return true;
for (int il = 0; il < n_layer; ++il) {
if (!layer_has_attention_kv(il) || kv_self.v_l[il] == nullptr) {
continue;
}
auto& c = cache_copies[2*il+1];
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.v_l[il]) return false;
c.cpy->view_offs = kv_self.head*c.step;
@ -640,6 +652,58 @@ llama_context::~llama_context() {
// kv cache helpers
//
static inline bool llama_qwen3next_is_recurrent_layer(
const llama_model & model,
const llama_hparams & hparams,
uint32_t il) {
return model.arch == LLM_ARCH_QWEN3NEXT && hparams.is_recurrent(il);
}
static inline uint32_t llama_kv_v_row_embd(
const llama_model & model,
const llama_hparams & hparams,
uint32_t il) {
// qwen3next recurrent state is stored in a dedicated V-cache tail (per sequence),
// so per-token V rows include only attention values.
if (model.arch == LLM_ARCH_QWEN3NEXT) {
return hparams.n_embd_v_gqa(il);
}
return hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
}
static inline uint32_t llama_qwen3next_state_slots(const llama_cparams & cparams, uint32_t kv_size) {
return std::min<uint32_t>(std::max<uint32_t>(1, cparams.n_seq_max), kv_size);
}
static inline uint32_t llama_kv_qnext_state_slots(const llama_kv_cache & cache) {
uint32_t n_slots = 0;
for (const ggml_tensor * t : cache.s_l) {
if (t == nullptr) {
continue;
}
const uint32_t layer_slots = (uint32_t) t->ne[1];
if (n_slots == 0) {
n_slots = layer_slots;
} else {
GGML_ASSERT(n_slots == layer_slots);
}
}
return n_slots;
}
static inline bool llama_kv_has_qnext_state_storage(const llama_kv_cache & cache) {
return llama_kv_qnext_state_slots(cache) > 0;
}
static inline bool llama_kv_qnext_seq_id_in_range(const llama_kv_cache & cache, llama_seq_id seq_id) {
const uint32_t n_slots = llama_kv_qnext_state_slots(cache);
return n_slots > 0 && seq_id >= 0 && (uint32_t) seq_id < n_slots;
}
static bool llama_kv_cache_init(
struct llama_kv_cache & cache,
const llama_context * ctx,
@ -658,7 +722,9 @@ static bool llama_kv_cache_init(
// TODO: find a nicer way to add other recurrent model architectures
cache.recurrent = model.arch == LLM_ARCH_MAMBA;
cache.v_trans = !cache.recurrent && !cparams.flash_attn;
// qwen3next uses hybrid recurrent+attention cache semantics. Keep V rows in
// standard layout to match the mainline hybrid path when flash attention is off.
cache.v_trans = !cache.recurrent && !cparams.flash_attn && model.arch != LLM_ARCH_QWEN3NEXT;
cache.head = 0;
cache.size = kv_size;
@ -670,7 +736,7 @@ static bool llama_kv_cache_init(
cache.cells.clear();
cache.cells.resize(kv_size);
if (cache.recurrent) {
if (cache.recurrent || model.arch == LLM_ARCH_QWEN3NEXT) {
// init state copy sources
for (uint32_t i = 0; i < cache.size; ++i) {
cache.cells[i].src = i;
@ -750,18 +816,27 @@ static bool llama_kv_cache_init(
needs_v_cache = cparams.mla_attn == 1 && !cparams.flash_attn;
}
if (needs_v_cache) cache.v_l.reserve(n_layer);
cache.s_l.reserve(n_layer);
std::vector<size_t> mem_split(model.splits.size(), 0);
const uint32_t qnext_state_slots = llama_qwen3next_state_slots(cparams, kv_size);
if (model.arch == LLM_ARCH_QWEN3NEXT && qnext_state_slots < std::max<uint32_t>(1, cparams.n_seq_max)) {
LLAMA_LOG_WARN("%s: reducing qwen3next state slots from %u to %u to fit KV cache size\n",
__func__, std::max<uint32_t>(1, cparams.n_seq_max), qnext_state_slots);
}
int n_mla = 0;
for (int i = 0; i < (int) n_layer; i++) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
const bool qnext_recurrent = llama_qwen3next_is_recurrent_layer(model, hparams, i);
const uint32_t n_embd_v_row = llama_kv_v_row_embd(model, hparams, i);
const uint32_t n_head_kv = hparams.n_head_kv(i);
const uint32_t n_embd_head_k= hparams.n_embd_head_k;
struct ggml_context * ctx = split_cache ? ctx_map.at(model.buft_layer[i].buft_matrix) : offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
ggml_tensor * k;
ggml_tensor * v;
ggml_tensor * k = nullptr;
ggml_tensor * v = nullptr;
ggml_tensor * s = nullptr;
if (is_mla_attn && cparams.mla_attn) {
// DeepSeek MLA
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
@ -792,56 +867,70 @@ static bool llama_kv_cache_init(
ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
split_cache_i = false;
}
int n_embd_head_v = hparams.n_embd_head_v;
k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
auto k_name = std::string{"cache_k_l"} + std::to_string(i);
auto v_name = std::string{"cache_v_l"} + std::to_string(i);
ggml_set_name(k, k_name.c_str());
ggml_set_name(v, v_name.c_str());
//ggml_format_name(k, "cache_k_l%d", i);
//ggml_format_name(v, "cache_v_l%d", i);
if (qnext_recurrent) {
s = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_embd_v_s(), qnext_state_slots);
split_cache_i = false;
} else {
int n_embd_head_v = hparams.n_embd_head_v;
k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
int64_t v_ne = int64_t(n_embd_v_row)*kv_size;
v = ggml_new_tensor_1d(ctx, type_v, v_ne);
auto k_name = std::string{"cache_k_l"} + std::to_string(i);
auto v_name = std::string{"cache_v_l"} + std::to_string(i);
ggml_set_name(k, k_name.c_str());
ggml_set_name(v, v_name.c_str());
//ggml_format_name(k, "cache_k_l%d", i);
//ggml_format_name(v, "cache_v_l%d", i);
if (split_cache_i) {
bool use_V_for_K = model.layers[i].attn_k_norm && model.layers[i].attn_k_norm->ne[0] == K->ne[1] ? true : false;
auto extra_K = (const ggml_split_tensor_t *)K->extra;
auto extra_V = (const ggml_split_tensor_t *)V->extra;
auto & split_k_l = cache.split_k_l.emplace_back();
auto & split_v_l = cache.split_v_l.emplace_back();
split_k_l.tensor_splits.resize(extra_K->n_device, nullptr);
split_v_l.tensor_splits.resize(extra_V->n_device, nullptr);
for (int is = 0; is < extra_K->n_device; ++is) {
auto split = use_V_for_K ? extra_V->splits[is] : extra_K->splits[is];
if (!split) continue;
int nhead_kv = use_V_for_K ? split->ne[1] / n_embd_head_v : split->ne[1]/n_embd_head_k;
if (use_V_for_K) {
LLAMA_LOG_DEBUG("K_cache(%d, %d): using %d instead of %ld heads\n",
i, is, nhead_kv, extra_K->splits[is]->ne[1]/n_embd_head_k);
}
split_k_l.tensor_splits[is] = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, nhead_kv * kv_size);
auto split_name = k_name + '.' + std::to_string(is);
ggml_set_name(split_k_l.tensor_splits[is], split_name.c_str());
mem_split[is] += ggml_nbytes(split_k_l.tensor_splits[is]);
}
split_k_l.ggml.n_device = extra_K->n_device;
split_k_l.ggml.split_dim = 0;
split_k_l.ggml.splits = split_k_l.tensor_splits.data();
for (int is = 0; is < extra_V->n_device; ++is) {
auto split = extra_V->splits[is];
if (!split) continue;
split_v_l.tensor_splits[is] = ggml_new_tensor_1d(ctx, type_v, split->ne[1] * kv_size);
auto split_name = v_name + '.' + std::to_string(is);
ggml_set_name(split_v_l.tensor_splits[is], split_name.c_str());
mem_split[is] += ggml_nbytes(split_v_l.tensor_splits[is]);
}
split_v_l.ggml.n_device = extra_V->n_device;
split_v_l.ggml.split_dim = 0;
split_v_l.ggml.splits = split_v_l.tensor_splits.data();
k->extra = (void *)&split_k_l.ggml;
v->extra = (void *)&split_v_l.ggml;
}
}
if (s) {
auto s_name = std::string{"cache_s_l"} + std::to_string(i);
ggml_set_name(s, s_name.c_str());
}
cache.k_l.push_back(k);
cache.v_l.push_back(v);
if (split_cache_i) {
bool use_V_for_K = model.layers[i].attn_k_norm && model.layers[i].attn_k_norm->ne[0] == K->ne[1] ? true : false;
auto extra_K = (const ggml_split_tensor_t *)K->extra;
auto extra_V = (const ggml_split_tensor_t *)V->extra;
auto & split_k_l = cache.split_k_l.emplace_back();
auto & split_v_l = cache.split_v_l.emplace_back();
split_k_l.tensor_splits.resize(extra_K->n_device, nullptr);
split_v_l.tensor_splits.resize(extra_V->n_device, nullptr);
for (int is = 0; is < extra_K->n_device; ++is) {
auto split = use_V_for_K ? extra_V->splits[is] : extra_K->splits[is];
if (!split) continue;
int nhead_kv = use_V_for_K ? split->ne[1] / n_embd_head_v : split->ne[1]/n_embd_head_k;
if (use_V_for_K) {
LLAMA_LOG_DEBUG("K_cache(%d, %d): using %d instead of %ld heads\n",
i, is, nhead_kv, extra_K->splits[is]->ne[1]/n_embd_head_k);
}
split_k_l.tensor_splits[is] = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, nhead_kv * kv_size);
auto split_name = k_name + '.' + std::to_string(is);
ggml_set_name(split_k_l.tensor_splits[is], split_name.c_str());
mem_split[is] += ggml_nbytes(split_k_l.tensor_splits[is]);
}
split_k_l.ggml.n_device = extra_K->n_device;
split_k_l.ggml.split_dim = 0;
split_k_l.ggml.splits = split_k_l.tensor_splits.data();
for (int is = 0; is < extra_V->n_device; ++is) {
auto split = extra_V->splits[is];
if (!split) continue;
split_v_l.tensor_splits[is] = ggml_new_tensor_1d(ctx, type_v, split->ne[1] * kv_size);
auto split_name = v_name + '.' + std::to_string(is);
ggml_set_name(split_v_l.tensor_splits[is], split_name.c_str());
mem_split[is] += ggml_nbytes(split_v_l.tensor_splits[is]);
}
split_v_l.ggml.n_device = extra_V->n_device;
split_v_l.ggml.split_dim = 0;
split_v_l.ggml.splits = split_v_l.tensor_splits.data();
k->extra = (void *)&split_k_l.ggml;
v->extra = (void *)&split_v_l.ggml;
}
}
cache.s_l.push_back(s);
}
if (is_mla_attn && cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d layers having MLA enabled\n", __func__, n_mla, int(n_layer));
@ -1017,6 +1106,7 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
cache.cells[i].pos = -1;
cache.cells[i].src = i;
cache.cells[i].seq_id.clear();
}
cache.head = 0;
@ -1056,6 +1146,8 @@ static bool llama_kv_cache_seq_rm(
}
}
const bool has_qnext_state = llama_kv_has_qnext_state_storage(cache);
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
if (seq_id < 0) {
@ -1070,6 +1162,9 @@ static bool llama_kv_cache_seq_rm(
if (cache.cells[i].pos >= 0) cache.used--;
cache.cells[i].pos = -1;
if (has_qnext_state) {
cache.cells[i].src = i;
}
if (new_head == cache.size) new_head = i;
}
}
@ -1111,6 +1206,21 @@ static void llama_kv_cache_seq_cp(
}
return;
}
const bool has_qnext_state = llama_kv_has_qnext_state_storage(cache);
if (has_qnext_state &&
llama_kv_qnext_seq_id_in_range(cache, seq_id_dst) &&
llama_kv_qnext_seq_id_in_range(cache, seq_id_src) &&
(uint32_t) seq_id_dst < cache.size &&
(uint32_t) seq_id_src < cache.size) {
seq_id_src = cache.cells[seq_id_src].src;
GGML_ASSERT((uint32_t) seq_id_src < cache.size);
cache.cells[seq_id_dst].src = seq_id_src;
cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
cache.do_copy = true;
}
// otherwise, this is the KV cache of a Transformer-like model
cache.head = 0;
@ -1124,11 +1234,15 @@ static void llama_kv_cache_seq_cp(
static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
uint32_t new_head = cache.size;
const bool has_qnext_state = llama_kv_has_qnext_state_storage(cache);
for (uint32_t i = 0; i < cache.size; ++i) {
if (!cache.cells[i].has_seq_id(seq_id)) {
if (cache.cells[i].pos >= 0) cache.used--;
cache.cells[i].pos = -1;
if (has_qnext_state) {
cache.cells[i].src = i;
}
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
} else {
@ -2764,6 +2878,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
if (lctx.inp_s_seq_qnext) {
const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq_qnext->buffer));
int32_t * data = (int32_t *) lctx.inp_s_seq_qnext->data;
for (int64_t j = 0; j < n_tokens; ++j) {
// qwen3next linear-attention path uses a single local recurrent state slot.
data[j] = 0;
}
}
if (lctx.inp_pos_bucket) {
const int64_t n_tokens = batch.n_tokens;
@ -3012,11 +3138,51 @@ static int llama_decode_internal(
}
}
for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
bool warned_qnext_mixed_repeat = false;
for (uint32_t cur_token = 0; cur_token < n_tokens_all; ) {
#if IK_PRINT_TIMING
auto tim1 = ggml_time_us();
#endif
const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
if (model.arch == LLM_ARCH_QWEN3NEXT &&
n_tokens > 1 &&
batch_all.n_seq_id != nullptr &&
batch_all.seq_id != nullptr) {
bool can_check = true;
bool any_diff = false;
bool has_dup = false;
llama_seq_id first_seq_id = 0;
std::unordered_set<llama_seq_id> seen_seq_ids;
seen_seq_ids.reserve(n_tokens);
for (uint32_t i = 0; i < n_tokens; ++i) {
const uint32_t idx = cur_token + i;
if (batch_all.n_seq_id[idx] <= 0 || batch_all.seq_id[idx] == nullptr) {
can_check = false;
break;
}
const llama_seq_id seq_id_i = batch_all.seq_id[idx][0];
if (i == 0) {
first_seq_id = seq_id_i;
} else if (seq_id_i != first_seq_id) {
any_diff = true;
}
if (!seen_seq_ids.insert(seq_id_i).second) {
has_dup = true;
}
}
if (can_check && any_diff && has_dup) {
n_tokens = 1;
if (!warned_qnext_mixed_repeat) {
LLAMA_LOG_WARN("%s: qwen3next mixed-sequence batch contains repeated seq_id values; falling back to single-token chunking\n", __func__);
warned_qnext_mixed_repeat = true;
}
}
}
llama_batch u_batch = {
/* .n_tokens = */ (int32_t) n_tokens,
/* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr,
@ -3293,6 +3459,7 @@ static int llama_decode_internal(
#endif
}
n_outputs_prev += lctx.n_outputs;
cur_token += n_tokens;
}
// set to total number of outputs in the batch, for use in llama_get_logits_ith
@ -3766,7 +3933,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {
}
}
if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
if ((lctx.kv_self.recurrent || llama_kv_has_qnext_state_storage(lctx.kv_self)) && lctx.kv_self.do_copy) {
{
lctx.reset_scheduler();
@ -4787,11 +4954,15 @@ struct llama_context * llama_init_from_model(
size_t memory_size_v = 0;
for (auto & k : ctx->kv_self.k_l) {
memory_size_k += ggml_nbytes(k);
if (k) {
memory_size_k += ggml_nbytes(k);
}
}
for (auto & v : ctx->kv_self.v_l) {
memory_size_v += ggml_nbytes(v);
if (v) {
memory_size_v += ggml_nbytes(v);
}
}
if (memory_size_k + memory_size_v > 0) {
@ -4918,7 +5089,7 @@ struct llama_context * llama_init_from_model(
}
if (params.only_active_experts) {
LLAMA_LOG_INFO("XXXXXXXXXXXXXXXXXXXXX Setting only active experts offload\n");
LLAMA_LOG_INFO("%s: enabling only_active_experts scheduling\n", __func__);
ggml_backend_sched_set_only_active_experts(ctx->sched, true);
}
if (model->split_mode == LLAMA_SPLIT_MODE_GRAPH && (!model->has_tensor_overrides() || cparams.split_mode_graph_scheduling)) {
@ -5031,6 +5202,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_QWEN2MOE:
case LLM_ARCH_QWEN3:
case LLM_ARCH_QWEN3MOE:
case LLM_ARCH_QWEN3NEXT:
case LLM_ARCH_PHI2:
case LLM_ARCH_PHI3:
case LLM_ARCH_GEMMA:
@ -5586,7 +5758,7 @@ struct llama_data_write {
}
}
void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
const struct llama_kv_cache & kv_self = ctx->kv_self;
const struct llama_hparams & hparams = ctx->model.hparams;
@ -5599,23 +5771,30 @@ struct llama_data_write {
write(&v_state, sizeof(v_state));
write(&n_layer, sizeof(n_layer));
std::vector<uint8_t> tmp_buf;
// Iterate and write all the keys first, each row is a cell
// Get whole range at a time
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t kv_lora_rank = hparams.n_lora_kv;
const bool has_k_cache = kv_self.k_l[il] != nullptr;
// Write key type
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
const int32_t k_type_i = has_k_cache ? (int32_t) kv_self.k_l[il]->type : -1;
write(&k_type_i, sizeof(k_type_i));
// Write row size of key
const uint64_t k_size_row = (ctx->cparams.mla_attn == 0) ? ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa) : ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope);
const uint64_t k_size_row = has_k_cache
? ((ctx->cparams.mla_attn == 0)
? ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)
: ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope))
: 0;
write(&k_size_row, sizeof(k_size_row));
if (!has_k_cache) {
continue;
}
// Read each range of cells of k_size length each into tmp_buf and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
@ -5626,16 +5805,21 @@ struct llama_data_write {
if (v_state == 0) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
const bool has_v_cache = kv_self.v_l[il] != nullptr;
// Write value type
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
const int32_t v_type_i = has_v_cache ? (int32_t) kv_self.v_l[il]->type : -1;
write(&v_type_i, sizeof(v_type_i));
// Write row size of value
const uint64_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
const uint64_t v_size_row = has_v_cache ? ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa) : 0;
write(&v_size_row, sizeof(v_size_row));
if (!has_v_cache) {
continue;
}
// Read each range of cells of v_size length each into tmp_buf and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
@ -5648,18 +5832,24 @@ struct llama_data_write {
// When v is transposed, we also need the element size and get the element ranges from each row
const uint32_t kv_size = kv_self.size;
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
const bool has_v_cache = kv_self.v_l[il] != nullptr;
// Write value type
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
const int32_t v_type_i = has_v_cache ? (int32_t) kv_self.v_l[il]->type : -1;
write(&v_type_i, sizeof(v_type_i));
// Write element size
const uint32_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
const uint32_t v_size_el = has_v_cache ? ggml_type_size(kv_self.v_l[il]->type) : 0;
write(&v_size_el, sizeof(v_size_el));
// Write GQA embedding size
write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
const uint32_t n_embd_v_gqa_write = has_v_cache ? n_embd_v_gqa : 0;
write(&n_embd_v_gqa_write, sizeof(n_embd_v_gqa_write));
if (!has_v_cache) {
continue;
}
// For each row, we get the element values of each cell
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
@ -5673,6 +5863,42 @@ struct llama_data_write {
}
}
}
const uint32_t qnext_state = llama_kv_has_qnext_state_storage(kv_self) ? 1 : 0;
write(&qnext_state, sizeof(qnext_state));
if (qnext_state != 0) {
for (uint32_t il = 0; il < n_layer; ++il) {
const bool has_s_cache = il < kv_self.s_l.size() && kv_self.s_l[il] != nullptr;
const int32_t s_type_i = has_s_cache ? (int32_t) kv_self.s_l[il]->type : -1;
write(&s_type_i, sizeof(s_type_i));
const uint64_t s_size_row = has_s_cache ? ggml_row_size(kv_self.s_l[il]->type, kv_self.s_l[il]->ne[0]) : 0;
write(&s_size_row, sizeof(s_size_row));
uint32_t s_rows = 0;
size_t s_offset = 0;
if (has_s_cache) {
const uint32_t n_slots = (uint32_t) kv_self.s_l[il]->ne[1];
if (seq_id == -1) {
s_rows = n_slots;
} else if (llama_kv_qnext_seq_id_in_range(kv_self, seq_id) && (uint32_t) seq_id < kv_self.size) {
llama_seq_id src_seq_id = kv_self.cells[seq_id].src;
if (llama_kv_qnext_seq_id_in_range(kv_self, src_seq_id)) {
s_rows = 1;
s_offset = (size_t) src_seq_id * s_size_row;
}
}
}
write(&s_rows, sizeof(s_rows));
if (has_s_cache && s_rows > 0) {
write_tensor_data(kv_self.s_l[il], s_offset, s_rows * s_size_row, il);
}
}
}
}
void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
@ -5711,7 +5937,7 @@ struct llama_data_write {
write(&cell_count, sizeof(cell_count));
write_kv_cache_meta(kv_self, cell_ranges, seq_id);
write_kv_cache_data(ctx, cell_ranges);
write_kv_cache_data(ctx, cell_ranges, seq_id);
}
};
@ -5922,7 +6148,7 @@ struct llama_data_read {
GGML_ASSERT(sum_split_row_size == row_size);
}
bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1) {
const struct llama_hparams & hparams = ctx->model.hparams;
struct llama_kv_cache & kv_self = ctx->kv_self;
@ -5954,20 +6180,35 @@ struct llama_data_read {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t kv_lora_rank = hparams.n_lora_kv;
const bool has_k_cache = kv_self.k_l[il] != nullptr;
// Read type of key
int32_t k_type_i_ref;
read_to(&k_type_i_ref, sizeof(k_type_i_ref));
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
if (k_type_i != k_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
return false;
if (!has_k_cache) {
if (k_type_i_ref != -1) {
LLAMA_LOG_ERROR("%s: missing key cache for layer %d\n", __func__, il);
return false;
}
} else {
const int32_t k_type_i = (int32_t) kv_self.k_l[il]->type;
if (k_type_i != k_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
return false;
}
}
// Read row size of key
uint64_t k_size_row_ref;
read_to(&k_size_row_ref, sizeof(k_size_row_ref));
if (!has_k_cache) {
if (k_size_row_ref != 0) {
LLAMA_LOG_ERROR("%s: expected empty key row size for layer %d\n", __func__, il);
return false;
}
continue;
}
const uint64_t k_size_row = (ctx->cparams.mla_attn == 0) ? ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa) : ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope);
if (k_size_row != k_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
@ -5986,20 +6227,35 @@ struct llama_data_read {
if (v_state == 0) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
const bool has_v_cache = kv_self.v_l[il] != nullptr;
// Read type of value
int32_t v_type_i_ref;
read_to(&v_type_i_ref, sizeof(v_type_i_ref));
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
if (v_type_i != v_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return false;
if (!has_v_cache) {
if (v_type_i_ref != -1) {
LLAMA_LOG_ERROR("%s: missing value cache for layer %d\n", __func__, il);
return false;
}
} else {
const int32_t v_type_i = (int32_t) kv_self.v_l[il]->type;
if (v_type_i != v_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return false;
}
}
// Read row size of value
uint64_t v_size_row_ref;
read_to(&v_size_row_ref, sizeof(v_size_row_ref));
if (!has_v_cache) {
if (v_size_row_ref != 0) {
LLAMA_LOG_ERROR("%s: expected empty value row size for layer %d\n", __func__, il);
return false;
}
continue;
}
const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
if (v_size_row != v_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
@ -6019,35 +6275,58 @@ struct llama_data_read {
else if (v_state == 1) {
// For each layer, read the values for each cell (transposed)
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
const bool has_v_cache = kv_self.v_l[il] != nullptr;
// Read type of value
int32_t v_type_i_ref;
read_to(&v_type_i_ref, sizeof(v_type_i_ref));
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
if (v_type_i != v_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return false;
if (!has_v_cache) {
if (v_type_i_ref != -1) {
LLAMA_LOG_ERROR("%s: missing transposed value cache for layer %d\n", __func__, il);
return false;
}
} else {
const int32_t v_type_i = (int32_t) kv_self.v_l[il]->type;
if (v_type_i != v_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return false;
}
}
// Read element size of value
uint32_t v_size_el_ref;
read_to(&v_size_el_ref, sizeof(v_size_el_ref));
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
if (v_size_el != v_size_el_ref) {
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
return false;
if (!has_v_cache) {
if (v_size_el_ref != 0) {
LLAMA_LOG_ERROR("%s: expected empty transposed value element size for layer %d\n", __func__, il);
return false;
}
} else {
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
if (v_size_el != v_size_el_ref) {
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
return false;
}
}
// Read GQA embedding size
uint32_t n_embd_v_gqa_ref;
read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
if (!has_v_cache) {
if (n_embd_v_gqa_ref != 0) {
LLAMA_LOG_ERROR("%s: expected empty transposed value rows for layer %d\n", __func__, il);
return false;
}
continue;
}
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
return false;
}
if (cell_count) {
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
if (kv_self.v_l[il]->extra) {
throw std::runtime_error("Transposed V cache is not sypported with split mode 'graph'");
}
@ -6059,6 +6338,76 @@ struct llama_data_read {
}
}
}
uint32_t qnext_state_ref = 0;
read_to(&qnext_state_ref, sizeof(qnext_state_ref));
const bool has_qnext_state = llama_kv_has_qnext_state_storage(kv_self);
if ((qnext_state_ref != 0) != has_qnext_state) {
LLAMA_LOG_ERROR("%s: incompatible qwen3next state cache presence\n", __func__);
return false;
}
if (qnext_state_ref != 0) {
for (uint32_t il = 0; il < n_layer; ++il) {
const bool has_s_cache = il < kv_self.s_l.size() && kv_self.s_l[il] != nullptr;
int32_t s_type_i_ref;
read_to(&s_type_i_ref, sizeof(s_type_i_ref));
if (!has_s_cache) {
if (s_type_i_ref != -1) {
LLAMA_LOG_ERROR("%s: missing qwen3next state cache for layer %d\n", __func__, il);
return false;
}
} else {
const int32_t s_type_i = (int32_t) kv_self.s_l[il]->type;
if (s_type_i != s_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched qwen3next state type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
return false;
}
}
uint64_t s_size_row_ref;
read_to(&s_size_row_ref, sizeof(s_size_row_ref));
const uint64_t s_size_row = has_s_cache ? ggml_row_size(kv_self.s_l[il]->type, kv_self.s_l[il]->ne[0]) : 0;
if (s_size_row != s_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched qwen3next state row size (%zu != %zu, layer %d)\n",
__func__, (size_t) s_size_row, (size_t) s_size_row_ref, il);
return false;
}
uint32_t s_rows_ref;
read_to(&s_rows_ref, sizeof(s_rows_ref));
uint32_t s_rows = 0;
uint32_t s_dst_row = 0;
if (has_s_cache) {
const uint32_t n_slots = (uint32_t) kv_self.s_l[il]->ne[1];
if (seq_id == -1) {
s_rows = n_slots;
} else if (llama_kv_qnext_seq_id_in_range(kv_self, seq_id)) {
s_rows = 1;
s_dst_row = (uint32_t) seq_id;
}
}
if (s_rows_ref != s_rows) {
LLAMA_LOG_ERROR("%s: mismatched qwen3next state row count (%u != %u, layer %d)\n", __func__, s_rows, s_rows_ref, il);
return false;
}
if (s_rows > 0) {
const size_t s_data_size = s_rows * s_size_row;
const size_t s_dst_offset = (size_t) s_dst_row * s_size_row;
if (kv_self.s_l[il]->extra) {
read_kv_cache_data_split(ctx, kv_self.s_l[il], read(s_data_size), s_dst_row, s_size_row, s_rows, il);
} else {
ggml_backend_tensor_set(kv_self.s_l[il], read(s_data_size), s_dst_offset, s_data_size);
}
}
}
}
return true;
}
@ -6066,7 +6415,7 @@ struct llama_data_read {
uint32_t cell_count;
read_to(&cell_count, sizeof(cell_count));
bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count, seq_id);
if (!res) {
if (seq_id == -1) {