From 32eddaf2ea8dd5d499dee9655592a89b91bfde9d Mon Sep 17 00:00:00 2001 From: o7si <32285332+o7si@users.noreply.github.com> Date: Fri, 19 Jun 2026 00:59:18 +0800 Subject: [PATCH 01/86] cmake : fix ui build with read-only source (#24752) --- scripts/ui-assets.cmake | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/scripts/ui-assets.cmake b/scripts/ui-assets.cmake index 78a0f4c844..349fa9bf81 100644 --- a/scripts/ui-assets.cmake +++ b/scripts/ui-assets.cmake @@ -20,6 +20,7 @@ set(LLAMA_UI_GZIP "" CACHE STRING "Apply gzip compress to assets to save ban set(DIST_DIR "${UI_BINARY_DIR}/dist") set(SRC_DIST_DIR "${UI_SOURCE_DIR}/dist") +set(WORK_DIR "${UI_BINARY_DIR}/ui-src") set(STAMP_FILE "${UI_BINARY_DIR}/.ui-stamp") set(UI_CPP "${UI_BINARY_DIR}/ui.cpp") set(UI_H "${UI_BINARY_DIR}/ui.h") @@ -64,6 +65,22 @@ function(npm_build_should_skip out_var) set(${out_var} TRUE PARENT_SCOPE) endfunction() +function(stage_sources) + if(EXISTS "${WORK_DIR}") + file(GLOB staged RELATIVE "${WORK_DIR}" "${WORK_DIR}/*") + list(REMOVE_ITEM staged "node_modules") + foreach(entry ${staged}) + file(REMOVE_RECURSE "${WORK_DIR}/${entry}") + endforeach() + endif() + + file(COPY "${UI_SOURCE_DIR}/" + DESTINATION "${WORK_DIR}" + NO_SOURCE_PERMISSIONS + PATTERN "node_modules" EXCLUDE + ) +endfunction() + function(npm_build out_var) set(${out_var} FALSE PARENT_SCOPE) @@ -89,14 +106,16 @@ function(npm_build out_var) return() endif() + stage_sources() + # npm writes node_modules/.package-lock.json on every successful install, # so a package-lock.json newer than this marker means node_modules is stale - set(NPM_MARKER "${UI_SOURCE_DIR}/node_modules/.package-lock.json") + set(NPM_MARKER "${WORK_DIR}/node_modules/.package-lock.json") set(need_install FALSE) if(NOT EXISTS "${NPM_MARKER}") set(need_install TRUE) else() - file(TIMESTAMP "${UI_SOURCE_DIR}/package-lock.json" lock_ts) + file(TIMESTAMP "${WORK_DIR}/package-lock.json" lock_ts) file(TIMESTAMP 
"${NPM_MARKER}" marker_ts) if(lock_ts STRGREATER marker_ts) set(need_install TRUE) @@ -107,7 +126,7 @@ function(npm_build out_var) message(STATUS "UI: running npm install") execute_process( COMMAND ${NPM_EXECUTABLE} install - WORKING_DIRECTORY "${UI_SOURCE_DIR}" + WORKING_DIRECTORY "${WORK_DIR}" RESULT_VARIABLE rc ERROR_VARIABLE err ) @@ -124,7 +143,7 @@ function(npm_build out_var) execute_process( COMMAND ${CMAKE_COMMAND} -E env "LLAMA_UI_OUT_DIR=${DIST_DIR}" "LLAMA_UI_VERSION=${HF_VERSION}" "LLAMA_BUILD_NUMBER=${LLAMA_BUILD_NUMBER}" ${NPM_EXECUTABLE} run build - WORKING_DIRECTORY "${UI_SOURCE_DIR}" + WORKING_DIRECTORY "${WORK_DIR}" RESULT_VARIABLE rc ERROR_VARIABLE err ) From a6b3260a4268d3c9931e0c47859fe4ff8f108bd9 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Thu, 18 Jun 2026 21:55:04 +0200 Subject: [PATCH 02/86] mtmd: add batching for mtmd-cli, add video tests (#24778) --- tools/mtmd/mtmd-cli.cpp | 117 ++++++++++++++++++++++++++++++++++------ tools/mtmd/tests.sh | 52 ++++++++++++++---- 2 files changed, 145 insertions(+), 24 deletions(-) diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index a3cad7cd06..0ad000ef01 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -32,9 +32,9 @@ static volatile bool g_is_generating = false; static volatile bool g_is_interrupted = false; /** - * Please note that this is NOT a production-ready stuff. + * Please note that this is NOT a production-ready binary. * It is a playground for trying multimodal support in llama.cpp. - * For contributors: please keep this code simple and easy to understand. + * For contributors: please keep this code simple and easy to understand. Do not add unnecessary complexity. The goal is to have a simple CLI for testing multimodal support. 
*/ static void show_additional_info(int /*argc*/, char ** argv) { @@ -65,6 +65,14 @@ static void sigint_handler(int signo) { } #endif +// this is only used by tests.sh to capture the response ; it's not meant to be used in production +static void inject_test_response_marker() { + const char * env = std::getenv("MTMD_TEST_RESPONSE_MARKER"); + if (env) { + LOG("%s\n", env); + } +} + struct mtmd_cli_context { mtmd::context_ptr ctx_vision; common_init_result_ptr llama_init; @@ -79,6 +87,8 @@ struct mtmd_cli_context { mtmd::bitmaps bitmaps; std::vector videos; + mtmd::batch_ptr mbatch; + // chat template common_chat_templates_ptr tmpls; std::vector chat_history; @@ -233,6 +243,8 @@ static std::string chat_add_and_format(mtmd_cli_context & ctx, common_chat_msg & } static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) { + inject_test_response_marker(); + bool add_bos = ctx.chat_history.empty(); auto formatted_chat = chat_add_and_format(ctx, msg); LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.c_str()); @@ -259,20 +271,95 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) { ctx.bitmaps.entries.clear(); ctx.videos.clear(); - llama_pos new_n_past; - if (mtmd_helper_eval_chunks(ctx.ctx_vision.get(), - ctx.lctx, // lctx - chunks.ptr.get(), // chunks - ctx.n_past, // n_past - 0, // seq_id - ctx.n_batch, // n_batch - true, // logits_last - &new_n_past)) { - LOG_ERR("Unable to eval prompt\n"); - return 1; - } + // batch encode all media chunks, then decode each + size_t n_chunks = mtmd_input_chunks_size(chunks.ptr.get()); + for (size_t i = 0; i < n_chunks; i++) { + auto chunk = mtmd_input_chunks_get(chunks.ptr.get(), i); + auto chunk_type = mtmd_input_chunk_get_type(chunk); - ctx.n_past = new_n_past; + if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + // decode text chunk + llama_pos new_n_past = ctx.n_past; + res = mtmd_helper_eval_chunk_single(ctx.ctx_vision.get(), + ctx.lctx, + chunk, + ctx.n_past, + 0, // seq_id + ctx.n_batch, + 
i == n_chunks - 1, // logits_last + &new_n_past); + if (res != 0) { + LOG_ERR("Unable to eval text chunk %zu\n", i); + return 1; + } + ctx.n_past = new_n_past; + } else { + // media chunk: try to get embd from existing batch, or create a new batch + float * embd = nullptr; + if (ctx.mbatch) { + embd = mtmd_batch_get_output_embd(ctx.mbatch.get(), chunk); + + if (embd) { + LOG_DBG("found embd for media chunk %zu in existing batch\n", i); + } else { + LOG_DBG("media chunk %zu not found in existing batch, creating new batch\n", i); + } + } + + if (!embd) { + // create and encode a new batch with as many media chunks as possible + ctx.mbatch.reset(mtmd_batch_init(ctx.ctx_vision.get())); + res = mtmd_batch_add_chunk(ctx.mbatch.get(), chunk); + GGML_ASSERT(res == 0); // first chunk must always succeed + + int n_added = 1; + // add as many subsequent media chunks as possible + for (size_t j = i + 1; j < n_chunks; j++) { + auto next_chunk = mtmd_input_chunks_get(chunks.ptr.get(), j); + auto next_type = mtmd_input_chunk_get_type(next_chunk); + if (next_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + break; // text chunk splits the batch + } + res = mtmd_batch_add_chunk(ctx.mbatch.get(), next_chunk); + if (res != 0) { + break; // batch full or incompatible + } + n_added++; + } + + int64_t time_start = ggml_time_ms(); + LOG_INF("encoding mtmd batch, n_chunks = %d (done = %zu, total = %zu)\n", n_added, i, n_chunks); + res = mtmd_batch_encode(ctx.mbatch.get()); + if (res != 0) { + LOG_ERR("Failed to encode mtmd batch, res = %d\n", res); + return 1; + } + LOG_INF("mtmd batch encoding done in %d ms\n", (int)(ggml_time_ms() - time_start)); + + embd = mtmd_batch_get_output_embd(ctx.mbatch.get(), chunk); + } + + GGML_ASSERT(embd != nullptr); + + llama_pos new_n_past = ctx.n_past; + res = mtmd_helper_decode_image_chunk(ctx.ctx_vision.get(), + ctx.lctx, + chunk, + embd, + ctx.n_past, + 0, // seq_id + ctx.n_batch, + &new_n_past, + nullptr, // callback + nullptr // user_data + ); + if (res != 0) 
{ + LOG_ERR("Unable to decode media chunk %zu\n", i); + return 1; + } + ctx.n_past = new_n_past; + } + } LOG("\n"); diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index 5da48d61bf..6fe26478ab 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -13,6 +13,8 @@ mkdir -p $SCRIPT_DIR/output PROJ_ROOT="$SCRIPT_DIR/../.." cd $PROJ_ROOT +export MTMD_TEST_RESPONSE_MARKER="" + # Check if the first argument is "big", then run test with big models # This is useful if we're running the script on a larger machine, so we can test the big models RUN_BIG_TESTS=false @@ -28,6 +30,15 @@ if [ "${1:-}" = "huge" ]; then echo "Include BIG and HUGE models..." fi +USE_VIDEO=false +if [ "${1:-}" = "video" ]; then + USE_VIDEO=true + echo "Using video as input..." + # behavior of USE_VIDEO: + # do NOT check if the output contains "new york", only verify if the exit code is 0 + # when printing the result, print the OK/FAIL line then print the generated text +fi + # Check if the second argument is "flash", then enable flash attention # This is useful to test if flash attention off works correctly FLASH_ATTN="on" @@ -50,13 +61,20 @@ add_test_vision() { if [ $# -gt 0 ]; then extra_args=$(printf " %q" "$@") fi + if [ "$USE_VIDEO" = true ]; then + arr_file+=("test-3.mp4") + else + arr_file+=("test-1.jpeg") + fi arr_prefix+=("[vision]") arr_hf+=("$hf") arr_extra_args+=("$extra_args") - arr_file+=("test-1.jpeg") } add_test_audio() { + if [ "$USE_VIDEO" = true ]; then + return 0 + fi local hf=$1 shift local extra_args="" @@ -166,19 +184,35 @@ for i in "${!arr_hf[@]}"; do cmd+=" -p \"what is the publisher name of the newspaper?\"" fi - output=$(eval "$cmd" 2>&1 | tee /dev/tty) + exit_code=0 + output=$(eval "$cmd" 2>&1 | tee /dev/tty) || exit_code=$? 
echo "$output" > $SCRIPT_DIR/output/$bin-$(echo "$hf" | tr '/' '-').log - # either contains "new york" or both "men" and "walk" - if echo "$output" | grep -iq "new york" \ - || (echo "$output" | grep -iq "men" && echo "$output" | grep -iq "walk") - then - result="$prefix \033[32mOK\033[0m: $hf" + if [ "$USE_VIDEO" = true ]; then + # for video, only check exit code; do not grep for "new york" + if [ $exit_code -eq 0 ]; then + result="$prefix \033[32mOK\033[0m: $hf" + else + result="$prefix \033[31mFAIL\033[0m: $hf" + fi + # append generated text (after the response marker) + generated_text=$(echo "$output" | sed "1,/${MTMD_TEST_RESPONSE_MARKER}/d" | tail -10) + if [ -n "$generated_text" ]; then + result+="\n$generated_text" + fi + echo -e "$result" else - result="$prefix \033[31mFAIL\033[0m: $hf" + # either contains "new york" or both "men" and "walk" + if echo "$output" | grep -iq "new york" \ + || (echo "$output" | grep -iq "men" && echo "$output" | grep -iq "walk") + then + result="$prefix \033[32mOK\033[0m: $hf" + else + result="$prefix \033[31mFAIL\033[0m: $hf" + fi + echo -e "$result" fi - echo -e "$result" arr_res+=("$result") echo "" From 40f3aafc45990c1646397c67d2bd3e8eff3f543e Mon Sep 17 00:00:00 2001 From: Reguna Date: Fri, 19 Jun 2026 04:01:24 +0800 Subject: [PATCH 03/86] server: add "X-Accel-Buffering": "no" header to streaming endpoints (#24774) * server: add "X-Accel-Buffering": "no" header to streaming endpoints This header tells Nginx (as a reverse proxy) to NOT buffer responses. (only affects streaming endpoints) Without it, Nginx will break streaming with certain applications (notably the Pi coding harness). 
--- tools/server/server-http.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp index 5defee1f5e..4f2abab00c 100644 --- a/tools/server/server-http.cpp +++ b/tools/server/server-http.cpp @@ -492,6 +492,8 @@ using server_http_req_ptr = std::unique_ptr; static void process_handler_response(server_http_req_ptr && request, server_http_res_ptr & response, httplib::Response & res) { if (response->is_stream()) { res.status = response->status; + // Tell Nginx to not buffer any streamed response + response->headers["X-Accel-Buffering"] = "no"; set_headers(res, response->headers); const std::string content_type = response->content_type; // convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it From 3a3edc9ac65cca79584ca497be41d70c75a58ba8 Mon Sep 17 00:00:00 2001 From: Pascal Date: Thu, 18 Jun 2026 22:23:01 +0200 Subject: [PATCH 04/86] Ggml/cuda col2im 1d (#24417) * cuda: add GGML_OP_COL2IM_1D, follow-up to the CPU op * cuda: col2im_1d use fast_div_modulo for the index decomposition * cuda: col2im_1d tighten supports_op, type match and contiguous dst --- ggml/src/ggml-cuda/col2im-1d.cu | 81 ++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/col2im-1d.cuh | 3 ++ ggml/src/ggml-cuda/ggml-cuda.cu | 12 +++++ 3 files changed, 96 insertions(+) create mode 100644 ggml/src/ggml-cuda/col2im-1d.cu create mode 100644 ggml/src/ggml-cuda/col2im-1d.cuh diff --git a/ggml/src/ggml-cuda/col2im-1d.cu b/ggml/src/ggml-cuda/col2im-1d.cu new file mode 100644 index 0000000000..fecd4c6a95 --- /dev/null +++ b/ggml/src/ggml-cuda/col2im-1d.cu @@ -0,0 +1,81 @@ +#include "col2im-1d.cuh" +#include "convert.cuh" + +// col2im_1d: scatter-add GEMM columns to 1D signal (gather approach) +// columns: [K*OC, T_in] -> output: [T_out, OC] +// Supports F32, F16, BF16 data with F32 accumulator. 
+ +template +static __global__ void col2im_1d_kernel( + const T * __restrict__ col, + T * __restrict__ dst, + const int T_in, const uint3 T_out_fd, + const int OC, const int K, const int K_OC, + const int s0, const int p0, const int total) { + + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= total) return; + + // dst layout: [T_out, OC], ne[0]=T_out fastest + const uint2 qr = fast_div_modulo((uint32_t)idx, T_out_fd); // qr.x = idx / T_out, qr.y = idx % T_out + const int oc = (int)qr.x; + const int t_out = (int)qr.y; + const int t_abs = t_out + p0; // absolute position in uncropped signal + + // Gather: find all (t_in, k) where t_in*s + k == t_abs, 0 <= k < K + int t_in_min = (t_abs - K + s0) / s0; // ceil((t_abs - K + 1) / s) + if (t_in_min < 0) t_in_min = 0; + int t_in_max = t_abs / s0; + if (t_in_max >= T_in) t_in_max = T_in - 1; + + float sum = 0.0f; + for (int t_in = t_in_min; t_in <= t_in_max; t_in++) { + const int k = t_abs - t_in * s0; + // col layout: [K*OC, T_in], column index = oc * K + k + sum += ggml_cuda_cast(col[(oc * K + k) + t_in * K_OC]); + } + + dst[idx] = ggml_cuda_cast(sum); +} + +void ggml_cuda_op_col2im_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t OC = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + + const int K_OC = (int) src0->ne[0]; + const int T_in = (int) src0->ne[1]; + const int K = K_OC / OC; + const int T_out = (int) dst->ne[0]; + + const uint3 T_out_fd = init_fastdiv_values((uint32_t)T_out); + + const int total = T_out * OC; + const int block_size = 256; + const int num_blocks = (total + block_size - 1) / block_size; + + switch (src0->type) { + case GGML_TYPE_F32: { + col2im_1d_kernel<<>>( + (const float *)src0->data, (float *)dst->data, + T_in, 
T_out_fd, OC, K, K_OC, s0, p0, total); + } break; + case GGML_TYPE_F16: { + col2im_1d_kernel<<>>( + (const half *)src0->data, (half *)dst->data, + T_in, T_out_fd, OC, K, K_OC, s0, p0, total); + } break; + case GGML_TYPE_BF16: { + col2im_1d_kernel<<>>( + (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data, + T_in, T_out_fd, OC, K, K_OC, s0, p0, total); + } break; + default: + GGML_ABORT("col2im_1d: unsupported type"); + } +} diff --git a/ggml/src/ggml-cuda/col2im-1d.cuh b/ggml/src/ggml-cuda/col2im-1d.cuh new file mode 100644 index 0000000000..efc3313c4d --- /dev/null +++ b/ggml/src/ggml-cuda/col2im-1d.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_col2im_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 34cdbc81c2..3d4b5f6056 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -11,6 +11,7 @@ #include "ggml-cuda/argsort.cuh" #include "ggml-cuda/binbcast.cuh" #include "ggml-cuda/clamp.cuh" +#include "ggml-cuda/col2im-1d.cuh" #include "ggml-cuda/concat.cuh" #include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/conv2d.cuh" @@ -3051,6 +3052,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CONV_TRANSPOSE_1D: ggml_cuda_op_conv_transpose_1d(ctx,dst); break; + case GGML_OP_COL2IM_1D: + ggml_cuda_op_col2im_1d(ctx, dst); + break; case GGML_OP_POOL_2D: ggml_cuda_op_pool2d(ctx, dst); break; @@ -5316,6 +5320,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } return false; } break; + case GGML_OP_COL2IM_1D: + { + ggml_type src0_type = op->src[0]->type; + return (src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_F16 || src0_type == GGML_TYPE_BF16) && + op->type == src0_type && + ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op); + } break; case GGML_OP_SILU_BACK: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == 
GGML_TYPE_F32; break; From db52540f730de39efcf7172d4ab1f79bb50556e2 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Fri, 19 Jun 2026 01:16:16 +0200 Subject: [PATCH 05/86] mtmd: add batching support for internvl (#24775) --- tools/mtmd/clip.cpp | 2 +- tools/mtmd/models/internvl.cpp | 18 +++++++++++------- tools/mtmd/models/models.h | 1 + 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index dc62232957..17079815d4 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -534,7 +534,7 @@ ggml_tensor * clip_graph::build_vit( ggml_tensor * clip_graph::build_inp() { ggml_tensor * inp_raw = build_inp_raw(); ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); + inp = ggml_reshape_3d(ctx0, inp, n_patches, n_embd, n_batch); inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); if (model.patch_bias) { inp = ggml_add(ctx0, inp, model.patch_bias); diff --git a/tools/mtmd/models/internvl.cpp b/tools/mtmd/models/internvl.cpp index 9aded3b97c..65d7d5a6b7 100644 --- a/tools/mtmd/models/internvl.cpp +++ b/tools/mtmd/models/internvl.cpp @@ -8,7 +8,9 @@ ggml_cgraph * clip_graph_internvl::build() { ggml_tensor * inp = build_inp(); // add CLS token - inp = ggml_concat(ctx0, inp, model.class_embedding, 1); + ggml_tensor * cls_repeated = ggml_repeat_4d(ctx0, model.class_embedding, + model.class_embedding->ne[0], 1, n_batch, 1); + inp = ggml_concat(ctx0, inp, cls_repeated, 1); // The larger models use a different ViT, which uses RMS norm instead of layer norm // ref: https://github.com/ggml-org/llama.cpp/pull/13443#issuecomment-2869786188 @@ -24,14 +26,15 @@ ggml_cgraph * clip_graph_internvl::build() { nullptr); // remove CLS token - cur = ggml_view_2d(ctx0, cur, - n_embd, n_patches, - ggml_row_size(cur->type, n_embd), 0); + cur = ggml_view_3d(ctx0, cur, + n_embd, n_patches, n_batch, + cur->nb[1], cur->nb[2], 0); + 
cur = ggml_cont(ctx0, cur); // pixel shuffle { const int scale_factor = model.hparams.n_merge; - const int bsz = 1; // batch size, always 1 for now since we don't support batching + const int bsz = n_batch; const int height = n_patches_y; const int width = n_patches_x; GGML_ASSERT(scale_factor > 0); @@ -44,9 +47,10 @@ ggml_cgraph * clip_graph_internvl::build() { bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // flatten to 2D - cur = ggml_cont_2d(ctx0, cur, + cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, - cur->ne[1] * cur->ne[2]); + cur->ne[1] * cur->ne[2], + cur->ne[3]); } // projector (always using GELU activation) diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 3a15f76829..12d5e69493 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -80,6 +80,7 @@ struct clip_graph_minicpmv4_6 : clip_graph { struct clip_graph_internvl : clip_graph { clip_graph_internvl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; + bool support_batch() const override { return true; } }; struct clip_graph_nemotron_v2_vl : clip_graph { From 8141e730f1598780c19b153e0e212ed70a672c53 Mon Sep 17 00:00:00 2001 From: shalinib-ibm Date: Fri, 19 Jun 2026 11:25:38 +0530 Subject: [PATCH 06/86] ggml-cpu: support K tails in power10 Q8/Q4 MMA matmul (#24753) * ggml-cpu: support K tails in Power10 MMA Q8/Q4 matmul This patch removes the requirement that K be divisible by kc in the tinyBlas_Q0_PPC tiled matmul path. Process the final K panel using its actual depth and pass the reduced panel size through packing and kernel execution. This allows more workloads to use the MMA kernel and reduces fallback to mnpack. 
* Apply suggestion from @taronaeo Co-authored-by: Aaron Teo --------- Co-authored-by: Aaron Teo --- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index e13828e3be..0b8323e60c 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -2345,7 +2345,7 @@ class tinyBLAS_Q0_PPC { else if (n_aligned % 16 == 0) nc = 16; else nc = 8; } - bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0); + bool can_use_tiled = n_aligned > 0 && (m % mc == 0); if (can_use_tiled) { matmul_tiled(m, n_aligned, mc, nc, kc); if (n > n_aligned) { @@ -3063,13 +3063,14 @@ class tinyBLAS_Q0_PPC { int64_t ii = (job / xtiles) * mc; int64_t jj = (job % xtiles) * nc; for (int64_t kk = 0; kk < k; kk += kc) { + int64_t k_cur = MIN(kc, k - kk); if constexpr(is_Ablock_q4) { - packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack); + packNormal_q4_fp16(A + ii * lda + kk, lda, mc, k_cur, (uint8_t *)A_pack); } else { - packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack); + packNormal_q8_fp16(A + ii * lda + kk, lda, mc, k_cur, (uint8_t *)A_pack); } - packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack); - KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack); + packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, k_cur, (uint8_t *)B_pack); + KERNEL_Q0(ii, jj, mc, nc, k_cur, kk, A_pack, B_pack); } } } From 80452d65b9b1d44b496ed729f1fb0b6c4c39d7bf Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Jun 2026 09:22:34 +0300 Subject: [PATCH 07/86] server : consolidate slot selection into get_available_slot (#24755) Absorb get_slot_by_id logic into get_available_slot so slot selection is handled by a single function call. When a specific slot id is requested, the LCP similarity check still runs to enable proper prompt cache updates. 
Assisted-by: pi:llama.cpp/Qwen3.6-27B --- tools/server/server-context.cpp | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index aebca306a8..ded622cfd6 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1395,11 +1395,23 @@ private: bool update_cache = false; + // if a specific slot is requested, use it (still goes through cache update logic below) + if (task.id_slot != -1) { + ret = get_slot_by_id(task.id_slot); + if (ret) { + SLT_INF(*ret, "selected slot by id (%d)\n", task.id_slot); + } + } + // find the slot that has at least n% prompt similarity - if (ret == nullptr && slot_prompt_similarity != 0.0f) { + if (slot_prompt_similarity != 0.0f) { float sim_best = 0; for (server_slot & slot : slots) { + if (task.id_slot != -1 && slot.id != task.id_slot) { + continue; + } + // skip the slot if it is not available if (slot.is_processing()) { continue; @@ -1426,8 +1438,10 @@ private: if (ret != nullptr) { const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size(); - SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n", - sim_best, slot_prompt_similarity, f_keep); + if (task.id_slot == -1) { + SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n", + sim_best, slot_prompt_similarity, f_keep); + } // if we are about to lose a large portion of the existing context - save it in the prompt cache if (f_keep < 0.5f) { @@ -2180,10 +2194,9 @@ private: } } - const int id_slot = task.id_slot; const int id_task = task.id; - server_slot * slot = id_slot != -1 ? 
get_slot_by_id(id_slot) : get_available_slot(task); + server_slot * slot = get_available_slot(task); // // slot scheduling logic From 5bd21b8555edf203aee78aacedf2d2744b8702a3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Jun 2026 09:34:00 +0300 Subject: [PATCH 08/86] pi : remove docs from system prompt (#24791) --- .pi/gg/SYSTEM.md | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/.pi/gg/SYSTEM.md b/.pi/gg/SYSTEM.md index 197173faed..17ce71cc1b 100644 --- a/.pi/gg/SYSTEM.md +++ b/.pi/gg/SYSTEM.md @@ -25,13 +25,3 @@ Commits: - Do not explicitly set the git author in commits - rely on the default git config - Always use `--no-gpg-sign` when committing - Never `git push` without explicit confirmation from the user - -Resources (read on demand): -- [CONTRIBUTING.md](CONTRIBUTING.md) -- [Build documentation](docs/build.md) -- [Server usage documentation](tools/server/README.md) -- [Server development documentation](tools/server/README-dev.md) -- [PEG parser](docs/development/parsing.md) -- [Auto parser](docs/autoparser.md) -- [Jinja engine](common/jinja/README.md) -- [PR template](.github/pull_request_template.md) From 1868af13ac1c2c6dd5b0dd1a4ac43f6ccc80fd80 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Jun 2026 10:14:26 +0300 Subject: [PATCH 09/86] ggml : bump version to 0.15.2 (ggml/1548) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 0507e0c5aa..04069784f1 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -5,7 +5,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 15) -set(GGML_VERSION_PATCH 1) +set(GGML_VERSION_PATCH 2) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 5fd2dc2c41c342a75c26f9756ca6b1814ed05fb4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 
19 Jun 2026 10:18:14 +0300 Subject: [PATCH 10/86] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 87d353ef45..499be5a585 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -3af5f5760e19a96427f5f7a93b79cbdf3d4b265b +707321c4cf6d21cb4bc831aa8b687dbf01a521ce From 159d093a43e87b977e9749b18a980c15b1a66a90 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Fri, 19 Jun 2026 10:53:44 +0200 Subject: [PATCH 11/86] server: fix non-bound n_discard value (ctx shifting) (#24786) * server: fix non-bound n_discard value * Update tools/server/server-context.cpp Co-authored-by: Georgi Gerganov --------- Co-authored-by: Georgi Gerganov --- tools/server/server-context.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index ded622cfd6..a23b0405ce 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2565,7 +2565,10 @@ private: n_keep = std::min(slot.n_ctx - 4, n_keep); const int n_left = slot.prompt.n_tokens() - n_keep; - const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2); + int n_discard = slot.task->params.n_discard ? 
slot.task->params.n_discard : (n_left / 2); + + // ref: https://github.com/ggml-org/llama.cpp/pull/24786 + n_discard = std::clamp(n_discard, 0, std::max(0, n_left - 1)); SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); From b14e3fb90ca8c760f4254ddc9aa7845ebbdb2edf Mon Sep 17 00:00:00 2001 From: Ruixiang Wang Date: Fri, 19 Jun 2026 12:08:50 +0200 Subject: [PATCH 12/86] spec: support eagle3 for qwen3.5 & 3.6 (#24593) * spec: support qwen3.5 & 3.6 eagle3 draft * eagle3: Add deferred boundary checkpoints restore support for hybrid models * apply suggestions Co-authored-by: Georgi Gerganov * spec: adapt to API change * spec: fix naming * cont : add TODO --------- Co-authored-by: Georgi Gerganov --- common/common.cpp | 4 +- common/common.h | 6 ++- common/speculative.cpp | 72 +++++++++++++++++++++++++++++++++ common/speculative.h | 4 ++ src/models/qwen35.cpp | 2 + src/models/qwen35moe.cpp | 2 + tools/server/server-context.cpp | 4 ++ 7 files changed, 92 insertions(+), 2 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index b01772e1cb..f3f114f682 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2034,7 +2034,7 @@ bool common_prompt_batch_decode( } size_t common_prompt_checkpoint::size() const { - return data_tgt.size() + data_dft.size(); + return data_tgt.size() + data_dft.size() + data_spec.size(); } bool common_prompt_checkpoint::empty() const { @@ -2049,6 +2049,7 @@ void common_prompt_checkpoint::clear() { data_tgt.clear(); data_dft.clear(); + data_spec.clear(); } void common_prompt_checkpoint::update_pos( @@ -2138,4 +2139,5 @@ void common_prompt_checkpoint::clear_tgt() { void common_prompt_checkpoint::clear_dft() { data_dft.clear(); + data_spec.clear(); } diff --git a/common/common.h b/common/common.h index 040b9cf233..535a4ed335 100644 --- a/common/common.h +++ b/common/common.h @@ -363,7 +363,7 @@ struct common_params_speculative { uint32_t need_n_rs_seq() const { bool needs_rs_seq = 
std::any_of(types.begin(), types.end(), [&](auto t) { - return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP; + return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3; }); return needs_rs_seq ? draft.n_max : 0u; @@ -1065,6 +1065,10 @@ struct common_prompt_checkpoint { std::vector data_tgt; std::vector data_dft; + // (optional) speculative-decoding implementation state stashed with the checkpoint + // (e.g. eagle3's deferred-boundary g_embd row) + std::vector data_spec; + size_t size() const; bool empty() const; diff --git a/common/speculative.cpp b/common/speculative.cpp index 6f387f2cfc..9c20585dc3 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -161,6 +161,10 @@ struct common_speculative_impl { virtual void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) = 0; + // (optional) serialize/restore per-seq internal state (e.g. eagle3's deferred boundary). + virtual bool get_state(llama_seq_id /*seq_id*/, std::vector & /*data*/) const { return false; } + virtual void set_state(llama_seq_id /*seq_id*/, const std::vector & /*data*/) {} + // true if this implementation requires the target context to extract post-norm embeddings virtual bool need_embd() const = 0; @@ -841,6 +845,49 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { (size_t) n_embd_dec * sizeof(float)); } + // we only need to stash the deferred boundary's g_embd row for recurrent/hybrid targets: + // their single-position checkpoints drop it on restore + bool need_boundary_stash() const { + const llama_model * model_tgt = llama_get_model(params.ctx_tgt); + return llama_model_is_recurrent(model_tgt) || llama_model_is_hybrid(model_tgt); + } + + bool get_state(llama_seq_id seq_id, std::vector & data) const override { + if (!need_boundary_stash()) { + return false; + } + if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq || pending_pos_last[seq_id] < 0) { + return false; + } + + const llama_pos pos = 
pending_pos_last[seq_id]; + const std::vector & g = pending_g_last[seq_id]; + + data.resize(sizeof(llama_pos) + g.size() * sizeof(float)); + std::memcpy(data.data(), &pos, sizeof(llama_pos)); + std::memcpy(data.data() + sizeof(llama_pos), g.data(), g.size() * sizeof(float)); + return true; + } + + void set_state(llama_seq_id seq_id, const std::vector & data) override { + if (!need_boundary_stash()) { + return; + } + if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) { + return; + } + if (data.size() != sizeof(llama_pos) + (size_t) n_embd_dec * sizeof(float)) { + return; + } + + llama_pos pos = -1; + std::memcpy(&pos, data.data(), sizeof(llama_pos)); + + pending_pos_last[seq_id] = pos; + pending_g_last[seq_id].resize(n_embd_dec); + std::memcpy(pending_g_last[seq_id].data(), data.data() + sizeof(llama_pos), (size_t) n_embd_dec * sizeof(float)); + } + bool need_embd() const override { return false; } @@ -2118,6 +2165,31 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u } } +// TODO: support the case of more than one speculative implementations having a state +bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector & data) { + if (spec == nullptr) { + return false; + } + + for (auto & impl : spec->impls) { + if (impl->get_state(seq_id, data)) { + return true; + } + } + + return false; +} + +void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector & data) { + if (spec == nullptr) { + return; + } + + for (auto & impl : spec->impls) { + impl->set_state(seq_id, data); + } +} + void common_speculative_print_stats(const common_speculative * spec) { if (spec == nullptr) { return; diff --git a/common/speculative.h b/common/speculative.h index bf76ad709e..c58fac3cc6 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -68,6 +68,10 @@ void common_speculative_draft(common_speculative * spec); // informs the speculative context that n_accepted tokens were 
accepted by the target model void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted); +// (optional) get/set internal state +bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector & data); +void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector & data); + // print statistics about the speculative decoding void common_speculative_print_stats(const common_speculative * spec); diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 6783d98ec2..d8ffe43ae7 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -156,6 +156,8 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index eb5e9a406a..7b0876cbb0 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -179,6 +179,8 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. 
for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index a23b0405ce..c887beb0ac 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2172,6 +2172,8 @@ private: cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + // stash the draft's speculative state with the checkpoint + common_speculative_get_state(spec.get(), slot.id, cur.data_spec); SLT_INF(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", @@ -2998,6 +3000,8 @@ private: // restore the context checkpoint it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + // restore the draft's speculative state + common_speculative_set_state(spec.get(), slot.id, it->data_spec); pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max)); n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens); From e2e7a9b2d04e0404131e551243d4bd399ccfe606 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Fri, 19 Jun 2026 12:18:36 +0200 Subject: [PATCH 13/86] mtmd: several bug fixes (#24784) * mtmd: several bug fixes * fix build * fix gemma4ua * add sanity check in get_u32() * fix build (2) * area() avoid overflow --- tools/mtmd/clip.cpp | 30 +++++++- tools/mtmd/clip.h | 3 + tools/mtmd/mtmd-audio.cpp | 139 +++++++++++++++++++++----------------- tools/mtmd/mtmd-audio.h | 14 ++-- tools/mtmd/mtmd.cpp | 5 +- 5 files changed, 119 insertions(+), 72 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 17079815d4..10840a851f 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1675,6 +1675,9 @@ 
struct clip_model_loader { // note: some models having hparams.image_size == 0, which means the image size is dynamic throw std::runtime_error(string_format("%s: image_size (%d) cannot be negative\n", __func__, hparams.image_size)); } + if (hparams.image_size > 65536) { + throw std::runtime_error(string_format("%s: image_size (%d) is too large (max 65536)\n", __func__, hparams.image_size)); + } if (hparams.patch_size <= 0) { throw std::runtime_error(string_format("%s: patch_size (%d) must be greater than 0\n", __func__, hparams.patch_size)); } @@ -1723,6 +1726,19 @@ struct clip_model_loader { LOG_INF("%s: audio_n_fft: %d\n", __func__, hparams.audio_n_fft); LOG_INF("%s: audio_window_len: %d\n", __func__, hparams.audio_window_len); LOG_INF("%s: audio_hop_len: %d\n", __func__, hparams.audio_hop_len); + + // GEMMA4UA is encoder-free: it uses n_mel_bins as a raw-waveform frame size (640) and has no FFT/filterbank, so the mel-range and FFT + // checks below do not apply to it. + const bool fft_based = model.proj_type != PROJECTOR_TYPE_GEMMA4UA; + + // Validate audio hparams loaded from GGUF metadata + if (hparams.n_mel_bins <= 0 || (fft_based && hparams.n_mel_bins > 256)) { + throw std::runtime_error(string_format("%s: n_mel_bins (%d) must be in range [1, 256]\n", __func__, hparams.n_mel_bins)); + } + if (fft_based && (hparams.audio_sample_rate <= 0 || hparams.audio_n_fft <= 0 || hparams.audio_hop_len <= 0 || hparams.audio_window_len <= 0)) { + throw std::runtime_error(string_format("%s: audio hparams invalid: sample_rate=%d n_fft=%d window_len=%d hop_len=%d\n", + __func__, hparams.audio_sample_rate, hparams.audio_n_fft, hparams.audio_window_len, hparams.audio_hop_len)); + } } LOG_INF("\n"); LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0); @@ -2831,6 +2847,12 @@ struct clip_model_loader { img.set_size({sz, sz}, false, false); LOG_INF("%s: warmup with image size = %d x %d\n", __func__, sz, sz); } else { + // GEMMA4UA uses n_mel_bins as a 
raw-waveform frame size (640), not a mel-bin count, + // so the [1, 256] bound only applies to FFT-based models. + const bool fft_based = ctx_clip.model.proj_type != PROJECTOR_TYPE_GEMMA4UA; + if (hparams.n_mel_bins <= 0 || (fft_based && hparams.n_mel_bins > 256)) { + throw std::runtime_error(string_format("%s: invalid n_mel_bins (%d), must be in [1, 256]\n", __func__, hparams.n_mel_bins)); + } img.set_size({hparams.warmup_audio_size, hparams.n_mel_bins}, false, false); LOG_INF("%s: warmup with audio size = %d\n", __func__, hparams.warmup_audio_size); } @@ -2994,7 +3016,13 @@ struct clip_model_loader { } return; } - output = gguf_get_val_u32(ctx_gguf.get(), i); + const uint32_t val = gguf_get_val_u32(ctx_gguf.get(), i); + // sanity check + if (val > (uint32_t) INT32_MAX) { + throw std::runtime_error(string_format("%s: value %u for key '%s' exceeds INT32_MAX\n", + __func__, val, key.c_str())); + } + output = (int) val; } void get_f32(const std::string & key, float & output, bool required = true) const { diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index f66f4bc3bb..e0f1d298c8 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -24,6 +24,9 @@ struct clip_image_size { return !(*this == other); } int area() const { + // avoid overflow when computing area + GGML_ASSERT(width >= 0 && width <= 46000); + GGML_ASSERT(height >= 0 && height <= 46000); return width * height; } }; diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp index 13f211fd90..b72fd067a5 100644 --- a/tools/mtmd/mtmd-audio.cpp +++ b/tools/mtmd/mtmd-audio.cpp @@ -32,8 +32,8 @@ void mtmd_audio_cache::fill_hann_window(uint32_t length, bool periodic) { } } -void mtmd_audio_cache::fill_mel_filterbank_matrix(int n_mel, - int n_fft, +void mtmd_audio_cache::fill_mel_filterbank_matrix(int64_t n_mel, + int64_t n_fft, int sample_rate, float fmin, float fmax, @@ -86,11 +86,16 @@ void mtmd_audio_cache::fill_mel_filterbank_matrix(int n_mel, hz_pts[i] = mel_to_hz(mel_pts[i]); } - const int 
n_fft_bins = n_fft / 2 + 1; + const int64_t n_fft_bins = n_fft / 2 + 1; + + // Validate allocation size + if ((size_t)n_mel * (size_t)n_fft_bins > SIZE_MAX) { + GGML_ASSERT(false && "mel filterbank allocation too large"); + } // filterbank - std::vector out(n_mel * n_fft_bins, 0); - for (int m = 0; m < n_mel; ++m) { + std::vector out((size_t)n_mel * (size_t)n_fft_bins, 0); + for (int64_t m = 0; m < n_mel; ++m) { const double f_left = hz_pts[m]; const double f_center = hz_pts[m + 1]; const double f_right = hz_pts[m + 2]; @@ -266,8 +271,8 @@ static void ifft(const mtmd_audio_cache & cache, float * in, int N, float * out) } struct filter_params { - int32_t n_mel; - int32_t n_fft_bins; + int64_t n_mel; + int64_t n_fft_bins; int32_t hann_window_size; int32_t hop_length; int32_t sample_rate; @@ -293,8 +298,8 @@ static void log_mel_spectrogram_worker_thread(int ith, std::vector fft_in(frame_size * 2, 0.0); std::vector fft_out(frame_size * 2 * 2 * 2); - int n_fft_bins = params.n_fft_bins; - int i = ith; + int64_t n_fft_bins = params.n_fft_bins; + int64_t i = ith; const auto & filters = cache.filters; @@ -302,17 +307,18 @@ static void log_mel_spectrogram_worker_thread(int ith, GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2)); GGML_ASSERT(cache.sin_vals.size() == cache.cos_vals.size()); // calculate FFT only when fft_in are not all zero - for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) { - const int offset = i * frame_step; + for (; i < std::min((int64_t)(n_samples / frame_step + 1), out.n_len); i += n_threads) { + const int64_t offset = i * frame_step; // apply Hann window (~10% faster) - for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) { + const int valid_len = std::min(frame_size, std::max(0, n_samples - (int)offset)); + for (int j = 0; j < valid_len; j++) { fft_in[j] = hann[j] * samples[offset + j]; } // fill the rest with zeros - if (n_samples - offset < frame_size) { - std::fill(fft_in.begin() + (n_samples - offset), 
fft_in.end(), 0.0); + if (valid_len < frame_size) { + std::fill(fft_in.begin() + valid_len, fft_in.end(), 0.0); } // FFT @@ -325,7 +331,7 @@ static void log_mel_spectrogram_worker_thread(int ith, } // mel spectrogram - for (int j = 0; j < out.n_mel; j++) { + for (int64_t j = 0; j < out.n_mel; j++) { double sum = 0.0; // unroll loop (suggested by GH user @lunixbochs) int k = 0; @@ -339,21 +345,21 @@ static void log_mel_spectrogram_worker_thread(int ith, } // handle n_fft remainder for (; k < n_fft_bins; k++) { - sum += fft_out[k] * filters.data[j * n_fft_bins + k]; + sum += fft_out[k] * filters.data[(size_t)j * n_fft_bins + k]; } sum = std::max(sum, (double)params.mel_floor); sum = params.use_natural_log ? log(sum) : log10(sum); - out.data[j * out.n_len + i] = sum; + out.data[(size_t)j * out.n_len + i] = sum; } } // Otherwise fft_out are all zero double sum = params.use_natural_log ? log(1e-10) : log10(1e-10); for (; i < out.n_len; i += n_threads) { - for (int j = 0; j < out.n_mel; j++) { - out.data[j * out.n_len + i] = sum; + for (int64_t j = 0; j < out.n_mel; j++) { + out.data[(size_t)j * out.n_len + i] = sum; } } } @@ -437,16 +443,21 @@ static bool log_mel_spectrogram( GGML_ASSERT(params.hop_length > 0); out.n_mel = params.n_mel; out.n_len = (n_samples - frame_size) / frame_step + 1; - // TODO: handle these checks better - if (out.n_mel > 0 && (unsigned long)out.n_len > SIZE_MAX / out.n_mel) { - LOG_ERR("%s: size overflow\n", __func__); + // Validate dimensions before allocation to prevent integer overflow + if (out.n_mel <= 0 || out.n_len <= 0) { + LOG_ERR("%s: invalid mel dimensions n_mel=%lld n_len=%lld\n", __func__, (long long)out.n_mel, (long long)out.n_len); + return false; + } + const size_t total_size = (size_t)out.n_mel * (size_t)out.n_len; + if (total_size > SIZE_MAX / sizeof(float)) { + LOG_ERR("%s: size overflow: n_mel=%lld n_len=%lld\n", __func__, (long long)out.n_mel, (long long)out.n_len); return false; } if (n_samples < frame_size) { LOG_ERR("%s: 
not enough samples after padding\n", __func__); return false; } - out.data.resize(out.n_mel * out.n_len); + out.data.resize(total_size); { std::vector workers(n_threads - 1); @@ -464,38 +475,39 @@ static bool log_mel_spectrogram( } } - const int effective_n_len = n_samples_in / frame_step; + const int64_t effective_n_len = n_samples_in / frame_step; if (params.norm_per_feature) { GGML_ASSERT(effective_n_len > 1); - for (int i = 0; i < out.n_mel; i++) { + for (int64_t i = 0; i < out.n_mel; i++) { double mean = 0; - for (int j = 0; j < effective_n_len; ++j) { - mean += out.data[i * out.n_len + j]; + for (int64_t j = 0; j < effective_n_len; ++j) { + mean += out.data[(size_t)i * out.n_len + j]; } mean /= effective_n_len; double var = 0.0; - for (int j = 0; j < effective_n_len; ++j) { - const double value = out.data[i * out.n_len + j] - mean; + for (int64_t j = 0; j < effective_n_len; ++j) { + const double value = out.data[(size_t)i * out.n_len + j] - mean; var += value * value; } var /= effective_n_len - 1; // unbiased const double mstd = std::sqrt(var + 1e-5); - for (int j = 0; j < effective_n_len; ++j) { - auto &value = out.data[i * out.n_len + j]; + for (int64_t j = 0; j < effective_n_len; ++j) { + auto &value = out.data[(size_t)i * out.n_len + j]; value = (value - mean) / mstd; } // pad the rest with zeros - for (int j = effective_n_len; j < out.n_len; ++j) { - out.data[i * out.n_len + j] = 0.0; + for (int64_t j = effective_n_len; j < out.n_len; ++j) { + out.data[(size_t)i * out.n_len + j] = 0.0; } } } else if (!params.no_padding) { // Whisper-style clamping and normalization (NOT used by Gemma4) double mmax = -1e20; - for (int i = 0; i < out.n_mel*out.n_len; i++) { + const size_t mel_size = (size_t)out.n_mel * (size_t)out.n_len; + for (size_t i = 0; i < mel_size; i++) { if (out.data[i] > mmax) { mmax = out.data[i]; } @@ -503,7 +515,7 @@ static bool log_mel_spectrogram( mmax -= 8.0; - for (int i = 0; i < out.n_mel*out.n_len; i++) { + for (size_t i = 0; i < 
mel_size; i++) { if (out.data[i] < mmax) { out.data[i] = mmax; } @@ -582,13 +594,13 @@ bool mtmd_audio_preprocessor_whisper::preprocess(const float * s // because the cgraph in clip.cpp only accepts 3000 frames each, we need to split the mel // we always expect the mel to have 3000 silent frames at the end if (DEBUG) { - printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len); + printf("output: n_mel = %d, n_len = %d\n", (int) out_full.n_mel, (int) out_full.n_len); } const size_t frames_per_chunk = 3000; GGML_ASSERT((size_t) out_full.n_len > frames_per_chunk); for (size_t off = 0; off < (size_t) out_full.n_len; off += frames_per_chunk) { - int n_len = std::min(frames_per_chunk, (size_t) out_full.n_len - off); - if ((size_t) n_len < frames_per_chunk) { + int64_t n_len = std::min((int64_t)frames_per_chunk, out_full.n_len - (int64_t)off); + if (n_len < (int64_t)frames_per_chunk) { break; // last incomplete chunk will always be a padded chunk, safe to ignore } @@ -596,10 +608,10 @@ bool mtmd_audio_preprocessor_whisper::preprocess(const float * s out_chunk.n_len = n_len; out_chunk.n_mel = out_full.n_mel; out_chunk.n_len_org = out_full.n_mel; // unused - out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len); + out_chunk.data.reserve((size_t)out_chunk.n_mel * (size_t)out_chunk.n_len); - for (int i = 0; i < out_full.n_mel; i++) { - auto src = out_full.data.begin() + i * out_full.n_len + off; + for (int64_t i = 0; i < out_full.n_mel; i++) { + auto src = out_full.data.begin() + (size_t)i * out_full.n_len + off; out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk); } @@ -681,8 +693,8 @@ bool mtmd_audio_preprocessor_qwen3a::preprocess(const float * sa // The effective frame count: center-padded STFT gives ~n_samples/hop_length frames. // We take min(mel_full.n_len, n_samples/hop + 1) to avoid including excess frames. 
- const int n_eff = std::min(mel_full.n_len, - (int)(n_samples / hparams.audio_hop_len) + 1); + const int64_t n_eff = std::min(mel_full.n_len, + (int64_t)(n_samples / hparams.audio_hop_len) + 1); // Split into inference windows matching n_window_infer=800 from model config. // Each window is padded to the next multiple of chunk_size for the cgraph. @@ -690,18 +702,18 @@ bool mtmd_audio_preprocessor_qwen3a::preprocess(const float * sa const int chunk_size = 100; // conv sub-chunk size (n_window * 2, n_window=50) const int window_size = 800; // mel frames per forward pass (n_window_infer=800) - for (int off = 0; off < n_eff; off += window_size) { - const int win_eff = std::min(window_size, n_eff - off); - const int n_chunks = (win_eff + chunk_size - 1) / chunk_size; - const int n_padded = n_chunks * chunk_size; + for (int64_t off = 0; off < n_eff; off += window_size) { + const int64_t win_eff = std::min((int64_t)window_size, n_eff - off); + const int64_t n_chunks = (win_eff + chunk_size - 1) / chunk_size; + const int64_t n_padded = n_chunks * chunk_size; mtmd_audio_mel out; out.n_mel = mel_full.n_mel; out.n_len = n_padded; out.n_len_org = win_eff; - out.data.assign(out.n_mel * out.n_len, 0.0f); - for (int m = 0; m < out.n_mel; m++) { - const int copy_len = std::min(win_eff, mel_full.n_len - off); + out.data.assign((size_t)out.n_mel * (size_t)out.n_len, 0.0f); + for (int64_t m = 0; m < out.n_mel; m++) { + const int64_t copy_len = std::min((int64_t)win_eff, mel_full.n_len - off); if (copy_len > 0) { std::copy(mel_full.data.begin() + (size_t)m * mel_full.n_len + off, mel_full.data.begin() + (size_t)m * mel_full.n_len + off + copy_len, @@ -823,37 +835,38 @@ bool mtmd_audio_preprocessor_granite_speech::preprocess(const float * } double mmax = -1e20; - for (int i = 0; i < mel.n_mel * mel.n_len; i++) { + const size_t mel_size = (size_t)mel.n_mel * (size_t)mel.n_len; + for (size_t i = 0; i < mel_size; i++) { if (mel.data[i] > mmax) { mmax = mel.data[i]; } } mmax -= 8.0; - 
for (int i = 0; i < mel.n_mel * mel.n_len; i++) { + for (size_t i = 0; i < mel_size; i++) { if (mel.data[i] < mmax) { mel.data[i] = mmax; } mel.data[i] = (mel.data[i] + 4.0) / 4.0; } - int n_frames = mel.n_len; + int64_t n_frames = mel.n_len; if (n_frames % 2 == 1) { n_frames--; } - const int n_mel = mel.n_mel; - const int n_stacked = n_frames / 2; + const int64_t n_mel = mel.n_mel; + const int64_t n_stacked = n_frames / 2; mtmd_audio_mel stacked; stacked.n_mel = 2 * n_mel; stacked.n_len = n_stacked; - stacked.n_len_org = (int)n_samples; - stacked.data.resize(2 * n_mel * n_stacked); + stacked.n_len_org = (int64_t)n_samples; + stacked.data.resize((size_t)2 * (size_t)n_mel * (size_t)n_stacked); - for (int t = 0; t < n_stacked; t++) { - for (int m = 0; m < n_mel; m++) { - stacked.data[m * n_stacked + t] = mel.data[m * mel.n_len + 2 * t]; - stacked.data[(m + n_mel) * n_stacked + t] = mel.data[m * mel.n_len + 2 * t + 1]; + for (int64_t t = 0; t < n_stacked; t++) { + for (int64_t m = 0; m < n_mel; m++) { + stacked.data[(size_t)m * n_stacked + t] = mel.data[(size_t)m * mel.n_len + 2 * t]; + stacked.data[(size_t)(m + n_mel) * n_stacked + t] = mel.data[(size_t)m * mel.n_len + 2 * t + 1]; } } @@ -921,8 +934,8 @@ bool mtmd_audio_preprocessor_gemma4a::preprocess(const float * s const int hop = hparams.audio_hop_len; const int n_with_left = (int)chunk_len + pad_left; // PyTorch: unfold(size=frame_length+1, step=hop) on semicausal-padded waveform - const int pt_frames = (n_with_left - (hparams.audio_window_len + 1)) / hop + 1; - const int n_padded_needed = (pt_frames - 1) * hop + fft_size; + const int64_t pt_frames = (n_with_left - (hparams.audio_window_len + 1)) / hop + 1; + const int64_t n_padded_needed = (pt_frames - 1) * hop + fft_size; const int total_pad = std::max((int)(n_padded_needed - (int)chunk_len), pad_left); std::vector padded_samples(total_pad + chunk_len, 0.0f); std::copy(chunk_ptr, chunk_ptr + chunk_len, padded_samples.data() + pad_left); diff --git 
a/tools/mtmd/mtmd-audio.h b/tools/mtmd/mtmd-audio.h index 9656e3940f..ad96bd847c 100644 --- a/tools/mtmd/mtmd-audio.h +++ b/tools/mtmd/mtmd-audio.h @@ -10,16 +10,16 @@ #define MTMD_INTERNAL_HEADER struct mtmd_audio_mel { - int n_len; - int n_len_org; - int n_mel; + int64_t n_len; + int64_t n_len_org; + int64_t n_mel; std::vector data; }; struct mtmd_audio_mel_filters { - int32_t n_mel; - int32_t n_fft; + int64_t n_mel; + int64_t n_fft; std::vector data; }; @@ -39,8 +39,8 @@ struct mtmd_audio_cache { // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime. // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257. - void fill_mel_filterbank_matrix(int n_mel, - int n_fft, + void fill_mel_filterbank_matrix(int64_t n_mel, + int64_t n_fft, int sample_rate, // e.g. 16000 float fmin = 0.0f, // e.g. 0.0 float fmax = -1.0f, // e.g. sr/2; pass -1 for auto diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index abba2ebf2c..cbaac1d377 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -1295,9 +1295,12 @@ struct mtmd_tokenizer { for (auto & mel_spec : mel_spec_chunks) { const bool is_placeholder = mel_spec.data.empty(); + // Validate dimensions fit in clip_image_size (int) + GGML_ASSERT(mel_spec.n_len <= INT32_MAX && mel_spec.n_len >= 0); + GGML_ASSERT(mel_spec.n_mel <= INT32_MAX && mel_spec.n_mel >= 0); clip_image_f32 mel_f32; mel_f32.set_size( - {mel_spec.n_len, mel_spec.n_mel}, + {(int)mel_spec.n_len, (int)mel_spec.n_mel}, is_placeholder, /* is_audio */ true); mel_f32.cpy_buf(mel_spec.data); From 38724ab5937f6440993d4e9814563385569d3b5a Mon Sep 17 00:00:00 2001 From: Aldehir Rojas Date: Fri, 19 Jun 2026 08:32:31 -0500 Subject: [PATCH 14/86] docker : build the UI (#24794) * docker : build the UI * cont : use existing APP_VERSION --- .devops/cann.Dockerfile | 16 ++++++++++++++++ .devops/cpu.Dockerfile | 16 ++++++++++++++++ .devops/cuda.Dockerfile | 16 ++++++++++++++++ .devops/intel.Dockerfile | 16 ++++++++++++++++ 
.devops/musa.Dockerfile | 16 ++++++++++++++++ .devops/openvino.Dockerfile | 16 ++++++++++++++++ .devops/rocm.Dockerfile | 16 ++++++++++++++++ .devops/s390x.Dockerfile | 16 ++++++++++++++++ .devops/vulkan.Dockerfile | 16 ++++++++++++++++ .devops/zendnn.Dockerfile | 16 ++++++++++++++++ .dockerignore | 3 +++ 11 files changed, 163 insertions(+) diff --git a/.devops/cann.Dockerfile b/.devops/cann.Dockerfile index 9df86d0489..dc95e3f38d 100644 --- a/.devops/cann.Dockerfile +++ b/.devops/cann.Dockerfile @@ -13,6 +13,20 @@ ARG APP_REVISION=N/A # BUILD STAGE # Compile all binary files and libraries # ============================================================================== +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + FROM ${CANN_BASE_IMAGE} AS build # -- Install build dependencies -- @@ -26,6 +40,8 @@ WORKDIR /app # -- Copy project files -- COPY . . +COPY --from=web /app/tools/ui/dist tools/ui/dist + # -- Set CANN environment variables (required for compilation) -- # Using ENV instead of `source` allows environment variables to persist across the entire image layer ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest diff --git a/.devops/cpu.Dockerfile b/.devops/cpu.Dockerfile index 9dbfdd11df..caf727bcdb 100644 --- a/.devops/cpu.Dockerfile +++ b/.devops/cpu.Dockerfile @@ -3,6 +3,20 @@ ARG BUILD_DATE=N/A ARG APP_VERSION=N/A ARG APP_REVISION=N/A +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + FROM docker.io/ubuntu:$UBUNTU_VERSION AS build ARG TARGETARCH @@ -16,6 +30,8 @@ WORKDIR /app COPY . . 
+COPY --from=web /app/tools/ui/dist tools/ui/dist + RUN if [ "$TARGETARCH" = "amd64" ] || [ "$TARGETARCH" = "arm64" ]; then \ cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON; \ else \ diff --git a/.devops/cuda.Dockerfile b/.devops/cuda.Dockerfile index 276f82c34c..b16b9a8f1a 100644 --- a/.devops/cuda.Dockerfile +++ b/.devops/cuda.Dockerfile @@ -11,6 +11,20 @@ ARG BUILD_DATE=N/A ARG APP_VERSION=N/A ARG APP_REVISION=N/A +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + FROM ${BASE_CUDA_DEV_CONTAINER} AS build ARG GCC_VERSION @@ -26,6 +40,8 @@ WORKDIR /app COPY . . +COPY --from=web /app/tools/ui/dist tools/ui/dist + RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \ export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \ fi && \ diff --git a/.devops/intel.Dockerfile b/.devops/intel.Dockerfile index 4d0c0a8fd8..3c059eb301 100644 --- a/.devops/intel.Dockerfile +++ b/.devops/intel.Dockerfile @@ -5,6 +5,20 @@ ARG APP_REVISION=N/A ## Build Image +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + FROM docker.io/intel/deep-learning-essentials:$ONEAPI_VERSION AS build ARG GGML_SYCL_F16=ON @@ -22,6 +36,8 @@ WORKDIR /app COPY . . 
+COPY --from=web /app/tools/ui/dist tools/ui/dist + RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \ echo "GGML_SYCL_F16 is set" \ && export OPT_SYCL_F16="-DGGML_SYCL_F16=ON" \ diff --git a/.devops/musa.Dockerfile b/.devops/musa.Dockerfile index c98c44c951..0c23cc5547 100644 --- a/.devops/musa.Dockerfile +++ b/.devops/musa.Dockerfile @@ -10,6 +10,20 @@ ARG BUILD_DATE=N/A ARG APP_VERSION=N/A ARG APP_REVISION=N/A +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + FROM ${BASE_MUSA_DEV_CONTAINER} AS build # MUSA architecture to build for (defaults to all supported archs) @@ -29,6 +43,8 @@ WORKDIR /app COPY . . +COPY --from=web /app/tools/ui/dist tools/ui/dist + RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \ export CMAKE_ARGS="-DMUSA_ARCHITECTURES=${MUSA_DOCKER_ARCH}"; \ fi && \ diff --git a/.devops/openvino.Dockerfile b/.devops/openvino.Dockerfile index 9e96244ced..fec72b1c7d 100644 --- a/.devops/openvino.Dockerfile +++ b/.devops/openvino.Dockerfile @@ -22,6 +22,20 @@ ARG BUILD_DATE=N/A ARG APP_VERSION=N/A ARG APP_REVISION=N/A +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + ## Build Image FROM docker.io/ubuntu:${UBUNTU_VERSION} AS build @@ -69,6 +83,8 @@ WORKDIR /app COPY . . 
+COPY --from=web /app/tools/ui/dist tools/ui/dist + # Build Stage RUN bash -c "source ${OpenVINO_DIR}/setupvars.sh && \ cmake -B build/ReleaseOV -G Ninja \ diff --git a/.devops/rocm.Dockerfile b/.devops/rocm.Dockerfile index 2ab10f4117..7fad0c22e5 100644 --- a/.devops/rocm.Dockerfile +++ b/.devops/rocm.Dockerfile @@ -11,6 +11,20 @@ ARG BUILD_DATE=N/A ARG APP_VERSION=N/A ARG APP_REVISION=N/A +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + ### Build image FROM ${BASE_ROCM_DEV_CONTAINER} AS build @@ -38,6 +52,8 @@ WORKDIR /app COPY . . +COPY --from=web /app/tools/ui/dist tools/ui/dist + RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \ cmake -S . -B build \ -DGGML_HIP=ON \ diff --git a/.devops/s390x.Dockerfile b/.devops/s390x.Dockerfile index d88dd2d92d..149d79a615 100644 --- a/.devops/s390x.Dockerfile +++ b/.devops/s390x.Dockerfile @@ -4,6 +4,20 @@ ARG BUILD_DATE=N/A ARG APP_VERSION=N/A ARG APP_REVISION=N/A +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + ### Build Llama.cpp stage FROM docker.io/gcc:${GCC_VERSION} AS build @@ -20,6 +34,8 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ WORKDIR /app COPY . . +COPY --from=web /app/tools/ui/dist tools/ui/dist + RUN --mount=type=cache,target=/root/.ccache \ --mount=type=cache,target=/app/build \ cmake -S . 
-B build -G Ninja \ diff --git a/.devops/vulkan.Dockerfile b/.devops/vulkan.Dockerfile index 05df94ec44..26c1902b14 100644 --- a/.devops/vulkan.Dockerfile +++ b/.devops/vulkan.Dockerfile @@ -3,6 +3,20 @@ ARG BUILD_DATE=N/A ARG APP_VERSION=N/A ARG APP_REVISION=N/A +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + FROM docker.io/ubuntu:$UBUNTU_VERSION AS build # Install build tools @@ -17,6 +31,8 @@ WORKDIR /app COPY . . +COPY --from=web /app/tools/ui/dist tools/ui/dist + RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=ON -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON && \ cmake --build build --config Release -j$(nproc) diff --git a/.devops/zendnn.Dockerfile b/.devops/zendnn.Dockerfile index 9f811ab278..80daf56710 100644 --- a/.devops/zendnn.Dockerfile +++ b/.devops/zendnn.Dockerfile @@ -3,6 +3,20 @@ ARG BUILD_DATE=N/A ARG APP_VERSION=N/A ARG APP_REVISION=N/A +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + FROM docker.io/ubuntu:$UBUNTU_VERSION AS build RUN apt-get update && \ @@ -14,6 +28,8 @@ WORKDIR /app COPY . . +COPY --from=web /app/tools/ui/dist tools/ui/dist + RUN cmake -S . 
-B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_ZENDNN=ON && \ cmake --build build -j $(nproc) diff --git a/.dockerignore b/.dockerignore index 064b7c7be8..a223b7e898 100644 --- a/.dockerignore +++ b/.dockerignore @@ -10,6 +10,9 @@ build*/ +tools/ui/node_modules/ +tools/ui/dist/ + models/* /llama-cli From 8c2d6f6475f2586e079fa6677dec91088de85604 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Fri, 19 Jun 2026 16:06:13 +0200 Subject: [PATCH 15/86] server: add --agent arg, remove redundant webui naming compat (#24801) * server: add --agent arg, remove redundant webui naming compat * corrent env * fix the test * llama-gen-docs * nits: wordings --- common/arg.cpp | 70 +++++++++------------------------ common/common.h | 6 --- tools/cli/README.md | 3 +- tools/server/README.md | 18 ++++----- tools/server/server-context.cpp | 15 +++---- tools/server/server-models.cpp | 15 ++++--- tools/server/server-models.h | 5 +-- tools/server/server.cpp | 3 +- 8 files changed, 43 insertions(+), 92 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index bd4b113d16..52425f25e4 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2830,62 +2830,26 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.api_prefix = value; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_API_PREFIX")); - // Deprecated: use --ui-config instead (kept for backward compat) add_opt(common_arg( - {"--webui-config"}, "JSON", - "[DEPRECATED: use --ui-config] JSON that provides default WebUI settings (overrides WebUI defaults)", - [](common_params & params, const std::string & value) { - params.ui_config_json = value; - params.webui_config_json = value; - } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_CONFIG")); - - add_opt(common_arg( - {"--ui-config"}, "JSON", + {"--ui-config", "--webui-config"}, "JSON", "JSON that provides default UI settings 
(overrides UI defaults)", [](common_params & params, const std::string & value) { params.ui_config_json = value; - params.webui_config_json = value; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI_CONFIG")); - - // Deprecated: use --ui-config-file instead (kept for backward compat) add_opt(common_arg( - {"--webui-config-file"}, "PATH", - "[DEPRECATED: use --ui-config-file] JSON file that provides default WebUI settings (overrides WebUI defaults)", - [](common_params & params, const std::string & value) { - params.ui_config_json = read_file(value); - params.webui_config_json = params.ui_config_json; - } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_CONFIG_FILE")); - - add_opt(common_arg( - {"--ui-config-file"}, "PATH", + {"--ui-config-file", "--webui-config-file"}, "PATH", "JSON file that provides default UI settings (overrides UI defaults)", [](common_params & params, const std::string & value) { params.ui_config_json = read_file(value); - params.webui_config_json = params.ui_config_json; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI_CONFIG_FILE")); - - // Deprecated: use --ui-mcp-proxy instead (kept for backward compat) add_opt(common_arg( - {"--webui-mcp-proxy"}, - {"--no-webui-mcp-proxy"}, - "[DEPRECATED: use --ui-mcp-proxy/--no-ui-mcp-proxy] experimental: whether to enable MCP CORS proxy", - [](common_params & params, bool value) { - params.ui_mcp_proxy = value; - params.webui_mcp_proxy = value; - } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_MCP_PROXY")); - - add_opt(common_arg( - {"--ui-mcp-proxy"}, - {"--no-ui-mcp-proxy"}, + {"--ui-mcp-proxy", "--webui-mcp-proxy"}, + {"--no-ui-mcp-proxy", "--no-webui-mcp-proxy"}, "experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)", [](common_params & params, bool value) { params.ui_mcp_proxy = value; - params.webui_mcp_proxy = value; } 
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI_MCP_PROXY")); add_opt(common_arg( @@ -2897,24 +2861,26 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.server_tools = parse_csv_row(value); } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TOOLS")); - // Deprecated: use --ui/--no-ui instead (kept for backward compat) - add_opt(common_arg( - {"--webui"}, - {"--no-webui"}, - "[DEPRECATED: use --ui/--no-ui] whether to enable the Web UI", + add_opt(common_arg( + {"-ag", "--agent"}, + {"-no-ag", "--no-agent"}, + "whether to enable CORS proxy and all built-in tools - do not enable in untrusted environments (default: disabled)", [](common_params & params, bool value) { - params.ui = value; - params.webui = value; + if (value) { + params.server_tools = {"all"}; + params.ui_mcp_proxy = true; + } else { + params.server_tools.clear(); + params.ui_mcp_proxy = false; + } } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI")); - + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_AGENT")); add_opt(common_arg( - {"--ui"}, - {"--no-ui"}, + {"--ui", "--webui"}, + {"--no-ui", "--no-webui"}, string_format("whether to enable the Web UI (default: %s)", params.ui ? 
"enabled" : "disabled"), [](common_params & params, bool value) { params.ui = value; - params.webui = value; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI")); add_opt(common_arg( diff --git a/common/common.h b/common/common.h index 535a4ed335..44c605189c 100644 --- a/common/common.h +++ b/common/common.h @@ -624,12 +624,6 @@ struct common_params { // UI configs bool ui = true; - - // Deprecated: use ui, ui_mcp_proxy, ui_config_json instead - bool webui = ui; - bool webui_mcp_proxy = false; - std::string webui_config_json; - bool ui_mcp_proxy = false; std::string ui_config_json; diff --git a/tools/cli/README.md b/tools/cli/README.md index b11aa45ce9..f93ae914ce 100644 --- a/tools/cli/README.md +++ b/tools/cli/README.md @@ -161,7 +161,7 @@ | `-mmu, --mmproj-url URL` | URL to a multimodal projector file. see tools/mtmd/README.md
(env: LLAMA_ARG_MMPROJ_URL) | | `--mmproj-auto, --no-mmproj, --no-mmproj-auto` | whether to use multimodal projector file (if available), useful when using -hf (default: enabled)
(env: LLAMA_ARG_MMPROJ_AUTO) | | `--mmproj-offload, --no-mmproj-offload` | whether to enable GPU offloading for multimodal projector (default: enabled)
(env: LLAMA_ARG_MMPROJ_OFFLOAD) | -| `--image, --audio FILE` | path to an image or audio file. use with multimodal models, use comma-separated values for multiple files | +| `--image, --audio, --video FILE` | path to an image, audio, or video file. use with multimodal models, use comma-separated values for multiple files | | `--image-min-tokens N` | minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)
(env: LLAMA_ARG_IMAGE_MIN_TOKENS) | | `--image-max-tokens N` | maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)
(env: LLAMA_ARG_IMAGE_MAX_TOKENS) | | `--chat-template-kwargs STRING` | sets additional params for the json template parser, must be a valid json object string, e.g. '{"key1":"value1","key2":"value2"}'
(env: LLAMA_ARG_CHAT_TEMPLATE_KWARGS) | @@ -174,6 +174,7 @@ | `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek-ocr, deepseek2, deepseek3, exaone-moe, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, granite-4.0, granite-4.1, grok-2, hunyuan-dense, hunyuan-moe, hunyuan-vl, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, solar-open, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | | `--skip-chat-parsing, --no-skip-chat-parsing` | force a pure content parser, even if a Jinja template is specified; model will output everything in the content section, including any reasoning and/or tool calls (default: disabled)
(env: LLAMA_ARG_SKIP_CHAT_PARSING) | | `--simple-io` | use basic IO for better compatibility in subprocesses and limited consoles | +| `--log-prompts-dir PATH` | Log prompts to directory (only used for debugging, default: disabled) | | `--spec-draft-hf, -hfd, -hfrd, --hf-repo-draft <user>/<model>[:quant]` | Same as --hf-repo, but for the draft model (default: unused)
(env: LLAMA_ARG_SPEC_DRAFT_HF_REPO) | | `--spec-draft-threads, -td, --threads-draft N` | number of threads to use during generation (default: same as --threads) | | `--spec-draft-threads-batch, -tbd, --threads-batch-draft N` | number of threads to use during batch and prompt processing (default: same as --threads-draft) | diff --git a/tools/server/README.md b/tools/server/README.md index 88a507e2c5..1f74ba52ae 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -175,13 +175,12 @@ For the full list of features, please refer to [server's changelog](https://gith | `-np, --parallel N` | number of server slots (default: -1, -1 = auto)
(env: LLAMA_ARG_N_PARALLEL) | | `-cb, --cont-batching, -nocb, --no-cont-batching` | whether to enable continuous batching (a.k.a dynamic batching) (default: enabled)
(env: LLAMA_ARG_CONT_BATCHING) | | `-mm, --mmproj FILE` | path to a multimodal projector file. see tools/mtmd/README.md
note: if -hf is used, this argument can be omitted
(env: LLAMA_ARG_MMPROJ) | -| `-tk, --talker-model FILE` | path to the qwen3-omni talker gguf, enables the /v1/audio/speech endpoint
(env: LLAMA_ARG_TALKER_MODEL) | -| `-c2w, --code2wav-model FILE` | path to the qwen3-omni code2wav gguf, the talker code detokenizer
(env: LLAMA_ARG_CODE2WAV_MODEL) | | `-mmu, --mmproj-url URL` | URL to a multimodal projector file. see tools/mtmd/README.md
(env: LLAMA_ARG_MMPROJ_URL) | | `--mmproj-auto, --no-mmproj, --no-mmproj-auto` | whether to use multimodal projector file (if available), useful when using -hf (default: enabled)
(env: LLAMA_ARG_MMPROJ_AUTO) | | `--mmproj-offload, --no-mmproj-offload` | whether to enable GPU offloading for multimodal projector (default: enabled)
(env: LLAMA_ARG_MMPROJ_OFFLOAD) | | `--image-min-tokens N` | minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)
(env: LLAMA_ARG_IMAGE_MIN_TOKENS) | | `--image-max-tokens N` | maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)
(env: LLAMA_ARG_IMAGE_MAX_TOKENS) | +| `--mtmd-batch-max-tokens N` | maximum number of image tokens per batch when encoding images (default: 1024)
(env: LLAMA_ARG_MTMD_BATCH_MAX_TOKENS) | | `-a, --alias STRING` | set model name aliases, comma-separated (to be used by API)
(env: LLAMA_ARG_ALIAS) | | `--tags STRING` | set model tags, comma-separated (informational, not used for routing)
(env: LLAMA_ARG_TAGS) | | `--embd-normalize N` | normalisation for embeddings (default: 2) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) | @@ -190,15 +189,12 @@ For the full list of features, please refer to [server's changelog](https://gith | `--reuse-port` | allow multiple sockets to bind to the same port (default: disabled)
(env: LLAMA_ARG_REUSE_PORT) | | `--path PATH` | path to serve static files from (default: )
(env: LLAMA_ARG_STATIC_PATH) | | `--api-prefix PREFIX` | prefix path the server serves from, without the trailing slash (default: )
(env: LLAMA_ARG_API_PREFIX) | -| `--webui-config JSON` | [DEPRECATED: use --ui-config] JSON that provides default WebUI settings (overrides WebUI defaults)
(env: LLAMA_ARG_WEBUI_CONFIG) | -| `--ui-config JSON` | JSON that provides default UI settings (overrides UI defaults)
(env: LLAMA_ARG_UI_CONFIG) | -| `--webui-config-file PATH` | [DEPRECATED: use --ui-config-file] JSON file that provides default WebUI settings (overrides WebUI defaults)
(env: LLAMA_ARG_WEBUI_CONFIG_FILE) | -| `--ui-config-file PATH` | JSON file that provides default UI settings (overrides UI defaults)
(env: LLAMA_ARG_UI_CONFIG_FILE) | -| `--webui-mcp-proxy, --no-webui-mcp-proxy` | [DEPRECATED: use --ui-mcp-proxy/--no-ui-mcp-proxy] experimental: whether to enable MCP CORS proxy
(env: LLAMA_ARG_WEBUI_MCP_PROXY) | -| `--ui-mcp-proxy, --no-ui-mcp-proxy` | experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)
(env: LLAMA_ARG_UI_MCP_PROXY) | +| `--ui-config, --webui-config JSON` | JSON that provides default UI settings (overrides UI defaults)
(env: LLAMA_ARG_UI_CONFIG) | +| `--ui-config-file, --webui-config-file PATH` | JSON file that provides default UI settings (overrides UI defaults)
(env: LLAMA_ARG_UI_CONFIG_FILE) | +| `--ui-mcp-proxy, --webui-mcp-proxy, --no-ui-mcp-proxy, --no-webui-mcp-proxy` | experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)
(env: LLAMA_ARG_UI_MCP_PROXY) | | `--tools TOOL1,TOOL2,...` | experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)
specify "all" to enable all tools
available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff, get_datetime
(env: LLAMA_ARG_TOOLS) | -| `--webui, --no-webui` | [DEPRECATED: use --ui/--no-ui] whether to enable the Web UI
(env: LLAMA_ARG_WEBUI) | -| `--ui, --no-ui` | whether to enable the Web UI (default: enabled)
(env: LLAMA_ARG_UI) | +| `-ag, --agent, -no-ag, --no-agent` | whether to enable CORS proxy and all built-in tools - do not enable in untrusted environments (default: disabled)
(env: LLAMA_ARG_AGENT) | +| `--ui, --webui, --no-ui, --no-webui` | whether to enable the Web UI (default: enabled)
(env: LLAMA_ARG_UI) | | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)
(env: LLAMA_ARG_EMBEDDINGS) | | `--rerank, --reranking` | enable reranking endpoint on server (default: disabled)
(env: LLAMA_ARG_RERANKING) | | `--api-key KEY` | API key to use for authentication, multiple keys can be provided as a comma-separated list (default: none)
(env: LLAMA_API_KEY) | @@ -207,6 +203,7 @@ For the full list of features, please refer to [server's changelog](https://gith | `--ssl-cert-file FNAME` | path to file a PEM-encoded SSL certificate
(env: LLAMA_ARG_SSL_CERT_FILE) | | `--chat-template-kwargs STRING` | sets additional params for the json template parser, must be a valid json object string, e.g. '{"key1":"value1","key2":"value2"}'
(env: LLAMA_ARG_CHAT_TEMPLATE_KWARGS) | | `-to, --timeout N` | server read/write timeout in seconds (default: 3600)
(env: LLAMA_ARG_TIMEOUT) | +| `--sse-ping-interval N` | server SSE ping interval in seconds (-1 = disabled, default: 30)
(env: LLAMA_ARG_SSE_PING_INTERVAL) | | `--threads-http N` | number of threads used to process HTTP requests (default: -1)
(env: LLAMA_ARG_THREADS_HTTP) | | `--cache-prompt, --no-cache-prompt` | whether to enable prompt caching (default: enabled)
(env: LLAMA_ARG_CACHE_PROMPT) | | `--cache-reuse N` | min chunk size to attempt reusing from the cache via KV shifting, requires prompt caching to be enabled (default: 0)
[(card)](https://ggml.ai/f0.png)
(env: LLAMA_ARG_CACHE_REUSE) | @@ -231,6 +228,7 @@ For the full list of features, please refer to [server's changelog](https://gith | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.10, 0.0 = disabled) | | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) | | `--sleep-idle-seconds SECONDS` | number of seconds of idleness after which the server will sleep (default: -1; -1 = disabled) | +| `--log-prompts-dir PATH` | Log prompts to directory (only used for debugging, default: disabled) | | `--spec-draft-hf, -hfd, -hfrd, --hf-repo-draft <user>/<model>[:quant]` | Same as --hf-repo, but for the draft model (default: unused)
(env: LLAMA_ARG_SPEC_DRAFT_HF_REPO) | | `--spec-draft-threads, -td, --threads-draft N` | number of threads to use during generation (default: same as --threads) | | `--spec-draft-threads-batch, -tbd, --threads-batch-draft N` | number of threads to use during batch and prompt processing (default: same as --threads-draft) | diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index c887beb0ac..00ab31340b 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1302,11 +1302,8 @@ private: } } - // populate UI settings (from either new ui_config_json or deprecated webui_config_json) { - const std::string & cfg = !params_base.ui_config_json.empty() - ? params_base.ui_config_json - : params_base.webui_config_json; + const std::string & cfg = params_base.ui_config_json; if (!cfg.empty()) { try { json json_settings = json::parse(cfg); @@ -4304,18 +4301,18 @@ void server_routes::init_routes() { { "endpoint_props", params.endpoint_props }, { "endpoint_metrics", params.endpoint_metrics }, // New keys - { "ui", params.ui }, - { "ui_settings", meta->json_ui_settings }, + { "ui", params.ui }, + { "ui_settings", meta->json_ui_settings }, // Deprecated: use ui/ui_settings instead (kept for backward compat) - { "webui", params.webui }, - { "webui_settings", meta->json_webui_settings }, + { "webui", params.ui }, + { "webui_settings", meta->json_ui_settings }, { "chat_template", tmpl_default }, { "chat_template_caps", meta->chat_template_caps }, { "bos_token", meta->bos_token_str }, { "eos_token", meta->eos_token_str }, { "build_info", meta->build_info }, { "is_sleeping", queue_tasks.is_sleeping() }, - { "cors_proxy_enabled", params.ui_mcp_proxy || params.webui_mcp_proxy }, + { "cors_proxy_enabled", params.ui_mcp_proxy }, }; if (params.use_jinja) { if (!tmpl_tools.empty()) { diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 7aaad69261..23c1f16689 100644 --- a/tools/server/server-models.cpp +++ 
b/tools/server/server-models.cpp @@ -1462,9 +1462,9 @@ void server_models_routes::init_routes() { auto res = std::make_unique(); res_ok(res, { // TODO: add support for this on web UI - {"role", "router"}, - {"max_instances", params.models_max}, - {"models_autoload", params.models_autoload}, + {"role", "router"}, + {"max_instances", params.models_max}, + {"models_autoload", params.models_autoload}, // this is a dummy response to make sure the UI doesn't break {"model_alias", "llama-server"}, {"model_path", "none"}, @@ -1473,11 +1473,10 @@ void server_models_routes::init_routes() { {"n_ctx", 0}, }}, // New key - {"ui_settings", ui_settings}, - // Deprecated: use ui_settings instead (kept for backward compat) - {"webui_settings", webui_settings}, - {"build_info", std::string(llama_build_info())}, - {"cors_proxy_enabled", params.ui_mcp_proxy || params.webui_mcp_proxy}, + {"ui_settings", ui_settings}, + {"webui_settings", webui_settings}, + {"build_info", std::string(llama_build_info())}, + {"cors_proxy_enabled", params.ui_mcp_proxy}, }); return res; } diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 319c4352e2..aeb0e874de 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -212,10 +212,7 @@ struct server_models_routes { server_models models; server_models_routes(const common_params & params, int argc, char ** argv) : params(params), models(params, argc, argv) { - // Support both new ui_config_json and deprecated webui_config_json - const std::string & cfg = !this->params.ui_config_json.empty() - ? 
this->params.ui_config_json - : this->params.webui_config_json; + const std::string & cfg = this->params.ui_config_json; if (!cfg.empty()) { try { json json_settings = json::parse(cfg); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 78ab0318cf..2a67bfcfed 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -227,8 +227,7 @@ int llama_server(int argc, char ** argv) { ctx_http.register_gcp_compat(); // CORS proxy (EXPERIMENTAL, only used by the Web UI for MCP) - // Supports both new ui_mcp_proxy and deprecated webui_mcp_proxy fields - if (params.ui_mcp_proxy || params.webui_mcp_proxy) { + if (params.ui_mcp_proxy) { SRV_WRN("%s", "-----------------\n"); SRV_WRN("%s", "CORS proxy is enabled, do not expose server to untrusted environments\n"); SRV_WRN("%s", "This feature is EXPERIMENTAL and may be removed or changed in future versions\n"); From 0d2d9ccbf6aae92de310712297fd52becc134092 Mon Sep 17 00:00:00 2001 From: "Alessandro de Oliveira Faria (A.K.A.CABELO)" Date: Fri, 19 Jun 2026 11:16:35 -0300 Subject: [PATCH 16/86] vendor : update cpp-httplib to 0.48.0 (#24787) --- scripts/sync_vendor.py | 2 +- vendor/cpp-httplib/httplib.cpp | 241 +++++++++++++-------------------- vendor/cpp-httplib/httplib.h | 81 ++++++++--- 3 files changed, 160 insertions(+), 164 deletions(-) diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 402d7bbad3..f913b0c7dc 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -5,7 +5,7 @@ import os import sys import subprocess -HTTPLIB_VERSION = "refs/tags/v0.47.0" +HTTPLIB_VERSION = "refs/tags/v0.48.0" vendor = { "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp", diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index 370c0e798d..1ac4fa4ba4 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -5809,11 +5809,9 @@ std::string decode_query_component(const std::string &component, for 
(size_t i = 0; i < component.size(); i++) { if (component[i] == '%' && i + 2 < component.size()) { - std::string hex = component.substr(i + 1, 2); - char *end; - unsigned long value = std::strtoul(hex.c_str(), &end, 16); - if (end == hex.c_str() + 2) { - result += static_cast(value); + auto val = 0; + if (detail::from_hex_to_i(component, i + 1, 2, val)) { + result += static_cast(val); i += 2; } else { result += component[i]; @@ -12551,6 +12549,21 @@ bool parse_ipv4(const std::string &str, unsigned char *out) { return *p == '\0'; } +// Parse an IP literal (IPv4 or IPv6) into raw network-order bytes. +// `out` must have room for at least 16 bytes. Returns the address length +// (4 for IPv4, 16 for IPv6) on success, or 0 if the string is not an IP +// literal. Used to match a host against iPAddress SANs the same way the +// OpenSSL backend does via X509_check_ip. +size_t parse_ip_address(const std::string &str, unsigned char *out) { + if (is_ipv4_address(str)) { return parse_ipv4(str, out) ? 4 : 0; } + struct in6_addr addr6 = {}; + if (inet_pton(AF_INET6, str.c_str(), &addr6) == 1) { + memcpy(out, &addr6, 16); + return 16; + } + return 0; +} + #ifdef _WIN32 // Enumerate Windows system certificates and call callback with DER data template @@ -12852,6 +12865,30 @@ int openssl_verify_callback(int preverify_ok, X509_STORE_CTX *ctx) { return callback(verify_ctx) ? 
1 : 0; } +// X509_STORE_get0_objects is deprecated since OpenSSL 4.0 because it is not +// thread-safe; X509_STORE_get1_objects (OpenSSL 3.3+) returns a snapshot +// that must be released with release_store_objects +#if !defined(OPENSSL_IS_BORINGSSL) && !defined(LIBRESSL_VERSION_NUMBER) && \ + OPENSSL_VERSION_NUMBER >= 0x30300000L +#define CPPHTTPLIB_HAS_X509_STORE_GET1_OBJECTS +#endif + +STACK_OF(X509_OBJECT) * get_store_objects(X509_STORE *store) { +#ifdef CPPHTTPLIB_HAS_X509_STORE_GET1_OBJECTS + return X509_STORE_get1_objects(store); +#else + return X509_STORE_get0_objects(store); +#endif +} + +void release_store_objects(STACK_OF(X509_OBJECT) * objs) { +#ifdef CPPHTTPLIB_HAS_X509_STORE_GET1_OBJECTS + sk_X509_OBJECT_pop_free(objs, X509_OBJECT_free); +#else + (void)objs; // get0 variant returns an internal pointer; nothing to free +#endif +} + } // namespace impl ctx_t create_client_context() { @@ -13373,11 +13410,19 @@ std::string get_cert_subject_cn(cert_t cert) { auto subject_name = X509_get_subject_name(x509); if (!subject_name) return ""; - char buf[256]; - auto len = - X509_NAME_get_text_by_NID(subject_name, NID_commonName, buf, sizeof(buf)); - if (len < 0) return ""; - return std::string(buf, static_cast(len)); + // X509_NAME_get_text_by_NID is deprecated since OpenSSL 4.0 + auto idx = X509_NAME_get_index_by_NID(subject_name, NID_commonName, -1); + if (idx < 0) return ""; + + auto entry = X509_NAME_get_entry(subject_name, idx); + if (!entry) return ""; + + auto data = X509_NAME_ENTRY_get_data(entry); + if (!data) return ""; + + return std::string( + reinterpret_cast(ASN1_STRING_get0_data(data)), + static_cast(ASN1_STRING_length(data))); } std::string get_cert_issuer_name(cert_t cert) { @@ -13582,8 +13627,9 @@ size_t get_ca_certs(ctx_t ctx, std::vector &certs) { auto store = SSL_CTX_get_cert_store(ssl_ctx); if (!store) { return 0; } - auto objs = X509_STORE_get0_objects(store); + auto objs = impl::get_store_objects(store); if (!objs) { return 0; } + auto se 
= detail::scope_exit([&] { impl::release_store_objects(objs); }); auto count = sk_X509_OBJECT_num(objs); for (decltype(count) i = 0; i < count; i++) { @@ -13609,8 +13655,9 @@ std::vector get_ca_names(ctx_t ctx) { auto store = SSL_CTX_get_cert_store(ssl_ctx); if (!store) { return names; } - auto objs = X509_STORE_get0_objects(store); + auto objs = impl::get_store_objects(store); if (!objs) { return names; } + auto se = detail::scope_exit([&] { impl::release_store_objects(objs); }); auto count = sk_X509_OBJECT_num(objs); for (decltype(count) i = 0; i < count; i++) { @@ -13716,110 +13763,6 @@ std::string verify_error_string(long error_code) { } // namespace tls -bool SSLClient::verify_host(X509 *server_cert) const { - /* Quote from RFC2818 section 3.1 "Server Identity" - - If a subjectAltName extension of type dNSName is present, that MUST - be used as the identity. Otherwise, the (most specific) Common Name - field in the Subject field of the certificate MUST be used. Although - the use of the Common Name is existing practice, it is deprecated and - Certification Authorities are encouraged to use the dNSName instead. - - Matching is performed using the matching rules specified by - [RFC2459]. If more than one identity of a given type is present in - the certificate (e.g., more than one dNSName name, a match in any one - of the set is considered acceptable.) Names may contain the wildcard - character * which is considered to match any single domain name - component or component fragment. E.g., *.a.com matches foo.a.com but - not bar.foo.a.com. f*.com matches foo.com but not bar.com. - - In some cases, the URI is specified as an IP address rather than a - hostname. In this case, the iPAddress subjectAltName must be present - in the certificate and must exactly match the IP in the URI. 
- - */ - return verify_host_with_subject_alt_name(server_cert) || - verify_host_with_common_name(server_cert); -} - -bool -SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { - auto ret = false; - - auto type = GEN_DNS; - - struct in6_addr addr6 = {}; - struct in_addr addr = {}; - size_t addr_len = 0; - -#ifndef __MINGW32__ - if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { - type = GEN_IPADD; - addr_len = sizeof(struct in6_addr); - } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { - type = GEN_IPADD; - addr_len = sizeof(struct in_addr); - } -#endif - - auto alt_names = static_cast( - X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); - - if (alt_names) { - auto dsn_matched = false; - auto ip_matched = false; - - auto count = sk_GENERAL_NAME_num(alt_names); - - for (decltype(count) i = 0; i < count && !dsn_matched; i++) { - auto val = sk_GENERAL_NAME_value(alt_names, i); - if (!val || val->type != type) { continue; } - - auto name = - reinterpret_cast(ASN1_STRING_get0_data(val->d.ia5)); - if (name == nullptr) { continue; } - - auto name_len = static_cast(ASN1_STRING_length(val->d.ia5)); - - switch (type) { - case GEN_DNS: - dsn_matched = - detail::match_hostname(std::string(name, name_len), host_); - break; - - case GEN_IPADD: - if (!memcmp(&addr6, name, addr_len) || !memcmp(&addr, name, addr_len)) { - ip_matched = true; - } - break; - } - } - - if (dsn_matched || ip_matched) { ret = true; } - } - - GENERAL_NAMES_free(const_cast( - reinterpret_cast(alt_names))); - return ret; -} - -bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { - const auto subject_name = X509_get_subject_name(server_cert); - - if (subject_name != nullptr) { - char name[BUFSIZ]; - auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, - name, sizeof(name)); - - if (name_len != -1) { - return detail::match_hostname( - std::string(name, static_cast(name_len)), host_); - } - } - - return false; -} - #endif // 
CPPHTTPLIB_OPENSSL_SUPPORT /* @@ -14622,10 +14565,10 @@ bool verify_hostname(cert_t cert, const char *hostname) { auto mcert = static_cast(cert); std::string host_str(hostname); - // Check if hostname is an IP address - bool is_ip = impl::is_ipv4_address(host_str); - unsigned char ip_bytes[4]; - if (is_ip) { impl::parse_ipv4(host_str, ip_bytes); } + // Check if hostname is an IP address (IPv4 or IPv6) + unsigned char ip_bytes[16]; + auto ip_len = impl::parse_ip_address(host_str, ip_bytes); + auto is_ip = ip_len > 0; // Check Subject Alternative Names (SAN) // In Mbed TLS 3.x, subject_alt_names contains raw values without ASN.1 tags @@ -14637,9 +14580,9 @@ bool verify_hostname(cert_t cert, const char *hostname) { size_t len = san->buf.len; if (is_ip) { - // Check if this SAN is an IPv4 address (4 bytes) - if (len == 4 && memcmp(p, ip_bytes, 4) == 0) { return true; } - // Check if this SAN is an IPv6 address (16 bytes) - skip for now + // For an IP host, only a matching iPAddress SAN of the same family + // (4 bytes for IPv4, 16 bytes for IPv6) may authenticate it. + if (len == ip_len && memcmp(p, ip_bytes, ip_len) == 0) { return true; } } else { // Check if this SAN is a DNS name (printable ASCII string) bool is_dns = len > 0; @@ -14654,21 +14597,25 @@ bool verify_hostname(cert_t cert, const char *hostname) { san = san->next; } - // Fallback: Check Common Name (CN) in subject - char cn[256]; - int ret = mbedtls_x509_dn_gets(cn, sizeof(cn), &mcert->subject); - if (ret > 0) { - std::string cn_str(cn); + // Fallback: Check Common Name (CN) in subject. Skipped for IP-literal hosts: + // an IP identity is only valid via an iPAddress SAN, never the CN (RFC 9110; + // the OpenSSL backend's X509_check_ip behaves the same way). 
+ if (!is_ip) { + char cn[256]; + int ret = mbedtls_x509_dn_gets(cn, sizeof(cn), &mcert->subject); + if (ret > 0) { + std::string cn_str(cn); - // Look for "CN=" in the DN string - size_t cn_pos = cn_str.find("CN="); - if (cn_pos != std::string::npos) { - size_t start = cn_pos + 3; - size_t end = cn_str.find(',', start); - std::string cn_value = - cn_str.substr(start, end == std::string::npos ? end : end - start); + // Look for "CN=" in the DN string + size_t cn_pos = cn_str.find("CN="); + if (cn_pos != std::string::npos) { + size_t start = cn_pos + 3; + size_t end = cn_str.find(',', start); + std::string cn_value = + cn_str.substr(start, end == std::string::npos ? end : end - start); - if (detail::match_hostname(cn_value, host_str)) { return true; } + if (detail::match_hostname(cn_value, host_str)) { return true; } + } } } @@ -15774,10 +15721,10 @@ bool verify_hostname(cert_t cert, const char *hostname) { auto x509 = static_cast(cert); std::string host_str(hostname); - // Check if hostname is an IP address - bool is_ip = impl::is_ipv4_address(host_str); - unsigned char ip_bytes[4]; - if (is_ip) { impl::parse_ipv4(host_str, ip_bytes); } + // Check if hostname is an IP address (IPv4 or IPv6) + unsigned char ip_bytes[16]; + auto ip_len = impl::parse_ip_address(host_str, ip_bytes); + auto is_ip = ip_len > 0; // Check Subject Alternative Names auto *san_names = static_cast( @@ -15804,10 +15751,12 @@ bool verify_hostname(cert_t cert, const char *hostname) { } } } else if (is_ip && names->type == WOLFSSL_GEN_IPADD) { - // IP address + // IP address: only an iPAddress SAN of the same family (4 bytes for + // IPv4, 16 bytes for IPv6) may authenticate the host. 
unsigned char *ip_data = wolfSSL_ASN1_STRING_data(names->d.iPAddress); - int ip_len = wolfSSL_ASN1_STRING_length(names->d.iPAddress); - if (ip_data && ip_len == 4 && memcmp(ip_data, ip_bytes, 4) == 0) { + auto san_ip_len = wolfSSL_ASN1_STRING_length(names->d.iPAddress); + if (ip_data && san_ip_len == static_cast(ip_len) && + memcmp(ip_data, ip_bytes, ip_len) == 0) { wolfSSL_sk_free(san_names); return true; } @@ -15816,8 +15765,10 @@ bool verify_hostname(cert_t cert, const char *hostname) { wolfSSL_sk_free(san_names); } - // Fallback: Check Common Name (CN) in subject - WOLFSSL_X509_NAME *subject = wolfSSL_X509_get_subject_name(x509); + // Fallback: Check Common Name (CN) in subject. Skipped for IP-literal hosts: + // an IP identity is only valid via an iPAddress SAN, never the CN (RFC 9110; + // the OpenSSL backend's X509_check_ip behaves the same way). + auto subject = is_ip ? nullptr : wolfSSL_X509_get_subject_name(x509); if (subject) { char cn[256] = {}; int cn_len = wolfSSL_X509_NAME_get_text_by_NID(subject, NID_commonName, cn, diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index 94d93e88a5..bfdbfc1da7 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -8,8 +8,8 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.47.0" -#define CPPHTTPLIB_VERSION_NUM "0x002f00" +#define CPPHTTPLIB_VERSION "0.48.0" +#define CPPHTTPLIB_VERSION_NUM "0x003000" #ifdef _WIN32 #if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00 @@ -686,18 +686,70 @@ inline from_chars_result from_chars(const char *first, const char *last, return {p, std::errc{}}; } -// from_chars for double (simple wrapper for strtod) +// from_chars for double (hand-written, locale-independent) +// +// The only double consumed by this library is the HTTP quality value, whose +// grammar is (RFC 9110 12.4.2): +// qvalue = ( "0" [ "." 0*3DIGIT ] ) / ( "1" [ "." 0*3("0") ] ) +// i.e. 
a non-negative decimal with no sign, exponent, "inf"/"nan", or wide +// magnitude. So this parser recognizes exactly 1*DIGIT [ "." *DIGIT ] with +// '.' always the decimal separator (std::strtod would instead read it from the +// global C locale, mis-parsing q-values once an embedder calls +// setlocale(LC_ALL, "") into a comma-decimal locale). The caller range-checks +// the result to [0, 1], so inputs outside that range need not be distinguished +// here. Allocation-free, single pass, and free of the overflow/rounding edge +// cases that exponent and wide-range handling would introduce. inline from_chars_result from_chars(const char *first, const char *last, double &value) { - std::string s(first, last); - char *endptr = nullptr; - errno = 0; - value = std::strtod(s.c_str(), &endptr); - if (endptr == s.c_str()) { return {first, std::errc::invalid_argument}; } - if (errno == ERANGE) { - return {first + (endptr - s.c_str()), std::errc::result_out_of_range}; + value = 0.0; + const char *p = first; + + // Each 1eN is exactly representable, so a single final division by the + // matching entry yields a correctly-rounded result. + static const double powers_of_ten[] = { + 1e0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, + 1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18}; + const int max_frac_digits = + static_cast(sizeof(powers_of_ten) / sizeof(powers_of_ten[0])) - 1; + + // Accumulate digits into a 64-bit integer and remember how many were + // fractional. Two independent caps keep this bounded and safe: + // * accumulation saturates before mantissa could overflow uint64_t, and + // * frac_digits is capped at max_frac_digits so it is always a valid index + // into powers_of_ten (without this an input like "0.000...0" would never + // grow mantissa, so the saturation cap alone would not bound it). + // Both caps only drop digits far beyond the precision a q-value needs; any + // value they would change is well outside [0, 1] and rejected by the caller. 
+ uint64_t mantissa = 0; + int frac_digits = 0; + bool seen_digit = false; + + const uint64_t limit = ((std::numeric_limits::max)() - 9) / 10; + auto accumulate = [&](char c) { + if (mantissa <= limit) { + mantissa = mantissa * 10 + static_cast(c - '0'); + return true; + } + return false; + }; + + for (; p != last && '0' <= *p && *p <= '9'; ++p) { + seen_digit = true; + accumulate(*p); } - return {first + (endptr - s.c_str()), std::errc{}}; + + if (p != last && *p == '.') { + ++p; + for (; p != last && '0' <= *p && *p <= '9'; ++p) { + seen_digit = true; + if (frac_digits < max_frac_digits && accumulate(*p)) { ++frac_digits; } + } + } + + if (!seen_digit) { return {first, std::errc::invalid_argument}; } + + value = static_cast(mantissa) / powers_of_ten[frac_digits]; + return {p, std::errc{}}; } inline bool parse_port(const char *s, size_t len, int &port) { @@ -2826,13 +2878,6 @@ private: #endif friend class ClientImpl; - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -private: - bool verify_host(X509 *server_cert) const; - bool verify_host_with_subject_alt_name(X509 *server_cert) const; - bool verify_host_with_common_name(X509 *server_cert) const; -#endif }; #endif // CPPHTTPLIB_SSL_ENABLED From fabde3bf5136940eb03821aa2490e2360093965b Mon Sep 17 00:00:00 2001 From: Mikolaj Kucharski Date: Fri, 19 Jun 2026 15:33:54 +0000 Subject: [PATCH 17/86] arg: Add comment line support to --api-key-file (#23168) --- common/arg.cpp | 4 ++-- tools/server/README.md | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 52425f25e4..6fd366d33b 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2911,7 +2911,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY")); add_opt(common_arg( {"--api-key-file"}, "FNAME", - "path to file containing API keys (default: none)", + "path to file containing API keys, one per line; lines starting with a hash are 
treated as comments (default: none)", [](common_params & params, const std::string & value) { std::ifstream key_file(value); if (!key_file) { @@ -2919,7 +2919,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } std::string key; while (std::getline(key_file, key)) { - if (!key.empty()) { + if (!key.empty() && key[0] != '#') { params.api_keys.push_back(key); } } diff --git a/tools/server/README.md b/tools/server/README.md index 1f74ba52ae..eb730e713a 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -198,7 +198,7 @@ For the full list of features, please refer to [server's changelog](https://gith | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)
(env: LLAMA_ARG_EMBEDDINGS) | | `--rerank, --reranking` | enable reranking endpoint on server (default: disabled)
(env: LLAMA_ARG_RERANKING) | | `--api-key KEY` | API key to use for authentication, multiple keys can be provided as a comma-separated list (default: none)
(env: LLAMA_API_KEY) | -| `--api-key-file FNAME` | path to file containing API keys (default: none)
(env: LLAMA_ARG_API_KEY_FILE) | +| `--api-key-file FNAME` | path to file containing API keys, one per line; lines starting with a hash are treated as comments (default: none)
(env: LLAMA_ARG_API_KEY_FILE) | | `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key
(env: LLAMA_ARG_SSL_KEY_FILE) | | `--ssl-cert-file FNAME` | path to file a PEM-encoded SSL certificate
(env: LLAMA_ARG_SSL_CERT_FILE) | | `--chat-template-kwargs STRING` | sets additional params for the json template parser, must be a valid json object string, e.g. '{"key1":"value1","key2":"value2"}'
(env: LLAMA_ARG_CHAT_TEMPLATE_KWARGS) | From 175147e8f612671b7906e21fb1cdea62e4da0e21 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Fri, 19 Jun 2026 22:12:46 +0200 Subject: [PATCH 18/86] server: remove all internal mentions about "webui" (#24817) --- tools/server/server-context.cpp | 9 +-------- tools/server/server-context.h | 3 +-- tools/server/server-models.cpp | 1 - tools/server/server-models.h | 2 -- tools/server/tests/unit/test_basic.py | 8 ++++---- tools/server/tests/unit/test_proxy.py | 6 +++--- tools/server/tests/utils.py | 12 ++++++------ 7 files changed, 15 insertions(+), 26 deletions(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 00ab31340b..791188b1e7 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -825,8 +825,7 @@ private: server_metrics metrics; - json json_ui_settings = json::object(); // Primary: new name - json json_webui_settings = json::object(); // Deprecated: use json_ui_settings instead (kept for compat) + json json_ui_settings = json::object(); // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; @@ -1308,7 +1307,6 @@ private: try { json json_settings = json::parse(cfg); json_ui_settings = json_settings; - json_webui_settings = json_settings; // deprecated: keep in sync } catch (const std::exception & e) { SRV_ERR("%s: failed to parse UI config: %s\n", __func__, e.what()); return false; @@ -3687,7 +3685,6 @@ server_context_meta server_context::get_meta() const { /* has_inp_audio */ impl->chat_params.allow_audio, /* has_inp_video */ impl->chat_params.allow_video, /* json_ui_settings */ impl->json_ui_settings, - /* json_webui_settings */ impl->json_webui_settings, // Deprecated /* slot_n_ctx */ impl->get_slot_n_ctx(), /* pooling_type */ llama_pooling_type(impl->ctx_tgt), @@ -4300,12 +4297,8 @@ void server_routes::init_routes() { { "endpoint_slots", params.endpoint_slots }, { "endpoint_props", params.endpoint_props }, { 
"endpoint_metrics", params.endpoint_metrics }, - // New keys { "ui", params.ui }, { "ui_settings", meta->json_ui_settings }, - // Deprecated: use ui/ui_settings instead (kept for backward compat) - { "webui", params.ui }, - { "webui_settings", meta->json_ui_settings }, { "chat_template", tmpl_default }, { "chat_template_caps", meta->chat_template_caps }, { "bos_token", meta->bos_token_str }, diff --git a/tools/server/server-context.h b/tools/server/server-context.h index 0e84785af4..07afabb926 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -22,8 +22,7 @@ struct server_context_meta { bool has_inp_image; bool has_inp_audio; bool has_inp_video; - json json_ui_settings; // Primary: new name - json json_webui_settings; // Deprecated: use json_ui_settings instead (kept for backward compat) + json json_ui_settings; int slot_n_ctx; enum llama_pooling_type pooling_type; diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 23c1f16689..1fffa6b6e5 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -1474,7 +1474,6 @@ void server_models_routes::init_routes() { }}, // New key {"ui_settings", ui_settings}, - {"webui_settings", webui_settings}, {"build_info", std::string(llama_build_info())}, {"cors_proxy_enabled", params.ui_mcp_proxy}, }); diff --git a/tools/server/server-models.h b/tools/server/server-models.h index aeb0e874de..98872b0461 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -207,7 +207,6 @@ public: struct server_models_routes { common_params params; json ui_settings = json::object(); // Primary: new name - json webui_settings = json::object(); // Deprecated: use ui_settings (kept for compat) std::atomic stopping = false; // for graceful disconnecting SSE clients during shutdown server_models models; server_models_routes(const common_params & params, int argc, char ** argv) @@ -217,7 +216,6 @@ struct server_models_routes { try { json 
json_settings = json::parse(cfg); ui_settings = json_settings; - webui_settings = json_settings; // Deprecated: keep in sync } catch (const std::exception & e) { LOG_ERR("%s: failed to parse UI config: %s\n", __func__, e.what()); throw; diff --git a/tools/server/tests/unit/test_basic.py b/tools/server/tests/unit/test_basic.py index d1b89cf1a9..285726abf4 100644 --- a/tools/server/tests/unit/test_basic.py +++ b/tools/server/tests/unit/test_basic.py @@ -79,9 +79,9 @@ def test_load_split_model(): assert match_regex("(little|girl)+", res.body["content"]) -def test_no_webui(): +def test_no_ui(): global server - # default: webui enabled + # default: UI enabled server.start() url = f"http://{server.server_host}:{server.server_port}" res = requests.get(url) @@ -89,8 +89,8 @@ def test_no_webui(): assert "" in res.text server.stop() - # with --no-webui - server.no_webui = True + # with --no-ui, the UI should be disabled + server.no_ui = True server.start() res = requests.get(url) assert res.status_code == 404 diff --git a/tools/server/tests/unit/test_proxy.py b/tools/server/tests/unit/test_proxy.py index b7c3326187..3b86d80473 100644 --- a/tools/server/tests/unit/test_proxy.py +++ b/tools/server/tests/unit/test_proxy.py @@ -12,7 +12,7 @@ def create_server(): def test_mcp_no_proxy(): global server - server.webui_mcp_proxy = False + server.ui_mcp_proxy = False server.start() res = server.make_request("GET", "/cors-proxy") @@ -21,7 +21,7 @@ def test_mcp_no_proxy(): def test_mcp_proxy(): global server - server.webui_mcp_proxy = True + server.ui_mcp_proxy = True server.start() url = f"http://{server.server_host}:{server.server_port}/cors-proxy?url=http://example.com" @@ -32,7 +32,7 @@ def test_mcp_proxy(): def test_mcp_proxy_custom_port(): global server - server.webui_mcp_proxy = True + server.ui_mcp_proxy = True server.start() # try getting the server's models API via the proxy diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index c50c9a0f5a..63a959449e 
100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -94,7 +94,7 @@ class ServerProcess: enable_ctx_shift: int | None = False spec_draft_n_min: int | None = None spec_draft_n_max: int | None = None - no_webui: bool | None = None + no_ui: bool | None = None jinja: bool | None = None reasoning_format: Literal['deepseek', 'none', 'nothink'] | None = None reasoning: Literal['on', 'off', 'auto'] | None = None @@ -107,7 +107,7 @@ class ServerProcess: cache_ram: int | None = None no_cache_idle_slots: bool = False log_path: str | None = None - webui_mcp_proxy: bool = False + ui_mcp_proxy: bool = False backend_sampling: bool = False gcp_compat: bool = False @@ -225,8 +225,8 @@ class ServerProcess: server_args.extend(["--spec-draft-n-max", self.spec_draft_n_max]) if self.spec_draft_n_min: server_args.extend(["--spec-draft-n-min", self.spec_draft_n_min]) - if self.no_webui: - server_args.append("--no-webui") + if self.no_ui: + server_args.append("--no-ui") if self.no_models_autoload: server_args.append("--no-models-autoload") if self.jinja: @@ -251,8 +251,8 @@ class ServerProcess: server_args.extend(["--cache-ram", self.cache_ram]) if self.no_cache_idle_slots: server_args.append("--no-cache-idle-slots") - if self.webui_mcp_proxy: - server_args.append("--webui-mcp-proxy") + if self.ui_mcp_proxy: + server_args.append("--ui-mcp-proxy") if self.backend_sampling: server_args.append("--backend_sampling") if self.gcp_compat: From e475fa2b5f9fb50c3d6fc3e7c6fdf1e004465b62 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Fri, 19 Jun 2026 22:28:38 +0200 Subject: [PATCH 19/86] mtmd, arg: fix utf8 handling on windows (#24779) * mtmd, arg: fix utf8 handling on windows * also fix ggml_fopen * fix build fail * also fix CLI --- common/arg.cpp | 38 ++++++++++++++++++++++++++++++++++++++ common/common.cpp | 12 ++++++++++++ common/common.h | 3 +++ ggml/src/ggml.c | 17 +++++++---------- tools/cli/cli.cpp | 2 +- tools/mtmd/clip-impl.h | 24 ++++++++++++++++++++++++ 
tools/mtmd/clip.cpp | 2 +- tools/mtmd/mtmd-cli.cpp | 3 +++ tools/mtmd/mtmd-helper.cpp | 18 +++++++++++++++++- 9 files changed, 106 insertions(+), 13 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 6fd366d33b..8f4f7d0763 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -17,6 +17,7 @@ # define NOMINMAX #endif #include +#include #endif #define JSON_ASSERT GGML_ASSERT @@ -893,7 +894,44 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map buf; + std::vector ptrs; +}; + +static utf8_argv make_utf8_argv() { + utf8_argv out; + int wargc = 0; + LPWSTR* wargv = CommandLineToArgvW(GetCommandLineW(), &wargc); + if (!wargv) return out; + + out.buf.reserve(wargc); + for (int i = 0; i < wargc; ++i) { + int n = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS, wargv[i], -1, nullptr, 0, nullptr, nullptr); + if (n <= 0) { out.buf.emplace_back(); continue; } + auto& s = out.buf.emplace_back(); + s.resize(static_cast(n - 1)); + (void)WideCharToMultiByte(CP_UTF8, 0, wargv[i], -1, s.data(), n, nullptr, nullptr); + } + LocalFree(wargv); + + out.ptrs.reserve(out.buf.size() + 1); + for (auto& s : out.buf) out.ptrs.push_back(s.data()); + out.ptrs.push_back(nullptr); + return out; +} +#endif + bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) { +#ifdef _WIN32 + auto utf8 = make_utf8_argv(); + if (!utf8.ptrs.empty()) { + argc = static_cast(utf8.buf.size()); + argv = utf8.ptrs.data(); + } +#endif + auto ctx_arg = common_params_parser_init(params, ex, print_usage); const common_params params_org = ctx_arg.params; // the example can modify the default params diff --git a/common/common.cpp b/common/common.cpp index f3f114f682..a14e7bbed9 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1074,6 +1074,18 @@ std::vector fs_list(const std::string & path, bool include_dir return files; } +std::ifstream fs_open_ifstream(const std::string & fname, std::ios_base::openmode 
mode) { +#ifdef _WIN32 + int wlen = MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, NULL, 0); + if (!wlen) { return std::ifstream(); } + std::vector wfname(wlen); + (void)MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, wfname.data(), wlen); + return std::ifstream(wfname.data(), mode); +#else + return std::ifstream(fname, mode); +#endif +} + // // TTY utils // diff --git a/common/common.h b/common/common.h index 44c605189c..254454dcb1 100644 --- a/common/common.h +++ b/common/common.h @@ -842,6 +842,9 @@ struct common_file_info { }; std::vector fs_list(const std::string & path, bool include_directories); +// fs open, also handle UTF8 on Windows +std::ifstream fs_open_ifstream(const std::string & fname, std::ios_base::openmode mode); + // // TTY utils // diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b43016c87d..0f682fd185 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -600,18 +600,15 @@ FILE * ggml_fopen(const char * fname, const char * mode) { // convert fname (UTF-8) wchar_t * wfname = ggml_mbstowcs(fname); if (wfname) { - // convert mode (ANSI) - wchar_t * wmode = GGML_MALLOC((strlen(mode) + 1) * sizeof(wchar_t)); - wchar_t * wmode_p = wmode; - do { - *wmode_p++ = (wchar_t)*mode; - } while (*mode++); - - // open file - file = _wfopen(wfname, wmode); + // convert mode (UTF-8) + wchar_t * wmode = ggml_mbstowcs(mode); + if (wmode) { + // open file + file = _wfopen(wfname, wmode); + GGML_FREE(wmode); + } GGML_FREE(wfname); - GGML_FREE(wmode); } return file; diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp index c03894b4b1..8b7b58693f 100644 --- a/tools/cli/cli.cpp +++ b/tools/cli/cli.cpp @@ -202,7 +202,7 @@ struct cli_context { // TODO: support remote files in the future (http, https, etc) std::string load_input_file(const std::string & fname, bool is_media) { - std::ifstream file(fname, std::ios::binary); + std::ifstream file = fs_open_ifstream(fname, std::ios::binary); if (!file) { return ""; } diff --git a/tools/mtmd/clip-impl.h 
b/tools/mtmd/clip-impl.h index f232b68e5a..e7b5301445 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -13,6 +13,14 @@ #include #include #include +#include + +#ifdef _WIN32 +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#endif // Internal header for clip.cpp @@ -661,6 +669,22 @@ struct clip_image_f32_batch { // common utils // +#ifdef _WIN32 +static std::ifstream open_ifstream_binary(const std::string & fname) { + int wlen = MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, NULL, 0); + if (!wlen) { + throw std::runtime_error("failed to convert filename to UTF-16: " + fname); + } + std::vector wfname(wlen); + (void)MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, wfname.data(), wlen); + return std::ifstream(wfname.data(), std::ios::binary); +} +#else +static std::ifstream open_ifstream_binary(const std::string & fname) { + return std::ifstream(fname, std::ios::binary); +} +#endif + static std::string string_format(const char * fmt, ...) { va_list ap; va_list ap2; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 10840a851f..c713703e01 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1752,7 +1752,7 @@ struct clip_model_loader { std::map tensor_offset; std::vector tensors_to_load; - auto fin = std::ifstream(fname, std::ios::binary); + auto fin = open_ifstream_binary(fname); if (!fin) { throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str())); } diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index 0ad000ef01..8704ea79d7 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -396,6 +396,9 @@ int main(int argc, char ** argv) { int n_predict = params.n_predict < 0 ? 
INT_MAX : params.n_predict; + console::init(params.simple_io, params.use_color); + atexit([]() { console::cleanup(); }); + // Ctrl+C handling { #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index b5c4089232..3c73db4431 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -582,13 +582,29 @@ mtmd_helper_bitmap_wrapper mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, } mtmd_helper_bitmap_wrapper mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char * fname, bool placeholder) { - std::vector buf; +#ifdef _WIN32 + int wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, NULL, 0); + if (!wlen) { + LOG_ERR("Unable to convert filename to UTF-16: %s\n", fname); + return {nullptr, nullptr}; + } + std::vector wfname(wlen); + wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, wfname.data(), wlen); + if (!wlen) { + LOG_ERR("Unable to convert filename to UTF-16: %s\n", fname); + return {nullptr, nullptr}; + } + FILE * f = _wfopen(wfname.data(), L"rb"); +#else FILE * f = fopen(fname, "rb"); +#endif if (!f) { LOG_ERR("Unable to open file %s: %s\n", fname, strerror(errno)); return {nullptr, nullptr}; } + std::vector buf; + fseek(f, 0, SEEK_END); long file_size = ftell(f); fseek(f, 0, SEEK_SET); From 4b48a53b6cc60e051f35f2acbd06264a909bb255 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Fri, 19 Jun 2026 23:26:54 +0200 Subject: [PATCH 20/86] server : optimize get_token_probabilities (#24796) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use std::partial_sort to order only the requested top-n tokens instead of the full vocabulary logprobs sort: vocab=128000 n_top=0 iters=100 full sort: 8555.6 us/op partial sort: 704.3 us/op Signed-off-by: Adrien Gallouët --- tools/server/server-common.cpp | 36 +++++++++++++++++++++++---------- tools/server/server-common.h | 2 +- 
tools/server/server-context.cpp | 3 +-- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 75729e62dd..3dc686bb46 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -12,6 +12,7 @@ #include #include #include +#include json format_error_response(const std::string & message, const enum error_type type) { std::string type_str; @@ -1238,7 +1239,7 @@ json format_response_rerank( // other utils // -std::vector get_token_probabilities(llama_context * ctx, int idx) { +std::vector get_token_probabilities(llama_context * ctx, int idx, size_t n_top) { std::vector cur; const auto * logits = llama_get_logits_ith(ctx, idx); @@ -1257,21 +1258,34 @@ std::vector get_token_probabilities(llama_context * ctx, int i } } - // sort tokens by logits - std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); + // sort tokens by logits (partial: only the leading `n_top` need ordering) + if (n_top > cur.size()) { + n_top = cur.size(); + } + if (n_top > 0) { + std::partial_sort(cur.begin(), cur.begin() + n_top, cur.end(), + [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + } // apply softmax - float max_l = cur[0].logit; + float max_l = -std::numeric_limits::infinity(); + if (n_top > 0) { + max_l = cur[0].logit; // partial_sort guarantees the absolute maximum is at index 0 + } else { + for (const auto & t : cur) { + max_l = std::max(max_l, t.logit); + } + } float cum_sum = 0.0f; - for (size_t i = 0; i < cur.size(); ++i) { - float p = expf(cur[i].logit - max_l); - cur[i].p = p; + for (auto & t : cur) { + float p = expf(t.logit - max_l); + t.p = p; cum_sum += p; } - for (size_t i = 0; i < cur.size(); ++i) { - cur[i].p /= cum_sum; + for (auto & t : cur) { + t.p /= cum_sum; } return cur; diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 
f286b3d156..efd31733b0 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -326,7 +326,7 @@ json format_response_rerank( // other utils // -std::vector get_token_probabilities(llama_context * ctx, int idx); +std::vector get_token_probabilities(llama_context * ctx, int idx, size_t n_top); std::string safe_json_to_str(const json & data); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 791188b1e7..1f0e1bfd42 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1824,8 +1824,7 @@ private: }); } } else { - // TODO: optimize this with min-p optimization - std::vector cur = get_token_probabilities(ctx_tgt, idx); + std::vector cur = get_token_probabilities(ctx_tgt, idx, n_probs_request); const size_t max_probs = cur.size(); const size_t n_probs = std::min(max_probs, n_probs_request); From 2b686a9120e2f1f9eabb65a597a4eef03eea9d87 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Sat, 20 Jun 2026 01:02:26 +0200 Subject: [PATCH 21/86] server: refactor child --> router communication (#24821) * server: refactor child --> router communication * fix wakeup case * add docs * improve update_status() * nits --- common/arg.cpp | 2 - common/common.h | 11 ++- tools/server/README-dev.md | 11 +++ tools/server/server-context.cpp | 18 ++--- tools/server/server-context.h | 30 ++++++- tools/server/server-models.cpp | 133 ++++++++++++++++++-------------- tools/server/server-models.h | 30 +++++-- tools/server/server.cpp | 29 ++++--- 8 files changed, 173 insertions(+), 91 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 8f4f7d0763..a9b1a25b27 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -303,7 +303,6 @@ static handle_model_result common_params_handle_model(struct common_params_model if (!model.docker_repo.empty()) { model.path = common_docker_resolve_model(model.docker_repo); - model.name = model.docker_repo; } else if (!model.hf_repo.empty()) { // If -m was used with -hf, 
treat the model "path" as the hf_file to download if (model.hf_file.empty() && !model.path.empty()) { @@ -323,7 +322,6 @@ static handle_model_result common_params_handle_model(struct common_params_model throw std::runtime_error("failed to download model from Hugging Face"); } - model.name = model.hf_repo; model.path = download_result.model_path; if (!download_result.mmproj_path.empty()) { diff --git a/common/common.h b/common/common.h index 254454dcb1..f2f2202ec2 100644 --- a/common/common.h +++ b/common/common.h @@ -295,7 +295,16 @@ struct common_params_model { std::string hf_repo = ""; // HF repo // NOLINT std::string hf_file = ""; // HF file // NOLINT std::string docker_repo = ""; // Docker repo // NOLINT - std::string name = ""; // in format /[:] (tag is optional) // NOLINT + + std::string get_name() { + if (!hf_repo.empty()) { + return hf_repo; + } + if (!docker_repo.empty()) { + return docker_repo; + } + return path; + } }; // draft-model-based speculative decoding parameters diff --git a/tools/server/README-dev.md b/tools/server/README-dev.md index 4c41031239..2796d28350 100644 --- a/tools/server/README-dev.md +++ b/tools/server/README-dev.md @@ -180,6 +180,17 @@ That requires `JSON.stringify` when formatted to message content: } ``` +### Router mode: how child <--> router communicates + +Upon spawning a new child process using `subprocess`, both child and router listen to the stdout/stderr (combined) + +For the direction from child to router: +- Generic messages are logs, it will be forwarded to router's stdout +- Special state update messages are prefixed by `cmd_child_to_router:state:`, followed by a JSON. 
See `server_models::handle_child_state` for more + +For the direction from router to child: +- When server sends `cmd_router_to_child:exit`, the child should exit gracefully --> if after `DEFAULT_STOP_TIMEOUT` and the child is still running, force-kill it + ### Model management API (router mode) Model management API was added via PR [#23976](https://github.com/ggml-org/llama.cpp/pull/23976) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 1f0e1bfd42..3de1335ec2 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -63,11 +63,6 @@ enum slot_state { SLOT_STATE_GENERATING, }; -enum server_state { - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded -}; - struct server_slot { int id; @@ -773,6 +768,8 @@ public: // note: chat_params must not be refreshed upon existing sleeping state server_chat_params chat_params; + server_state_callback_t callback_state = [](server_state, json) -> void {}; + server_context_impl() { mtmd_helper_log_set(common_log_default_callback, nullptr); } @@ -1244,8 +1241,8 @@ private: if (!params_base.model_alias.empty()) { // backward compat: use first alias as model name model_name = *params_base.model_alias.begin(); - } else if (!params_base.model.name.empty()) { - model_name = params_base.model.name; + } else if (!params_base.model.get_name().empty()) { + model_name = params_base.model.get_name(); } else { // fallback: derive model name from file name auto model_path = std::filesystem::path(params_base.model.path); @@ -3734,8 +3731,11 @@ struct server_res_generator : server_http_res { } }; -void server_context::on_sleeping_changed(std::function callback) { - impl->queue_tasks.on_sleeping_state(std::move(callback)); +void server_context::set_state_callback(server_state_callback_t callback) { + impl->callback_state = std::move(callback); + impl->queue_tasks.on_sleeping_state([this](bool 
sleeping) { + impl->callback_state(sleeping ? SERVER_STATE_SLEEPING : SERVER_STATE_READY, {}); + }); } // compute the number of tokens before the last user message in the prompt diff --git a/tools/server/server-context.h b/tools/server/server-context.h index 07afabb926..c7218a12ed 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -52,6 +52,31 @@ struct server_context_meta { uint64_t model_size; }; +enum server_state { + // SERVER_STATE_DOWNLOADING, + SERVER_STATE_LOADING, + SERVER_STATE_READY, + SERVER_STATE_SLEEPING, +}; + +static std::string server_state_to_str(server_state state) { + switch (state) { + case SERVER_STATE_LOADING: return "loading"; + case SERVER_STATE_READY: return "ready"; + case SERVER_STATE_SLEEPING: return "sleeping"; + default: GGML_ASSERT(false && "invalid server_state"); + } +} + +static server_state server_state_from_str(const std::string & str) { + if (str == "loading") return SERVER_STATE_LOADING; + if (str == "ready") return SERVER_STATE_READY; + if (str == "sleeping") return SERVER_STATE_SLEEPING; + GGML_ASSERT(false && "invalid server_state string"); +} + +using server_state_callback_t = std::function; + struct server_context { std::unique_ptr impl; @@ -79,9 +104,8 @@ struct server_context { // not thread-safe, should only be used from the main thread server_context_meta get_meta() const; - // register a callback to be called when sleeping state changes - // must be set before load_model() is called - void on_sleeping_changed(std::function callback); + // note: must be set before load_model() is called + void set_state_callback(server_state_callback_t callback); }; diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 1fffa6b6e5..a569c8be3c 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -1,5 +1,6 @@ #include "server-common.h" #include "server-models.h" +#include "server-context.h" #include "build-info.h" #include "preset.h" @@ -44,9 +45,7 
@@ extern char **environ; #define DEFAULT_STOP_TIMEOUT 10 // seconds #define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit" -#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready" // also sent when waking up from sleep -#define CMD_CHILD_TO_ROUTER_SLEEP "cmd_child_to_router:sleep" -#define CMD_CHILD_TO_ROUTER_INFO "cmd_child_to_router:info:" // followed by json string +#define CMD_CHILD_TO_ROUTER_STATE "cmd_child_to_router:state:" // followed by json string // address for child process, this is needed because router may run on 0.0.0.0 // ref: https://github.com/ggml-org/llama.cpp/issues/17862 @@ -904,12 +903,8 @@ void server_models::load(const std::string & name) { while (fgets(buffer, vec_buf.size(), stdout_file) != nullptr) { LOG("[%5d] %s", port, buffer); std::string str(buffer); - if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_READY)) { - this->update_status(name, SERVER_MODEL_STATUS_LOADED, 0); - } else if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_INFO)) { - this->update_loaded_info(name, str); - } else if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_SLEEP)) { - this->update_status(name, SERVER_MODEL_STATUS_SLEEPING, 0); + if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_STATE)) { + this->handle_child_state(name, str); } } } else { @@ -976,7 +971,10 @@ void server_models::load(const std::string & name) { subprocess_destroy(&child_proc->get()); // update status and exit code - this->update_status(name, SERVER_MODEL_STATUS_UNLOADED, exit_code); + this->update_status(name, { + SERVER_MODEL_STATUS_UNLOADED, + exit_code + }); SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code); }); @@ -1016,7 +1014,8 @@ struct server_models_download_res : public common_download_callback { common_download_model(model, opts); is_ok = true; } catch (const std::exception & e) { - SRV_ERR("download failed for model name=%s: %s\n", model.name.c_str(), e.what()); + auto model_name = model.get_name(); + SRV_ERR("download failed for model 
name=%s: %s\n", model_name.c_str(), e.what()); is_ok = false; } return is_ok; @@ -1036,7 +1035,7 @@ struct server_models_download_res : public common_download_callback { }; void server_models::download(common_params_model && model, common_download_opts && opts) { - std::string name = model.name; + std::string name = model.get_name(); GGML_ASSERT(name == model.hf_repo); std::unique_lock lk(mutex); @@ -1064,9 +1063,10 @@ void server_models::download(common_params_model && model, common_download_opts inst.th = std::thread([this, dl = std::move(dl)]() { dl->opts.callback = dl.get(); bool ok = dl->run(); + auto model_name = dl->model.get_name(); SRV_INF("download finished for model name=%s with status=%s\n", - dl->model.name.c_str(), ok ? "success" : "failure"); - update_download_progress(dl->model.name, {}, true, ok); + model_name.c_str(), ok ? "success" : "failure"); + update_download_progress(model_name, {}, true, ok); // need_reload is set inside update_download_progress under the mutex; // the next load_models() call will clean up this instance }); @@ -1130,21 +1130,27 @@ void server_models::unload_all() { } } -void server_models::update_status(const std::string & name, server_model_status status, int exit_code) { +void server_models::update_status(const std::string & name, const update_status_args & args) { std::unique_lock lk(mutex); auto it = mapping.find(name); if (it != mapping.end()) { auto & meta = it->second.meta; - meta.status = status; - meta.exit_code = exit_code; + meta.status = args.status; + meta.exit_code = args.exit_code; + if (!args.loaded_info.is_null()) { + meta.loaded_info = args.loaded_info; + } } // broadcast status change to SSE { json data = { - {"status", server_model_status_to_string(status)}, + {"status", server_model_status_to_string(args.status)}, }; - if (status == SERVER_MODEL_STATUS_UNLOADED) { - data["exit_code"] = exit_code; + if (args.status == SERVER_MODEL_STATUS_UNLOADED) { + data["exit_code"] = args.exit_code; + } + if 
(!args.loaded_info.is_null()) { + data["info"] = args.loaded_info; } // note: notify_sse doesn't acquire the lock, so no deadlock here notify_sse("status_change", name, data); @@ -1152,29 +1158,6 @@ void server_models::update_status(const std::string & name, server_model_status cv.notify_all(); } -void server_models::update_loaded_info(const std::string & name, std::string & raw_info) { - if (!string_starts_with(raw_info, CMD_CHILD_TO_ROUTER_INFO)) { - SRV_WRN("invalid loaded info format from child for model name=%s: %s\n", name.c_str(), raw_info.c_str()); - return; - } - - json info; - try { - info = json::parse(raw_info.substr(strlen(CMD_CHILD_TO_ROUTER_INFO))); - } catch (const std::exception & e) { - SRV_WRN("failed to parse loaded info from child for model name=%s: %s\n", name.c_str(), e.what()); - return; - } - - std::unique_lock lk(mutex); - auto it = mapping.find(name); - if (it != mapping.end()) { - auto & meta = it->second.meta; - meta.loaded_info = info; - } - cv.notify_all(); -} - void server_models::update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok) { json curr; { @@ -1323,21 +1306,54 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co return proxy; } -bool server_models::is_child_server() { +void server_models::handle_child_state(const std::string & name, const std::string & raw_input) { + server_state state; + json payload; + + try { + json data = json::parse(raw_input.substr(strlen(CMD_CHILD_TO_ROUTER_STATE))); + state = server_state_from_str(json_value(data, "state", std::string())); + payload = json_value(data, "payload", json{}); + } catch (const std::exception & e) { + SRV_ERR("failed to parse child state update for name=%s: %s\n", name.c_str(), e.what()); + return; + } + + switch (state) { + case SERVER_STATE_LOADING: + { + // do nothing for now + // TODO: report loading progress for first load and wakeup from sleep + } break; + case 
SERVER_STATE_READY: + { + update_status(name, { + SERVER_MODEL_STATUS_LOADED, + 0, + // note: payload can be empty if this is a wakeup from sleep + payload.size() > 0 ? payload : nullptr + }); + } break; + case SERVER_STATE_SLEEPING: + { + update_status(name, { SERVER_MODEL_STATUS_SLEEPING }); + } break; + default: + // should never happen, but just in case + GGML_ASSERT(false && "unexpected state from child server"); + } +} + +// +// server_child +// + +bool server_child::is_child() { const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT"); return router_port != nullptr; } -std::thread server_models::setup_child_server(const std::function & shutdown_handler, const json & model_info) { - // send a notification to the router server that a model instance is ready - common_log_pause(common_log_main()); - fflush(stdout); - fprintf(stdout, "%s\n", CMD_CHILD_TO_ROUTER_READY); - fflush(stdout); - fprintf(stdout, "%s%s\n", CMD_CHILD_TO_ROUTER_INFO, safe_json_to_str(model_info).c_str()); - fflush(stdout); - common_log_resume(common_log_main()); - +std::thread server_child::setup(const std::function & shutdown_handler) { // setup thread for monitoring stdin return std::thread([shutdown_handler]() { // wait for EOF on stdin @@ -1363,10 +1379,14 @@ std::thread server_models::setup_child_server(const std::function & s }); } -void server_models::notify_router_sleeping_state(bool is_sleeping) { +void server_child::notify_to_router(const std::string & state, const json & payload) { + json data = { + {"state", state}, + {"payload", payload}, + }; common_log_pause(common_log_main()); fflush(stdout); - fprintf(stdout, "%s\n", is_sleeping ? 
CMD_CHILD_TO_ROUTER_SLEEP : CMD_CHILD_TO_ROUTER_READY); + fprintf(stdout, "%s%s\n", CMD_CHILD_TO_ROUTER_STATE, safe_json_to_str(data).c_str()); fflush(stdout); common_log_resume(common_log_main()); } @@ -1644,7 +1664,6 @@ void server_models_routes::init_routes() { common_params_model model; common_download_opts opts; - model.name = name; model.hf_repo = name; opts.bearer_token = params.hf_token; opts.download_mmproj = true; diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 98872b0461..40a0e078c6 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -171,8 +171,12 @@ public: void download(common_params_model && model, common_download_opts && opts); // update the status of a model instance (thread-safe) - void update_status(const std::string & name, server_model_status status, int exit_code); - void update_loaded_info(const std::string & name, std::string & raw_info); + struct update_status_args { + server_model_status status; + int exit_code = 0; // only valid if status == UNLOADED + json loaded_info = nullptr; + }; + void update_status(const std::string & name, const update_status_args & args); void update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok = true); // remove a cache model from disk and update the list (thread-safe) @@ -193,15 +197,27 @@ public: // proxy an HTTP request to the model instance server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used); + // handle message sent from server_child::notify_to_router() + // raw input must starts with CMD_CHILD_TO_ROUTER_STATE, followed by a JSON string + // this function is not thread-safe, must be called from instance's monitoring thread + // payload per state: + // state = loading -> payload = {} (TODO: add progress info) + // state = ready -> payload = model_info (json), or {} if wakeup from sleeping + // state = 
sleeping -> payload = {} + void handle_child_state(const std::string & name, const std::string & raw_input); +}; + +struct server_child { // return true if the current process is a child server instance - static bool is_child_server(); + bool is_child(); - // notify the router server that a model instance is ready + // register the shutdown_handler to be called by the router // return the monitoring thread (to be joined by the caller) - static std::thread setup_child_server(const std::function & shutdown_handler, const json & model_info); + std::thread setup(const std::function & shutdown_handler); - // notify the router server that the sleeping state has changed - static void notify_router_sleeping_state(bool sleeping); + // notify router server for status changes (e.g. loading, downloading, sleeping, etc.) + // message will be handled by server_models::handle_child_state() on the router side + void notify_to_router(const std::string & state_name, const json & payload); }; struct server_models_routes { diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 2a67bfcfed..bf3680b9f0 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -90,8 +90,10 @@ int llama_server(int argc, char ** argv) { llama_numa_init(params.numa); // router server never loads a model and must not touch the GPU + const bool is_router_server = params.model.path.empty() + && params.model.hf_repo.empty(); + // skip device enumeration so the CUDA primary context stays uncreated - const bool is_router_server = params.model.path.empty(); common_params_print_info(params, !is_router_server); if (!is_router_server) { @@ -113,8 +115,9 @@ int llama_server(int argc, char ** argv) { } // for consistency between server router mode and single-model mode, we set the same model name as alias - if (params.model_alias.empty() && !params.model.name.empty()) { - params.model_alias.insert(params.model.name); + auto model_name = params.model.get_name(); + if (params.model_alias.empty() 
&& !model_name.empty()) { + params.model_alias.insert(model_name); } // struct that contains llama context and inference @@ -255,6 +258,7 @@ int llama_server(int argc, char ** argv) { // Start the server // + server_child child; // only used in non-router mode std::function clean_up; if (is_router_server) { @@ -300,15 +304,16 @@ int llama_server(int argc, char ** argv) { return 1; } - // load the model - SRV_INF("%s", "loading model\n"); - - if (server_models::is_child_server()) { - ctx_server.on_sleeping_changed([&](bool sleeping) { - server_models::notify_router_sleeping_state(sleeping); + // setup communication child --> router if necessary + if (child.is_child()) { + ctx_server.set_state_callback([&](server_state state, json payload) { + child.notify_to_router(server_state_to_str(state), payload); }); } + // load the model + SRV_INF("%s", "loading model\n"); + if (!ctx_server.load_model(params)) { clean_up(); if (ctx_http.thread.joinable()) { @@ -365,9 +370,9 @@ int llama_server(int argc, char ** argv) { // optionally, notify router server that this instance is ready std::thread monitor_thread; - if (server_models::is_child_server()) { - json model_info = routes.get_model_info(); - monitor_thread = server_models::setup_child_server(shutdown_handler, model_info); + if (child.is_child()) { + monitor_thread = child.setup(shutdown_handler); + child.notify_to_router(server_state_to_str(SERVER_STATE_READY), routes.get_model_info()); } // this call blocks the main thread until queue_tasks.terminate() is called From f449e0553708b895adbd94a301431cef691f632d Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Sat, 20 Jun 2026 08:12:32 +0900 Subject: [PATCH 22/86] ggml-webgpu: add adapter toggles for F16 on Vulkan + NVIDIA --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 42 +++++++++++----------------- 1 file changed, 16 insertions(+), 26 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 0b605fa86b..f71d1aee73 100644 
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3788,7 +3788,7 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } -static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { +static void ggml_backend_webgpu_request_adapter(wgpu::Instance & instance, wgpu::Adapter & adapter) { wgpu::RequestAdapterOptions options = {}; #ifndef __EMSCRIPTEN__ @@ -3800,17 +3800,20 @@ static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { options.nextInChain = &adapterTogglesDesc; #endif - ctx->webgpu_global_ctx->instance.WaitAny( - ctx->webgpu_global_ctx->instance.RequestAdapter( - &options, wgpu::CallbackMode::AllowSpontaneous, - [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { - if (status != wgpu::RequestAdapterStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); - return; - } - ctx->webgpu_global_ctx->adapter = std::move(adapter); - }), - UINT64_MAX); + instance.WaitAny(instance.RequestAdapter( + &options, wgpu::CallbackMode::AllowSpontaneous, + [&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) { + if (status != wgpu::RequestAdapterStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); + return; + } + adapter = std::move(_adapter); + }), + UINT64_MAX); +} + +static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { + ggml_backend_webgpu_request_adapter(ctx->webgpu_global_ctx->instance, ctx->webgpu_global_ctx->adapter); GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr); ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits); @@ -4543,20 +4546,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { // Probe for adapter support wgpu::Adapter adapter; if 
(ctx->webgpu_global_ctx->instance != nullptr) { - wgpu::RequestAdapterOptions options = {}; - - // probe for adapter support - ctx->webgpu_global_ctx->instance.WaitAny( - ctx->webgpu_global_ctx->instance.RequestAdapter( - &options, wgpu::CallbackMode::AllowSpontaneous, - [&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) { - if (status != wgpu::RequestAdapterStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); - return; - } - adapter = std::move(_adapter); - }), - UINT64_MAX); + ggml_backend_webgpu_request_adapter(ctx->webgpu_global_ctx->instance, adapter); } // WebGPU backend requires f16 support and, on native, implicit device synchronization. From f4043fec0103872bf4339f6fa18d8b17824d5b6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= <1629204+CISC@users.noreply.github.com> Date: Sat, 20 Jun 2026 12:42:36 +0200 Subject: [PATCH 23/86] convert : more consistent handling of rope_parameters (#24833) --- conversion/bailingmoe.py | 2 +- conversion/base.py | 8 +++++++- conversion/chatglm.py | 2 +- conversion/deci.py | 2 +- conversion/exaone.py | 6 +++--- conversion/gemma.py | 2 +- conversion/glm.py | 4 ++-- conversion/llama.py | 2 +- conversion/mimo.py | 2 +- conversion/minicpm.py | 16 ++++++---------- conversion/nemotron.py | 7 ++++--- conversion/phi.py | 20 +++++++++----------- conversion/qwen.py | 2 +- conversion/stablelm.py | 2 +- conversion/step3.py | 2 +- 15 files changed, 40 insertions(+), 39 deletions(-) diff --git a/conversion/bailingmoe.py b/conversion/bailingmoe.py index 319ff6dabe..2c6425cb64 100644 --- a/conversion/bailingmoe.py +++ b/conversion/bailingmoe.py @@ -126,7 +126,7 @@ class BailingMoeV2Model(TextModel): if (rope_dim := hparams.get("head_dim")) is None: rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] - self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))) + 
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5))) self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) self.gguf_writer.add_vocab_size(hparams["vocab_size"]) self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) diff --git a/conversion/base.py b/conversion/base.py index c872bcbb3c..08fd3747c4 100644 --- a/conversion/base.py +++ b/conversion/base.py @@ -1119,8 +1119,10 @@ class TextModel(ModelBase): rope_theta = self.find_hparam(["global_rope_theta", "rope_global_theta", "rope_theta_global", "rope_theta", "rotary_emb_base"], optional=True) local_rope_theta = self.find_hparam(["local_rope_theta", "rope_local_theta", "rope_theta_local", "swa_rope_theta", "rope_local_base_freq"], optional=True) + partial_rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"], optional=True) + original_max_position_embeddings = self.find_hparam(["original_max_position_embeddings"], optional=True) - # Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters + # Ensure global params are mirrored in rope_parameters if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters: if local_rope_theta is not None: self.rope_parameters["sliding_attention"] = {"rope_theta": local_rope_theta} @@ -1128,6 +1130,10 @@ class TextModel(ModelBase): self.rope_parameters["rope_theta"] = rope_theta if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None: self.rope_parameters["rope_type"] = rope_type + if "partial_rotary_factor" not in self.rope_parameters and partial_rotary_factor is not None: + self.rope_parameters["partial_rotary_factor"] = partial_rotary_factor + if "original_max_position_embeddings" not in self.rope_parameters and original_max_position_embeddings is not None: + self.rope_parameters["original_max_position_embeddings"] = 
original_max_position_embeddings @classmethod def __init_subclass__(cls): diff --git a/conversion/chatglm.py b/conversion/chatglm.py index 7e323b8900..801913075d 100644 --- a/conversion/chatglm.py +++ b/conversion/chatglm.py @@ -148,7 +148,7 @@ class ChatGLMModel(TextModel): rope_dim = self.hparams["attention_dim"] else: rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] - self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))) + self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5))) self.gguf_writer.add_add_bos_token(False) rope_freq = 10000 if "rope_ratio" in self.hparams: diff --git a/conversion/deci.py b/conversion/deci.py index 46d8568c5a..be446eefa6 100644 --- a/conversion/deci.py +++ b/conversion/deci.py @@ -161,7 +161,7 @@ class DeciModel(TextModel): factor = rope_params.get("factor", 8.0) low_freq_factor = rope_params.get("low_freq_factor", 1.0) high_freq_factor = rope_params.get("high_freq_factor", 4.0) - old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + old_context_len = rope_params.get("original_max_position_embeddings", 8192) low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor diff --git a/conversion/exaone.py b/conversion/exaone.py index b21f027842..bc4fb3f1b1 100644 --- a/conversion/exaone.py +++ b/conversion/exaone.py @@ -24,7 +24,7 @@ class ExaoneModel(TextModel): assert (hparams["activation_function"] == "silu") - rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True) + rotary_factor = self.rope_parameters.get("partial_rotary_factor") rotary_factor = rotary_factor if rotary_factor is not None else 1.0 self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"]))) @@ -39,7 +39,7 @@ class ExaoneModel(TextModel): factor = 
rope_params.get("factor", 8.0) low_freq_factor = rope_params.get("low_freq_factor", 1.0) high_freq_factor = rope_params.get("high_freq_factor", 4.0) - old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + old_context_len = rope_params.get("original_max_position_embeddings", 8192) low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor @@ -104,7 +104,7 @@ class Exaone4Model(TextModel): factor = rope_params.get("factor", 16.0) low_freq_factor = rope_params.get("low_freq_factor", 1.0) high_freq_factor = rope_params.get("high_freq_factor", 4.0) - old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + old_context_len = rope_params.get("original_max_position_embeddings", 8192) low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor diff --git a/conversion/gemma.py b/conversion/gemma.py index 5b4ca5c583..c552df732b 100644 --- a/conversion/gemma.py +++ b/conversion/gemma.py @@ -693,7 +693,7 @@ class Gemma4Model(Gemma3Model): self.gguf_writer.add_head_count_kv(value_arr) # handle n_rot differently for global vs swa layers - partial_rotary_factor_swa = self.hparams.get("partial_rotary_factor", 1.0) + partial_rotary_factor_swa = self.rope_parameters.get("partial_rotary_factor", 1.0) n_rot_full = int(head_dim_full) # "proportional" is used, see generate_extra_tensors n_rot_swa = int(head_dim_swa * partial_rotary_factor_swa) self.gguf_writer.add_rope_dimension_count(n_rot_full) diff --git a/conversion/glm.py b/conversion/glm.py index 641937720d..895cefc22b 100644 --- a/conversion/glm.py +++ b/conversion/glm.py @@ -124,7 +124,7 @@ class Glm4MoeModel(TextModel): self.hparams["hidden_size"] // self.hparams["num_attention_heads"] ) self.gguf_writer.add_rope_dimension_count( - int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)) + int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5)) ) # MoE parameters - 
Use only routed expert count (shared experts handled separately) @@ -226,7 +226,7 @@ class GlmMoeDsaModel(DeepseekV2Model): super().set_gguf_parameters() rope_dim = self.hparams["qk_rope_head_dim"] - partial_rotary_factor = self.hparams.get("partial_rotary_factor", 1.0) + partial_rotary_factor = self.rope_parameters.get("partial_rotary_factor", 1.0) self.gguf_writer.add_rope_dimension_count(int(rope_dim * partial_rotary_factor)) # NextN/MTP prediction layers diff --git a/conversion/llama.py b/conversion/llama.py index b87bf92d46..a0d39472eb 100644 --- a/conversion/llama.py +++ b/conversion/llama.py @@ -289,7 +289,7 @@ class LlamaModel(TextModel): factor = rope_params.get("factor", 8.0) low_freq_factor = rope_params.get("low_freq_factor", 1.0) high_freq_factor = rope_params.get("high_freq_factor", 4.0) - old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + old_context_len = rope_params.get("original_max_position_embeddings", 8192) low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor diff --git a/conversion/mimo.py b/conversion/mimo.py index d4067aab4b..11ec286794 100644 --- a/conversion/mimo.py +++ b/conversion/mimo.py @@ -154,7 +154,7 @@ class MimoV2Model(TextModel): self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"]) self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"]) - rope_dim = int(self.hparams["head_dim"] * self.hparams["partial_rotary_factor"]) + rope_dim = int(self.hparams["head_dim"] * self.rope_parameters["partial_rotary_factor"]) self.gguf_writer.add_rope_dimension_count(rope_dim) self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon", 1e-5)) diff --git a/conversion/minicpm.py b/conversion/minicpm.py index e9a4c4a74d..e31b26a008 100644 --- a/conversion/minicpm.py +++ b/conversion/minicpm.py @@ -32,11 +32,9 @@ class MiniCPMModel(TextModel): def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: 
rope_dims = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] - rope_scaling = self.find_hparam(['rope_scaling'], True) - if rope_scaling is not None: - long_factors = rope_scaling.get('long_factor', None) - short_factors = rope_scaling.get('short_factor', None) - + long_factors = self.rope_parameters.get('long_factor') + short_factors = self.rope_parameters.get('short_factor') + if long_factors or short_factors: if long_factors is None or short_factors is None: raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') @@ -85,13 +83,11 @@ class MiniCPM3Model(TextModel): self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: - rope_scaling = self.find_hparam(['rope_scaling'], True) - if rope_scaling is not None: + long_factors = self.rope_parameters.get('long_factor') + short_factors = self.rope_parameters.get('short_factor') + if long_factors or short_factors: rope_dims = self.hparams["qk_rope_head_dim"] - long_factors = rope_scaling.get('long_factor', None) - short_factors = rope_scaling.get('short_factor', None) - if long_factors is None or short_factors is None: raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') diff --git a/conversion/nemotron.py b/conversion/nemotron.py index dfeeb97858..e44688a788 100644 --- a/conversion/nemotron.py +++ b/conversion/nemotron.py @@ -125,17 +125,18 @@ class NemotronModel(TextModel): self.gguf_writer.add_layer_norm_eps(f_norm_eps) # * Partial RoPE - rot_pct = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"]) + rot_pct = self.rope_parameters["partial_rotary_factor"] n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head) # * RopeScaling for Nemotron - if "rope_scaling" not in self.hparams or 
self.hparams["rope_scaling"] is None: + factor = self.hparams.get("factor") or self.rope_parameters.get("factor") + if factor is None: self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) else: self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) - self.gguf_writer.add_rope_scaling_factor(self.hparams["factor"]) + self.gguf_writer.add_rope_scaling_factor(factor) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # * Adding +1 to LayerNorm's weights here to implement layernorm1p w/o changing anything on the GGML engine side diff --git a/conversion/phi.py b/conversion/phi.py index 5e0d72847a..df4bfe809a 100644 --- a/conversion/phi.py +++ b/conversion/phi.py @@ -18,7 +18,7 @@ class Phi2Model(TextModel): model_arch = gguf.MODEL_ARCH.PHI2 def set_gguf_parameters(self): - rot_pct = self.find_hparam(["partial_rotary_factor"]) + rot_pct = self.rope_parameters["partial_rotary_factor"] n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) @@ -149,8 +149,8 @@ class Phi3MiniModel(TextModel): n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"]) rms_eps = self.find_hparam(["rms_norm_eps"]) max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"]) - orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"]) - rot_pct = self.hparams.get("partial_rotary_factor", 1.0) + orig_max_pos_embds = self.rope_parameters["original_max_position_embeddings"] + rot_pct = self.rope_parameters.get("partial_rotary_factor", 1.0) rope_dims = int(rot_pct * n_embd) // n_head self.gguf_writer.add_context_length(max_pos_embds) @@ -174,18 +174,19 @@ class Phi3MiniModel(TextModel): n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"]) - orig_max_pos_embds = 
self.find_hparam(["original_max_position_embeddings"]) - rot_pct = self.hparams.get("partial_rotary_factor", 1.0) + orig_max_pos_embds = self.rope_parameters["original_max_position_embeddings"] + rot_pct = self.rope_parameters.get("partial_rotary_factor", 1.0) rope_dims = int(rot_pct * n_embd) // n_head # write rope scaling for long context (128k) model - rope_scaling = self.find_hparam(['rope_scaling'], True) - if rope_scaling is None: + long_factors = self.rope_parameters.get('long_factor') + short_factors = self.rope_parameters.get('short_factor') + if not long_factors: return scale = max_pos_embds / orig_max_pos_embds - rope_scaling_type = rope_scaling.get('rope_type', rope_scaling.get('type', '')).lower() + rope_scaling_type = self.rope_parameters.get('rope_type', '').lower() if len(rope_scaling_type) == 0: raise KeyError('Missing the required key rope_scaling.type') @@ -198,9 +199,6 @@ class Phi3MiniModel(TextModel): self.gguf_writer.add_rope_scaling_attn_factors(attn_factor) - long_factors = rope_scaling.get('long_factor', None) - short_factors = rope_scaling.get('short_factor', None) - if long_factors is None or short_factors is None: raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') diff --git a/conversion/qwen.py b/conversion/qwen.py index 7eb135c832..6b85eb9aaf 100644 --- a/conversion/qwen.py +++ b/conversion/qwen.py @@ -280,7 +280,7 @@ class Qwen3NextModel(Qwen2MoeModel): self.gguf_writer.add_full_attention_interval(self.hparams.get("full_attention_interval", 4)) if (rope_dim := self.hparams.get("head_dim")) is None: rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] - self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25))) + self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.25))) @classmethod def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, 
Callable[[], Tensor]] | None: diff --git a/conversion/stablelm.py b/conversion/stablelm.py index ba5e9aa6ca..6e16378a03 100644 --- a/conversion/stablelm.py +++ b/conversion/stablelm.py @@ -28,7 +28,7 @@ class StableLMModel(TextModel): self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) - rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"]) + rotary_factor = self.rope_parameters["partial_rotary_factor"] self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"]))) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"]) diff --git a/conversion/step3.py b/conversion/step3.py index 8c45b61c95..49bb5244a6 100644 --- a/conversion/step3.py +++ b/conversion/step3.py @@ -314,7 +314,7 @@ class Step35Model(TextModel): factor = float(rope_params.get("factor", 8.0)) low_freq_factor = float(rope_params.get("low_freq_factor", 1.0)) high_freq_factor = float(rope_params.get("high_freq_factor", 4.0)) - old_context_len = int(rope_params.get("original_max_position_embeddings", self.hparams.get("original_max_position_embeddings", 8192))) + old_context_len = int(rope_params.get("original_max_position_embeddings", 8192)) low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor From 37a77fb0579be9d71e2c73da0553cfd42b7b103a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Sat, 20 Jun 2026 12:43:06 +0200 Subject: [PATCH 24/86] ggml : optimize AMX (#24806) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Flatten the partition over n_batch * M so every thread participates in the quantization | CPU | Model | Test | t/s OLD | t/s NEW | Speedup | 
|:--------------------------------|:------------------------------|:-------|----------:|----------:|----------:| | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B IQ4_NL - 4.5 bpw | pp512 | 730.71 | 779.86 | 1.07 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B IQ4_NL - 4.5 bpw | tg128 | 87.88 | 86.79 | 0.99 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B IQ4_XS - 4.25 bpw | pp512 | 725.09 | 1023.31 | 1.41 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B IQ4_XS - 4.25 bpw | tg128 | 83.64 | 83.62 | 1.00 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_0 | pp512 | 820.51 | 924.05 | 1.13 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_0 | tg128 | 90.59 | 92.46 | 1.02 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_1 | pp512 | 776.88 | 872.79 | 1.12 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_1 | tg128 | 89.39 | 90.94 | 1.02 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_K_M | pp512 | 719.28 | 1009.27 | 1.40 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_K_M | tg128 | 80.62 | 80.86 | 1.00 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_K_S | pp512 | 732.29 | 1077.29 | 1.47 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_K_S | tg128 | 86.42 | 83.53 | 0.97 | Signed-off-by: Adrien Gallouët --- ggml/src/ggml-cpu/amx/mmq.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cpu/amx/mmq.cpp b/ggml/src/ggml-cpu/amx/mmq.cpp index d9383a04be..9f3a744b5d 100644 --- a/ggml/src/ggml-cpu/amx/mmq.cpp +++ b/ggml/src/ggml-cpu/amx/mmq.cpp @@ -2417,15 +2417,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size); - parallel_for_ggml(params, n_batch, [&](int begin, int end) { - for (int batch_idx = begin; batch_idx < end; ++batch_idx) { + parallel_for_ggml(params, n_batch * M, [&](int begin, int end) { + for (int idx = begin; idx < end; ++idx) { + int 
batch_idx = idx / M; + int m = idx % M; int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2); const float * A_data = (const float *)((const char *)src1->data + src1_offset); char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A; - - for (int m = 0; m < M; ++m) { - from_float(A_data + m * K, wdata_batch + m * row_size_A, K); - } + from_float(A_data + m * K, wdata_batch + m * row_size_A, K); } }); }); From 796f41bedca8a786ab3eb5584cd97b7730b303d8 Mon Sep 17 00:00:00 2001 From: davidrhodus Date: Sat, 20 Jun 2026 03:48:24 -0700 Subject: [PATCH 25/86] model : glm-dsa load DSA indexer tensors as optional (#24770) GLM-5.2 ships the DSA "lightning indexer" on only a subset of layers (the "full" layers; others omit it), but the GLM_DSA loader created the five indexer tensors on every layer as required, so loading any GLM-5.2 GGUF failed with e.g. `missing tensor 'blk.3.indexer.k_norm.weight'`. GLM_DSA's graph is llama_model_deepseek2::graph (plain MLA) and does not use the indexer tensors (indexer runtime not yet implemented), so they are loaded-but-unused. Marking them TENSOR_NOT_REQUIRED lets layers without an indexer load as nullptr and the model runs as full MLA attention. DeepSeek-V3.2 (uniform indexer on all layers) is unaffected. 
--- src/models/glm-dsa.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/models/glm-dsa.cpp b/src/models/glm-dsa.cpp index 11d91312de..32fe6def6f 100644 --- a/src/models/glm-dsa.cpp +++ b/src/models/glm-dsa.cpp @@ -101,11 +101,11 @@ void llama_model_glm_dsa::load_arch_tensors(llama_model_loader &) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); // DSA indexer - layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); - layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); - layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); - layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); - layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED); + layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags | TENSOR_NOT_REQUIRED); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED); if (i < (int) hparams.n_layer_dense_lead) { layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); 
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); From 67e9fd3b74b7fab3a153161f1942cc2121aa90a3 Mon Sep 17 00:00:00 2001 From: Aldehir Rojas Date: Sat, 20 Jun 2026 05:54:42 -0500 Subject: [PATCH 26/86] docker : prebuild web UI for s390x build [no release] (#24829) --- .devops/s390x.Dockerfile | 16 ---------------- .dockerignore | 1 - .github/workflows/docker.yml | 18 ++++++++++++++++-- 3 files changed, 16 insertions(+), 19 deletions(-) diff --git a/.devops/s390x.Dockerfile b/.devops/s390x.Dockerfile index 149d79a615..d88dd2d92d 100644 --- a/.devops/s390x.Dockerfile +++ b/.devops/s390x.Dockerfile @@ -4,20 +4,6 @@ ARG BUILD_DATE=N/A ARG APP_VERSION=N/A ARG APP_REVISION=N/A -ARG NODE_VERSION=24 - -FROM docker.io/node:$NODE_VERSION AS web - -ARG APP_VERSION - -WORKDIR /app/tools/ui - -COPY tools/ui/package.json tools/ui/package-lock.json ./ -RUN npm ci - -COPY tools/ui/ ./ -RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build - ### Build Llama.cpp stage FROM docker.io/gcc:${GCC_VERSION} AS build @@ -34,8 +20,6 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ WORKDIR /app COPY . . -COPY --from=web /app/tools/ui/dist tools/ui/dist - RUN --mount=type=cache,target=/root/.ccache \ --mount=type=cache,target=/app/build \ cmake -S . 
-B build -G Ninja \ diff --git a/.dockerignore b/.dockerignore index a223b7e898..0b81e83bf5 100644 --- a/.dockerignore +++ b/.dockerignore @@ -11,7 +11,6 @@ build*/ tools/ui/node_modules/ -tools/ui/dist/ models/* diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 8195a55ff2..afe4b7c664 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -58,6 +58,13 @@ jobs: git tag ${{ steps.srctag.outputs.name }} || exit 0 git push origin ${{ steps.srctag.outputs.name }} || exit 0 + build_ui: + name: Build UI + needs: create_tag + uses: ./.github/workflows/ui-build.yml + with: + hf_ui_version: ${{ needs.create_tag.outputs.source_tag }} + prepare_matrices: name: Prepare Docker matrices runs-on: ubuntu-24.04 @@ -79,7 +86,7 @@ jobs: [ { "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" }, { "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-arm" }, - { "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x" }, + { "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x", "prebuilt_ui": true }, { "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" }, { "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": 
"ubuntu-24.04-arm" }, { "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.3.0", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" }, @@ -135,7 +142,7 @@ jobs: push_to_registry: name: Push Docker image to Docker Registry - needs: [prepare_matrices, create_tag] + needs: [prepare_matrices, create_tag, build_ui] runs-on: ${{ matrix.config.runs_on }} strategy: @@ -150,6 +157,13 @@ jobs: fetch-depth: 0 ref: ${{ needs.create_tag.outputs.source_tag }} + - name: Download prebuilt UI + if: ${{ matrix.config.prebuilt_ui == true }} + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 + with: + name: ui-build + path: tools/ui/dist + - name: Set up QEMU if: ${{ contains(matrix.config.platforms, 'linux/amd64') }} uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4 From e27f3085973722407518ea4822fb3e0a2b41df9c Mon Sep 17 00:00:00 2001 From: Matti4 Date: Sat, 20 Jun 2026 15:34:47 +0200 Subject: [PATCH 27/86] server: avoid forwarding auth headers in CORS proxy (#24373) * server: avoid forwarding auth headers in CORS proxy * format * fix test * fix e2e test --------- Co-authored-by: Xuan Son Nguyen --- tools/server/server-cors-proxy.h | 22 ++++++- tools/server/tests/unit/test_security.py | 45 ++++++++++++++ tools/ui/src/lib/constants/mcp.ts | 3 + tools/ui/src/lib/services/mcp.service.ts | 23 +++++-- tools/ui/src/lib/utils/api-headers.ts | 15 ++++- tools/ui/src/lib/utils/cors-proxy.ts | 8 ++- tools/ui/tests/e2e/pwa.e2e.ts | 10 +-- tools/ui/tests/unit/mcp-service.test.ts | 64 +++++++++++++++++++- tools/ui/tests/unit/sanitize-headers.test.ts | 18 ++++++ 9 files changed, 187 insertions(+), 21 deletions(-) diff --git a/tools/server/server-cors-proxy.h b/tools/server/server-cors-proxy.h index 2af0c7e1c2..53a6909ed2 100644 --- a/tools/server/server-cors-proxy.h +++ b/tools/server/server-cors-proxy.h @@ -7,9 +7,18 @@ #include #include 
#include +#include +#include #include "server-http.h" +static std::string proxy_header_to_lower(std::string header) { + std::transform(header.begin(), header.end(), header.begin(), [](unsigned char c) { + return std::tolower(c); + }); + return header; +} + static server_http_res_ptr proxy_request(const server_http_req & req, std::string method) { std::string target_url = req.get_param("url"); common_http_url parsed_url = common_http_parse_url(target_url); @@ -33,11 +42,18 @@ static server_http_res_ptr proxy_request(const server_http_req & req, std::strin SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.port, parsed_url.path.c_str()); std::map headers; + const std::string proxy_header_prefix = "x-llama-server-proxy-header-"; for (auto [key, value] : req.headers) { - auto new_key = key; - if (string_starts_with(new_key, "x-proxy-header-")) { - string_replace_all(new_key, "x-proxy-header-", ""); + const std::string lowered_key = proxy_header_to_lower(key); + if (!string_starts_with(lowered_key, proxy_header_prefix)) { + continue; } + + auto new_key = key.substr(proxy_header_prefix.size()); + if (new_key.empty()) { + continue; + } + headers[new_key] = value; } diff --git a/tools/server/tests/unit/test_security.py b/tools/server/tests/unit/test_security.py index 02d0b1afbc..a0c3e214ae 100644 --- a/tools/server/tests/unit/test_security.py +++ b/tools/server/tests/unit/test_security.py @@ -1,6 +1,8 @@ import pytest from openai import OpenAI from utils import * +import threading +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer server = ServerPreset.tinyllama2() @@ -105,6 +107,49 @@ def test_cors_options(origin: str, cors_header: str, cors_header_value: str): assert res.headers[cors_header] == cors_header_value +def test_cors_proxy_only_forwards_explicit_proxy_headers(): + class CaptureHeadersHandler(BaseHTTPRequestHandler): + def do_GET(self): + self.server.captured_headers = 
dict(self.headers) + self.send_response(200) + self.end_headers() + self.wfile.write(b"ok") + + def log_message(self, format, *args): + pass + + target = ThreadingHTTPServer(("127.0.0.1", 0), CaptureHeadersHandler) + target.captured_headers = {} + target_thread = threading.Thread(target=target.serve_forever, daemon=True) + target_thread.start() + + try: + server = ServerPreset.tinyllama2() + server.api_key = TEST_API_KEY + server.ui_mcp_proxy = True + server.start() + + res = server.make_request("GET", f"/cors-proxy?url=http://127.0.0.1:{target.server_port}/capture", headers={ + "Authorization": f"Bearer {TEST_API_KEY}", + "Proxy-Authorization": "Basic secret", + "X-Api-Key": TEST_API_KEY, + "Cookie": "session=secret", + "x-llama-server-proxy-header-accept": "application/json", + "x-llama-server-proxy-header-authorization": "Bearer explicit", + }) + + assert res.status_code == 200 + captured = {key.lower(): value for key, value in target.captured_headers.items()} + assert captured["accept"] == "application/json" + assert captured["authorization"] == "Bearer explicit" + assert "proxy-authorization" not in captured + assert "x-api-key" not in captured + assert "cookie" not in captured + finally: + target.shutdown() + target.server_close() + + @pytest.mark.parametrize( "media_path, image_url, success", [ diff --git a/tools/ui/src/lib/constants/mcp.ts b/tools/ui/src/lib/constants/mcp.ts index 5b11f989e2..a7381df0bf 100644 --- a/tools/ui/src/lib/constants/mcp.ts +++ b/tools/ui/src/lib/constants/mcp.ts @@ -51,6 +51,9 @@ export const EXPECTED_THEMED_ICON_PAIR_COUNT = 2; /** CORS proxy URL query parameter name */ export const CORS_PROXY_URL_PARAM = 'url'; +/** Header prefix for headers that should be forwarded by the CORS proxy */ +export const CORS_PROXY_HEADER_PREFIX = 'x-llama-server-proxy-header-'; + /** Number of trailing characters to keep visible when partially redacting mcp-session-id */ export const MCP_SESSION_ID_VISIBLE_CHARS = 5; diff --git 
a/tools/ui/src/lib/services/mcp.service.ts b/tools/ui/src/lib/services/mcp.service.ts index 0aa58dc5d8..90de0d5d88 100644 --- a/tools/ui/src/lib/services/mcp.service.ts +++ b/tools/ui/src/lib/services/mcp.service.ts @@ -16,6 +16,7 @@ import { DEFAULT_MCP_CONFIG, DEFAULT_CLIENT_VERSION, DEFAULT_IMAGE_MIME_TYPE, + CORS_PROXY_HEADER_PREFIX, MCP_PARTIAL_REDACT_HEADERS, CORS_PROXY_ENDPOINT } from '$lib/constants'; @@ -133,6 +134,20 @@ export class MCPService { return details; } + private static addRequestHeaders( + requestHeaders: Headers, + headers: HeadersInit, + useProxy: boolean + ) { + for (const [key, value] of new Headers(headers).entries()) { + const proxiedKey = + useProxy && !key.toLowerCase().startsWith(CORS_PROXY_HEADER_PREFIX) + ? `${CORS_PROXY_HEADER_PREFIX}${key}` + : key; + requestHeaders.set(proxiedKey, value); + } + } + private static summarizeError(error: unknown): Record { if (error instanceof Error) { return { @@ -271,15 +286,11 @@ export class MCPService { const requestHeaders = new Headers(baseInit.headers); if (typeof Request !== 'undefined' && input instanceof Request) { - for (const [key, value] of input.headers.entries()) { - requestHeaders.set(key, value); - } + this.addRequestHeaders(requestHeaders, input.headers, useProxy); } if (init?.headers) { - for (const [key, value] of new Headers(init.headers).entries()) { - requestHeaders.set(key, value); - } + this.addRequestHeaders(requestHeaders, init.headers, useProxy); } const request = this.createDiagnosticRequestDetails( diff --git a/tools/ui/src/lib/utils/api-headers.ts b/tools/ui/src/lib/utils/api-headers.ts index c0a5309b99..a2b70d492a 100644 --- a/tools/ui/src/lib/utils/api-headers.ts +++ b/tools/ui/src/lib/utils/api-headers.ts @@ -1,5 +1,5 @@ import { config } from '$lib/stores/settings.svelte'; -import { REDACTED_HEADERS } from '$lib/constants'; +import { CORS_PROXY_HEADER_PREFIX, REDACTED_HEADERS } from '$lib/constants'; import { redactValue } from './redact'; /** @@ -52,11 +52,20 @@ 
export function sanitizeHeaders( for (const [key, value] of normalized.entries()) { const normalizedKey = key.toLowerCase(); - const partialChars = partialRedactHeaders?.get(normalizedKey); + const unproxiedKey = normalizedKey.startsWith(CORS_PROXY_HEADER_PREFIX) + ? normalizedKey.slice(CORS_PROXY_HEADER_PREFIX.length) + : normalizedKey; + const partialChars = + partialRedactHeaders?.get(normalizedKey) ?? partialRedactHeaders?.get(unproxiedKey); if (partialChars !== undefined) { sanitized[key] = redactValue(value, partialChars); - } else if (REDACTED_HEADERS.has(normalizedKey) || redactedHeaders.has(normalizedKey)) { + } else if ( + REDACTED_HEADERS.has(normalizedKey) || + REDACTED_HEADERS.has(unproxiedKey) || + redactedHeaders.has(normalizedKey) || + redactedHeaders.has(unproxiedKey) + ) { sanitized[key] = redactValue(value); } else { sanitized[key] = value; diff --git a/tools/ui/src/lib/utils/cors-proxy.ts b/tools/ui/src/lib/utils/cors-proxy.ts index 47caf27427..1694b7dbe6 100644 --- a/tools/ui/src/lib/utils/cors-proxy.ts +++ b/tools/ui/src/lib/utils/cors-proxy.ts @@ -3,7 +3,11 @@ */ import { base } from '$app/paths'; -import { CORS_PROXY_ENDPOINT, CORS_PROXY_URL_PARAM } from '$lib/constants'; +import { + CORS_PROXY_ENDPOINT, + CORS_PROXY_HEADER_PREFIX, + CORS_PROXY_URL_PARAM +} from '$lib/constants'; /** * Build a proxied URL that routes through llama-server's CORS proxy. 
@@ -28,7 +32,7 @@ export function buildProxiedHeaders(headers: Record): Record = {}; for (const [key, value] of Object.entries(headers)) { - proxiedHeaders[`x-proxy-header-${key}`] = value; + proxiedHeaders[`${CORS_PROXY_HEADER_PREFIX}${key}`] = value; } return proxiedHeaders; diff --git a/tools/ui/tests/e2e/pwa.e2e.ts b/tools/ui/tests/e2e/pwa.e2e.ts index be7642b191..e21672239b 100644 --- a/tools/ui/tests/e2e/pwa.e2e.ts +++ b/tools/ui/tests/e2e/pwa.e2e.ts @@ -39,8 +39,8 @@ test.describe('PWA Service Worker', () => { const swContent = await swResponse.text(); // Precache contains SvelteKit content-hashed bundle paths - expect(swContent).toMatch(/"_app\/immutable\/bundle\.[a-zA-Z0-9-]+\.js"/); - expect(swContent).toMatch(/"_app\/immutable\/assets\/bundle\.[a-zA-Z0-9-]+\.css"/); + expect(swContent).toMatch(/"_app\/immutable\/bundle\.[a-zA-Z0-9_-]+\.js"/); + expect(swContent).toMatch(/"_app\/immutable\/assets\/bundle\.[a-zA-Z0-9_-]+\.css"/); expect(swContent).toMatch(/"manifest\.webmanifest"/); expect(swContent).toMatch(/"_app\/version\.json"/); expect(swContent).toMatch(/NavigationRoute/); @@ -99,8 +99,8 @@ test.describe('PWA Service Worker', () => { const html = await response.text(); // SvelteKit outputs content-hashed bundle names in _app/immutable/ - expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9-]+\.js"/); - expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/assets\/bundle\.[a-zA-Z0-9-]+\.css"/); - expect(html).toMatch(/import\("(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9-]+\.js"\)/); + expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9_-]+\.js"/); + expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/assets\/bundle\.[a-zA-Z0-9_-]+\.css"/); + expect(html).toMatch(/import\("(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9_-]+\.js"\)/); }); }); diff --git a/tools/ui/tests/unit/mcp-service.test.ts b/tools/ui/tests/unit/mcp-service.test.ts index afd3bdd5cf..1f6fdda377 100644 --- 
a/tools/ui/tests/unit/mcp-service.test.ts +++ b/tools/ui/tests/unit/mcp-service.test.ts @@ -3,6 +3,7 @@ import { Client } from '@modelcontextprotocol/sdk/client'; import { MCPService } from '$lib/services/mcp.service'; import { MCPConnectionPhase, MCPTransportType } from '$lib/enums'; import type { MCPConnectionLog, MCPServerConfig } from '$lib/types'; +import { CORS_PROXY_HEADER_PREFIX } from '$lib/constants'; type DiagnosticFetchFactory = ( serverName: string, @@ -16,11 +17,12 @@ type DiagnosticFetchFactory = ( const createDiagnosticFetch = ( config: MCPServerConfig, onLog?: (log: MCPConnectionLog) => void, - baseInit: RequestInit = {} + baseInit: RequestInit = {}, + useProxy = false ) => ( MCPService as unknown as { createDiagnosticFetch: DiagnosticFetchFactory } - ).createDiagnosticFetch('test-server', config, baseInit, new URL(config.url), false, onLog); + ).createDiagnosticFetch('test-server', config, baseInit, new URL(config.url), useProxy, onLog); describe('MCPService', () => { afterEach(() => { @@ -94,6 +96,64 @@ describe('MCPService', () => { }); }); + it('wraps dynamic request headers when using the CORS proxy', async () => { + const logs: MCPConnectionLog[] = []; + const proxiedAuthToken = `${CORS_PROXY_HEADER_PREFIX}x-auth-token`; + const proxiedContentType = `${CORS_PROXY_HEADER_PREFIX}content-type`; + const proxiedSessionId = `${CORS_PROXY_HEADER_PREFIX}mcp-session-id`; + const response = new Response('{}', { + status: 200, + headers: { 'content-type': 'application/json' } + }); + const fetchMock = vi.fn().mockResolvedValue(response); + + vi.stubGlobal('fetch', fetchMock); + + const config: MCPServerConfig = { + url: 'https://example.com/mcp', + transport: MCPTransportType.STREAMABLE_HTTP, + useProxy: true + }; + + const controller = createDiagnosticFetch( + config, + (log) => logs.push(log), + { + headers: { + authorization: 'Bearer llama-server-key', + [proxiedAuthToken]: 'target-token' + } + }, + true + ); + + await 
controller.fetch('http://localhost:8080/cors-proxy?url=https%3A%2F%2Fexample.com%2Fmcp', { + method: 'POST', + headers: { + 'content-type': 'application/json', + 'mcp-session-id': 'session-request-12345' + }, + body: '{}' + }); + + const sentHeaders = fetchMock.mock.calls[0]?.[1]?.headers as Headers; + expect(sentHeaders.get('authorization')).toBe('Bearer llama-server-key'); + expect(sentHeaders.get(proxiedAuthToken)).toBe('target-token'); + expect(sentHeaders.get(proxiedContentType)).toBe('application/json'); + expect(sentHeaders.get(proxiedSessionId)).toBe('session-request-12345'); + expect(sentHeaders.has('content-type')).toBe(false); + expect(sentHeaders.has('mcp-session-id')).toBe(false); + expect(logs[0].details).toMatchObject({ + request: { + headers: { + authorization: '[redacted]', + [proxiedAuthToken]: '[redacted]', + [proxiedSessionId]: '....12345' + } + } + }); + }); + it('partially redacts mcp-session-id in diagnostic request and response logs', async () => { const logs: MCPConnectionLog[] = []; const response = new Response('{}', { diff --git a/tools/ui/tests/unit/sanitize-headers.test.ts b/tools/ui/tests/unit/sanitize-headers.test.ts index f5a682d863..8cc1fcdfc8 100644 --- a/tools/ui/tests/unit/sanitize-headers.test.ts +++ b/tools/ui/tests/unit/sanitize-headers.test.ts @@ -1,5 +1,6 @@ import { describe, expect, it } from 'vitest'; import { sanitizeHeaders } from '$lib/utils/api-headers'; +import { CORS_PROXY_HEADER_PREFIX } from '$lib/constants'; describe('sanitizeHeaders', () => { it('returns empty object for undefined input', () => { @@ -52,4 +53,21 @@ describe('sanitizeHeaders', () => { const result = sanitizeHeaders(headers, ['X-CUSTOM-TOKEN']); expect(result['x-custom-token']).toBe('[redacted]'); }); + + it('redacts proxied sensitive and custom target headers', () => { + const proxiedAuthorization = `${CORS_PROXY_HEADER_PREFIX}authorization`; + const proxiedSessionId = `${CORS_PROXY_HEADER_PREFIX}mcp-session-id`; + const proxiedVendorKey = 
`${CORS_PROXY_HEADER_PREFIX}x-vendor-key`; + const headers = new Headers({ + [proxiedAuthorization]: 'Bearer secret', + [proxiedSessionId]: 'session-12345', + [proxiedVendorKey]: 'vendor-secret' + }); + const partial = new Map([['mcp-session-id', 5]]); + const result = sanitizeHeaders(headers, ['x-vendor-key'], partial); + + expect(result[proxiedAuthorization]).toBe('[redacted]'); + expect(result[proxiedSessionId]).toBe('....12345'); + expect(result[proxiedVendorKey]).toBe('[redacted]'); + }); }); From 8452824611be321246f33339727f60a90c02c277 Mon Sep 17 00:00:00 2001 From: Muhammad Salem Date: Sat, 20 Jun 2026 18:08:59 +0300 Subject: [PATCH 28/86] release: add missing link for win opencl adreno arm64 (#24809) --- .github/workflows/release.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9789215f29..c7b67e4925 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1627,6 +1627,7 @@ jobs: **Windows:** - [Windows x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-x64.zip) - [Windows arm64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-arm64.zip) + - [Windows arm64 (OpenCL Adreno)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-opencl-adreno-arm64.zip) - [Windows x64 (CUDA 12)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip) - [CUDA 12.4 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-12.4-x64.zip) - [Windows x64 (CUDA 13)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ 
steps.tag.outputs.name }}-bin-win-cuda-13.3-x64.zip) - [CUDA 13.3 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-13.3-x64.zip) - [Windows x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-vulkan-x64.zip) From 75f460ac289e61eb3c2bb63c9487794a1ed514d1 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Sat, 20 Jun 2026 19:45:27 +0200 Subject: [PATCH 29/86] arg: try fixing test-args-parser randomly fails (#24826) * arg: try fixing test-args-parser randomly fails * return ref * try triggering the workflow * exception wrapper * wip * test * test 2 * arg: guard win32 utf8 argv override make_utf8_argv rebuilds argv from GetCommandLineW to fix utf8 handling of non ascii arguments on windows. the override runs unconditionally inside common_params_parse, so it also clobbers a programmatic argv passed by a caller. test-arg-parser builds a synthetic argv but then sees the real process command line instead, the model argument is never parsed, and the assert that expects success aborts via fastfail (0xC0000409). this shows up as a random failure in the openvino windows workflow. only override argv when its length matches the caller argc, so the utf8 repair still applies to real binaries while a programmatic argv stays intact. 
--------- Co-authored-by: Pascal --- common/arg.cpp | 6 +++--- tests/test-arg-parser.cpp | 12 +++++++++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index a9b1a25b27..8f54b5c814 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -924,8 +924,8 @@ static utf8_argv make_utf8_argv() { bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) { #ifdef _WIN32 auto utf8 = make_utf8_argv(); - if (!utf8.ptrs.empty()) { - argc = static_cast(utf8.buf.size()); + // repair argv only when it matches the process command line + if (static_cast(utf8.buf.size()) == argc) { argv = utf8.ptrs.data(); } #endif @@ -2897,7 +2897,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.server_tools = parse_csv_row(value); } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TOOLS")); - add_opt(common_arg( + add_opt(common_arg( {"-ag", "--agent"}, {"-no-ag", "--no-agent"}, "whether to enable CORS proxy and all built-in tools - do not enable in untrusted environments (default: disabled)", diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp index 0dd8422e73..e83ee85dd4 100644 --- a/tests/test-arg-parser.cpp +++ b/tests/test-arg-parser.cpp @@ -10,7 +10,7 @@ #undef NDEBUG #include -int main(void) { +static void test(void) { common_params params; printf("test-arg-parser: make sure there is no duplicated arguments in any examples\n\n"); @@ -210,3 +210,13 @@ int main(void) { printf("test-arg-parser: all tests OK\n\n"); } + +int main(void) { + try { + test(); + } catch (std::exception & e) { + fprintf(stderr, "test-arg-parser: exception: %s\n", e.what()); + return 1; + } + return 0; +} From 84de01a1f1c847292b8d90a9c0bff6619f2919be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Sat, 20 Jun 2026 20:07:01 +0200 Subject: [PATCH 30/86] llama : use LLM_KV for quantization_version & file_type (#24802) 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Adrien Gallouët --- src/llama-quant.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index cf92ce4bb8..89b7fe8d43 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -932,8 +932,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // copy the KV pairs from the input file gguf_set_kv (ctx_out.get(), ml.metadata); - gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV - gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV + gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION).c_str(), GGML_QNT_VERSION); + gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_GENERAL_FILE_TYPE).c_str(), ftype); // Remove split metadata gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str()); From 4a8094317436a23c484c9803cc3ac348e236708f Mon Sep 17 00:00:00 2001 From: Guanhuai Zhang <67999475+BiReRa@users.noreply.github.com> Date: Sun, 21 Jun 2026 05:58:49 +0800 Subject: [PATCH 31/86] fix(hexagon): use padded stride for ssm-conv weights (#24470) --- ggml/src/ggml-hexagon/htp/ssm-conv.c | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c index d574da2e2b..a48bc9ed86 100644 --- a/ggml/src/ggml-hexagon/htp/ssm-conv.c +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -183,24 +183,25 @@ static inline void hvx_transpose_32x32_f32(HVX_Vector m[32]) { // transposed into VTCM. // // VTCM layouts (per thread): -// src1_T : {d_inner_per_thread, d_conv} — staged once per launch (small). -// src0_T : {d_inner_tile, ncs} — staged per d_inner-tile. +// src1_T : {d_inner_stride, d_conv} - staged once per launch (small). 
+// src0_T : {d_inner_tile, ncs} - staged per d_inner-tile. // // d_inner_tile is chosen so that per-thread VTCM stays under the budget. // Each thread iterates ceil(d_inner_per_thread d_inner_tile) tiles serially. #define HTP_SSM_CONV_VTCM_BUDGET (1u << 20) // 1 MiB per thread -// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_per_thread, d_conv} (VTCM) +// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_stride, d_conv} (VTCM) static inline void transpose_src1(const float * src1_data, uint32_t src1_stride_inner, uint32_t i1_off, uint32_t d_inner_per_thread, + uint32_t d_inner_stride, uint32_t d_conv, float * src1_T) { for (uint32_t i = 0; i < d_inner_per_thread; ++i) { const float * src_row = src1_data + (i1_off + i) * src1_stride_inner; for (uint32_t j = 0; j < d_conv; ++j) { - src1_T[j * d_inner_per_thread + i] = src_row[j]; + src1_T[j * d_inner_stride + i] = src_row[j]; } } } @@ -280,6 +281,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void } const uint32_t d_inner_per_thread = ir1 - ir0; + const uint32_t d_inner_stride = scctx->nrows_per_thread; const uint32_t d_inner_tile = scctx->d_inner_tile; const float * src0_data = (const float *) src0->data; @@ -290,8 +292,8 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void float * src0_T = (float *)(octx->src0_spad.data + ith * octx->src0_spad.size_per_thread); float * src1_T = (float *)(octx->src1_spad.data + ith * octx->src1_spad.size_per_thread); - // Stage src1 weights once into VTCM in {d_inner_per_thread, d_conv} layout. - transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_conv, src1_T); + // Stage src1 weights once into VTCM in {d_inner_stride, d_conv} layout. 
+ transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_inner_stride, d_conv, src1_T); const uint32_t C_TILE = VLEN_FP32; @@ -314,7 +316,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void HVX_Vector acc = hvx_vec_splat_f32(0.0f); for (uint32_t j = 0; j < d_conv; ++j) { HVX_Vector x = *(const HVX_Vector *) (src0_T + (t + j) * d_inner_tile + cb); - HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_per_thread + tile_off + cb); + HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_stride + tile_off + cb); acc = Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(x, w)); } HVX_Vector res = Q6_Vsf_equals_Vqf32(acc); @@ -362,8 +364,7 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) { use_hvx = 1; } - scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; - scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); + scctx.nrows_per_thread = hex_round_up((d_inner + n_threads - 1) / n_threads, VLEN_FP32); const uint32_t d_inner_per_thread = scctx.nrows_per_thread; const uint32_t ncs = src0->ne[0]; From c57607016a1ebdd08d269e3378eee5546fc3bf3a Mon Sep 17 00:00:00 2001 From: Aldehir Rojas Date: Sat, 20 Jun 2026 17:43:04 -0500 Subject: [PATCH 32/86] common/json-schema-to-grammar : align spacing rules with parsers (#24835) --- common/json-schema-to-grammar.cpp | 46 ++-- common/peg-parser.cpp | 2 +- examples/json_schema_to_grammar.py | 42 ++-- tests/test-chat.cpp | 4 +- tests/test-json-schema-to-grammar.cpp | 310 +++++++++++++------------- 5 files changed, 202 insertions(+), 202 deletions(-) diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index e2c4d6ce22..b18607cd65 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -233,27 +233,27 @@ struct BuiltinRule { }; static std::unordered_map PRIMITIVE_RULES = { - {"boolean", {"(\"true\" | \"false\") space", {}}}, + {"boolean", {"(\"true\" | \"false\")", {}}}, {"decimal-part", 
{"[0-9]{1,16}", {}}}, {"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}}, - {"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}}, - {"integer", {"(\"-\"? integral-part) space", {"integral-part"}}}, + {"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)?", {"integral-part", "decimal-part"}}}, + {"integer", {"(\"-\"? integral-part)", {"integral-part"}}}, {"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}}, - {"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}}, - {"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}}, - {"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}}, + {"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? space \"}\"", {"string", "value"}}}, + {"array", {"\"[\" space ( value (\",\" space value)* )? space \"]\"", {"value"}}}, + {"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\"", {}}}, {"char", {"[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}}, - {"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}}, - {"null", {"\"null\" space", {}}}, + {"string", {"\"\\\"\" char* \"\\\"\"", {"char"}}}, + {"null", {"\"null\"", {}}}, }; static std::unordered_map STRING_FORMAT_RULES = { {"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}}, {"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? 
( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}}, {"date-time", {"date \"T\" time", {"date", "time"}}}, - {"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}}, - {"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}}, - {"date-time-string", {"\"\\\"\" date-time \"\\\"\" space", {"date-time"}}} + {"date-string", {"\"\\\"\" date \"\\\"\"", {"date"}}}, + {"time-string", {"\"\\\"\" time \"\\\"\"", {"time"}}}, + {"date-time-string", {"\"\\\"\" date-time \"\\\"\"", {"date-time"}}} }; static bool is_reserved_name(const std::string & name) { @@ -551,16 +551,16 @@ private: } return join_seq(); }; - return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space"); + return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\""); } /* Returns a rule that matches a JSON string that is none of the provided strings not_strings({"a"}) - -> ["] ( [a] char+ | [^"a] char* )? ["] space + -> ["] ( [a] char+ | [^"a] char* )? ["] not_strings({"and", "also"}) - -> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space + -> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? 
["] */ std::string _not_strings(const std::vector & strings) { @@ -619,7 +619,7 @@ private: if (!trie.is_end_of_string) { out << "?"; } - out << " [\"] space"; + out << " [\"]"; return out.str(); } @@ -725,7 +725,7 @@ private: rule += " )?"; } - rule += " \"}\" space"; + rule += " space \"}\""; return rule; } @@ -858,14 +858,14 @@ public: return _add_rule(rule_name, _generate_union_rule(name, schema_types)); } if (schema.contains("const")) { - return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space"); + return _add_rule(rule_name, _generate_constant_rule(schema["const"])); } if (schema.contains("enum")) { std::vector enum_values; for (const auto & v : schema["enum"]) { enum_values.push_back(_generate_constant_rule(v)); } - return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space"); + return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ")"); } if ((schema_type.is_null() || schema_type == "object") && (schema.contains("properties") || @@ -933,7 +933,7 @@ public: } } if (!enum_intersection.empty()) { - return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space"); + return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ")"); } } return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json())); @@ -948,7 +948,7 @@ public: } rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i)); } - rule += " \"]\" space"; + rule += " space \"]\""; return _add_rule(rule_name, rule); } std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item"); @@ -956,7 +956,7 @@ public: json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json(); int max_items = max_items_json.is_number_integer() ? 
max_items_json.get() : std::numeric_limits::max(); - return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space"); + return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " space \"]\""); } if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) { return _visit_pattern(schema["pattern"], rule_name); @@ -972,7 +972,7 @@ public: std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char")); int min_len = schema.contains("minLength") ? schema["minLength"].get() : 0; int max_len = schema.contains("maxLength") ? schema["maxLength"].get() : std::numeric_limits::max(); - return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space"); + return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\""); } if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) { int64_t min_value = std::numeric_limits::min(); @@ -990,7 +990,7 @@ public: std::stringstream out; out << "("; build_min_max_int(min_value, max_value, out); - out << ") space"; + out << ")"; return _add_rule(rule_name, out.str()); } if (schema.empty() || schema_type == "object") { diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp index d4b491a80e..ff0d24d43f 100644 --- a/common/peg-parser.cpp +++ b/common/peg-parser.cpp @@ -1342,7 +1342,7 @@ common_peg_parser common_peg_parser_builder::json_object() { common_peg_parser common_peg_parser_builder::json_array() { return rule("json-array", [this]() { auto ws = space(); - auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))}); + auto elements = sequence({json(), zero_or_more(sequence({ws, literal(","), ws, json()}))}); return sequence({ literal("["), ws, diff 
--git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 077fcfacac..83abd259da 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -198,18 +198,18 @@ class BuiltinRule: SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}' PRIMITIVE_RULES = { - 'boolean' : BuiltinRule('("true" | "false") space', []), + 'boolean' : BuiltinRule('("true" | "false")', []), 'decimal-part' : BuiltinRule('[0-9]{1,16}', []), 'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []), - 'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']), - 'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']), + 'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?', ['integral-part', 'decimal-part']), + 'integer' : BuiltinRule('("-"? integral-part)', ['integral-part']), 'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']), - 'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']), - 'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']), - 'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []), + 'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? space "}"', ['string', 'value']), + 'array' : BuiltinRule('"[" space ( value ("," space value)* )? 
space "]"', ['value']), + 'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\""', []), 'char' : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []), - 'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']), - 'null' : BuiltinRule('"null" space', []), + 'string' : BuiltinRule(r'"\"" char* "\""', ['char']), + 'null' : BuiltinRule('"null"', []), } # TODO: support "uri", "email" string formats @@ -217,9 +217,9 @@ STRING_FORMAT_RULES = { 'date' : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), 'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), 'date-time' : BuiltinRule('date "T" time', ['date', 'time']), - 'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']), - 'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']), - 'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']), + 'date-string' : BuiltinRule('"\\"" date "\\""', ['date']), + 'time-string' : BuiltinRule('"\\"" time "\\""', ['time']), + 'date-time-string': BuiltinRule('"\\"" date-time "\\""', ['date-time']), } DOTALL = '[\\U00000000-\\U0010FFFF]' @@ -319,7 +319,7 @@ class SchemaConverter: out.append(f'[^"{"".join(rejects)}] {char_rule}*') visit(trie) - out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space') + out.append(f' ){"" if trie.is_end_of_string else "?"} ["]') return ''.join(out) def _add_rule(self, name, rule): @@ -549,7 +549,7 @@ class SchemaConverter: return self._add_rule( name, to_rule(transform()) if self._raw_pattern \ - else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space") + else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\"") def _resolve_ref(self, ref): @@ -580,10 +580,10 @@ class SchemaConverter: return self._add_rule(rule_name, 
self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type])) elif 'const' in schema: - return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space') + return self._add_rule(rule_name, self._generate_constant_rule(schema['const'])) elif 'enum' in schema: - rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space' + rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ')' return self._add_rule(rule_name, rule) elif schema_type in (None, 'object') and \ @@ -624,7 +624,7 @@ class SchemaConverter: enum_intersection &= s if enum_intersection: - rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space' + rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ')' return self._add_rule(rule_name, rule) return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None)) @@ -638,12 +638,12 @@ class SchemaConverter: ' "," space '.join( self.visit(item, f'{name}{"-" if name else ""}tuple-{i}') for i, item in enumerate(items)) + - ' "]" space') + ' space "]"') else: item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item') min_items = schema.get("minItems", 0) max_items = schema.get("maxItems") - return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space') + return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' space "]"') elif schema_type in (None, 'string') and 'pattern' in schema: return self._visit_pattern(schema['pattern'], rule_name) @@ -663,7 +663,7 @@ class SchemaConverter: min_len = schema.get('minLength', 0) max_len = schema.get('maxLength') - return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' 
"\"" space') + return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\""') elif schema_type in (None, 'integer') and \ ('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema): @@ -680,7 +680,7 @@ class SchemaConverter: out = ["("] _generate_min_max_int(min_value, max_value, out) - out.append(") space") + out.append(")") return self._add_rule(rule_name, ''.join(out)) elif (schema_type == 'object') or (len(schema) == 0): @@ -765,7 +765,7 @@ class SchemaConverter: rule += ' )' rule += ' )?' - rule += ' "}" space' + rule += ' space "}"' return rule diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 902a4c135a..30aa35e137 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -5022,14 +5022,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).expect_reconstruction().run(); tst.test( - "```json\n\"42\" \n```") + "```json\n\"42\"\n```") .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .json_schema(const_schema) .expect_content(R"("42")") .run(); tst.test( - "\"42\" \n") + "\"42\"\n") .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .json_schema(const_schema) .expect_content(R"("42")") diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index b4362852c3..f095274cd1 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -92,7 +92,7 @@ static void test_all(const std::string & lang, std::function Date: Sat, 20 Jun 2026 21:15:06 -0500 Subject: [PATCH 33/86] common/peg : refactor until gbnf grammar generation (#24839) * common/peg : refactor until gbnf grammar into an ac automaton * cont : add a test with multiple strings * cont : pad state with 0s so rules line up * cont : clean up comments * cont : use set everywhere * cont : inline state num string padding * cont : add a 
ref to PR * cont : fix regression in server-tools.cpp --- common/peg-parser.cpp | 194 +++++++++++++--------- common/peg-parser.h | 4 +- tests/peg-parser/test-gbnf-generation.cpp | 80 ++++++++- tools/server/server-tools.cpp | 1 + 4 files changed, 199 insertions(+), 80 deletions(-) diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp index ff0d24d43f..506b902451 100644 --- a/common/peg-parser.cpp +++ b/common/peg-parser.cpp @@ -6,13 +6,14 @@ #include "unicode.h" #include +#include #include #include #include #include #include +#include #include -#include // Trick to catch missing branches template @@ -88,40 +89,7 @@ struct trie { return match_result{match_result::NO_MATCH}; } - struct prefix_and_next { - std::vector prefix; - std::vector next_chars; - }; - - std::vector collect_prefix_and_next() { - std::vector prefix; - std::vector result; - collect_prefix_and_next(0, prefix, result); - return result; - } - private: - void collect_prefix_and_next(size_t index, std::vector & prefix, std::vector & out) { - if (!nodes[index].is_word) { - if (!nodes[index].children.empty()) { - std::vector chars; - chars.reserve(nodes[index].children.size()); - for (const auto & p : nodes[index].children) { - chars.push_back(p.first); - } - out.emplace_back(prefix_and_next{prefix, chars}); - } - } - - for (const auto & p : nodes[index].children) { - uint32_t ch = p.first; - auto child = p.second; - prefix.push_back(ch); - collect_prefix_and_next(child, prefix, out); - prefix.pop_back(); - } - } - size_t create_node() { size_t index = nodes.size(); nodes.emplace_back(); @@ -153,6 +121,65 @@ struct trie { } }; +// Aho-Corasick automaton +struct aho_corasick { + trie t; + std::vector fail; // failure links + std::vector order; // states in BFS order + std::vector terminal; // match states (directly or via a suffix link) + std::set alphabet; // every character with a transition + + aho_corasick(const std::vector & strings) : t(strings) { + const auto & nodes = t.nodes; + const size_t n 
= nodes.size(); + + fail.assign(n, 0); + order.reserve(n); + + std::deque queue{ 0 }; + while (!queue.empty()) { + size_t u = queue.front(); + queue.pop_front(); + order.push_back(u); + for (const auto & [ch, v] : nodes[u].children) { + if (u != 0) { + size_t f = fail[u]; + while (f && nodes[f].children.find(ch) == nodes[f].children.end()) { + f = fail[f]; + } + auto it = nodes[f].children.find(ch); + fail[v] = (it != nodes[f].children.end() && it->second != v) ? it->second : 0; + } + queue.push_back(v); + } + } + + terminal.assign(n, false); + for (size_t u : order) { + terminal[u] = nodes[u].is_word || (u != 0 && terminal[fail[u]]); + } + + for (const auto & node : nodes) { + for (const auto & [ch, v] : node.children) { + alphabet.insert(ch); + } + } + } + + size_t num_states() const { return t.nodes.size(); } + bool is_terminal(size_t s) const { return terminal[s]; } + + // follow failure links until a transition on `ch` exists. + size_t next(size_t state, uint32_t ch) const { + const auto & nodes = t.nodes; + while (state && nodes[state].children.find(ch) == nodes[state].children.end()) { + state = fail[state]; + } + auto it = nodes[state].children.find(ch); + return it != nodes[state].children.end() ? 
it->second : 0; + } +}; + static std::pair parse_hex_escape(const std::string & str, size_t pos, int hex_count) { if (pos + hex_count > str.length()) { return {0, 0}; @@ -992,12 +1019,12 @@ void common_peg_arena::resolve_refs() { } std::string common_peg_arena::dump(common_peg_parser_id id) const { - std::unordered_set visited; + std::set visited; return dump_impl(id, visited); } std::string common_peg_arena::dump_impl(common_peg_parser_id id, - std::unordered_set & visited) const { + std::set & visited) const { // Check for cycles if (visited.count(id)) { return "[cycle]"; @@ -1502,61 +1529,74 @@ static std::string gbnf_escape_char_class(uint32_t c) { return std::string(buf); } -static std::string gbnf_excluding_pattern(const std::vector & strings) { - trie matcher(strings); - auto pieces = matcher.collect_prefix_and_next(); +// GBNF grammar matching strings that contain no string in `strings` as a +// substring. Emits the complement of an Aho-Corasick automaton DFA and returns +// the start state rule name. +// +// ref: https://github.com/ggml-org/llama.cpp/pull/24839 +static std::string gbnf_excluding_grammar(const common_grammar_builder & builder, + const std::string & prefix, + const std::vector & strings) { + aho_corasick ac(strings); - std::string pattern; - std::string trailing; // optional proper-prefix of a delimiter, allowed only at the very end - for (size_t i = 0; i < pieces.size(); ++i) { - if (i > 0) { - pattern += " | "; + auto state_name = [&](size_t s) -> std::string { + if (s == 0) { + return prefix; } + std::string num = std::to_string(s); + num = num.size() == 1 ? ("0" + num) : num; + return prefix + "-" + num; + }; - const auto & pre = pieces[i].prefix; - const auto & chars = pieces[i].next_chars; - - std::string cls; - cls.reserve(chars.size()); + auto char_class = [](const std::vector & chars, bool negate) { + std::string s = negate ? 
"[^" : "["; for (uint32_t ch : chars) { - cls += gbnf_escape_char_class(ch); + s += gbnf_escape_char_class(ch); + } + return s + "]"; + }; + + for (size_t q = 0; q < ac.num_states(); q++) { + if (ac.is_terminal(q)) { + continue; // match states are dropped } - if (!pre.empty()) { - std::string pre_literal = gbnf_format_literal(common_unicode_cpts_to_utf8(pre)); - pattern += pre_literal + " [^" + cls + "]"; - // Each interior alternative consumes a delimiter-prefix plus a disambiguating - // char, so the repetition alone cannot match a value that *ends* on a proper - // prefix of a delimiter (e.g. a trailing "\n" when the delimiter is - // "\n\n"). The runtime until() (greedy first-match) accepts such - // values, so without this the grammar would reject input the parser accepts. - // Allow the value to terminate on any proper prefix as an optional tail. - // This makes the grammar a slight superset of the runtime language (a value - // may end on the longest prefix, which greedy first-match would not itself - // produce); harmless for constrained generation, which only needs to admit - // every runtime-valid string. 
- if (!trailing.empty()) { - trailing += " | "; + std::map> buckets; + std::vector excluded; + for (uint32_t c : ac.alphabet) { + size_t d = ac.next(q, c); + if (ac.is_terminal(d)) { + excluded.push_back(c); // completes a forbidden string -> omit + } else if (d != 0) { + buckets[d].push_back(c); // specific non-root destination + excluded.push_back(c); } - trailing += pre_literal; - } else { - pattern += "[^" + cls + "]"; } + + std::string rhs = "|"; // every state is accepting + for (const auto & [d, chars] : buckets) { + rhs += " " + char_class(chars, false) + " " + state_name(d) + " |"; + } + rhs += " " + char_class(excluded, true) + " " + state_name(0); + + builder.add_rule(state_name(q), rhs); } - std::string result = "(" + pattern + ")*"; - if (!trailing.empty()) { - result += " (" + trailing + ")?"; + // An empty delimiter makes the start state terminal. Emit an entry rule + // that matches nothing so the returned reference stays valid. + if (ac.is_terminal(0)) { + builder.add_rule(prefix, "|"); } - return result; + + return state_name(0); } -static std::unordered_set collect_reachable_rules( +static std::set collect_reachable_rules( const common_peg_arena & arena, const common_peg_parser_id & rule ) { - std::unordered_set reachable; - std::unordered_set visited; + std::set reachable; + std::set visited; std::function visit = [&](common_peg_parser_id id) { const auto & parser = arena.get(id); @@ -1765,7 +1805,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo if (p.delimiters.empty()) { return ".*"; } - return gbnf_excluding_pattern(p.delimiters); + return gbnf_excluding_grammar(builder, "until-" + std::to_string(id), p.delimiters); } else if constexpr (std::is_same_v) { if (schema_delegates(p)) { return to_gbnf(p.child); @@ -1789,7 +1829,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo }; // Collect reachable rules - std::unordered_set reachable_rules; + std::set reachable_rules; if 
(lazy) { // Collect rules reachable from trigger rules diff --git a/common/peg-parser.h b/common/peg-parser.h index b6bb05214b..132173a64c 100644 --- a/common/peg-parser.h +++ b/common/peg-parser.h @@ -3,8 +3,8 @@ #include #include +#include #include -#include #include #include #include @@ -335,7 +335,7 @@ class common_peg_arena { friend class common_peg_parser_builder; private: - std::string dump_impl(common_peg_parser_id id, std::unordered_set & visited) const; + std::string dump_impl(common_peg_parser_id id, std::set & visited) const; common_peg_parser_id add_parser(common_peg_parser_variant parser); void add_rule(const std::string & name, common_peg_parser_id id); diff --git a/tests/peg-parser/test-gbnf-generation.cpp b/tests/peg-parser/test-gbnf-generation.cpp index 00111e6a19..45d692ca60 100644 --- a/tests/peg-parser/test-gbnf-generation.cpp +++ b/tests/peg-parser/test-gbnf-generation.cpp @@ -129,8 +129,86 @@ void test_gbnf_generation(testing &t) { }); assert_gbnf_equal(t, R"""( - root ::= ([^<] | "<" [^/] | "])* ("<" | "] until-0 + )""", gbnf); + }); + + t.test("until grammar overlapping delimiter", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.until("\n\n"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= until-0 + space ::= | " " | "\n"{1,2} [ \t]{0,20} + until-0 ::= | [\n] until-0-01 | [^\n] until-0 + until-0-01 ::= | [\n] until-0-01 | [<] until-0-02 | [^\n<] until-0 + until-0-02 ::= | [\n] until-0-01 | [/] until-0-03 | [^\n/] until-0 + until-0-03 ::= | [\n] until-0-01 | [p] until-0-04 | [^\np] until-0 + until-0-04 ::= | [\n] until-0-01 | [a] until-0-05 | [^\na] until-0 + until-0-05 ::= | [\n] until-0-01 | [r] until-0-06 | [^\nr] until-0 + until-0-06 ::= | [\n] until-0-01 | [a] until-0-07 | [^\na] until-0 + until-0-07 ::= | [\n] until-0-01 | [m] until-0-08 | [^\nm] until-0 + until-0-08 ::= | 
[\n] until-0-01 | [e] until-0-09 | [^\ne] until-0 + until-0-09 ::= | [\n] until-0-01 | [t] until-0-10 | [^\nt] until-0 + until-0-10 ::= | [\n] until-0-01 | [e] until-0-11 | [^\ne] until-0 + until-0-11 ::= | [\n] until-0-01 | [r] until-0-12 | [^\nr] until-0 + until-0-12 ::= | [\n] until-0-01 | [>] until-0-13 | [^\n>] until-0 + until-0-13 ::= | [^\n] until-0 + )""", gbnf); + }); + + // DeepSeek-V3.2 tag prefix. The DSML token (|DSML|) embeds U+FF5C, + // so the delimiter mixes ASCII and multi-byte codepoints. + t.test("until grammar unicode delimiter", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.until("<|DSML|"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= until-0 + space ::= | " " | "\n"{1,2} [ \t]{0,20} + until-0 ::= | [<] until-0-01 | [^<] until-0 + until-0-01 ::= | [<] until-0-01 | [\uFF5C] until-0-02 | [^<\uFF5C] until-0 + until-0-02 ::= | [<] until-0-01 | [D] until-0-03 | [^ #include #include +#include namespace fs = std::filesystem; From d789527482d925156d7c4adfecebf5fb8481e0ee Mon Sep 17 00:00:00 2001 From: YiChen Lv <63285796+forforever73@users.noreply.github.com> Date: Sun, 21 Jun 2026 16:33:18 +0800 Subject: [PATCH 34/86] spec : Support Step3.5/3.7 flash mtp3 (#24340) * add mtp_layer_offset + include nextn flags in graph reuse * add llama_set_mtp_layer_offset + llama_model_n_nextn_layer API * offset head select + require all MTP blocks * speculative multi-head process() * speculative multi-head draft() * gather outputs via inp_out_ids * cleanup * fix core * minor cleanup * merged draft_multi_head into draft() * mtp rename nextn * Apply suggestions from code review Co-authored-by: Aman Gupta * clean-up comments * fix for multi seq * apply suggestions && chain-heads comment * add a reference for chain_heads discussion --------- Co-authored-by: Aman Gupta --- common/speculative.cpp | 137 
++++++++++++++++++++++++++++++----------- include/llama.h | 17 ++--- src/llama-context.cpp | 8 +++ src/llama-context.h | 1 + src/llama-cparams.h | 2 + src/llama-ext.h | 5 ++ src/llama-graph.h | 11 +++- src/llama-model.cpp | 4 ++ src/models/step35.cpp | 55 ++++++++--------- 9 files changed, 167 insertions(+), 73 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 9c20585dc3..3c38ae2b02 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -905,7 +905,13 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { int32_t n_embd = 0; - bool is_mem_shared = false; + // One MTP draft driver, three modes (set once in the ctor): + // is_mem_shared (gemma4): shares the target KV, runs all heads in one graph. + // chain_heads (step35): n_mtp_layers trained heads, one per draft step. + // neither (qwen35 / qwen35moe): a single trained MTP head. + int32_t n_mtp_layers = 1; + bool is_mem_shared = false; // gemma4 + bool chain_heads = false; // derived in the ctor: n_mtp_layers > 1 && !is_mem_shared // Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1. // The last h-row of one process() call needs the first token of the NEXT @@ -920,10 +926,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { std::vector> verify_h; std::vector verify_h_rows; - // Per-seq draft length from the last draft() call, used in accept() to - // roll back ctx_dft's recurrent state past the AR draft's redundant - // pre-advancement before process() mirrored the verify batch. 
- std::vector last_n_drafted; + std::vector i_last; + std::vector> chain_h; common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq) : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq) @@ -936,6 +940,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft)); GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) && "MTP input row width must match the target h_nextn width"); + n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft))); LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__); LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling); @@ -982,16 +987,25 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true); is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt; + chain_heads = n_mtp_layers > 1 && !is_mem_shared; + + if (chain_heads) { + this->params.n_max = std::min(this->params.n_max, n_mtp_layers); + + chain_h.assign(n_seq, {}); + for (auto & c : chain_h) { + c.reserve((size_t) (this->params.n_max + 1) * n_embd); + } + } pending_h.assign(n_seq, std::vector(n_embd, 0.0f)); + i_last.assign(n_seq, -1); i_batch_beg.assign(n_seq, -1); i_batch_end.assign(n_seq, -1); verify_h.assign(n_seq, {}); verify_h_rows.assign(n_seq, 0); - - last_n_drafted.assign(n_seq, 0); } ~common_speculative_impl_draft_mtp() override { @@ -1097,9 +1111,34 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { set_h(i_batch_beg[seq_id], pending_h[seq_id].data()); } - const int32_t rc = llama_decode(ctx_dft, batch); - if (rc != 0) { - LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]); + auto * mem_dft 
= llama_get_memory(ctx_dft); + + bool ok = true; + for (int head = 0; head < n_mtp_layers; ++head) { + if (chain_heads) { + // ref: https://github.com/ggml-org/llama.cpp/pull/24340/changes#r3413498544 + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (i_batch_beg[seq_id] < 0) { + continue; + } + llama_memory_seq_rm(mem_dft, seq_id, batch_in.pos[i_batch_beg[seq_id]], -1); + } + llama_set_nextn_layer_offset(ctx_dft, head); + } + + const int32_t rc = llama_decode(ctx_dft, batch); + if (rc != 0) { + LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n", + __func__, head, (int) rc, (int) batch_in.pos[0]); + ok = false; + break; + } + } + + if (chain_heads) { + llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes + } + if (!ok) { return false; } } @@ -1134,7 +1173,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { int n_drafting = 0; std::vector drafting(n_seq); - const float * h_row = nullptr; const size_t row_bytes = (size_t) n_embd * sizeof(float); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { @@ -1149,22 +1187,43 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { common_sampler_reset(smpls[seq_id].get()); common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true); + std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, pending_h[seq_id].data(), row_bytes); - h_row = pending_h[seq_id].data(); - std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); - } + i_last[seq_id] = batch.n_tokens - 1; - int ret = llama_decode(ctx_dft, batch); - if (ret != 0) { - LOG_WRN("%s: llama_decode returned %d\n", __func__, ret); - return; + if (chain_heads) { + chain_h[seq_id].assign(pending_h[seq_id].begin(), pending_h[seq_id].end()); + } } int i = 0; while (n_drafting > 0) { - int i_batch = 0; + // each step decodes under a different head, i.e. a different decoder layer, and + // KV is per layer. 
process() filled this layer's KV only for positions < n_past + // (prompt + accepted prefix) — nothing in the draft region yet. so reset the + // draft region (the seq_rm lower bound is n_past, leaving the prompt KV intact) + // and select head i so it rebuilds its own layer's KV there; decoding just the + // latest token would leave its attention reading cells only another head wrote. + if (chain_heads) { + auto * mem_dft = llama_get_memory(ctx_dft); + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (drafting[seq_id]) { + llama_memory_seq_rm(mem_dft, seq_id, dparams[seq_id].n_past, -1); + } + } + llama_set_nextn_layer_offset(ctx_dft, i); + } + int ret = llama_decode(ctx_dft, batch); + if (ret != 0) { + LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret); + break; + } + + // rebuild the batch for the next step: the growing-KV paths re-add only the + // new token (the KV already holds the prefix), while chained heads re-add the + // whole prefix at the next head. dropped sequences are simply not re-added. 
common_batch_clear(batch); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { @@ -1174,9 +1233,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { auto * smpl = smpls[seq_id].get(); - common_sampler_sample(smpl, ctx_dft, i_batch, true); - h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_batch); - ++i_batch; + common_sampler_sample(smpl, ctx_dft, i_last[seq_id], true); + const float * h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_last[seq_id]); const auto * cur_p = common_sampler_get_candidates(smpl, true); @@ -1210,30 +1268,41 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { continue; } - if (is_mem_shared) { + if (chain_heads) { + // ref: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448031546 + chain_h[seq_id].insert(chain_h[seq_id].end(), h_row, h_row + n_embd); + + const int n_rows = (int) result.size() + 1; // id_last + tokens drafted so far + for (int t = 0; t < n_rows; ++t) { + const llama_token tok = (t == 0) ? dp.id_last : result[t - 1]; + common_batch_add(batch, tok, dp.n_past + t, { seq_id }, t == n_rows - 1); + std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, + chain_h[seq_id].data() + (size_t) t * n_embd, row_bytes); + } + } else if (is_mem_shared) { // note: with shared memory (e.g. 
Gemma4 assistants) we use the same position for all draft tokens // ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37 common_batch_add(batch, id, dp.n_past, { seq_id }, true); + std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes); } else { common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true); + std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes); } - std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); + + i_last[seq_id] = batch.n_tokens - 1; } if (batch.n_tokens == 0) { break; } - // evaluate the drafted tokens on the draft model - ret = llama_decode(ctx_dft, batch); - if (ret != 0) { - LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret); - break; - } - ++i; } + if (chain_heads) { + llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes + } + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { @@ -1243,8 +1312,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { if (dp.result->size() < (size_t) params.n_min) { dp.result->clear(); } - - last_n_drafted[seq_id] = (uint16_t) dp.result->size(); } } @@ -1857,7 +1924,7 @@ common_speculative * common_speculative_init(common_params_speculative & params, bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE)); bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr; - bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr; + bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr; @@ -1895,7 +1962,7 @@ common_speculative * common_speculative_init(common_params_speculative & 
params, if (has_draft_eagle3) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params)); } - if (has_mtp) { + if (has_draft_mtp) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params)); } } diff --git a/include/llama.h b/include/llama.h index 27e4806742..f723c9f60c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -558,14 +558,15 @@ extern "C" { LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); - LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model); - LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_ctx_train (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_layer_nextn(const struct llama_model * model); + LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); // Get the model's RoPE frequency scaling factor LLAMA_API float 
llama_model_rope_freq_scale_train(const struct llama_model * model); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 529bc4a5e9..220240ea95 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1156,6 +1156,10 @@ void llama_context::set_embeddings_layer_inp(uint32_t lid, bool enable) { sched_need_reserve = true; } +void llama_context::set_nextn_layer_offset(int32_t offset) { + cparams.nextn_layer_offset = offset; +} + void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -3699,6 +3703,10 @@ void llama_set_embeddings_layer_inp(llama_context * ctx, uint32_t lid, bool valu ctx->set_embeddings_layer_inp(lid, value); } +void llama_set_nextn_layer_offset(llama_context * ctx, int32_t offset) { + ctx->set_nextn_layer_offset(offset); +} + llama_memory_t llama_get_memory(const struct llama_context * ctx) { if (!ctx) { return nullptr; diff --git a/src/llama-context.h b/src/llama-context.h index 853052be2c..f8b7805871 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -115,6 +115,7 @@ struct llama_context { void set_embeddings (bool value); void set_embeddings_nextn(bool value, bool masked); void set_embeddings_layer_inp(uint32_t lid, bool enable); + void set_nextn_layer_offset(int32_t offset); void set_causal_attn(bool value); void set_warmup(bool value); diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 2b109f909c..546ae1e2c1 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -18,6 +18,8 @@ struct llama_cparams { int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing + int32_t nextn_layer_offset = 0; + float rope_freq_base; float rope_freq_scale; diff --git a/src/llama-ext.h b/src/llama-ext.h index 8b5679b690..348bbae957 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -95,6 +95,11 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c // If masked 
== false, output the embeddings for all tokens in the batch regardless of batch.logits LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked); +// Select which appended NextN block the DECODER_MTP graph runs (offset past +// the trunk: il = n_layer() + offset). Used by the speculative NextN driver to +// chain multiple trained NextN heads. Default 0 (first head). +LLAMA_API void llama_set_nextn_layer_offset(struct llama_context * ctx, int32_t offset); + // mirrors: // LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx); diff --git a/src/llama-graph.h b/src/llama-graph.h index 5e8a658350..a6e8c3985b 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -682,9 +682,16 @@ struct llm_graph_params { } } + // TODO: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448035248 + if (cparams.nextn_layer_offset != other.cparams.nextn_layer_offset) { + return false; + } + return - cparams.embeddings == other.cparams.embeddings && - cparams.causal_attn == other.cparams.causal_attn && + cparams.embeddings == other.cparams.embeddings && + cparams.embeddings_nextn == other.cparams.embeddings_nextn && + cparams.embeddings_nextn_masked == other.cparams.embeddings_nextn_masked && + cparams.causal_attn == other.cparams.causal_attn && arch == other.arch && gtype == other.gtype && cvec == other.cvec && diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c528755339..d041a9ce3e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2312,6 +2312,10 @@ int32_t llama_model_n_layer(const llama_model * model) { return model->hparams.n_layer(); } +int32_t llama_model_n_layer_nextn(const llama_model * model) { + return model->hparams.n_layer_nextn; +} + int32_t llama_model_n_head(const llama_model * model) { return model->hparams.n_head(); } diff --git a/src/models/step35.cpp b/src/models/step35.cpp index e2218c5870..9b7b18a367 100644 
--- a/src/models/step35.cpp +++ b/src/models/step35.cpp @@ -112,7 +112,7 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); }; - auto load_block_mtp = [&](int i, bool is_first_mtp) { + auto load_block_mtp = [&](int i) { auto & layer = layers[i]; const uint32_t n_head_l = hparams.n_head(i); @@ -121,15 +121,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { // The MTP block is a full Step3p5 decoder layer (mtp_block) plus the // NextN-specific wiring (enorm/hnorm/eh_proj + optional shared head). - // `mtp_flags` becomes NOT_REQUIRED when the GGUF is trunk-only. - // - // Only the FIRST MTP block (i == n_main) is required for the - // single-block MTP runtime; trailing MTP blocks are always tolerated - // as missing so pruned GGUFs (block 0 only) load cleanly. Override - // mtp_flags to NOT_REQUIRED for those. - const int eff_mtp_flags = is_first_mtp ? mtp_flags : (mtp_flags | TENSOR_NOT_REQUIRED); + // Multi-block MTP: every declared MTP block is required (the draft chain + // runs all n_layer_nextn heads), so each block uses the captured + // `mtp_flags` directly — already NOT_REQUIRED for a trunk-only GGUF, + // which keeps that path correct. 
- layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, eff_mtp_flags); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, mtp_flags); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); @@ -140,12 +137,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); } - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, eff_mtp_flags); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, eff_mtp_flags); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, mtp_flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, mtp_flags); layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, eff_mtp_flags); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, mtp_flags); // dense MLP (leading dense blocks) — present if the MTP block isn't MoE layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); @@ -165,9 +162,9 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); // NextN-specific tensors that define the MTP block. 
- layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, eff_mtp_flags); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, eff_mtp_flags); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, eff_mtp_flags); + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, mtp_flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, mtp_flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, mtp_flags); layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); @@ -176,13 +173,11 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { load_block_trunk(i, trunk_flags); } - // Only the first MTP block (i == n_main) is required at runtime — the - // single-block-MTP graph in build_arch_graph always uses that one. - // Trailing MTP blocks are loaded if present (so an un-pruned GGUF with - // all MTP layers still works) but tolerated when absent via the pruning - // path. See scripts/prune_step35_extra_mtp.py for the pruner. + // All n_layer_nextn MTP blocks are required — the multi-block draft chain + // runs every head (head k at offset k). The GGUF declares the count via + // step35.nextn_predict_layers. 
for (int i = n_layer; i < n_layer_all; ++i) { - load_block_mtp(i, /*is_first_mtp=*/ i == n_layer); + load_block_mtp(i); } } @@ -372,13 +367,14 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr : llm_graph_context(params) { GGML_ASSERT(hparams.n_layer_nextn > 0 && "STEP35 MTP requires n_layer_nextn > 0"); - // Single-block MTP only: always run the first trained MTP block (Qwen - // MTP / vLLM single-MTP-layer style). Multi-block round-robin proved to - // be a much deeper refactor than this PR justifies; the trailing MTP - // blocks are loaded with TENSOR_NOT_REQUIRED so pruned GGUFs (with just - // block 0) also work — see load_arch_tensors below and - // scripts/prune_step35_extra_mtp.py. - const int il = hparams.n_layer(); + // Multi-block MTP: the DECODER_MTP graph runs the MTP head selected by + // cparams.nextn_layer_offset (0 = first trained head). The speculative driver + // bumps the offset per draft step to chain heads 45->46->47. offset 0 keeps + // single-block behavior identical to before. + const int il = hparams.n_layer() + cparams.nextn_layer_offset; + GGML_ASSERT(cparams.nextn_layer_offset >= 0 && + cparams.nextn_layer_offset < (int) hparams.n_layer_nextn && + "nextn_layer_offset out of range [0, n_layer_nextn)"); const auto & layer = model.layers[il]; GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); @@ -536,6 +532,9 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "mtp_post_ffn", il); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. 
cb(cur, "h_nextn", -1); res->t_h_nextn = cur; From 8a118ee86c3b818ce7e1524e48fc7cc65f1dc69b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jun 2026 11:37:12 +0300 Subject: [PATCH 35/86] minor : clean-up whitespaces (#24862) [no ci] --- common/speculative.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 3c38ae2b02..c922a3f592 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -991,7 +991,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { if (chain_heads) { this->params.n_max = std::min(this->params.n_max, n_mtp_layers); - + chain_h.assign(n_seq, {}); for (auto & c : chain_h) { c.reserve((size_t) (this->params.n_max + 1) * n_embd); From d6d899580dcf0e50a3d14453d3d082f6ed050450 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Sun, 21 Jun 2026 11:58:14 +0200 Subject: [PATCH 36/86] server: real-time model load progress tracking via /models/sse (#24828) * server: real-time model load progress tracking via /models/sse * update docs * add mutex for notify_to_router * correct docs --- tools/server/README.md | 28 ++++++++++++++-- tools/server/server-context.cpp | 57 ++++++++++++++++++++++++++++++++- tools/server/server-models.cpp | 20 ++++++++++-- tools/server/server-models.h | 8 ++++- 4 files changed, 106 insertions(+), 7 deletions(-) diff --git a/tools/server/README.md b/tools/server/README.md index eb730e713a..5efdad0954 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -1859,9 +1859,33 @@ Example events: { "model": "...", - "event": "download_finished", + "event": "model_status", "data": { - "status": "loading" + "status": "loading", + "progress": { + "stage": "fit_params", + "value": 0.5 // from 0.0 to 1.0 ; note: not all stages have this "value" + } + } +} + +{ + "model": "...", + "event": "model_status", + "data": { + "status": "loaded", + "info": { + // note: only include info on first load + // waking up from 
sleep doesn't have this + } + } +} + +{ + "model": "...", + "event": "model_status", + "data": { + "status": "sleeping" } } diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 3de1335ec2..531b106e55 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -833,6 +833,8 @@ private: bool sleeping = false; + int64_t t_last_load_progress_ms = 0; + void destroy() { spec.reset(); ctx_dft.reset(); @@ -863,6 +865,30 @@ private: sleeping = new_state; } + static bool load_progress_callback(float progress, void * user_data) { + auto * ctx = static_cast(user_data); + GGML_ASSERT(ctx); + // always emit the first and final sample; throttle the rest to one per 200ms + { + auto & t_last = ctx->t_last_load_progress_ms; + const int64_t t_now = ggml_time_ms(); + const bool first = t_last == 0; + const bool done = progress >= 1.0f; + const bool throttled = !first && !done && (t_now - t_last) < 200; + if (throttled) { + return true; + } + t_last = t_now; + } + if (ctx->callback_state) { + ctx->callback_state(SERVER_STATE_LOADING, { + {"stage", "text_model"}, + {"value", progress}, + }); + } + return true; + } + // load the model and initialize llama_context // this may also be called to resume from sleeping state bool load_model(common_params & params) { @@ -916,6 +942,10 @@ private: // optionally reserve VRAM for the draft / MTP context before fitting the target model if (params_base.fit_params) { + if (callback_state) { + callback_state(SERVER_STATE_LOADING, {{"stage", "fit_params"}}); + } + const bool spec_mtp = std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(), COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end(); @@ -991,6 +1021,13 @@ private: } } + // attach a progress callback + { + t_last_load_progress_ms = 0; + params_base.load_progress_callback = load_progress_callback; + params_base.load_progress_callback_user_data = this; + } + llama_init = 
common_init_from_params(params_base); model_tgt = llama_init->model(); @@ -1008,6 +1045,10 @@ private: add_bos_token = llama_vocab_get_add_bos(vocab); if (params_base.speculative.has_dft()) { + if (callback_state) { + callback_state(SERVER_STATE_LOADING, {{"stage", "spec_model"}}); + } + // TODO speculative: move to common/speculative.cpp? const auto & params_spec = params_base.speculative.draft; @@ -1079,6 +1120,10 @@ private: } if (has_mmproj) { + if (callback_state) { + callback_state(SERVER_STATE_LOADING, {{"stage", "mmproj_model"}}); + } + if (!is_resume) { mtmd_helper_log_set(common_log_default_callback, nullptr); } @@ -1259,6 +1304,10 @@ private: return init(); } + if (callback_state) { + callback_state(SERVER_STATE_READY, {}); + } + return true; } @@ -1335,6 +1384,9 @@ private: const bool enable_thinking = params_base.enable_reasoning != 0 && template_supports_thinking; SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking); + // IMPORTANT: chat_params is reused across sleeping / resuming states, + // never store llama_context/llama_model pointers in chat_params, + // as they may be invalidated after sleeping chat_params = { /* use_jinja */ params_base.use_jinja, /* prefill_assistant */ params_base.prefill_assistant, @@ -3734,7 +3786,10 @@ struct server_res_generator : server_http_res { void server_context::set_state_callback(server_state_callback_t callback) { impl->callback_state = std::move(callback); impl->queue_tasks.on_sleeping_state([this](bool sleeping) { - impl->callback_state(sleeping ? 
SERVER_STATE_SLEEPING : SERVER_STATE_READY, {}); + if (sleeping) { + impl->callback_state(SERVER_STATE_SLEEPING, {}); + } + // for sleeping == false, event is emitted by load_model() }); } diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index a569c8be3c..68eefdffac 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -442,6 +442,7 @@ void server_models::load_models() { /* last_used */ 0, /* args */ std::vector(), /* loaded_info */ {}, + /* progress */ {}, /* exit_code */ 0, /* stop_timeout */ DEFAULT_STOP_TIMEOUT, /* multimodal */ mtmd_caps{false, false}, @@ -608,6 +609,7 @@ void server_models::load_models() { /* last_used */ 0, /* args */ std::vector(), /* loaded_info */ {}, + /* progress */ {}, /* exit_code */ 0, /* stop_timeout */ DEFAULT_STOP_TIMEOUT, /* multimodal */ mtmd_caps{false, false}, @@ -1140,6 +1142,9 @@ void server_models::update_status(const std::string & name, const update_status_ if (!args.loaded_info.is_null()) { meta.loaded_info = args.loaded_info; } + if (!args.progress.is_null()) { + meta.progress = args.progress; + } } // broadcast status change to SSE { @@ -1152,6 +1157,9 @@ void server_models::update_status(const std::string & name, const update_status_ if (!args.loaded_info.is_null()) { data["info"] = args.loaded_info; } + if (!args.progress.is_null()) { + data["progress"] = args.progress; + } // note: notify_sse doesn't acquire the lock, so no deadlock here notify_sse("status_change", name, data); } @@ -1322,8 +1330,12 @@ void server_models::handle_child_state(const std::string & name, const std::stri switch (state) { case SERVER_STATE_LOADING: { - // do nothing for now - // TODO: report loading progress for first load and wakeup from sleep + update_status(name, { + SERVER_MODEL_STATUS_LOADING, + 0, + nullptr, // no loaded_info yet + payload, + }); } break; case SERVER_STATE_READY: { @@ -1331,7 +1343,8 @@ void server_models::handle_child_state(const std::string & name, const 
std::stri SERVER_MODEL_STATUS_LOADED, 0, // note: payload can be empty if this is a wakeup from sleep - payload.size() > 0 ? payload : nullptr + payload.size() > 0 ? payload : nullptr, + {}, // reset progress info }); } break; case SERVER_STATE_SLEEPING: @@ -1384,6 +1397,7 @@ void server_child::notify_to_router(const std::string & state, const json & payl {"state", state}, {"payload", payload}, }; + std::lock_guard lk(mtx_stdout); common_log_pause(common_log_main()); fflush(stdout); fprintf(stdout, "%s%s\n", CMD_CHILD_TO_ROUTER_STATE, safe_json_to_str(data).c_str()); diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 40a0e078c6..17759b00a5 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -72,6 +72,7 @@ struct server_model_meta { int64_t last_used = 0; // for LRU unloading std::vector args; // args passed to the model instance, will be populated by render_args() json loaded_info; // info to be reflected via /v1/models endpoint ; if in DOWNLOADING state, it should contain download progress info + json progress; // reflect load or download progress info, if any int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED) int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown mtmd_caps multimodal; // multimodal capabilities @@ -170,12 +171,14 @@ public: // to stop the download, call unload() void download(common_params_model && model, common_download_opts && opts); - // update the status of a model instance (thread-safe) struct update_status_args { server_model_status status; int exit_code = 0; // only valid if status == UNLOADED json loaded_info = nullptr; + json progress = nullptr; }; + // update the status of a model instance (thread-safe) + // also send SSE notification to /models/sse endpoint void update_status(const std::string & name, const update_status_args & args); void update_download_progress(const std::string & name, const 
common_download_progress & progress, bool done, bool ok = true); @@ -208,6 +211,9 @@ public: }; struct server_child { + // serializes the notify_to_router writes + std::mutex mtx_stdout; + // return true if the current process is a child server instance bool is_child(); From bfa3219177c81bbf9f38939901656d60a745eb7e Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Sun, 21 Jun 2026 13:03:14 +0200 Subject: [PATCH 37/86] server: add "verbose" field to schema (#24864) --- tools/server/server-schema.cpp | 3 +++ .../server/tests/unit/test_chat_completion.py | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/tools/server/server-schema.cpp b/tools/server/server-schema.cpp index d5d747a654..ed4bda2412 100644 --- a/tools/server/server-schema.cpp +++ b/tools/server/server-schema.cpp @@ -14,6 +14,9 @@ std::vector> make_llama_cmpl_schema(const common_params & fields.emplace_back(f); }; + add((new field_bool("verbose", params.verbose)) + ->set_desc("Include __verbose field in the response with additional debug information")); + add((new field_bool("timings_per_token", params.timings_per_token)) ->set_desc("Include prompt processing and text generation speed information in each response")); diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index b00aac649d..0258b539ed 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -603,3 +603,23 @@ def test_chat_completions_token_count(): }) assert res.status_code == 200 assert res.body["input_tokens"] > 5 + + +def test_verbose_debug(): + global server + server.start() + for verbose in [True, False]: + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 2, + "messages": [ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + ], + "verbose": verbose, + }) + assert res.status_code == 200 + if verbose: + assert "__verbose" in 
res.body + assert "Book" in res.body["__verbose"]["prompt"] + else: + assert "__verbose" not in res.body From 2f89acc2bc614dc121db065a74e503bf88668951 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Sun, 21 Jun 2026 13:40:52 +0200 Subject: [PATCH 38/86] mtmd: add load progress callback (#24865) --- tools/mtmd/clip.cpp | 35 +++++++++++++++++++++++++++++++-- tools/mtmd/clip.h | 2 ++ tools/mtmd/mtmd.cpp | 8 ++++++++ tools/mtmd/mtmd.h | 8 ++++++++ tools/server/server-context.cpp | 33 ++++++++++++++++++------------- 5 files changed, 70 insertions(+), 16 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index c713703e01..fccc1e3487 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1045,8 +1045,17 @@ struct clip_model_loader { bool has_vision = false; bool has_audio = false; + mtmd_progress_callback progress_callback = nullptr; + void * progress_callback_user_data = nullptr; + // TODO @ngxson : we should not pass clip_ctx here, it should be clip_model - clip_model_loader(const char * fname, bool skip_tensors = false) : fname(fname) { + clip_model_loader(const char * fname, + bool skip_tensors = false, + mtmd_progress_callback progress_cb = nullptr, + void * progress_user_data = nullptr) + : fname(fname), + progress_callback(progress_cb), + progress_callback_user_data(progress_user_data) { struct ggml_context * meta = nullptr; struct gguf_init_params params = { @@ -2790,10 +2799,22 @@ struct clip_model_loader { if (!ctx_clip.no_alloc) { std::vector read_buf; + // start loading event + if (progress_callback){ + progress_callback(0.0, progress_callback_user_data); + } + + // compute total tensor data size for progress reporting + size_t total_data_size = 0; + for (auto & t : tensors_to_load) { + total_data_size += ggml_nbytes(t); + } + // alloc memory and offload data ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend); ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), 
buft)); ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + size_t data_loaded = 0; for (auto & t : tensors_to_load) { ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name); GGML_ASSERT(cur && "tensor not found in ctx_data"); @@ -2814,6 +2835,13 @@ struct clip_model_loader { fin.read(reinterpret_cast(read_buf.data()), num_bytes); ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); } + data_loaded += num_bytes; + if (progress_callback && total_data_size > 0) { + const float progress = (float)data_loaded / (float)total_data_size; + if (!progress_callback(progress, progress_callback_user_data)) { + throw std::runtime_error(string_format("%s: model loading cancelled by progress_callback\n", __func__)); + } + } } fin.close(); @@ -3105,7 +3133,10 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params clip_ctx * ctx_audio = nullptr; try { - clip_model_loader loader(fname); + clip_model_loader loader(fname, + /* skip_tensors */ false, + ctx_params.progress_callback, + ctx_params.progress_callback_user_data); bool skip_audio = false; if (loader.has_vision) { diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index e0f1d298c8..967093a812 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -54,6 +54,8 @@ struct clip_context_params { ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; bool no_alloc; + mtmd_progress_callback progress_callback; + void * progress_callback_user_data; }; struct clip_init_result { diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index cbaac1d377..564bafc621 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -251,6 +251,8 @@ mtmd_context_params mtmd_context_params_default() { /* cb_eval */ nullptr, /* cb_eval_user_data */ nullptr, /* batch_max_tokens */ 1024, + /* progress_callback */ nullptr, + /* progress_callback_user_data */ nullptr, }; return params; } @@ -345,6 +347,8 @@ struct mtmd_context { /* cb_eval */ 
ctx_params.cb_eval, /* cb_eval_user_data */ ctx_params.cb_eval_user_data, /* no_alloc */ no_alloc, + /* progress_callback */ ctx_params.progress_callback, + /* progress_callback_user_data */ ctx_params.progress_callback_user_data, }; auto res = clip_init(mmproj_fname, ctx_clip_params); @@ -2133,8 +2137,12 @@ std::map mtmd_get_memory_usage(const char * mmproj_f mtmd::context_ptr ctx; auto saved_log_callback = g_logger_state.log_callback; auto saved_log_user_data = g_logger_state.log_callback_user_data; + + ctx_params.progress_callback = nullptr; + try { mtmd_log_set(stub_log_callback, nullptr); // suppress logging + // TODO @ngxson : fix no_alloc here ctx.reset(new mtmd_context(mmproj_fname, nullptr, ctx_params)); mtmd_log_set(saved_log_callback, saved_log_user_data); // restore log callback std::map total_mem; diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index 2fd149e480..25d51ef58d 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -83,6 +83,8 @@ typedef struct mtmd_input_chunks mtmd_input_chunks; typedef struct mtmd_input_text mtmd_input_text; typedef struct mtmd_batch mtmd_batch; +typedef bool (*mtmd_progress_callback)(float progress, void * user_data); + struct mtmd_context_params { bool use_gpu; bool print_timings; @@ -104,6 +106,12 @@ struct mtmd_context_params { int32_t batch_max_tokens; // maximum number of output tokens in a batch // (note: this is not a hard-limit, the first image will always be added even if it exceeds this limit) // (default: 1024) + + // Called with a progress value between 0.0 and 1.0. Pass NULL to disable. + // If the provided progress_callback returns true, model loading continues. + // If it returns false, model loading is immediately aborted. 
+ mtmd_progress_callback progress_callback; + void * progress_callback_user_data; }; MTMD_API const char * mtmd_default_marker(void); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 531b106e55..7db4cb1986 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -833,8 +833,6 @@ private: bool sleeping = false; - int64_t t_last_load_progress_ms = 0; - void destroy() { spec.reset(); ctx_dft.reset(); @@ -865,12 +863,18 @@ private: sleeping = new_state; } + struct load_progress_data { + server_context_impl * ctx; + std::string stage; + int64_t t_last_load_progress_ms = 0; + load_progress_data(server_context_impl * ctx, const std::string & stage) : ctx(ctx), stage(stage) {} + }; static bool load_progress_callback(float progress, void * user_data) { - auto * ctx = static_cast(user_data); - GGML_ASSERT(ctx); + auto * d = static_cast(user_data); + GGML_ASSERT(d); // always emit the first and final sample; throttle the rest to one per 200ms { - auto & t_last = ctx->t_last_load_progress_ms; + auto & t_last = d->t_last_load_progress_ms; const int64_t t_now = ggml_time_ms(); const bool first = t_last == 0; const bool done = progress >= 1.0f; @@ -880,9 +884,9 @@ private: } t_last = t_now; } - if (ctx->callback_state) { - ctx->callback_state(SERVER_STATE_LOADING, { - {"stage", "text_model"}, + if (d->ctx->callback_state) { + d->ctx->callback_state(SERVER_STATE_LOADING, { + {"stage", d->stage}, {"value", progress}, }); } @@ -892,6 +896,9 @@ private: // load the model and initialize llama_context // this may also be called to resume from sleeping state bool load_model(common_params & params) { + load_progress_data load_progress_text(this, "text_model"); + load_progress_data load_progress_mmproj(this, "mmproj_model"); + bool is_resume = sleeping; SRV_INF("loading model '%s'\n", params.model.path.c_str()); @@ -912,6 +919,9 @@ private: mparams.image_max_tokens = params_base.image_max_tokens; mparams.batch_max_tokens = 
params_base.mtmd_batch_max_tokens; mparams.media_marker = get_media_marker(); + // progress callback + mparams.progress_callback = load_progress_callback; + mparams.progress_callback_user_data = &load_progress_mmproj; } // optionally get the memory usage of mmproj @@ -1023,9 +1033,8 @@ private: // attach a progress callback { - t_last_load_progress_ms = 0; params_base.load_progress_callback = load_progress_callback; - params_base.load_progress_callback_user_data = this; + params_base.load_progress_callback_user_data = &load_progress_text; } llama_init = common_init_from_params(params_base); @@ -1120,10 +1129,6 @@ private: } if (has_mmproj) { - if (callback_state) { - callback_state(SERVER_STATE_LOADING, {{"stage", "mmproj_model"}}); - } - if (!is_resume) { mtmd_helper_log_set(common_log_default_callback, nullptr); } From bf533823cd06e7fb21552265eee1bf2fd2752974 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= <1629204+CISC@users.noreply.github.com> Date: Sun, 21 Jun 2026 14:04:52 +0200 Subject: [PATCH 39/86] jinja : implement call statement (#24847) * implement call statement * undo unintended change * de-lambda * simplify * move caller context inside function handler --- common/jinja/runtime.cpp | 135 ++++++++++++++++++++++++++------------- common/jinja/runtime.h | 1 + tests/test-jinja.cpp | 26 ++++++++ 3 files changed, 116 insertions(+), 46 deletions(-) diff --git a/common/jinja/runtime.cpp b/common/jinja/runtime.cpp index 1fae7884e1..f98cb0876f 100644 --- a/common/jinja/runtime.cpp +++ b/common/jinja/runtime.cpp @@ -686,59 +686,62 @@ value set_statement::execute_impl(context & ctx) { return mk_val(); } +static inline void bind_parameters(const std::string & name, const statements & this_args, const func_args & args, context & ctx) { + const size_t expected_count = this_args.size(); + const size_t input_count = args.count(); + + JJ_DEBUG("Invoking '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count); + 
for (size_t i = 0; i < expected_count; ++i) { + if (i < input_count) { + if (is_stmt(this_args[i])) { + // normal parameter + std::string param_name = cast_stmt(this_args[i])->val; + value param_value = args.get_kwarg_or_pos(param_name, i); + JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str()); + ctx.set_val(param_name, param_value); + } else if (is_stmt(this_args[i])) { + // default argument used as normal parameter + auto kwarg = cast_stmt(this_args[i]); + if (!is_stmt(kwarg->key)) { + throw std::runtime_error("Keyword argument key must be an identifier in '" + name + "'"); + } + std::string param_name = cast_stmt(kwarg->key)->val; + value param_value = args.get_kwarg_or_pos(param_name, i); + JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str()); + ctx.set_val(param_name, param_value); + } else { + throw std::runtime_error("Invalid parameter type in '" + name + "'"); + } + } else { + auto & default_arg = this_args[i]; + if (is_stmt(default_arg)) { + auto kwarg = cast_stmt(default_arg); + if (!is_stmt(kwarg->key)) { + throw std::runtime_error("Keyword argument key must be an identifier in '" + name + "'"); + } + std::string param_name = cast_stmt(kwarg->key)->val; + JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str()); + ctx.set_val(param_name, kwarg->val->execute(args.ctx)); + } else { + throw std::runtime_error("Not enough arguments provided to '" + name + "'"); + } + //std::string param_name = cast_stmt(default_args[i])->val; + //JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str()); + //ctx.var[param_name] = default_args[i]->execute(ctx); + } + } +} + value macro_statement::execute_impl(context & ctx) { if (!is_stmt(this->name)) { throw std::runtime_error("Macro name must be an identifier"); } std::string name = cast_stmt(this->name)->val; - const func_handler func = [this, 
name, &ctx](const func_args & args) -> value { - size_t expected_count = this->args.size(); - size_t input_count = args.count(); + const func_handler func = [this, name](const func_args & args) -> value { + context macro_ctx(args.ctx); // new scope for macro execution - JJ_DEBUG("Invoking macro '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count); - context macro_ctx(ctx); // new scope for macro execution - - // bind parameters - for (size_t i = 0; i < expected_count; ++i) { - if (i < input_count) { - if (is_stmt(this->args[i])) { - // normal parameter - std::string param_name = cast_stmt(this->args[i])->val; - value param_value = args.get_kwarg_or_pos(param_name, i); - JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str()); - macro_ctx.set_val(param_name, param_value); - } else if (is_stmt(this->args[i])) { - // default argument used as normal parameter - auto kwarg = cast_stmt(this->args[i]); - if (!is_stmt(kwarg->key)) { - throw std::runtime_error("Keyword argument key must be an identifier in macro '" + name + "'"); - } - std::string param_name = cast_stmt(kwarg->key)->val; - value param_value = args.get_kwarg_or_pos(param_name, i); - JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str()); - macro_ctx.set_val(param_name, param_value); - } else { - throw std::runtime_error("Invalid parameter type in macro '" + name + "'"); - } - } else { - auto & default_arg = this->args[i]; - if (is_stmt(default_arg)) { - auto kwarg = cast_stmt(default_arg); - if (!is_stmt(kwarg->key)) { - throw std::runtime_error("Keyword argument key must be an identifier in macro '" + name + "'"); - } - std::string param_name = cast_stmt(kwarg->key)->val; - JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str()); - macro_ctx.set_val(param_name, kwarg->val->execute(ctx)); - } else { - throw 
std::runtime_error("Not enough arguments provided to macro '" + name + "'"); - } - //std::string param_name = cast_stmt(default_args[i])->val; - //JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str()); - //macro_ctx.var[param_name] = default_args[i]->execute(ctx); - } - } + bind_parameters(name, this->args, args, macro_ctx); // execute macro body JJ_DEBUG("Executing macro '%s' body with %zu statements", name.c_str(), this->body.size()); @@ -752,6 +755,46 @@ value macro_statement::execute_impl(context & ctx) { return mk_val(); } +value call_statement::execute_impl(context & ctx) { + auto call_expr = cast_stmt(this->call); + if (!call_expr) { + throw std::runtime_error("Call statement requires a valid call expression"); + } + + value callee_val = call_expr->callee->execute(ctx); + if (!is_val(callee_val)) { + throw std::runtime_error("Callee is not a function: got " + callee_val->type()); + } + auto * callee_func = cast_val(callee_val); + + context caller_ctx(ctx); // new scope for caller execution + + const func_handler func = [this, caller_ctx = std::move(caller_ctx)](const func_args & args) -> value { + context block_ctx(caller_ctx); // new scope for block execution + + bind_parameters("caller", this->caller_args, args, block_ctx); + + JJ_DEBUG("Executing call body with %zu statements", this->body.size()); + auto res = exec_statements(this->body, block_ctx); + JJ_DEBUG("Call body execution complete, result: %s", res->val_str.str().c_str()); + return res; + }; + + context call_ctx(ctx); + call_ctx.set_val("caller", mk_val("caller", func)); + + func_args args(call_ctx); + + for (const auto & arg_expr : call_expr->args) { + auto arg_val = arg_expr->execute(ctx); + JJ_DEBUG(" Argument type: %s", arg_val->type().c_str()); + args.push_back(arg_val); + } + + JJ_DEBUG("Calling macro '%s' with %zu arguments", callee_func->name.c_str(), args.count()); + return callee_func->invoke(args); +} + value member_expression::execute_impl(context & ctx) { value object = 
this->object->execute(ctx); diff --git a/common/jinja/runtime.h b/common/jinja/runtime.h index b6f4a6ab48..37b4c35cac 100644 --- a/common/jinja/runtime.h +++ b/common/jinja/runtime.h @@ -552,6 +552,7 @@ struct call_statement : public statement { for (const auto & arg : this->caller_args) chk_type(arg); } std::string type() const override { return "CallStatement"; } + value execute_impl(context & ctx) override; }; struct ternary_expression : public expression { diff --git a/tests/test-jinja.cpp b/tests/test-jinja.cpp index 8039956246..81bbcd55a4 100644 --- a/tests/test-jinja.cpp +++ b/tests/test-jinja.cpp @@ -995,6 +995,32 @@ static void test_macros(testing & t) { json::object(), "Hello, John Smith,Hi, Jane Doe" ); + + test_template(t, "macro with caller", + "\ +{%- macro nest_dict(o, i, ff='') %}\n\ + {{- caller(ff) }}\n\ + {%- for k, v in o|items %}\n\ + {{- i + k + ': ' }}\n\ + {%- if v is mapping %}\n\ + {{- '{' }}\n\ + {% call(f) nest_dict(v, i + ' ') %}\n\ + {{- 'fail' if ff is undefined }}\n\ + {%- endcall %}\n\ + {{- i + '}' }}\n\ + {% else %}\n\ + {{- v|string }}\n\ + {% endif %}\n\ + {%- endfor %}\n\ +{%- endmacro %}\n\ +{%- call(f) nest_dict({'root1': 1, 'root2': {'nest1': 1, 'nest2': {'nest3': 2}}}, ' ', 'Dict') %}\n\ + {{- 'fail' if ff is defined }}\n\ + {{- f + ' {' }}\n\ +{% endcall %}\n\ +{{- '}' }}", + json::object(), + "Dict {\n root1: 1\n root2: {\n nest1: 1\n nest2: {\n nest3: 2\n }\n }\n}" + ); } static void test_namespace(testing & t) { From 0d135df48ccee9a799fa9d9ea0ed494bd4fdd74f Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Sun, 21 Jun 2026 14:12:15 +0200 Subject: [PATCH 40/86] mtmd: fix mtmd_get_memory_usage (#24867) --- tools/mtmd/clip.cpp | 62 ++++++++++++++++++--------------- tools/mtmd/mtmd.cpp | 3 +- tools/server/server-context.cpp | 4 ++- 3 files changed, 37 insertions(+), 32 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index fccc1e3487..7dd7023c41 100644 --- a/tools/mtmd/clip.cpp +++ 
b/tools/mtmd/clip.cpp @@ -2796,7 +2796,7 @@ struct clip_model_loader { } // load data - if (!ctx_clip.no_alloc) { + { std::vector read_buf; // start loading event @@ -2814,38 +2814,42 @@ struct clip_model_loader { ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend); ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft)); ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); - size_t data_loaded = 0; - for (auto & t : tensors_to_load) { - ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name); - GGML_ASSERT(cur && "tensor not found in ctx_data"); - auto it_off = tensor_offset.find(t->name); - GGML_ASSERT(it_off != tensor_offset.end() && "no offset for tensor"); - const size_t offset = it_off->second; - fin.seekg(offset, std::ios::beg); - if (!fin) { - throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name)); - } - size_t num_bytes = ggml_nbytes(cur); - if (ggml_backend_buft_is_host(buft)) { - // for the CPU and Metal backend, we can read directly into the tensor - fin.read(reinterpret_cast(cur->data), num_bytes); - } else { - // read into a temporary buffer first, then copy to device memory - read_buf.resize(num_bytes); - fin.read(reinterpret_cast(read_buf.data()), num_bytes); - ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); - } - data_loaded += num_bytes; - if (progress_callback && total_data_size > 0) { - const float progress = (float)data_loaded / (float)total_data_size; - if (!progress_callback(progress, progress_callback_user_data)) { - throw std::runtime_error(string_format("%s: model loading cancelled by progress_callback\n", __func__)); + // read the weight from file + if (!ctx_clip.no_alloc) { + size_t data_loaded = 0; + for (auto & t : tensors_to_load) { + ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name); + GGML_ASSERT(cur && "tensor not found in ctx_data"); + 
auto it_off = tensor_offset.find(t->name); + GGML_ASSERT(it_off != tensor_offset.end() && "no offset for tensor"); + const size_t offset = it_off->second; + fin.seekg(offset, std::ios::beg); + if (!fin) { + throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name)); + } + size_t num_bytes = ggml_nbytes(cur); + if (ggml_backend_buft_is_host(buft)) { + // for the CPU and Metal backend, we can read directly into the tensor + fin.read(reinterpret_cast(cur->data), num_bytes); + } else { + // read into a temporary buffer first, then copy to device memory + read_buf.resize(num_bytes); + fin.read(reinterpret_cast(read_buf.data()), num_bytes); + ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); + } + data_loaded += num_bytes; + if (progress_callback && total_data_size > 0) { + const float progress = (float)data_loaded / (float)total_data_size; + if (!progress_callback(progress, progress_callback_user_data)) { + throw std::runtime_error(string_format("%s: model loading cancelled by progress_callback\n", __func__)); + } } } + LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str()); + } else { + LOG_DBG("%s: no_alloc is set, skipping tensor data loading (%zu tensors)\n", __func__, tensors_to_load.size()); } fin.close(); - - LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str()); } } diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 564bafc621..724538b585 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -2142,8 +2142,7 @@ std::map mtmd_get_memory_usage(const char * mmproj_f try { mtmd_log_set(stub_log_callback, nullptr); // suppress logging - // TODO @ngxson : fix no_alloc here - ctx.reset(new mtmd_context(mmproj_fname, nullptr, ctx_params)); + ctx.reset(new mtmd_context(mmproj_fname, nullptr, ctx_params, true)); mtmd_log_set(saved_log_callback, saved_log_user_data); // restore log callback std::map total_mem; auto merge = [&](const 
struct clip_ctx * c) { diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 7db4cb1986..aeb15096c8 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -926,13 +926,15 @@ private: // optionally get the memory usage of mmproj if (has_mmproj && params_base.fit_params) { + int64_t t_start = ggml_time_us(); auto mmproj_mem = mtmd_get_memory_usage(mmproj_path.c_str(), mparams); + int64_t t_elapsed = ggml_time_us() - t_start; if (!mmproj_mem.empty()) { size_t total = 0; for (auto & [dev, size] : mmproj_mem) { total += size; } - SRV_INF("[mtmd] estimated worst-case memory usage of mmproj is %.2f MiB\n", total / (1024.0 * 1024.0)); + SRV_INF("[mtmd] estimated worst-case memory usage of mmproj is %.2f MiB (took %.2f ms)\n", total / (1024.0 * 1024.0), t_elapsed / 1000.0); GGML_ASSERT(!params_base.fit_params_target.empty()); for (auto & [dev, size] : mmproj_mem) { for (size_t i = 0; i < ggml_backend_dev_count(); i++) { From bddfd2b1137cd6e51fbb939081caf50e9f496a66 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Sun, 21 Jun 2026 14:16:11 +0200 Subject: [PATCH 41/86] server: refactor batch construction (#24843) * server: refactor batch construction * wip * wip 2 * wip 3 * wip 4 * add abort_all_slots * handle batch full more carefully * fix assert * rm debug log * small nits * (debug) add timings * debug: force llama_synchronize for accurate timings * address comments * disable DEBUG_TIMINGS --- tools/server/server-context.cpp | 934 ++++++++++++++++++++------------ 1 file changed, 583 insertions(+), 351 deletions(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index aeb15096c8..91a8eb9452 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -63,6 +63,99 @@ enum slot_state { SLOT_STATE_GENERATING, }; +struct server_slot; // forward declaration + +struct server_batch { + llama_batch batch; + bool batch_rendered = false; + + struct token { + 
int32_t id_slot; + llama_token token; + llama_pos pos; + bool output; + }; + std::vector tokens; + int32_t n_tokens_alloc = 0; + + // track if given slot can be batched with slots already in the batch + server_slot * slot_batched = nullptr; + + float alora_scale = -1.0f; + size_t alora_disabled_id = 0; + + server_batch() { + batch.token = nullptr; // sentinel: uninitialized batch + } + + ~server_batch() { + llama_batch_free(batch); + } + + void init(int32_t n_tokens_alloc) { + this->n_tokens_alloc = n_tokens_alloc; + batch = llama_batch_init(n_tokens_alloc, 0, 1); + tokens.reserve(n_tokens_alloc); + } + + bool add(int32_t id_slot, llama_token token, llama_pos pos, bool output) { + GGML_ASSERT(batch.token != nullptr); + if ((int32_t)tokens.size() >= n_tokens_alloc) { + return false; + } + // LOG_INF("adding token to batch: slot=%d, token=%d, pos=%d, output=%d\n", id_slot, token, pos, output); + tokens.push_back({ id_slot, token, pos, output }); + return true; + } + + void clear() { + tokens.clear(); + common_batch_clear(batch); + slot_batched = nullptr; + alora_scale = -1.0f; + alora_disabled_id = 0; + batch_rendered = false; + } + + int32_t size() const { + return (int32_t)tokens.size(); + } + + void set_output(int32_t idx, bool output) { + GGML_ASSERT(idx >= 0 && idx < (int32_t)tokens.size()); + tokens[idx].output = output; + } + + void render() { + GGML_ASSERT(batch.token != nullptr); + common_batch_clear(batch); + for (int32_t i = 0; i < size(); i++) { + const auto & t = tokens[i]; + common_batch_add(batch, t.token, t.pos, { t.id_slot }, t.output); + } + batch_rendered = true; + } + + llama_batch get_view(int32_t off, int32_t n_tokens) const { + GGML_ASSERT(batch.token != nullptr); + GGML_ASSERT(batch_rendered); + GGML_ASSERT(off >= 0 && off < size()); + GGML_ASSERT(n_tokens > 0 && off + n_tokens <= size()); + + llama_batch view = { + n_tokens, + batch.token + off, + nullptr, + batch.pos + off, + batch.n_seq_id + off, + batch.seq_id + off, + batch.logits + off, 
+ }; + + return view; + } +}; + struct server_slot { int id; @@ -185,6 +278,7 @@ struct server_slot { // stats size_t n_sent_text = 0; // number of sent text character + // TODO @ngxson : move all metrics to a sub-struct for clarity int64_t t_start_process_prompt; int64_t t_start_generation; int64_t t_print_last = 0; @@ -348,12 +442,14 @@ struct server_slot { return n_draft_max; } - void update_batch(llama_batch & batch) { + // add sampled token of this slot to the batch, optionally add the speculative draft tokens if any + void handle_last_sampled_token(server_batch & batch) { + bool add_ok = true; if (spec_draft.empty()) { // no speculative decoding - i_batch = batch.n_tokens; + i_batch = batch.size(); - common_batch_add(batch, sampled, prompt.tokens.pos_next(), { this->id }, true); + add_ok &= batch.add(id, sampled, prompt.tokens.pos_next(), true); SLT_DBG(*this, "slot decode token, id=%d, n_ctx = %d, n_tokens = %d, truncated = %d\n", sampled, n_ctx, prompt.n_tokens(), truncated); @@ -363,19 +459,21 @@ struct server_slot { GGML_ASSERT(spec_i_batch.empty()); - spec_i_batch.push_back(batch.n_tokens); + spec_i_batch.push_back(batch.size()); for (size_t i = 0; i < spec_draft.size(); i++) { - spec_i_batch.push_back(batch.n_tokens + i + 1); + spec_i_batch.push_back(batch.size() + i + 1); } auto pos0 = prompt.tokens.pos_next(); - common_batch_add(batch, sampled, pos0++, { this->id }, true); + add_ok &= batch.add(id, sampled, pos0++, true); for (auto token : spec_draft) { - common_batch_add(batch, token, pos0++, { this->id }, true); + add_ok &= batch.add(this->id, token, pos0++, true); } } + GGML_ASSERT(add_ok && "batch must be large enough to hold the sampled and draft tokens"); + prompt.tokens.push_back(sampled); prompt.tokens.insert(spec_draft); } @@ -793,7 +891,7 @@ private: llama_context * ctx_tgt = nullptr; - llama_batch batch {}; + server_batch batch; llama_model_ptr model_dft; llama_context_ptr ctx_dft; @@ -845,8 +943,6 @@ private: mtmd_free(mctx); mctx = 
nullptr; - - llama_batch_free(batch); } void handle_sleeping_state(bool new_state) { @@ -1266,7 +1362,7 @@ private: // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) { const int32_t n_batch = llama_n_batch(ctx_tgt); - batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); + batch.init(std::max(n_batch, params_base.n_parallel)); } if (params_base.cache_ram_mib != 0) { @@ -2556,7 +2652,83 @@ private: } } + void iterate(std::vector & slots, std::function callback) { + for (auto & slot : slots) { + try { + callback(slot); + } catch (const std::exception & e) { + SLT_ERR(slot, "got exception: %s\n", e.what()); + send_error(slot, std::string("got exception: ") + e.what(), ERROR_TYPE_SERVER); + slot.release(); + } + } + } + + void iterate(std::vector & slots, std::function callback) { + for (auto & slot : slots) { + try { + callback(*slot); + } catch (const std::exception & e) { + SLT_ERR(*slot, "got exception: %s\n", e.what()); + send_error(*slot, std::string("got exception: ") + e.what(), ERROR_TYPE_SERVER); + slot->release(); + } + } + } + + void abort_all_slots(const std::string & reason) { + for (auto & slot : slots) { + if (slot.is_processing()) { + send_error(slot, reason, ERROR_TYPE_SERVER); + slot.release(); + } + } + } + + // @ngxson : for debugging only + int64_t t_pre_decode = 0; + int64_t t_decode = 0; + int64_t t_post_decode = 0; + int64_t t_sampl = 0; + int64_t n_pre_decode = 0; + int64_t n_decode = 0; + int64_t n_post_decode = 0; + int64_t n_sampl = 0; +// #define DEBUG_TIMINGS +#ifdef DEBUG_TIMINGS + struct scoped_timer { + int64_t & t; + int64_t & n; + int64_t t_start; + scoped_timer(int64_t & t_, int64_t & n_) : t(t_), n(n_) { + t_start = ggml_time_us(); + } + ~scoped_timer() { + t += ggml_time_us() - t_start; + n++; + } + }; +#else + struct scoped_timer { + scoped_timer(int64_t &, int64_t &) {} + ~scoped_timer() {} + }; +#endif + void update_slots() { +#ifdef 
DEBUG_TIMINGS + static int64_t t_prev = 0; + int64_t t_start = ggml_time_us(); + if (t_start - t_prev > 5 * 1000 * 1000) { // every 5 seconds + t_prev = t_start; + SRV_INF("n_pre_decode = %" PRId64 "\n", n_pre_decode); + SRV_INF("avg t_pre_decode = %f ms\n", (double) t_pre_decode / n_pre_decode / 1000.0); + SRV_INF("avg t_decode = %f ms\n", (double) t_decode / n_decode / 1000.0); + SRV_INF("avg t_post_decode = %f ms\n", (double) t_post_decode / n_post_decode / 1000.0); + SRV_INF("avg t_sampl = %f ms\n", (double) t_sampl / n_sampl / 1000.0); + } +#endif + // check if all slots are idle { bool all_idle = true; @@ -2570,29 +2742,80 @@ private: if (all_idle) { SRV_INF("%s", "all slots are idle\n"); + return; // skip further processing - return; + } else { + SRV_DBG("%s", "posting NEXT_RESPONSE\n"); + + server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); + task.id = queue_tasks.get_new_id(); + queue_tasks.post(std::move(task)); } } - { - SRV_DBG("%s", "posting NEXT_RESPONSE\n"); - - server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); - task.id = queue_tasks.get_new_id(); - queue_tasks.post(std::move(task)); + try { + scoped_timer t(t_pre_decode, n_pre_decode); + pre_decode(); + batch.render(); + } catch (const std::exception & e) { + SRV_ERR("pre_decode() failed: %s\n", e.what()); + abort_all_slots("pre_decode() failed: " + std::string(e.what())); } + llama_batch batch_view; + int32_t off_next = 0; + int32_t n_batch = llama_n_batch(ctx_tgt); + for (int32_t off = 0; off < batch.size(); off = off_next) { + const int32_t n_tokens = std::min(n_batch, batch.size() - off); + try { + scoped_timer t(t_decode, n_decode); + // TODO @ngxson : maybe handle n_batch == 1 here instead of inside decode() + + batch_view = batch.get_view(off, n_tokens); + bool ok = decode(n_batch, off, batch_view); +#ifdef DEBUG_TIMINGS + llama_synchronize(ctx_tgt); +#endif + + if (ok) { + // move the head of the batch forward with the number of tokens we just processed + off_next = off + n_tokens; + + // on 
successful decode, restore the original batch size + n_batch = llama_n_batch(ctx_tgt); + } else { + // try again with the updated n_batch + continue; + } + } catch (const std::exception & e) { + SRV_ERR("decode() failed: %s\n", e.what()); + abort_all_slots("decode() failed: " + std::string(e.what())); + break; // stop any further processing + } + + try { + scoped_timer t(t_post_decode, n_post_decode); + post_decode(n_tokens, off, batch_view); + } catch (const std::exception & e) { + SRV_ERR("post_decode() failed: %s\n", e.what()); + abort_all_slots("post_decode() failed: " + std::string(e.what())); + break; // stop any further processing + } + + } + } + + void pre_decode() { // apply context-shift if needed // TODO: simplify and improve - for (server_slot & slot : slots) { + iterate(slots, [&](server_slot & slot) { if (slot.state == SLOT_STATE_GENERATING && slot.prompt.n_tokens() + 1 >= slot.n_ctx) { if (!params_base.ctx_shift) { // this check is redundant (for good) // we should never get here, because generation should already stopped in process_token() send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); slot.release(); - continue; + return; } if (mctx) { @@ -2604,7 +2827,7 @@ private: if (slot.task->is_parent() || slot.task->is_child()) { send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER); slot.release(); - continue; + return; } // Shift context @@ -2650,28 +2873,28 @@ private: slot.truncated = true; } - } + }); // start populating the batch for this iteration - common_batch_clear(batch); + batch.clear(); // track if given slot can be batched with slots already in the batch - server_slot * slot_batched = nullptr; + auto & slot_batched = batch.slot_batched; std::vector generating; std::vector drafting; // determine which slots are generating and drafting - for (auto & slot : slots) { + iterate(slots, [&](server_slot & slot) { if (slot.state != SLOT_STATE_GENERATING) { - continue; + return; } // check if we can 
batch this slot with the previous one if (!slot_batched) { slot_batched = &slot; } else if (!slot_batched->can_batch_with(slot)) { - continue; + return; } generating.push_back(&slot); @@ -2719,7 +2942,7 @@ private: } } } - } + }); // generate the actual drafts (if any) { @@ -2727,9 +2950,7 @@ private: } // make checkpoints if needed - for (auto * slot_ptr : drafting) { - auto & slot = *slot_ptr; - + iterate(drafting, [&](server_slot & slot) { auto & draft = slot.spec_draft; auto & ckpt = slot.spec_ckpt; @@ -2772,38 +2993,42 @@ private: ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); } } - } + }); // update the batch with the sampled/drafted tokens - for (auto * slot_ptr : generating) { - auto & slot = *slot_ptr; - - slot.update_batch(batch); - } + iterate(generating, [&](server_slot & slot) { + slot.handle_last_sampled_token(batch); + }); // process in chunks of params.n_batch int32_t n_batch = llama_n_batch(ctx_tgt); int32_t n_ubatch = llama_n_ubatch(ctx_tgt); - float alora_scale = -1.0f; - size_t alora_disabled_id = 0; + auto & alora_scale = batch.alora_scale; + auto & alora_disabled_id = batch.alora_disabled_id; // next, batch any pending prompts without exceeding n_batch - if (params_base.cont_batching || batch.n_tokens == 0) { - for (auto & slot : slots) { + if (params_base.cont_batching || batch.size() == 0) { + bool add_ok = true; // false means the batch is full, skip remaining slots + + iterate(slots, [&](server_slot & slot) { + if (!add_ok || batch.size() >= n_batch) { + return; // batch is full, skip remaining slots + } + if (!slot.is_processing()) { - continue; + return; } // check if we can batch this slot with the previous one if (slot_batched && !slot_batched->can_batch_with(slot)) { - continue; + return; } // check if this is a child slot if (slot.state == SLOT_STATE_WAIT_OTHER) { SLT_DBG(slot, "%s", "waiting for parent slot to complete\n"); - continue; + return; } // this slot still has a prompt to be processed @@ 
-2811,7 +3036,7 @@ private: const auto & input_tokens = slot.task->tokens; // used to determine the number of tokens added to the batch for the current slot - const auto n_tokens_prev = batch.n_tokens; + const auto n_tokens_prev = batch.size(); // TODO: maybe move branch to outside of this loop in the future if (slot.state == SLOT_STATE_STARTED) { @@ -2847,14 +3072,14 @@ private: send_final_response(slot); slot.release(); - continue; + return; } // TODO: support memory-less logits computation if (slot.task->need_logits() && !llama_get_memory(ctx_tgt)) { send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); slot.release(); - continue; + return; } if (!slot.can_split()) { @@ -2866,7 +3091,7 @@ private: slot.task->n_tokens(), n_ubatch), ERROR_TYPE_SERVER); slot.release(); - continue; + return; } if (slot.task->n_tokens() > slot.n_ctx) { @@ -2877,7 +3102,7 @@ private: slot.task->n_tokens(), slot.n_ctx), ERROR_TYPE_EXCEED_CONTEXT_SIZE); slot.release(); - continue; + return; } } else { if (slot.task->n_tokens() >= slot.n_ctx) { @@ -2887,7 +3112,7 @@ private: slot.task->n_tokens(), slot.n_ctx), ERROR_TYPE_EXCEED_CONTEXT_SIZE); slot.release(); - continue; + return; } if (slot.task->params.cache_prompt) { @@ -3107,8 +3332,8 @@ private: if (!slot.can_split()) { // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.task->n_tokens() > n_batch) { - continue; + if (batch.size() + slot.task->n_tokens() > n_batch) { + return; } } @@ -3192,7 +3417,7 @@ private: const bool n_before_user_known = n_before_user > 0; // add prompt tokens for processing in the current batch - while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) { + while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.size() < n_batch) { // get next token to process llama_token cur_tok = input_tokens[slot.prompt.n_tokens()]; if (cur_tok == LLAMA_TOKEN_NULL) { @@ -3210,10 +3435,9 @@ private: // 
embedding requires all tokens in the batch to be output; // MTP also wants logits at every prompt position so the // streaming hook can mirror t_h_nextn into ctx_dft. - common_batch_add(batch, + add_ok &= batch.add(slot.id, cur_tok, slot.prompt.tokens.pos_next(), - { slot.id }, slot.need_embd()); slot.prompt.tokens.push_back(cur_tok); @@ -3249,7 +3473,7 @@ private: } // the number of tokens added to the batch for the current slot - const auto n_tokens_cur = batch.n_tokens - n_tokens_prev; + const auto n_tokens_cur = batch.size() - n_tokens_prev; const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch; @@ -3257,13 +3481,13 @@ private: if (slot.prompt.n_tokens() == slot.task->n_tokens()) { slot.state = SLOT_STATE_DONE_PROMPT; - GGML_ASSERT(batch.n_tokens > 0); + GGML_ASSERT(batch.size() > 0); // extract the logits only for the last token - batch.logits[batch.n_tokens - 1] = true; + batch.set_output(batch.size() - 1, true); slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; + slot.i_batch = batch.size() - 1; slot.init_sampler(); } else { @@ -3322,20 +3546,20 @@ private: if (!slot_batched) { slot_batched = &slot; } - - if (batch.n_tokens >= n_batch) { - break; - } - } + }); } + } - SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); + // returns true = success ; false = retry with smaller batch size + // throw std::runtime_error on fatal error + bool decode(int32_t & n_batch, int32_t off, llama_batch & batch_view) { + SRV_DBG("n_batch (effective) = %d, off = %d\n", n_batch, off); - auto accept_special_token = [&](server_slot & slot, llama_token token) { - return params_base.special || - slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); - }; + auto & slot_batched = batch.slot_batched; + auto & alora_scale = batch.alora_scale; + auto & alora_disabled_id = batch.alora_disabled_id; + // TODO @ngxson : alora handling is too messy, need to refactor it to be more clear and 
maintainable if (slot_batched) { // apply lora, only need to do it once per batch common_set_adapter_lora(ctx_tgt, slot_batched->lora); @@ -3350,340 +3574,348 @@ private: llama_set_embeddings(ctx_tgt, slot_batched->need_embd()); } - if (batch.n_tokens == 0) { + if (batch.size() == 0) { SRV_WRN("%s", "no tokens to decode\n"); if (++n_empty_consecutive > 3) { GGML_ABORT("fatal error - please provide logs and repro in %s\n", "https://github.com/ggml-org/llama.cpp/pull/20277"); } + + return true; // nothing to decode } else { n_empty_consecutive = 0; } - int32_t i_next = 0; + const int ret = llama_decode(ctx_tgt, batch_view); - // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens; i = i_next) { - const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); + metrics.on_decoded(slots); - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; + if (ret != 0) { + { + std::string err; - const int ret = llama_decode(ctx_tgt, batch_view); - - metrics.on_decoded(slots); - - if (ret != 0) { - { - std::string err; - - if (n_batch == 1 && ret == 1) { - // TODO: try to terminate only the largest active slot/sequence and continue with the rest - // need to remove the tokens from the current batch too - err = "Context size has been exceeded."; - } - - if (ret == -1) { - err = "Invalid input batch."; - } - - if (ret < -1) { - // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() - err = "Compute error."; - } - - // TODO: handle ret == 2 (abort) when we start aborting - - if (!err.empty()) { - SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); - - for (auto & slot : slots) { - if (slot.is_processing()) { - send_error(slot, err); - slot.release(); - - // note: it's complicated to keep track of how much of the current batch has been - // processed before the error occurred, so we simply clear 
the entire context - slot.prompt_clear(false); - } - } - - break; - } + if (n_batch == 1 && ret == 1) { + // TODO: try to terminate only the largest active slot/sequence and continue with the rest + // need to remove the tokens from the current batch too + err = "Context size has been exceeded."; } - // retry with half the batch size to try to find a free slot in the KV cache - if (!try_clear_idle_slots()) { - n_batch /= 2; + if (ret == -1) { + err = "Invalid input batch."; } - SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); + if (ret < -1) { + // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() + err = "Compute error."; + } - continue; // continue loop of n_batch - } + // TODO: handle ret == 2 (abort) when we start aborting - // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] - // for now, always re-evaluate for simplicity - // ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384 - if (!common_speculative_process(spec.get(), batch_view)) { - SRV_ERR("%s", "failed to process speculative batch\n"); + if (!err.empty()) { + SRV_ERR("%s off = %d, n_batch = %d, ret = %d\n", err.c_str(), off, n_batch, ret); - // TODO: handle error - break; - } + for (auto & slot : slots) { + if (slot.is_processing()) { + send_error(slot, err); + slot.release(); - // move the head of the batch forward with the number of tokens we just processed - i_next = i + n_tokens; - - // on successful decode, restore the original batch size - n_batch = llama_n_batch(ctx_tgt); - - // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too - for (auto & slot : slots) { - if (slot.state == SLOT_STATE_DONE_PROMPT && slot.task->is_parent()) { - std::vector children; - for (auto & other : slots) { - if (other.state == SLOT_STATE_WAIT_OTHER && 
slot.task->id == other.task->id_parent) { - children.push_back(&other); + // note: it's complicated to keep track of how much of the current batch has been + // processed before the error occurred, so we simply clear the entire context + slot.prompt_clear(false); } } - // all children slots should already launched by launch_slots_with_parent_task() - // copy state to the child slots - for (auto & child : children) { - SLT_INF(slot, " - copying state to child %d\n", child->id); - - GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER); - - slot.copy_state_to(*child); - child->state = SLOT_STATE_DONE_PROMPT; - } + // stop, do not retry with smaller batch size + throw std::runtime_error(err); } } - for (auto & slot : slots) { - // optionally send prompt processing progress - if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task->params.stream && slot.task->params.return_progress) { - send_partial_response(slot, {}, true); + // retry with half the batch size to try to find a free slot in the KV cache + if (!try_clear_idle_slots()) { + n_batch /= 2; + } + + SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, off = %d, n_batch = %d, ret = %d\n", off, n_batch, ret); + + return false; // retry with the updated n_batch + } + + // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] + // for now, always re-evaluate for simplicity + // ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384 + if (!common_speculative_process(spec.get(), batch_view)) { + SRV_ERR("%s", "failed to process speculative batch\n"); + + // TODO: handle error + throw std::runtime_error("failed to process speculative batch"); + } + + // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too + for (auto & slot : slots) { + if (slot.state == SLOT_STATE_DONE_PROMPT && slot.task->is_parent()) { + 
std::vector children; + for (auto & other : slots) { + if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) { + children.push_back(&other); } } - if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { - continue; // continue loop of slots + // all children slots should already launched by launch_slots_with_parent_task() + // copy state to the child slots + for (auto & child : children) { + SLT_INF(slot, " - copying state to child %d\n", child->id); + + GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER); + + slot.copy_state_to(*child); + child->state = SLOT_STATE_DONE_PROMPT; + } + } + } + + return true; + } + + void post_decode(int32_t n_batch_tokens, int32_t off, llama_batch & batch_view) { + // for checking if a given batch index is inside batch_view + auto is_inside_view = [&](int32_t idx) { + return idx >= off && idx < off + n_batch_tokens; + }; + + // TODO @ngxson : it's tricky to make sub-batch compatible with common_sampler_sample_and_accept_n, + // so for now we will throw an error in this case: https://github.com/ggml-org/llama.cpp/issues/24840 + iterate(slots, [&](server_slot & slot) { + for (auto & i : slot.spec_i_batch) { + if (!is_inside_view(i)) { + throw std::runtime_error(string_format("speculative batch index %d is not inside the current sub-batch [%d, %d)", i, off, off + n_batch_tokens)); + } + } + }); + + auto accept_special_token = [&](server_slot & slot, llama_token token) { + return params_base.special || + slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); + }; + + iterate(slots, [&](server_slot & slot) { + // optionally send prompt processing progress + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task->params.stream && slot.task->params.return_progress) { + send_partial_response(slot, {}, true); + } + } + + if (!is_inside_view(slot.i_batch)) { + // the required token not in this 
sub-batch, skip + return; + } + + if (slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { + // prompt evaluated for embedding + send_embedding(slot, batch_view); + slot.release(); + slot.i_batch = -1; + return; } - if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { - // prompt evaluated for embedding - send_embedding(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - if (slot.task->type == SERVER_TASK_TYPE_RERANK) { - send_rerank(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - GGML_ASSERT(slot.task->need_sampling()); - - // prompt evaluated for next-token prediction - slot.state = SLOT_STATE_GENERATING; - - if (slot.can_speculate()) { - common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens()); - } - } else if (slot.state != SLOT_STATE_GENERATING) { - continue; // continue loop of slots + if (slot.task->type == SERVER_TASK_TYPE_RERANK) { + send_rerank(slot, batch_view); + slot.release(); + slot.i_batch = -1; + return; } - if (slot.can_speculate() && !slot.spec_draft.empty()) { - continue; // sample using speculative decoding + GGML_ASSERT(slot.task->need_sampling()); + + // prompt evaluated for next-token prediction + slot.state = SLOT_STATE_GENERATING; + + if (slot.can_speculate()) { + common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens()); + } + } else if (slot.state != SLOT_STATE_GENERATING) { + return; + } + + if (slot.can_speculate() && !slot.spec_draft.empty()) { + return; // sample using speculative decoding + } + + // shifted according to the current sub-batch + const int tok_idx = slot.i_batch - off; + + llama_token id; + { + scoped_timer timer(t_sampl, n_sampl); + id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx); + } + + slot.i_batch = -1; + + common_sampler_accept(slot.smpl.get(), id, 
true); + + // here we have synchronized the llama_context (due to the sampling above), so we can do time measurement + const int64_t t_now = ggml_time_us(); + + slot.n_decoded += 1; + + if (slot.n_decoded == 1) { + slot.t_start_generation = t_now; + slot.t_print_last = t_now; + slot.n_decoded_last = 0; + slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; + metrics.on_prompt_eval(slot); + } + + slot.t_token_generation = std::max(1, t_now - slot.t_start_generation) / 1e3; + + completion_token_output result; + result.tok = id; + result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs + + if (slot.task->params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); + } + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + slot.release(); + + return; + } + + slot.print_timings_tg(); + }); + + // speculative decoding - main model sample and accept + iterate(slots, [&](server_slot & slot) { + if (slot.state != SLOT_STATE_GENERATING || !slot.can_speculate() || slot.spec_draft.empty()) { + return; + } + + // save the original draft size + const size_t n_draft = slot.spec_draft.size(); + + GGML_ASSERT(n_draft > 0); + + // verify and try to accept the draft + { + // save the sampler sampler state in case we need to restore it + common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get())); + + GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); + auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft); + slot.spec_i_batch.clear(); + + GGML_ASSERT(accepted.size() >= 1); + + const uint32_t n_rollback = slot.spec_draft.size() + 1 - 
accepted.size(); + + const bool use_ckpt_tgt = + ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || + (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && n_rollback > llama_n_rs_seq(ctx_tgt)); + + // check for partial draft acceptance + if (n_rollback > 0) { + if (use_ckpt_tgt) { + if (trace > 0) { + SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size()); + } + + // partial acceptance is not supported by the context -> truncate the draft and restore the state + slot.spec_draft = std::move(accepted); + + const auto & ckpt = slot.spec_ckpt; + + SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size()); + + { + ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + common_context_seq_rm(slot.ctx_tgt, slot.id, ckpt.pos_max + 1, -1); + } + + if (slot.ctx_dft) { + ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + common_context_seq_rm(slot.ctx_dft, slot.id, ckpt.pos_max + 1, -1); + } + + slot.prompt.tokens.keep_first(ckpt.n_tokens); + slot.smpl = std::move(smpl_save); + + return; + } } - const int tok_idx = slot.i_batch - i; + if (trace > 0) { + SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft); + } - llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx); + common_speculative_accept(spec.get(), slot.id, accepted.size() - 1); - slot.i_batch = -1; + slot.spec_draft = std::move(accepted); + } - common_sampler_accept(slot.smpl.get(), id, true); + const int64_t t_now = ggml_time_us(); - // here we have synchronized the llama_context (due to the sampling above), so we can do time measurement - const int64_t t_now = ggml_time_us(); + const auto ids = std::move(slot.spec_draft); + + slot.t_token_generation = std::max(1, t_now - slot.t_start_generation) / 1e3; + + // update how many tokens out of those tested were accepted + 
slot.n_draft_accepted += ids.size() - 1; + slot.n_draft_verif_steps += 1; + + if (slot.n_accepted_per_pos.empty()) { + slot.n_accepted_per_pos.resize(common_speculative_n_max(&params_base.speculative), 0); + } + for (size_t i = 0; i < ids.size() - 1 && i < slot.n_accepted_per_pos.size(); ++i) { + slot.n_accepted_per_pos[i]++; + } + + // add accepted tokens to the prompt + slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft); + slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); + + slot.sampled = ids.back(); // last accepted token + SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); + + common_context_seq_rm(slot.ctx_tgt, slot.id, slot.prompt.tokens.pos_next(), -1); + if (slot.ctx_dft) { + common_context_seq_rm(slot.ctx_dft, slot.id, slot.prompt.tokens.pos_next(), -1); + } + + for (size_t i = 0; i < ids.size(); ++i) { + completion_token_output result; + + result.tok = ids[i]; + result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // set later + + // TODO: set result.probs slot.n_decoded += 1; - if (slot.n_decoded == 1) { - slot.t_start_generation = t_now; - slot.t_print_last = t_now; - slot.n_decoded_last = 0; - slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; - metrics.on_prompt_eval(slot); - } - - slot.t_token_generation = std::max(1, t_now - slot.t_start_generation) / 1e3; - - completion_token_output result; - result.tok = id; - result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); - result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs - - if (slot.task->params.sampling.n_probs > 0) { - populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); - } - if (!process_token(result, slot)) { - // release slot because of stop condition
slot.print_timings(); send_final_response(slot); metrics.on_prediction(slot); slot.release(); - continue; + return; } - - slot.print_timings_tg(); } - // speculative decoding - main model sample and accept - for (auto & slot : slots) { - if (slot.state != SLOT_STATE_GENERATING || !slot.can_speculate() || slot.spec_draft.empty()) { - continue; - } + slot.print_timings_tg(); - // save the original draft size - const size_t n_draft = slot.spec_draft.size(); - - GGML_ASSERT(n_draft > 0); - - // verify and try to accept the draft - { - // save the sampler sampler state in case we need to restore it - common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get())); - - GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); - auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft); - slot.spec_i_batch.clear(); - - GGML_ASSERT(accepted.size() >= 1); - - const uint32_t n_rollback = slot.spec_draft.size() + 1 - accepted.size(); - - const bool use_ckpt_tgt = - ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || - (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && n_rollback > llama_n_rs_seq(ctx_tgt)); - - // check for partial draft acceptance - if (n_rollback > 0) { - if (use_ckpt_tgt) { - if (trace > 0) { - SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size()); - } - - // partial acceptance is not supported by the context -> truncate the draft and restore the state - slot.spec_draft = std::move(accepted); - - const auto & ckpt = slot.spec_ckpt; - - SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size()); - - { - ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - common_context_seq_rm(slot.ctx_tgt, slot.id, ckpt.pos_max + 1, -1); - } - - if (slot.ctx_dft) { - ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - 
common_context_seq_rm(slot.ctx_dft, slot.id, ckpt.pos_max + 1, -1); - } - - slot.prompt.tokens.keep_first(ckpt.n_tokens); - slot.smpl = std::move(smpl_save); - - continue; - } - } - - if (trace > 0) { - SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft); - } - - common_speculative_accept(spec.get(), slot.id, accepted.size() - 1); - - slot.spec_draft = std::move(accepted); - } - - const int64_t t_now = ggml_time_us(); - - const auto ids = std::move(slot.spec_draft); - - slot.t_token_generation = std::max(1, t_now - slot.t_start_generation) / 1e3; - - // update how many tokens out of those tested were accepted - slot.n_draft_accepted += ids.size() - 1; - slot.n_draft_verif_steps += 1; - - if (slot.n_accepted_per_pos.empty()) { - slot.n_accepted_per_pos.resize(common_speculative_n_max(&params_base.speculative), 0); - } - for (size_t i = 0; i < ids.size() - 1 && i < slot.n_accepted_per_pos.size(); ++i) { - slot.n_accepted_per_pos[i]++; - } - - // add accepted tokens to the prompt - slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft); - slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); - - slot.sampled = ids.back(); // last accepted token - SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); - - common_context_seq_rm(slot.ctx_tgt, slot.id, slot.prompt.tokens.pos_next(), -1); - if (slot.ctx_dft) { - common_context_seq_rm(slot.ctx_dft, slot.id, slot.prompt.tokens.pos_next(), -1); - } - - for (size_t i = 0; i < ids.size(); ++i) { - completion_token_output result; - - result.tok = ids[i]; - result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); - result.prob = 1.0f; // set later - - // TODO: set result.probs - - slot.n_decoded += 1; - - if (!process_token(result, slot)) { - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); - slot.release(); - - break; - } - } - -
slot.print_timings_tg(); - - SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens()); - } - } - - SRV_DBG("%s", "run slots completed\n"); + SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens()); + }); } int get_slot_n_ctx() { From 7c082bc417bbe53210a83df4ba5b49e18ce6193c Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Sun, 21 Jun 2026 17:36:52 +0200 Subject: [PATCH 42/86] server: fix report progress for loading spec models, add "stages" list (#24870) * server: fix report progress for loading spec models, add "stages" list * improve * nits * nits 2 --- tools/server/README.md | 8 +++- tools/server/server-context.cpp | 71 ++++++++++++++++++++------------- 2 files changed, 50 insertions(+), 29 deletions(-) diff --git a/tools/server/README.md b/tools/server/README.md index 5efdad0954..7fa3a4d728 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -1863,11 +1863,15 @@ Example events: "data": { "status": "loading", "progress": { - "stage": "fit_params", - "value": 0.5 // from 0.0 to 1.0 ; note: not all stages have this "value" + "stages": ["text_model", "spec_model", "mmproj_model"], + "current": "text_model", + "value": 0.5 } } } +// note for "loading" status: +// - subsequent events will follow the same order of "stages" list +// - mmap may report incorrect progress on some platforms; if you need exact progress, use --no-mmap { "model": "...", diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 91a8eb9452..3f9391cacb 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -962,6 +962,7 @@ private: struct load_progress_data { server_context_impl * ctx; std::string stage; + std::vector stages; int64_t t_last_load_progress_ms = 0; load_progress_data(server_context_impl * ctx, const std::string & stage) : ctx(ctx), stage(stage) {} }; @@ -982,7
+983,8 @@ private: } if (d->ctx->callback_state) { d->ctx->callback_state(SERVER_STATE_LOADING, { - {"stage", d->stage}, + {"stages", d->stages}, + {"current", d->stage}, {"value", progress}, }); } @@ -992,18 +994,42 @@ private: // load the model and initialize llama_context // this may also be called to resume from sleeping state bool load_model(common_params & params) { - load_progress_data load_progress_text(this, "text_model"); + load_progress_data load_progress_text (this, "text_model"); load_progress_data load_progress_mmproj(this, "mmproj_model"); + load_progress_data load_progress_spec (this, "spec_model"); - bool is_resume = sleeping; - - SRV_INF("loading model '%s'\n", params.model.path.c_str()); + const bool is_resume = sleeping; params_base = params; params_base.n_outputs_max = server_n_outputs_max(params_base); + const bool has_mmproj = !params.mmproj.path.empty(); + const bool has_draft = params.speculative.has_dft(); + const bool spec_mtp = std::find(params_base.speculative.types.begin(), + params_base.speculative.types.end(), + COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end(); + const bool has_spec = has_draft || spec_mtp; + + if (callback_state) { + std::vector stages = {"text_model"}; + if (has_spec) { + stages.push_back("spec_model"); + } + if (has_mmproj) { + stages.push_back("mmproj_model"); + } + load_progress_text.stages = stages; + load_progress_mmproj.stages = stages; + load_progress_spec.stages = stages; + + // trigger 0% progress + load_progress_callback(0.0f, &load_progress_text); + } + + + SRV_INF("loading model '%s'\n", params.model.path.c_str()); + std::string & mmproj_path = params_base.mmproj.path; - bool has_mmproj = !mmproj_path.empty(); mtmd_context_params mparams = mtmd_context_params_default(); if (has_mmproj) { mparams.use_gpu = params_base.mmproj_use_gpu; @@ -1050,16 +1076,7 @@ private: // optionally reserve VRAM for the draft / MTP context before fitting the target model if (params_base.fit_params) { 
- if (callback_state) { - callback_state(SERVER_STATE_LOADING, {{"stage", "fit_params"}}); - } - - const bool spec_mtp = std::find(params_base.speculative.types.begin(), - params_base.speculative.types.end(), - COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end(); - const bool has_draft = params_base.speculative.has_dft(); - - if (has_draft || spec_mtp) { + if (has_spec) { common_params params_dft = params_base; bool measure_model_bytes = true; @@ -1151,11 +1168,7 @@ private: add_bos_token = llama_vocab_get_add_bos(vocab); - if (params_base.speculative.has_dft()) { - if (callback_state) { - callback_state(SERVER_STATE_LOADING, {{"stage", "spec_model"}}); - } - + if (has_draft) { // TODO speculative: move to common/speculative.cpp? const auto & params_spec = params_base.speculative.draft; @@ -1178,6 +1191,10 @@ private: auto mparams_dft = common_model_params_to_llama(params_dft); + // progress callback + mparams_dft.progress_callback = load_progress_callback; + mparams_dft.progress_callback_user_data = &load_progress_spec; + model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft)); if (model_dft == nullptr) { SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str()); @@ -1186,10 +1203,6 @@ private: auto cparams = common_context_params_to_llama(params_dft); - const bool spec_mtp = std::find(params_base.speculative.types.begin(), - params_base.speculative.types.end(), - COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end(); - if (spec_mtp) { cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP; } @@ -1203,8 +1216,10 @@ private: params_base.speculative.draft.ctx_tgt = ctx_tgt; params_base.speculative.draft.ctx_dft = ctx_dft.get(); - } else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(), - COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end()) { + } else if (spec_mtp) { + // no new model load, so we simply report 0.0 and 1.0 
progress + load_progress_callback(0.0f, &load_progress_spec); + SRV_INF("creating MTP draft context against the target model '%s'\n", params_base.model.path.c_str()); @@ -1224,6 +1239,8 @@ private: params_base.speculative.draft.ctx_tgt = ctx_tgt; params_base.speculative.draft.ctx_dft = ctx_dft.get(); + + load_progress_callback(1.0f, &load_progress_spec); } if (has_mmproj) { From 52b3df0023659b142ce29f75c7a82cf437769c33 Mon Sep 17 00:00:00 2001 From: Aldehir Rojas Date: Sun, 21 Jun 2026 16:20:58 -0500 Subject: [PATCH 43/86] common/peg : implement ac parser for stricter grammar generation (#24869) * common/peg : implement ac parser * cont : extract functions * cont : tidy up * cont : remove a test * cont : move ac() def --- common/chat-auto-parser-generator.cpp | 9 +- common/peg-parser.cpp | 131 +++++++++++++++++----- common/peg-parser.h | 15 ++- tests/peg-parser/test-gbnf-generation.cpp | 69 ++++++++++++ 4 files changed, 190 insertions(+), 34 deletions(-) diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp index 37ca55c8df..36aab7ecbe 100644 --- a/common/chat-auto-parser-generator.cpp +++ b/common/chat-auto-parser-generator.cpp @@ -395,10 +395,11 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte arguments.name_suffix) + arguments.value_prefix + (schema_info.resolves_to_string(param_schema) ? 
- p.tool_arg_string_value(until_suffix) : - p.tool_arg_json_value(p.schema( - p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false))) + - p.tool_arg_close(p.literal(arguments.value_suffix))); + p.ac(p.tool_arg_string_value(until_suffix) + + p.tool_arg_close(p.literal(arguments.value_suffix)), arguments.value_suffix) : + (p.tool_arg_json_value(p.schema( + p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) + + p.tool_arg_close(p.literal(arguments.value_suffix))))); auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg); if (is_required) { diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp index 506b902451..807e952d90 100644 --- a/common/peg-parser.cpp +++ b/common/peg-parser.cpp @@ -921,6 +921,10 @@ struct parser_executor { common_peg_parse_result operator()(const common_peg_gbnf_parser & p) { return arena.parse(p.child, ctx, start_pos); } + + common_peg_parse_result operator()(const common_peg_ac_parser & p) { + return arena.parse(p.child, ctx, start_pos); + } }; common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const { @@ -989,7 +993,8 @@ void common_peg_arena::resolve_refs() { std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v || + std::is_same_v) { p.child = resolve_ref(p.child); } else if constexpr (std::is_same_v) { p.child = resolve_ref(p.child); @@ -1070,6 +1075,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id return "Atomic(" + dump_impl(p.child, visited) + ")"; } else if constexpr (std::is_same_v) { return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")"; + } else if constexpr (std::is_same_v) { + return "Ac(" + string_join(p.delimiters, " | ") + ", " + dump_impl(p.child, visited) + ")"; } else if constexpr (std::is_same_v) { return "Any"; } else if constexpr (std::is_same_v) { @@ -1479,6 +1486,13 @@ common_peg_parser common_peg_parser_builder::json_member(const 
std::string & key }); } +common_peg_parser common_peg_parser_builder::ac(const common_peg_parser & p, const std::vector & delimiters) { + if (delimiters.empty()) { + throw std::runtime_error("ac parser requires at least one delimiter"); + } + return add(common_peg_ac_parser{p, delimiters}); +} + static std::string gbnf_escape_char_class(uint32_t c) { if (c == '-' || c == ']' || c == '[' || c == '\\') { return "\\" + std::string(1, (char) c); @@ -1529,14 +1543,22 @@ static std::string gbnf_escape_char_class(uint32_t c) { return std::string(buf); } -// GBNF grammar matching strings that contain no string in `strings` as a -// substring. Emits the complement of an Aho-Corasick automaton DFA and returns -// the start state rule name. -// -// ref: https://github.com/ggml-org/llama.cpp/pull/24839 -static std::string gbnf_excluding_grammar(const common_grammar_builder & builder, - const std::string & prefix, - const std::vector & strings) { +static std::string gbnf_char_class(const std::vector & chars, bool negate) { + std::string s = negate ? "[^" : "["; + for (uint32_t ch : chars) { + s += gbnf_escape_char_class(ch); + } + return s + "]"; +} + +static std::string gbnf_ac_grammar( + const common_grammar_builder & builder, + const std::string & prefix, + const std::vector & strings, + const std::function &, + const std::map> &, + const std::vector &, + const std::function &)> & build_rule) { aho_corasick ac(strings); auto state_name = [&](size_t s) -> std::string { @@ -1548,42 +1570,30 @@ static std::string gbnf_excluding_grammar(const common_grammar_builder & builder return prefix + "-" + num; }; - auto char_class = [](const std::vector & chars, bool negate) { - std::string s = negate ? 
"[^" : "["; - for (uint32_t ch : chars) { - s += gbnf_escape_char_class(ch); - } - return s + "]"; - }; - for (size_t q = 0; q < ac.num_states(); q++) { if (ac.is_terminal(q)) { - continue; // match states are dropped + continue; // match states } std::map> buckets; - std::vector excluded; + std::vector completing; // chars that complete a delimiter + std::vector specific; // chars with an explicit transition for (uint32_t c : ac.alphabet) { size_t d = ac.next(q, c); if (ac.is_terminal(d)) { - excluded.push_back(c); // completes a forbidden string -> omit + completing.push_back(c); + specific.push_back(c); } else if (d != 0) { buckets[d].push_back(c); // specific non-root destination - excluded.push_back(c); + specific.push_back(c); } } - std::string rhs = "|"; // every state is accepting - for (const auto & [d, chars] : buckets) { - rhs += " " + char_class(chars, false) + " " + state_name(d) + " |"; - } - rhs += " " + char_class(excluded, true) + " " + state_name(0); - - builder.add_rule(state_name(q), rhs); + builder.add_rule(state_name(q), build_rule(completing, buckets, specific, state_name)); } // An empty delimiter makes the start state terminal. Emit an entry rule - // that matches nothing so the returned reference stays valid. + // that matches the empty string so the returned reference stays valid. if (ac.is_terminal(0)) { builder.add_rule(prefix, "|"); } @@ -1591,6 +1601,54 @@ static std::string gbnf_excluding_grammar(const common_grammar_builder & builder return state_name(0); } +// GBNF grammar matching strings that contain no string in `strings` as a +// substring. Emits the complement of an Aho-Corasick automaton DFA and returns +// the start state rule name. 
+// +// ref: https://github.com/ggml-org/llama.cpp/pull/24839 +static std::string gbnf_excluding_grammar(const common_grammar_builder & builder, + const std::string & prefix, + const std::vector & strings) { + return gbnf_ac_grammar(builder, prefix, strings, + [](const std::vector & /*completing*/, + const std::map> & buckets, + const std::vector & specific, + const std::function & state_name) { + // every state is accepting and completing chars get no + // alternative, so a forbidden string can never be matched + std::string rhs = "|"; + for (const auto & [d, chars] : buckets) { + rhs += " " + gbnf_char_class(chars, false) + " " + state_name(d) + " |"; + } + rhs += " " + gbnf_char_class(specific, true) + " " + state_name(0); + return rhs; + }); +} + +// GBNF grammar matching everything up to and including the first occurrence of +// any string in `strings`. Emits the Aho-Corasick automaton DFA and returns +// the start state rule name. +static std::string gbnf_including_grammar(const common_grammar_builder & builder, + const std::string & prefix, + const std::vector & strings) { + return gbnf_ac_grammar(builder, prefix, strings, + [](const std::vector & completing, + const std::map> & buckets, + const std::vector & specific, + const std::function & state_name) { + std::vector alts; + if (!completing.empty()) { + alts.push_back(gbnf_char_class(completing, false)); // terminate on match + } + for (const auto & [d, chars] : buckets) { + alts.push_back(gbnf_char_class(chars, false) + " " + state_name(d)); + } + // every other character keeps scanning from the start state + alts.push_back(gbnf_char_class(specific, true) + " " + state_name(0)); + return string_join(alts, " | "); + }); +} + static std::set collect_reachable_rules( const common_peg_arena & arena, const common_peg_parser_id & rule @@ -1628,6 +1686,7 @@ static std::set collect_reachable_rules( std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { visit(p.child); } else 
if constexpr (std::is_same_v) { @@ -1822,6 +1881,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo return to_gbnf(p.child); } else if constexpr (std::is_same_v) { return p.grammar; + } else if constexpr (std::is_same_v) { + return gbnf_including_grammar(builder, "ac-" + std::to_string(id), p.delimiters); } else { static_assert(is_always_false_v); } @@ -1958,6 +2019,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant & }; } else if constexpr (std::is_same_v) { return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "ac"}, {"child", p.child}, {"delimiters", p.delimiters}}; } }, variant); } @@ -2130,6 +2193,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json }; } + if (type == "ac") { + if (!j.contains("child") || !j.contains("delimiters") || !j["delimiters"].is_array() || j["delimiters"].empty()) { + throw std::runtime_error("ac parser requires 'child' and a non-empty 'delimiters' array"); + } + return common_peg_ac_parser{ + j["child"].get(), + j["delimiters"].get>(), + }; + } + throw std::runtime_error("Unknown parser type: " + type); } diff --git a/common/peg-parser.h b/common/peg-parser.h index 132173a64c..c198499dd9 100644 --- a/common/peg-parser.h +++ b/common/peg-parser.h @@ -275,6 +275,11 @@ struct common_peg_gbnf_parser { std::string grammar; }; +struct common_peg_ac_parser { + common_peg_parser_id child; + std::vector delimiters; +}; + // Variant holding all parser types using common_peg_parser_variant = std::variant< common_peg_epsilon_parser, @@ -296,7 +301,8 @@ using common_peg_parser_variant = std::variant< common_peg_ref_parser, common_peg_atomic_parser, common_peg_tag_parser, - common_peg_gbnf_parser + common_peg_gbnf_parser, + common_peg_ac_parser >; class common_peg_arena { @@ -514,6 +520,13 @@ class common_peg_parser_builder { // the child's grammar. 
Parsing delegates entirely to the child. common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); } + // Wraps a child parser but emits a GBNF grammar built from the Aho-Corasick + // automaton of `delimiters`, matching everything up to and including the + // first delimiter. Parsing delegates entirely to the child, which is + // responsible for consuming the delimiter (e.g. until(D) + literal(D)). + common_peg_parser ac(const common_peg_parser & p, const std::vector & delimiters); + common_peg_parser ac(const common_peg_parser & p, const std::string & delimiter) { return ac(p, std::vector{delimiter}); } + void set_root(const common_peg_parser & p); common_peg_arena build(); diff --git a/tests/peg-parser/test-gbnf-generation.cpp b/tests/peg-parser/test-gbnf-generation.cpp index 45d692ca60..60066a817b 100644 --- a/tests/peg-parser/test-gbnf-generation.cpp +++ b/tests/peg-parser/test-gbnf-generation.cpp @@ -212,6 +212,75 @@ void test_gbnf_generation(testing &t) { )""", gbnf); }); + t.test("ac grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.ac(p.until("") + p.literal(""), ""); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + ac-3 ::= [<] ac-3-01 | [^<] ac-3 + ac-3-01 ::= [<] ac-3-01 | [/] ac-3-02 | [^/<] ac-3 + ac-3-02 ::= [<] ac-3-01 | [t] ac-3-03 | [^] | [<] ac-3-01 | [^<>] ac-3 + root ::= ac-3 + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("ac grammar terminates at first delimiter", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.ac(p.until("\n\n") + p.literal("\n\n"), "\n\n"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + ac-3 ::= [\n] ac-3-01 | 
[^\n] ac-3 + ac-3-01 ::= [\n] ac-3-01 | [<] ac-3-02 | [^\n<] ac-3 + ac-3-02 ::= [\n] ac-3-01 | [/] ac-3-03 | [^\n/] ac-3 + ac-3-03 ::= [\n] ac-3-01 | [p] ac-3-04 | [^\np] ac-3 + ac-3-04 ::= [\n] ac-3-01 | [a] ac-3-05 | [^\na] ac-3 + ac-3-05 ::= [\n] ac-3-01 | [r] ac-3-06 | [^\nr] ac-3 + ac-3-06 ::= [\n] ac-3-01 | [a] ac-3-07 | [^\na] ac-3 + ac-3-07 ::= [\n] ac-3-01 | [m] ac-3-08 | [^\nm] ac-3 + ac-3-08 ::= [\n] ac-3-01 | [e] ac-3-09 | [^\ne] ac-3 + ac-3-09 ::= [\n] ac-3-01 | [t] ac-3-10 | [^\nt] ac-3 + ac-3-10 ::= [\n] ac-3-01 | [e] ac-3-11 | [^\ne] ac-3 + ac-3-11 ::= [\n] ac-3-01 | [r] ac-3-12 | [^\nr] ac-3 + ac-3-12 ::= [\n] ac-3-01 | [>] ac-3-13 | [^\n>] ac-3 + ac-3-13 ::= [\n] | [^\n] ac-3 + root ::= ac-3 + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("ac grammar multiple delimiters", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.ac(p.eps(), std::vector{"ab", "cd", "ef"}); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + ac-1 ::= [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^ace] ac-1 + ac-1-01 ::= [b] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^abce] ac-1 + ac-1-03 ::= [d] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^acde] ac-1 + ac-1-05 ::= [f] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^acef] ac-1 + root ::= ac-1 + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + t.test("complex expressions with parentheses", [](testing &t) { auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("a") | p.literal("b")); From 0ef6f06d553b160d8fc1fba38f5848c7940873a2 Mon Sep 17 00:00:00 2001 From: aafsmarak <92150196+aafsmarak@users.noreply.github.com> Date: Mon, 22 Jun 2026 09:18:31 +0530 Subject: [PATCH 44/86] docs/android.md: Add dependency `libandroid-spawn` for building in termux (#21812) Fixes 
https://github.com/ggml-org/llama.cpp/issues/18615 --- docs/android.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/android.md b/docs/android.md index 964ce8a1f0..e8d580a9ed 100644 --- a/docs/android.md +++ b/docs/android.md @@ -29,7 +29,7 @@ With Termux, you can install and run `llama.cpp` as if the environment were Linu ``` $ apt update && apt upgrade -y -$ apt install git cmake +$ apt install git cmake libandroid-spawn ``` Then, follow the [build instructions](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md), specifically for CMake. From d0f9d2e5ac5d4f51763755958b8f353fed01aaa2 Mon Sep 17 00:00:00 2001 From: Pascal Date: Mon, 22 Jun 2026 10:55:28 +0200 Subject: [PATCH 45/86] server: fix edit_file crash on append at end of file (line_start -1) (#24893) line_start -1 normalized to n+1, so append inserted at lines.begin() + n + 1, one past end() -> heap-buffer-overflow in vector::_M_range_insert. Normalize -1 to n (insert at end()), restrict -1 to append mode and reject it for replace/delete instead of silently clobbering the last line. Parenthesize the insert offset so empty-file append computes the position as int first, avoiding a transient begin() - 1 on a null vector data pointer. 
--- tools/server/server-tools.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tools/server/server-tools.cpp b/tools/server/server-tools.cpp index 95662d4ecb..790ed85a06 100644 --- a/tools/server/server-tools.cpp +++ b/tools/server/server-tools.cpp @@ -569,9 +569,13 @@ struct server_tool_edit_file : server_tool { } int n = (int) lines.size(); if (e.line_start == -1) { - // -1 means end of file; line_end is ignored — normalize to point past last line - e.line_start = n + 1; - e.line_end = n + 1; + // -1 targets end of file -> valid for append only; line_end is ignored + if (e.mode != "append") { + return {{"error", "line_start -1 (end of file) is only valid for append mode"}}; + } + // append at end of file: insert position is the current line count + e.line_start = n; + e.line_end = n; } else { if (e.line_start < 1 || e.line_end < e.line_start) { return {{"error", string_format("invalid line range [%d, %d]", e.line_start, e.line_end)}}; @@ -612,8 +616,8 @@ struct server_tool_edit_file : server_tool { } else if (e.mode == "delete") { lines.erase(lines.begin() + idx_start, lines.begin() + idx_end + 1); } else { // append - // idx_end + 1 may equal lines.size() when line_start == -1 (end of file) - lines.insert(lines.begin() + idx_end + 1, new_lines.begin(), new_lines.end()); + // insert after idx_end; idx_end + 1 == lines.size() for end-of-file append + lines.insert(lines.begin() + (idx_end + 1), new_lines.begin(), new_lines.end()); } } From 37957e8531bcd2e5f98233d6ecc864f2b76e6b8b Mon Sep 17 00:00:00 2001 From: Tim Neumann Date: Mon, 22 Jun 2026 13:08:32 +0200 Subject: [PATCH 46/86] sampling : remove unconditional softmax+sort in top-n-sigma sampler (#22645) --- src/llama-sampler.cpp | 2 -- tests/test-sampling.cpp | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/llama-sampler.cpp b/src/llama-sampler.cpp index 9bbc5dbde2..2370e91a14 100644 --- a/src/llama-sampler.cpp +++ b/src/llama-sampler.cpp @@ -2813,8 
+2813,6 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t cur_p->data[i].logit = -INFINITY; } } - - llama_sampler_softmax_impl(cur_p, true); } static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) { diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 7cd96c5cd3..2aecff90e7 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -360,9 +360,9 @@ int main(void) { test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.032727f, 0.241818f, 0.241818f}, 2.0f, 1.1f, 2, 5, {}); test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {}); - test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.0f, 0.0f, 0.428571f, 0.571429f}, 1.00f); test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 0.00f); // top_n_sigma == 0 now represents a no-op rather than greedy decoding as of PR#13345 - test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 3.00f); test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f); test_sampler_queue(10000, "k", 1, 1.0f, 1.0f); From f8cc15f163e784c58fe13aee58ebc03055bb0c40 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Mon, 22 Jun 2026 19:09:02 +0800 Subject: [PATCH 47/86] [SYCL] support bf16 on bin_bcast OP and unary OPs (#24838) * support bf16 on bin_bcast OP and unary OPs * support the older Intel compiler than 2026.0 --- ggml/src/ggml-sycl/binbcast.cpp | 5 + ggml/src/ggml-sycl/element_wise.cpp | 208 +++++++++++++++++++++------- 2 files changed, 160 insertions(+), 53 deletions(-) diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp index ad2e6ca35e..306eeddc0c 100644 --- a/ggml/src/ggml-sycl/binbcast.cpp +++ b/ggml/src/ggml-sycl/binbcast.cpp @@ -293,6 +293,11 
@@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t (sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream); + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_BF16) { + op()((const sycl::ext::oneapi::bfloat16 *) src0->data, (const float *) src1->data, + (sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, + ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), + ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream); #endif } else { fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type), diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index aca68e58ee..0c82ceb969 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -43,14 +43,44 @@ static __dpct_inline__ T op_sgn(T x) { return x > static_cast(0.f) ? static_cast(1.f) : ((x < static_cast(0.f) ? static_cast(-1.f) : static_cast(0.f))); } + template static __dpct_inline__ T op_abs(T x) { - return sycl::fabs(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::fabs(x); // or experimental namespace if needed + } else { + return sycl::fabs(x); + } +} + +template +static __dpct_inline__ T op_expm1(T x) { + if constexpr (std::is_same_v) { + return static_cast( + sycl::expm1(static_cast(x)) + ); + } else { + return sycl::expm1(x); + } } template static __dpct_inline__ T op_elu(T x) { - return (x > static_cast(0.f)) ? x : sycl::expm1(x); + return (x > static_cast(0.f)) ? 
x : op_expm1(x); +} + +template +static __dpct_inline__ T op_tanh(T x) { + if constexpr (std::is_same_v) { + constexpr int ver = __INTEL_LLVM_COMPILER; +#if defined(__INTEL_LLVM_COMPILER) && (__INTEL_LLVM_COMPILER >= 20260000) + return sycl::ext::oneapi::experimental::tanh(x); +#else + return static_cast(sycl::tanh(static_cast(x))); +#endif + } else { + return sycl::tanh(x); + } } template @@ -59,74 +89,106 @@ static __dpct_inline__ T op_gelu(T x) { const T SQRT_2_OVER_PI = static_cast(0.79788456080286535587989211986876f); return static_cast(0.5f) * x * (static_cast(1.0f) + - sycl::tanh(SQRT_2_OVER_PI * x * (static_cast(1.0f) + GELU_COEF_A * x * x))); + op_tanh(SQRT_2_OVER_PI * x * (static_cast(1.0f) + GELU_COEF_A * x * x))); +} + +template +static __dpct_inline__ T op_exp(T x) { + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::exp(x); + } else { + return sycl::exp(x); + } } template static __dpct_inline__ T op_silu(T x) { - return x / (static_cast(1.0f) + sycl::native::exp(-x)); + return x / (static_cast(1.0f) + op_exp(-x)); } template -static __dpct_inline__ T op_gelu_quick(T x) { - const T GELU_QUICK_COEF_LOCAL = static_cast(-1.702f); - return x * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x))); +static __dpct_inline__ T op_erf(T x) { + if constexpr (std::is_same_v) { + return static_cast( + sycl::erf(static_cast(x)) + ); + } else { + return sycl::erf(x); + } } template static __dpct_inline__ T op_gelu_erf(T x) { const T SQRT_2_INV = static_cast(0.70710678118654752440084436210484f); - return static_cast(0.5f) * x * (static_cast(1.0f) + sycl::erf(x * SQRT_2_INV)); + return static_cast(0.5f) * x * (static_cast(1.0f) + op_erf(x * SQRT_2_INV)); } template -static __dpct_inline__ T op_tanh(T x) { - return sycl::tanh(x); +static __dpct_inline__ T op_gelu_quick(T x) { + const T GELU_QUICK_COEF_LOCAL = static_cast(-1.702f); + return x * (static_cast(1.0f) / (static_cast(1.0f) + 
op_exp(GELU_QUICK_COEF_LOCAL * x))); } template static __dpct_inline__ T op_relu(T x) { - return sycl::fmax(x, static_cast(0)); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::fmax(x, static_cast(0)); + } else { + return sycl::fmax(x, static_cast(0)); + } } template static __dpct_inline__ T op_sigmoid(T x) { - return static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(-x)); + return static_cast(1.0f) / (static_cast(1.0f) + op_exp(-x)); } template static __dpct_inline__ T op_sqrt(T x) { - return sycl::sqrt(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::sqrt(x); + } else { + return sycl::sqrt(x); + } } template static __dpct_inline__ T op_sin(T x) { - return sycl::sin(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::sin(x); + } else { + return sycl::sin(x); + } } template static __dpct_inline__ T op_cos(T x) { - return sycl::cos(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::cos(x); + } else { + return sycl::cos(x); + } } template static __dpct_inline__ T op_hardsigmoid(T x) { - return sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::fmin( + static_cast(1.0f), sycl::ext::oneapi::experimental::fmax( + static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + } else { + return sycl::fmin(static_cast(1.0f), + sycl::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + } } template static __dpct_inline__ T op_hardswish(T x) { - return x * sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); -} - -template -static __dpct_inline__ T op_exp(T x) { - return sycl::exp(x); -} - -template -static __dpct_inline__ T op_expm1(T x) { - return sycl::expm1(x); + if constexpr (std::is_same_v) { + return x * 
sycl::ext::oneapi::experimental::fmin(static_cast(1.0f), sycl::ext::oneapi::experimental::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + } else { + return x * sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + } } template @@ -134,13 +196,17 @@ static __dpct_inline__ T op_log(T x) { if (x <= static_cast(0)) { return neg_infinity(); } - return sycl::log(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::log(x); + } else { + return sycl::log(x); + } } template static __dpct_inline__ T op_softplus(T x) { const float xf = (float) x; - const float ax = sycl::fabs(xf); + const float ax = op_abs(xf); const float m = sycl::fmax(xf, 0.0f); const float y = m + sycl::log1p(sycl::exp(-ax)); return (T) y; @@ -159,8 +225,14 @@ static __dpct_inline__ T op_step(T x) { template static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) { T neg_slope_T = static_cast(negative_slope); - return sycl::fmax(x, static_cast(0)) + + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::fmax(x, static_cast(0)) + + sycl::ext::oneapi::experimental::fmin(x, static_cast(0.0f)) * neg_slope_T; + + } else { + return sycl::fmax(x, static_cast(0)) + sycl::fmin(x, static_cast(0.0f)) * neg_slope_T; + } } template @@ -175,22 +247,40 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) { template static __dpct_inline__ T op_floor(T x) { - return sycl::floor(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::floor(x); + } else { + return sycl::floor(x); + } } template static __dpct_inline__ T op_ceil(T x) { - return sycl::ceil(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::ceil(x); + } else { + return sycl::ceil(x); + } } template static __dpct_inline__ T op_round(T x) { - return sycl::round(x); + if constexpr (std::is_same_v) { + return static_cast( + sycl::round(static_cast(x)) + ); + } 
else { + return sycl::round(x); + } } template static __dpct_inline__ T op_trunc(T x) { - return sycl::trunc(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::trunc(x); + } else { + return sycl::trunc(x); + } } template @@ -339,7 +429,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst, const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), - [=](sycl::nd_item<3> /*item_ct1*/) { + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset); }); } @@ -354,8 +444,8 @@ static void arange_kernel(T * dst, const int k, T start, T step, template static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... 
args) { - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16 || dst->src[0]->type == GGML_TYPE_BF16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_BF16); GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); @@ -367,6 +457,14 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward(args)...); break; } +#ifdef GGML_SYCL_HAS_BF16 + case GGML_TYPE_BF16: + { + auto data_pts = cast_data(dst); + kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward(args)...); + break; + } +#endif case GGML_TYPE_F32: { auto data_pts = cast_data(dst); @@ -480,7 +578,7 @@ static inline void ggml_sycl_op_unary( stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), sycl::range<1>(256)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_generic_kernel( src, dst_ptr, k_elements, ne0, ne1, ne2, ne3, @@ -508,7 +606,7 @@ static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_ten stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE), sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { arange_kernel(dst_ptr, k, start, step, item_ct1); }); } @@ -602,7 +700,7 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE), 
sycl::range<1>(SYCL_EXP_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -640,7 +738,7 @@ static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tenso stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE), sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -653,7 +751,7 @@ static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE), sycl::range<1>(SYCL_SIN_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -666,7 +764,7 @@ static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE), sycl::range<1>(SYCL_SIN_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -681,7 +779,7 @@ static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1); }); }, negative_slope); @@ -694,7 +792,7 @@ static inline void 
ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE), sycl::range<1>(SYCL_SQR_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -711,7 +809,7 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE), sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1); }); }, min_val, max_val); @@ -774,7 +872,8 @@ static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tens [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); @@ -785,7 +884,8 @@ static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tens [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * 
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), + sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); @@ -796,7 +896,8 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), + sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); @@ -811,7 +912,6 @@ __dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alp return out_glu; } - template static void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, @@ -845,7 +945,7 @@ static void swiglu_oai_sycl(const T * x, const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE; stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1); }); } @@ -899,7 +999,8 @@ static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & 
ctx, ggml_ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); @@ -910,7 +1011,8 @@ static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggm [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); From 099b579acb9f7fd0eefcbb2198fd453b00c6e787 Mon Sep 17 00:00:00 2001 From: Pascal Date: Mon, 22 Jun 2026 15:55:30 +0200 Subject: [PATCH 48/86] ui: model status and load progress via /models/sse feed (#24878) * ui: model status and load progress via /models/sse feed * ui: centralize SSE wire-format delimiters into shared constants for the chat and /models/sse parsers * ui: type /models/sse event names as a ServerModelsSseEventType enum Address review from allozaur --- tools/ui/src/app.d.ts | 10 + .../ChatMessageAssistant.svelte | 15 +- .../app/models/ModelsSelectorOption.svelte | 19 +- 
tools/ui/src/lib/constants/api-endpoints.ts | 3 +- tools/ui/src/lib/constants/index.ts | 2 + tools/ui/src/lib/constants/model-loading.ts | 14 + tools/ui/src/lib/constants/sse.ts | 16 + tools/ui/src/lib/enums/index.ts | 2 +- tools/ui/src/lib/enums/server.enums.ts | 14 + tools/ui/src/lib/services/chat.service.ts | 16 +- tools/ui/src/lib/stores/models.svelte.ts | 280 +++++++++++++++--- tools/ui/src/lib/types/api.d.ts | 48 ++- tools/ui/src/lib/types/index.ts | 11 +- tools/ui/src/lib/types/models.d.ts | 13 +- tools/ui/src/lib/utils/index.ts | 3 + tools/ui/src/lib/utils/progress.ts | 43 +++ tools/ui/src/routes/+layout.svelte | 14 + 17 files changed, 466 insertions(+), 57 deletions(-) create mode 100644 tools/ui/src/lib/constants/model-loading.ts create mode 100644 tools/ui/src/lib/constants/sse.ts create mode 100644 tools/ui/src/lib/utils/progress.ts diff --git a/tools/ui/src/app.d.ts b/tools/ui/src/app.d.ts index a7583eec59..5264e5cc4d 100644 --- a/tools/ui/src/app.d.ts +++ b/tools/ui/src/app.d.ts @@ -19,6 +19,10 @@ import type { ApiErrorResponse, ApiLlamaCppServerProps, ApiModelDataEntry, + ApiModelLoadStage, + ApiModelsSseProgress, + ApiModelsSseData, + ApiModelsSseEvent, ApiModelListResponse, ApiProcessingState, ApiRouterModelMeta, @@ -52,6 +56,7 @@ import type { // Model types ModelModalities, ModelOption, + ModelLoadProgress, // Settings types SettingsChatServiceOptions, SettingsConfigValue, @@ -83,6 +88,10 @@ declare global { ApiErrorResponse, ApiLlamaCppServerProps, ApiModelDataEntry, + ApiModelLoadStage, + ApiModelsSseProgress, + ApiModelsSseData, + ApiModelsSseEvent, ApiModelListResponse, ApiProcessingState, ApiRouterModelMeta, @@ -120,6 +129,7 @@ declare global { // Model types ModelModalities, ModelOption, + ModelLoadProgress, // Settings types SettingsChatServiceOptions, SettingsConfigValue, diff --git a/tools/ui/src/lib/components/app/chat/ChatMessages/ChatMessage/ChatMessageAssistant/ChatMessageAssistant.svelte 
b/tools/ui/src/lib/components/app/chat/ChatMessages/ChatMessage/ChatMessageAssistant/ChatMessageAssistant.svelte index 4c74206f1b..2272eaedb3 100644 --- a/tools/ui/src/lib/components/app/chat/ChatMessages/ChatMessage/ChatMessageAssistant/ChatMessageAssistant.svelte +++ b/tools/ui/src/lib/components/app/chat/ChatMessages/ChatMessage/ChatMessageAssistant/ChatMessageAssistant.svelte @@ -10,7 +10,7 @@ import { getMessageEditContext } from '$lib/contexts'; import { useProcessingState } from '$lib/hooks/use-processing-state.svelte'; import { isLoading, isChatStreaming } from '$lib/stores/chat.svelte'; - import { copyToClipboard, deriveAgenticSections } from '$lib/utils'; + import { copyToClipboard, deriveAgenticSections, modelLoadProgressText } from '$lib/utils'; import { AgenticSectionType } from '$lib/enums'; import { REASONING_TAGS } from '$lib/constants/agentic'; import { tick } from 'svelte'; @@ -185,6 +185,13 @@ let hasNoContent = $derived(!message?.content?.trim()); let isActivelyProcessing = $derived(isCurrentlyLoading || isStreaming); + // during a router auto-load the message has no model yet, so target the selected one + let loadTargetModel = $derived(message.model ?? modelsStore.selectedModelName); + let modelLoadProgress = $derived( + isRouter && loadTargetModel ? modelsStore.getLoadProgress(loadTargetModel) : null + ); + let modelLoadingText = $derived(modelLoadProgressText(modelLoadProgress)); + let showProcessingInfoTop = $derived( message?.role === MessageRole.ASSISTANT && isActivelyProcessing && @@ -220,7 +227,8 @@
- {processingState.getPromptProgressText() ?? + {modelLoadingText ?? + processingState.getPromptProgressText() ?? processingState.getProcessingMessage() ?? 'Processing...'} @@ -252,7 +260,8 @@
- {processingState.getPromptProgressText() ?? + {modelLoadingText ?? + processingState.getPromptProgressText() ?? processingState.getProcessingMessage() ?? 'Processing...'} diff --git a/tools/ui/src/lib/components/app/models/ModelsSelectorOption.svelte b/tools/ui/src/lib/components/app/models/ModelsSelectorOption.svelte index fef1490f37..f2a024d31d 100644 --- a/tools/ui/src/lib/components/app/models/ModelsSelectorOption.svelte +++ b/tools/ui/src/lib/components/app/models/ModelsSelectorOption.svelte @@ -13,6 +13,7 @@ import type { ModelOption } from '$lib/types/models'; import { ServerModelStatus } from '$lib/enums'; import { modelsStore, routerModels } from '$lib/stores/models.svelte'; + import { modelLoadFraction, modelLoadProgressText } from '$lib/utils'; interface Props { option: ModelOption; @@ -50,11 +51,15 @@ (serverStatus === ServerModelStatus.LOADED || isSleeping) && !isOperationInProgress ); let isLoading = $derived(serverStatus === ServerModelStatus.LOADING || isOperationInProgress); + + let loadProgress = $derived(isLoading ? modelsStore.getLoadProgress(option.model) : null); + let loadPercent = $derived(Math.round(modelLoadFraction(loadProgress) * 100)); + let loadTitle = $derived(modelLoadProgressText(loadProgress));
onSelect(option.id)} onmouseenter={onMouseEnter} @@ -188,4 +194,15 @@
{/if}
+ + {#if isLoading} +
+
+
+ {/if}
diff --git a/tools/ui/src/lib/constants/api-endpoints.ts b/tools/ui/src/lib/constants/api-endpoints.ts index 9eb6c74e75..a410905057 100644 --- a/tools/ui/src/lib/constants/api-endpoints.ts +++ b/tools/ui/src/lib/constants/api-endpoints.ts @@ -1,7 +1,8 @@ export const API_MODELS = { LIST: '/v1/models', LOAD: '/models/load', - UNLOAD: '/models/unload' + UNLOAD: '/models/unload', + SSE: '/models/sse' }; // chat completion routes, the control route drives realtime inference (e.g. end reasoning) diff --git a/tools/ui/src/lib/constants/index.ts b/tools/ui/src/lib/constants/index.ts index c51d84cdc2..4993ab647a 100644 --- a/tools/ui/src/lib/constants/index.ts +++ b/tools/ui/src/lib/constants/index.ts @@ -37,6 +37,8 @@ export * from './mcp-form'; export * from './mcp-resource'; export * from './message-export'; export * from './model-id'; +export * from './model-loading'; +export * from './sse'; export * from './precision'; export * from './processing-info'; export * from './pwa'; diff --git a/tools/ui/src/lib/constants/model-loading.ts b/tools/ui/src/lib/constants/model-loading.ts new file mode 100644 index 0000000000..a55ba708b1 --- /dev/null +++ b/tools/ui/src/lib/constants/model-loading.ts @@ -0,0 +1,14 @@ +/** + * Labels shown while a model loads, keyed by the stage reported on /models/sse. + */ +export const MODEL_LOAD_STAGE_LABELS: Record = { + text_model: 'Loading weights', + spec_model: 'Loading draft', + mmproj_model: 'Loading projector' +}; + +/** + * Share of the bar reserved for each load phase after text_model. + * text_model fills the rest, so a plain model reaches 100% on its own. + */ +export const MODEL_LOAD_TAIL_SHARE = 0.1; diff --git a/tools/ui/src/lib/constants/sse.ts b/tools/ui/src/lib/constants/sse.ts new file mode 100644 index 0000000000..0eb4b6edee --- /dev/null +++ b/tools/ui/src/lib/constants/sse.ts @@ -0,0 +1,16 @@ +/** + * Server-sent events wire format, shared by the chat stream and the + * /models/sse status feed (text/event-stream). 
+ */ + +// blank line between two events +export const SSE_RECORD_SEPARATOR = '\n\n'; + +// line break inside an event +export const SSE_LINE_SEPARATOR = '\n'; + +// data field prefix, the value follows after an optional space +export const SSE_DATA_PREFIX = 'data:'; + +// end-of-stream marker on the chat completion stream +export const SSE_DONE_MARKER = '[DONE]'; diff --git a/tools/ui/src/lib/enums/index.ts b/tools/ui/src/lib/enums/index.ts index 449e4f90a9..811744fd9a 100644 --- a/tools/ui/src/lib/enums/index.ts +++ b/tools/ui/src/lib/enums/index.ts @@ -54,7 +54,7 @@ export { export { ModelModality } from './model.enums'; -export { ServerRole, ServerModelStatus } from './server.enums'; +export { ServerRole, ServerModelStatus, ServerModelsSseEventType } from './server.enums'; export { ParameterSource, SyncableParameterType, SettingsFieldType } from './settings.enums'; diff --git a/tools/ui/src/lib/enums/server.enums.ts b/tools/ui/src/lib/enums/server.enums.ts index c9d599c52b..446af84be7 100644 --- a/tools/ui/src/lib/enums/server.enums.ts +++ b/tools/ui/src/lib/enums/server.enums.ts @@ -19,3 +19,17 @@ export enum ServerModelStatus { SLEEPING = 'sleeping', FAILED = 'failed' } + +/** + * /models/sse event type enum - discriminates the records broadcast on the + * model status feed in ROUTER mode. Matches the event names emitted by + * tools/server/server-models.cpp from the C++ server. 
+ */ +export enum ServerModelsSseEventType { + STATUS_CHANGE = 'status_change', + MODEL_STATUS = 'model_status', + STATUS_UPDATE = 'status_update', + MODELS_RELOAD = 'models_reload', + MODEL_REMOVE = 'model_remove', + DOWNLOAD_PROGRESS = 'download_progress' +} diff --git a/tools/ui/src/lib/services/chat.service.ts b/tools/ui/src/lib/services/chat.service.ts index 70844f57ee..9001c9572f 100644 --- a/tools/ui/src/lib/services/chat.service.ts +++ b/tools/ui/src/lib/services/chat.service.ts @@ -10,7 +10,10 @@ import { SETTINGS_KEYS, API_CHAT, API_SLOTS, - CONTROL_ACTION + CONTROL_ACTION, + SSE_LINE_SEPARATOR, + SSE_DATA_PREFIX, + SSE_DONE_MARKER } from '$lib/constants'; import { AttachmentType, @@ -18,8 +21,7 @@ import { FileTypeAudio, MessageRole, MimeTypeAudio, - ReasoningFormat, - UrlProtocol + ReasoningFormat } from '$lib/enums'; import type { ApiChatMessageContentPart, @@ -642,15 +644,15 @@ export class ChatService { if (abortSignal?.aborted) break; chunk += decoder.decode(value, { stream: true }); - const lines = chunk.split('\n'); + const lines = chunk.split(SSE_LINE_SEPARATOR); chunk = lines.pop() || ''; for (const line of lines) { if (abortSignal?.aborted) break; - if (line.startsWith(UrlProtocol.DATA)) { - const data = line.slice(6); - if (data === '[DONE]') { + if (line.startsWith(SSE_DATA_PREFIX)) { + const data = line.slice(SSE_DATA_PREFIX.length).trim(); + if (data === SSE_DONE_MARKER) { streamFinished = true; continue; diff --git a/tools/ui/src/lib/stores/models.svelte.ts b/tools/ui/src/lib/stores/models.svelte.ts index 1990ba6049..2ce450d423 100644 --- a/tools/ui/src/lib/stores/models.svelte.ts +++ b/tools/ui/src/lib/stores/models.svelte.ts @@ -1,6 +1,7 @@ +import { base } from '$app/paths'; import { SvelteMap, SvelteSet } from 'svelte/reactivity'; import { toast } from 'svelte-sonner'; -import { ServerModelStatus, ModelModality } from '$lib/enums'; +import { ServerModelStatus, ServerModelsSseEventType, ModelModality } from '$lib/enums'; import { 
ModelsService } from '$lib/services/models.service'; import { PropsService } from '$lib/services/props.service'; import { serverStore, isRouterMode } from '$lib/stores/server.svelte'; @@ -8,11 +9,15 @@ import { detectThinkingSupport, detectThinkingSupportWithReason } from '$lib/utils/chat-template-thinking-detector'; -import { TTLCache } from '$lib/utils'; +import { TTLCache, getAuthHeaders } from '$lib/utils'; import { MODEL_PROPS_CACHE_TTL_MS, MODEL_PROPS_CACHE_MAX_ENTRIES, - FAVORITE_MODELS_LOCALSTORAGE_KEY + FAVORITE_MODELS_LOCALSTORAGE_KEY, + API_MODELS, + SSE_RECORD_SEPARATOR, + SSE_LINE_SEPARATOR, + SSE_DATA_PREFIX } from '$lib/constants'; import { conversationsStore } from '$lib/stores/conversations.svelte'; @@ -55,6 +60,15 @@ class ModelsStore { private modelUsage = $state>>(new Map()); private modelLoadingStates = new SvelteMap(); + // /models/sse feed state, the single source of truth for status and load progress + private statusAbort: AbortController | null = null; + private statusReaderActive = false; + private loadProgress = new SvelteMap(); + private statusWaiters = new Map< + string, + { target: ServerModelStatus; resolve: () => void; reject: (e: Error) => void } + >(); + favoriteModelIds = $state>(this.loadFavoritesFromStorage()); /** @@ -626,49 +640,218 @@ class ModelsStore { * */ - /** - * WORKAROUND: Polling for model status after load/unload operations. - * - * Currently, `/models/load` and `/models/unload` return success before - * the operation actually completes on the server. - * - * TODO: Remove polling once llama-server properly waits for the operation - * to complete before returning success. - */ - - private static readonly STATUS_POLL_INTERVAL = 500; + // reconnect delay after the feed drops or the server is not ready yet + private static readonly SSE_RECONNECT_MS = 1000; /** - * Poll for expected model status after load/unload operation. - * Keeps polling until the model reaches the expected status or fails. 
+ * Open the /models/sse feed and keep it live with auto reconnect. + * Idempotent and router mode only. The feed drives status and progress, + * so it replaces any post-operation polling. */ - private async pollForModelStatus( - modelId: string, - expectedStatus: ServerModelStatus - ): Promise { - let attempt = 0; - while (true) { - await this.fetchRouterModels(); + subscribeStatus(): void { + if (this.statusReaderActive) return; + if (!isRouterMode()) return; - const currentStatus = this.getModelStatus(modelId); - if (currentStatus === expectedStatus) return; + this.statusReaderActive = true; + this.statusAbort = new AbortController(); + void this.runStatusReader(this.statusAbort.signal); + } - if (currentStatus === ServerModelStatus.FAILED) { - throw new Error( - `Model failed to ${expectedStatus === ServerModelStatus.LOADED ? 'load' : 'unload'}` - ); + /** + * Close the /models/sse feed and drop transient progress. + */ + unsubscribeStatus(): void { + this.statusReaderActive = false; + this.statusAbort?.abort(); + this.statusAbort = null; + this.loadProgress.clear(); + } + + /** + * Current load progress for a model, or null when not loading. + */ + getLoadProgress(modelId: string): ModelLoadProgress | null { + return this.loadProgress.get(modelId) ?? null; + } + + /** + * Read the feed and reconnect until unsubscribed. Splits the byte stream + * into SSE records on the blank line boundary. 
+ */ + private async runStatusReader(signal: AbortSignal): Promise { + const decoder = new TextDecoder(); + + while (!signal.aborted) { + try { + const response = await fetch(`${base}${API_MODELS.SSE}`, { + headers: getAuthHeaders(), + signal + }); + + if (response.ok && response.body) { + const reader = response.body.getReader(); + let buffer = ''; + + while (!signal.aborted) { + const { value, done } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + + let boundary = buffer.indexOf(SSE_RECORD_SEPARATOR); + while (boundary !== -1) { + this.handleStatusRecord(buffer.slice(0, boundary)); + buffer = buffer.slice(boundary + SSE_RECORD_SEPARATOR.length); + boundary = buffer.indexOf(SSE_RECORD_SEPARATOR); + } + } + } + } catch { + // network drop or abort falls through to the reconnect delay } - if ( - expectedStatus === ServerModelStatus.LOADED && - currentStatus === ServerModelStatus.UNLOADED && - attempt > 2 - ) { - throw new Error('Model was unloaded unexpectedly during loading'); - } + if (signal.aborted) return; - attempt++; - await new Promise((resolve) => setTimeout(resolve, ModelsStore.STATUS_POLL_INTERVAL)); + await new Promise((resolve) => setTimeout(resolve, ModelsStore.SSE_RECONNECT_MS)); + } + } + + /** + * Parse one SSE record. The payload rides in the data lines as a JSON + * envelope that carries its own model, event and data fields. + */ + private handleStatusRecord(record: string): void { + const payload = record + .split(SSE_LINE_SEPARATOR) + .filter((line) => line.startsWith(SSE_DATA_PREFIX)) + .map((line) => line.slice(SSE_DATA_PREFIX.length).trim()) + .join(SSE_LINE_SEPARATOR); + + if (payload.length === 0) return; + + let envelope: ApiModelsSseEvent; + try { + envelope = JSON.parse(payload); + } catch { + return; + } + + this.applyStatusEvent(envelope); + } + + /** + * Route one feed record by event kind. 
Only the status_* events carry a + * status payload, models_reload triggers a list refresh, model_remove drops + * the row, download_* belong to the download surface, not here. + */ + private applyStatusEvent(event: ApiModelsSseEvent): void { + switch (event.event) { + case ServerModelsSseEventType.STATUS_CHANGE: + case ServerModelsSseEventType.MODEL_STATUS: + case ServerModelsSseEventType.STATUS_UPDATE: + this.applyModelStatus(event); + break; + case ServerModelsSseEventType.MODELS_RELOAD: + void this.fetchRouterModels(); + break; + case ServerModelsSseEventType.MODEL_REMOVE: + this.removeRouterModel(event.model); + break; + case ServerModelsSseEventType.DOWNLOAD_PROGRESS: + break; + } + } + + /** + * Apply a status envelope: update the model row, track or clear progress, + * settle any pending load or unload awaiter. + */ + private applyModelStatus(event: ApiModelsSseEvent): void { + const model = event.model; + const data = event.data; + if (!model || !data?.status) return; + + const status = data.status; + + this.setRouterModelStatus(model, status); + + if (status === ServerModelStatus.LOADING) { + if (data.progress) this.loadProgress.set(model, data.progress); + } else { + this.loadProgress.delete(model); + } + + if (status === ServerModelStatus.LOADED) { + void this.updateModelModalities(model); + } + + const failed = + status === ServerModelStatus.FAILED || + (status === ServerModelStatus.UNLOADED && (data.exit_code ?? 0) !== 0); + + if (failed) { + this.rejectStatus(model, new Error(`Model failed: ${this.toDisplayName(model)}`)); + return; + } + + this.settleStatus(model, status); + } + + /** + * Drop a model row reported gone by the feed and settle its awaiters. 
+ */ + private removeRouterModel(modelId: string): void { + if (this.routerModels.findIndex((m) => m.id === modelId) === -1) return; + + this.routerModels = this.routerModels.filter((m) => m.id !== modelId); + this.loadProgress.delete(modelId); + this.rejectStatus(modelId, new Error(`Model removed: ${this.toDisplayName(modelId)}`)); + } + + /** + * Update one model row status in place, reassigning to trigger reactivity. + */ + private setRouterModelStatus(modelId: string, status: ServerModelStatus): void { + const idx = this.routerModels.findIndex((m) => m.id === modelId); + if (idx === -1) return; + + const current = this.routerModels[idx]; + if (current.status.value === status) return; + + const next = [...this.routerModels]; + next[idx] = { ...current, status: { ...current.status, value: status } }; + this.routerModels = next; + } + + /** + * Register an awaiter that resolves when the feed reports target status. + * One operation runs per model at a time, so one awaiter per model is kept. + */ + private waitForStatus(modelId: string, target: ServerModelStatus): Promise { + return new Promise((resolve, reject) => { + this.statusWaiters.set(modelId, { target, resolve, reject }); + }); + } + + /** + * Resolve and drop the awaiter when the model reaches its target status. + */ + private settleStatus(modelId: string, status: ServerModelStatus): void { + const waiter = this.statusWaiters.get(modelId); + if (waiter && waiter.target === status) { + this.statusWaiters.delete(modelId); + waiter.resolve(); + } + } + + /** + * Reject and drop the awaiter for a model. 
+ */ + private rejectStatus(modelId: string, error: Error): void { + const waiter = this.statusWaiters.get(modelId); + if (waiter) { + this.statusWaiters.delete(modelId); + waiter.reject(error); } } @@ -679,12 +862,18 @@ class ModelsStore { this.modelLoadingStates.set(modelId, true); this.error = null; + // the feed drives completion, so it must be live before the request + this.subscribeStatus(); + + const reachedLoaded = this.waitForStatus(modelId, ServerModelStatus.LOADED); + reachedLoaded.catch(() => {}); + try { await ModelsService.load(modelId); - await this.pollForModelStatus(modelId, ServerModelStatus.LOADED); - await this.updateModelModalities(modelId); + await reachedLoaded; toast.success(`Model loaded: ${this.toDisplayName(modelId)}`); } catch (error) { + this.rejectStatus(modelId, error instanceof Error ? error : new Error('load failed')); this.error = error instanceof Error ? error.message : 'Failed to load model'; toast.error(`Failed to load model: ${this.toDisplayName(modelId)}`); throw error; @@ -700,11 +889,17 @@ class ModelsStore { this.modelLoadingStates.set(modelId, true); this.error = null; + this.subscribeStatus(); + + const reachedUnloaded = this.waitForStatus(modelId, ServerModelStatus.UNLOADED); + reachedUnloaded.catch(() => {}); + try { await ModelsService.unload(modelId); - await this.pollForModelStatus(modelId, ServerModelStatus.UNLOADED); + await reachedUnloaded; toast.info(`Model unloaded: ${this.toDisplayName(modelId)}`); } catch (error) { + this.rejectStatus(modelId, error instanceof Error ? error : new Error('unload failed')); this.error = error instanceof Error ? 
error.message : 'Failed to unload model'; toast.error(`Failed to unload model: ${this.toDisplayName(modelId)}`); throw error; @@ -783,6 +978,9 @@ class ModelsStore { } clear(): void { + this.unsubscribeStatus(); + this.statusWaiters.forEach((waiter) => waiter.reject(new Error('Models store cleared'))); + this.statusWaiters.clear(); this.models = []; this.routerModels = []; this.loading = false; diff --git a/tools/ui/src/lib/types/api.d.ts b/tools/ui/src/lib/types/api.d.ts index f620d67351..2a2524d002 100644 --- a/tools/ui/src/lib/types/api.d.ts +++ b/tools/ui/src/lib/types/api.d.ts @@ -1,4 +1,10 @@ -import type { ContentPartType, FileTypeAudio, ServerModelStatus, ServerRole } from '$lib/enums'; +import type { + ContentPartType, + FileTypeAudio, + ServerModelStatus, + ServerModelsSseEventType, + ServerRole +} from '$lib/enums'; import type { ChatMessagePromptProgress, ChatRole } from './chat'; export type AudioInputFormat = FileTypeAudio.WAV | FileTypeAudio.MP3; @@ -96,6 +102,46 @@ export interface ApiModelDataEntry { meta?: Record | null; } +/** + * Load stage reported by the /models/sse feed, in load order. + */ +export type ApiModelLoadStage = 'text_model' | 'spec_model' | 'mmproj_model'; + +/** + * Load progress snapshot: the full ordered stage plan, the active stage, + * and its fractional value (0.0 -> 1.0). + */ +export interface ApiModelsSseProgress { + stages: ApiModelLoadStage[]; + current: ApiModelLoadStage; + value: number; +} + +/** + * Status payload carried by a /models/sse envelope. + * exit_code appears on unload. + */ +export interface ApiModelsSseData { + status: ServerModelStatus; + progress?: ApiModelsSseProgress; + exit_code?: number; +} + +/** + * Event kind multiplexed on the /models/sse feed. + * Only the status_* events carry a status payload, models_reload signals a + * full list refresh, model_remove drops a row, download_* drive download UI. + */ +/** + * One /models/sse record. 
event discriminates the kind, model names the + * target instance, data carries the status payload when present. + */ +export interface ApiModelsSseEvent { + model: string; + event: ServerModelsSseEventType; + data: ApiModelsSseData; +} + export interface ApiModelDetails { name: string; model: string; diff --git a/tools/ui/src/lib/types/index.ts b/tools/ui/src/lib/types/index.ts index c5f9488981..9b0b118045 100644 --- a/tools/ui/src/lib/types/index.ts +++ b/tools/ui/src/lib/types/index.ts @@ -11,6 +11,10 @@ export type { ApiChatMessageData, ApiModelStatus, ApiModelDataEntry, + ApiModelLoadStage, + ApiModelsSseProgress, + ApiModelsSseData, + ApiModelsSseEvent, ApiModelDetails, ApiModelListResponse, ApiLlamaCppServerProps, @@ -70,7 +74,12 @@ export type { } from './database'; // Model types -export type { ModelModalities, ModelOption, ModalityCapabilities } from './models'; +export type { + ModelModalities, + ModelOption, + ModelLoadProgress, + ModalityCapabilities +} from './models'; // Settings types export type { diff --git a/tools/ui/src/lib/types/models.d.ts b/tools/ui/src/lib/types/models.d.ts index 51069599d7..b32c16f6f2 100644 --- a/tools/ui/src/lib/types/models.d.ts +++ b/tools/ui/src/lib/types/models.d.ts @@ -1,4 +1,4 @@ -import type { ApiModelDataEntry, ApiModelDetails } from '$lib/types/api'; +import type { ApiModelDataEntry, ApiModelDetails, ApiModelLoadStage } from '$lib/types/api'; export interface ModelModalities { vision: boolean; @@ -20,6 +20,17 @@ export interface ModelOption { tags?: string[]; } +/** + * Ephemeral UI-only load progress for one model instance. + * Lives only while a load runs, driven by the /models/sse feed. + * stage is absent until the feed reports its first stage. 
+ */ +export interface ModelLoadProgress { + stages: ApiModelLoadStage[]; + current: ApiModelLoadStage; + value: number; +} + export interface ParsedModelId { raw: string; orgName: string | null; diff --git a/tools/ui/src/lib/utils/index.ts b/tools/ui/src/lib/utils/index.ts index 637db8812c..61b9932d3f 100644 --- a/tools/ui/src/lib/utils/index.ts +++ b/tools/ui/src/lib/utils/index.ts @@ -44,6 +44,9 @@ export { buildProxiedUrl, buildProxiedHeaders } from './cors-proxy'; // URL utilities export { extractRootDomain, sanitizeExternalUrl } from './url'; +// Progress helpers +export { modelLoadFraction, modelLoadProgressText } from './progress'; + // Conversation utilities export { createMessageCountMap, getMessageCount } from './conversation-utils'; diff --git a/tools/ui/src/lib/utils/progress.ts b/tools/ui/src/lib/utils/progress.ts new file mode 100644 index 0000000000..4d7e223882 --- /dev/null +++ b/tools/ui/src/lib/utils/progress.ts @@ -0,0 +1,43 @@ +/** + * Model load progress helpers for the /models/sse surfaces + * (selector row and chat message). + */ + +import { MODEL_LOAD_STAGE_LABELS, MODEL_LOAD_TAIL_SHARE } from '$lib/constants'; + +/** + * Human label for a model load stage. + */ +export function modelLoadStageLabel(stage: ApiModelLoadStage): string { + return MODEL_LOAD_STAGE_LABELS[stage]; +} + +/** + * Overall load fraction (0.0 -> 1.0) across the declared stage plan. + * text_model fills [0, 1 - tail], each later phase owns one tail slice. 
+ */ +export function modelLoadFraction(progress: ModelLoadProgress | null): number { + if (!progress) return 0; + + const { stages, current, value } = progress; + const tailCount = Math.max(stages.length - 1, 0); + const textCeiling = 1 - tailCount * MODEL_LOAD_TAIL_SHARE; + const idx = stages.indexOf(current); + + if (idx <= 0) { + return value * textCeiling; + } + + return textCeiling + (idx - 1 + value) * MODEL_LOAD_TAIL_SHARE; +} + +/** + * Single line describing load progress: active stage label and overall percent. + * Returns null when there is no progress to show. + */ +export function modelLoadProgressText(progress: ModelLoadProgress | null): string | null { + if (!progress) return null; + + const label = modelLoadStageLabel(progress.current); + return `${label} ${Math.round(modelLoadFraction(progress) * 100)}%`; +} diff --git a/tools/ui/src/routes/+layout.svelte b/tools/ui/src/routes/+layout.svelte index fdba9a9d37..1269692a78 100644 --- a/tools/ui/src/routes/+layout.svelte +++ b/tools/ui/src/routes/+layout.svelte @@ -230,6 +230,20 @@ } }); + // Live model status and load progress via the /models/sse feed (router mode) + $effect(() => { + if (!browser) return; + if (!isRouterMode()) return; + + untrack(() => { + modelsStore.subscribeStatus(); + }); + + return () => { + modelsStore.unsubscribeStatus(); + }; + }); + // Background MCP server health checks on app load // Fetch enabled servers from settings and run health checks in background $effect(() => { From 6ee0f65793da4bca2301826f70383aef2da60345 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Mon, 22 Jun 2026 16:42:47 +0200 Subject: [PATCH 49/86] server: refactor/generalize input file schema (#24299) * server: refactor/generalize input file schema * wire up input_video, accept raw base64 * nits * nits (2) * fix windows --- tools/server/README.md | 17 +++++++--- tools/server/server-common.cpp | 62 ++++++++++++++++++++-------------- 2 files changed, 49 insertions(+), 30 deletions(-) diff --git 
a/tools/server/README.md b/tools/server/README.md index 7fa3a4d728..e88bc5f28a 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -1230,8 +1230,6 @@ print(completion.choices[0].text) Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used. -If model supports multimodal, you can input the media file via `image_url` content part. We support both base64 and remote URL as input. See OAI documentation for more. - *Options:* See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). llama.cpp `/completion`-specific features such as `mirostat` are also supported. @@ -1250,9 +1248,18 @@ The `response_format` parameter supports both plain JSON output (e.g. `{"type": `parallel_tool_calls` : Whether to enable parallel/multiple tool calls (only supported on some models, verification is based on jinja template). -For multimodal input: -- Content type `image_url` and `input_audio` are the same as OAI schema -- Content type `input_video` is an extension from OAI schema. For now, it only accepts base64 input +For multimodal input (typed content, `messages[i].content[j]`): +- If `type == "image_url"`: + - `image_url.url` can be a remote URL, base64 (raw or URI-encoded via `data:image/...;base64`) or path to local file + - Accepts formats supported by `stb_image` (jpeg, png, tga, bmp, gif, ...) 
+- If `type == "input_audio"`: + - Either `input_audio.data` or `input_audio.url` can be specified, can be a remote URL, raw base64 or path to local file + - Accepts formats supported by `miniaudio` (mp3, wav, flac) + - `input_audio.format` will be ignored, the file format will be determined automatically +- If `type == "input_video"`: + - Either `input_video.data` or `input_video.url` can be specified, can be a remote URL, raw base64 or path to local file + - Accepts formats supported by `ffmpeg` +- Note: for local file, make sure to set `--media-path`. File path must be prefixed by `file://` *Examples:* diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 3dc686bb46..e412b94c5c 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -817,12 +817,21 @@ json oaicompat_completion_params_parse(const json & body) { return llama_params; } -// media_path always end with '/', see arg.cpp +// url can be +// - http(s):// for remote files +// - file:// for local files (only allowed if media_path is set) +// - data: for base64 encoded data with uri scheme (e.g. data:image/png;base64,...) 
+// - raw base64 encoded data static void handle_media( std::vector & out_files, - json & media_obj, - const std::string & media_path) { - std::string url = json_value(media_obj, "url", std::string()); + const std::string & url, + const std::string & media_path, + bool accept_base64_uri) { + if (!media_path.empty()) { + // should already be enforced by arg.cpp, but checking just in case + GGML_ASSERT(media_path.back() == DIRECTORY_SEPARATOR); + } + if (string_starts_with(url, "http")) { // download remote image // TODO @ngxson : maybe make these params configurable @@ -858,20 +867,28 @@ static void handle_media( data.assign((std::istreambuf_iterator(file)), std::istreambuf_iterator()); out_files.push_back(data); - } else { + } else if (accept_base64_uri && string_starts_with(url, "data:")) { // try to decode base64 image std::vector parts = string_split(url, /*separator*/ ','); if (parts.size() != 2) { - throw std::runtime_error("Invalid url value"); + throw std::runtime_error("Invalid uri-encoded base64 value"); } else if (!string_starts_with(parts[0], "data:image/")) { - throw std::runtime_error("Invalid url format: " + parts[0]); + throw std::runtime_error("Invalid uri format: " + parts[0]); } else if (!string_ends_with(parts[0], "base64")) { - throw std::runtime_error("url must be base64 encoded"); + throw std::runtime_error("uri must be base64 encoded"); } else { auto base64_data = parts[1]; auto decoded_data = base64_decode(base64_data); out_files.push_back(decoded_data); } + + } else { + // try as raw base64 string + auto decoded_data = base64_decode(url); + if (decoded_data.empty()) { + throw std::runtime_error("Invalid base64 value"); + } + out_files.push_back(decoded_data); } } @@ -957,14 +974,15 @@ json oaicompat_chat_params_parse( } for (auto & p : content) { - std::string type = json_value(p, "type", std::string()); + std::string type = json_value(p, "type", std::string()); if (type == "image_url") { if (!opt.allow_image) { throw 
std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); } json image_url = json_value(p, "image_url", json::object()); - handle_media(out_files, image_url, opt.media_path); + std::string url = json_value(image_url, "url", std::string()); + handle_media(out_files, url, opt.media_path, true); p["type"] = "media_marker"; p["text"] = get_media_marker(); @@ -975,17 +993,11 @@ json oaicompat_chat_params_parse( throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); } - json input_audio = json_value(p, "input_audio", json::object()); - std::string data = json_value(input_audio, "data", std::string()); - std::string format = json_value(input_audio, "format", std::string()); - // while we also support flac, we don't allow it here so we matches the OAI spec - if (format != "wav" && format != "mp3") { - throw std::invalid_argument("input_audio.format must be either 'wav' or 'mp3'"); - } - auto decoded_data = base64_decode(data); // expected to be base64 encoded - out_files.push_back(decoded_data); - - // TODO: add audio_url support by reusing handle_media() + // note: don't need to validate "format", it's redundant + json input_audio = json_value(p, "input_audio", json::object()); + std::string url = json_value(input_audio, "data", + json_value(input_audio, "url", std::string())); + handle_media(out_files, url, opt.media_path, false); p["type"] = "media_marker"; p["text"] = get_media_marker(); @@ -996,10 +1008,10 @@ json oaicompat_chat_params_parse( throw std::runtime_error("video input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); } - json input_video = json_value(p, "input_video", json::object()); - std::string data = json_value(input_video, "data", std::string()); - auto decoded_data = base64_decode(data); // expected to be base64 encoded - out_files.push_back(decoded_data); + json input_video = 
json_value(p, "input_video", json::object()); + std::string url = json_value(input_video, "data", + json_value(input_video, "url", std::string())); + handle_media(out_files, url, opt.media_path, false); p["type"] = "media_marker"; p["text"] = get_media_marker(); From 721354fbdfb7743e2be2183d918a3cdb9276c70f Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Mon, 22 Jun 2026 18:24:04 +0200 Subject: [PATCH 50/86] server: (router) move model downloading to dedicated process (#24834) * server: real-time model load progress tracking via /models/sse * update docs * server: move model download to child process * rm unused * fix most problems * clean up * nit fixes * fix test case * do not detact() thread * shorter MODEL_DOWNLOAD_TIMEOUT in test * throttle --- common/arg.cpp | 14 +- common/arg.h | 6 +- tools/server/README-dev.md | 6 +- tools/server/server-context.cpp | 6 + tools/server/server-context.h | 4 +- tools/server/server-models.cpp | 360 ++++++++++++++++--------- tools/server/server-models.h | 20 +- tools/server/server.cpp | 13 +- tools/server/tests/unit/test_router.py | 35 ++- 9 files changed, 312 insertions(+), 152 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 8f54b5c814..5297d90753 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -396,7 +396,7 @@ static bool parse_bool_value(const std::string & value) { // CLI argument parsing functions // -bool common_params_handle_models(common_params & params, llama_example curr_ex) { +bool common_params_handle_models(common_params & params, llama_example curr_ex, common_download_callback * callback) { const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(), params.speculative.types.end(), COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end(); @@ -408,6 +408,10 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex) opts.download_mtp = spec_type_draft_mtp; opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && 
params.mmproj.url.empty(); + if (callback) { + opts.callback = callback; + } + // sub-models (draft, mmproj, vocoder) are explicitly specified by the user, // so we should not auto-discover mtp/mmproj siblings for them common_download_opts sub_opts = opts; @@ -584,8 +588,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); } - // export_graph_ops loads only metadata - const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS; + const bool skip_model_download = + // server will call common_params_handle_models() later, so we skip it here + ctx_arg.ex == LLAMA_EXAMPLE_SERVER || + // export_graph_ops loads only metadata + ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS; if (!skip_model_download) { // handle model and download @@ -594,7 +601,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context // model is required (except for server) // TODO @ngxson : maybe show a list of available models in CLI in this case if (params.model.path.empty() - && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage && !params.completion) { throw std::invalid_argument("error: --model is required\n"); diff --git a/common/arg.h b/common/arg.h index 0010f2a9ac..c061fc60f7 100644 --- a/common/arg.h +++ b/common/arg.h @@ -1,6 +1,7 @@ #pragma once #include "common.h" +#include "download.h" #include #include @@ -133,7 +134,10 @@ void common_params_add_preset_options(std::vector & args); // return true if the model is ready to use // throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc) // if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. 
ETag check failed) -bool common_params_handle_models(common_params & params, llama_example curr_ex); +bool common_params_handle_models( + common_params & params, + llama_example curr_ex, + common_download_callback * callback = nullptr); // initialize argument parser context - used by test-arg-parser and preset common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); diff --git a/tools/server/README-dev.md b/tools/server/README-dev.md index 2796d28350..5959745e47 100644 --- a/tools/server/README-dev.md +++ b/tools/server/README-dev.md @@ -204,9 +204,9 @@ Instead of building everything from the ground up (like what most AI agents will The flow for downloading a new model: - POST request comes in --> `post_router_models` --> validation -- `server_models::download()` is called - - Sets up a new thread `inst.th` and runs the download inside -- If a stop request comes in, set `stop_download` to `true` +- A new `llama-server` subprocess will be spawned with special `SERVER_CHILD_MODE_DOWNLOAD` +- Child process runs the download and report status back to router via stdin/out +- If a stop request comes in, the router asks the child process to stop (same mechanism as running a model in child process) - Otherwise, upon completion, we call `load_models()` to refresh the list of models ### Notable Related PRs diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 3f9391cacb..0a25b414ed 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -931,6 +931,8 @@ private: bool sleeping = false; + int64_t t_last_load_progress_ms = 0; + void destroy() { spec.reset(); ctx_dft.reset(); @@ -1244,6 +1246,10 @@ private: } if (has_mmproj) { + if (callback_state) { + callback_state(SERVER_STATE_LOADING, {{"stage", "mmproj_model"}}); + } + if (!is_resume) { mtmd_helper_log_set(common_log_default_callback, nullptr); } diff --git 
a/tools/server/server-context.h b/tools/server/server-context.h index c7218a12ed..952f825f72 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -53,7 +53,7 @@ struct server_context_meta { }; enum server_state { - // SERVER_STATE_DOWNLOADING, + SERVER_STATE_DOWNLOADING, SERVER_STATE_LOADING, SERVER_STATE_READY, SERVER_STATE_SLEEPING, @@ -61,6 +61,7 @@ enum server_state { static std::string server_state_to_str(server_state state) { switch (state) { + case SERVER_STATE_DOWNLOADING: return "downloading"; case SERVER_STATE_LOADING: return "loading"; case SERVER_STATE_READY: return "ready"; case SERVER_STATE_SLEEPING: return "sleeping"; @@ -69,6 +70,7 @@ static std::string server_state_to_str(server_state state) { } static server_state server_state_from_str(const std::string & str) { + if (str == "downloading") return SERVER_STATE_DOWNLOADING; if (str == "loading") return SERVER_STATE_LOADING; if (str == "ready") return SERVER_STATE_READY; if (str == "sleeping") return SERVER_STATE_SLEEPING; diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 68eefdffac..a87e4e423e 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -64,6 +64,17 @@ struct server_subproc { return sproc.has_value() && subprocess_alive(&sproc.value()); } + void request_exit() { + if (sproc.has_value()) { + FILE * stdin_file = subprocess_stdin(&sproc.value()); + if (stdin_file) { + fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT); + fflush(stdin_file); + } + } + stopped.store(true, std::memory_order_relaxed); + } + void terminate() { if (!sproc.has_value()) { return; @@ -323,7 +334,7 @@ void server_models::notify_sse(const std::string & event, const std::string & mo } void server_models::load_models() { - // Phase 1: load presets from all sources — pure I/O, no lock needed + // Phase 1: load presets from all sources - pure I/O, no lock needed // 1. 
cached models common_presets cached_models = ctx_preset.load_from_cache(); SRV_INF("Loaded %zu cached model presets\n", cached_models.size()); @@ -376,7 +387,7 @@ void server_models::load_models() { return source_map.count(name) ? source_map.at(name) : SERVER_MODEL_SOURCE_PRESET; }; - // Helpers that read `mapping` — must be called while holding the lock. + // Helpers that read `mapping` - must be called while holding the lock. std::unordered_set custom_names; for (const auto & [name, preset] : custom_presets) custom_names.insert(name); auto join_set = [](const std::set & s) { @@ -523,7 +534,7 @@ void server_models::load_models() { } } - // join outside the lock — monitoring thread calls update_status (needs lock) + // join outside the lock - monitoring thread calls update_status (needs lock) lk.unlock(); for (auto & th : threads_to_join) th.join(); lk.lock(); @@ -622,7 +633,7 @@ void server_models::load_models() { apply_stop_timeout(); - // clear reload flag before unlocking for autoload — load() blocks on !is_reloading, + // clear reload flag before unlocking for autoload - load() blocks on !is_reloading, // so clearing it here (while still locked) prevents a deadlock in the autoload calls below is_reloading = false; cv.notify_all(); @@ -815,17 +826,23 @@ void server_models::unload_lru() { } void server_models::load(const std::string & name) { - if (!has_model(name)) { - throw std::runtime_error("model name=" + name + " is not found"); + load(name, load_options{}); +} + +void server_models::load(const std::string & name, const load_options & opts) { + if (!opts.custom_meta.has_value()) { + if (!has_model(name)) { + throw std::runtime_error("model name=" + name + " is not found"); + } + unload_lru(); } - unload_lru(); std::unique_lock lk(mutex); // edge case: block until any in-progress reload has finished so we always load // against the freshest preset and a consistent mapping state cv.wait(lk, [this]() { return !is_reloading; }); - auto meta = 
mapping[name].meta; + auto meta = opts.custom_meta.has_value() ? *opts.custom_meta : mapping[name].meta; if (meta.status != SERVER_MODEL_STATUS_UNLOADED) { SRV_INF("model %s is not ready\n", name.c_str()); return; @@ -869,6 +886,12 @@ void server_models::load(const std::string & name) { std::vector child_env = base_env; // copy child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port)); + if (opts.mode == SERVER_CHILD_MODE_DOWNLOAD) { + inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING; + child_env.push_back("LLAMA_SERVER_CHILD_MODE=download"); + child_env.push_back("LLAMA_ARG_HF_REPO=" + name); + } + SRV_INF("%s", "spawning server instance with args:\n"); for (const auto & arg : child_args) { SRV_INF(" %s\n", arg.c_str()); @@ -886,13 +909,17 @@ void server_models::load(const std::string & name) { if (result != 0) { throw std::runtime_error("failed to spawn server instance"); } - - inst.stdin_file = subprocess_stdin(&inst.subproc->get()); } // start a thread to manage the child process // captured variables are guaranteed to be destroyed only after the thread is joined - inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() { + inst.th = std::thread([ + this, name, + child_proc = inst.subproc, + port = inst.meta.port, + stop_timeout = inst.meta.stop_timeout, + child_mode = opts.mode + ]() { FILE * stdin_file = subprocess_stdin(&child_proc->get()); FILE * stdout_file = subprocess_stdout(&child_proc->get()); // combined stdout/stderr @@ -925,7 +952,7 @@ void server_models::load(const std::string & name) { return is_stopping() || child_proc->stopped.load(std::memory_order_acquire); }); } - // child crashed or finished on its own — skip graceful shutdown sequence + // child crashed or finished on its own, skip graceful shutdown sequence if (child_proc->stopped.load(std::memory_order_acquire)) { return; } @@ -973,10 +1000,14 @@ void server_models::load(const std::string & 
name) { subprocess_destroy(&child_proc->get()); // update status and exit code - this->update_status(name, { - SERVER_MODEL_STATUS_UNLOADED, - exit_code - }); + if (child_mode == SERVER_CHILD_MODE_DOWNLOAD) { + // instance will be cleaned up on next load_models() call + } else { + this->update_status(name, { + SERVER_MODEL_STATUS_UNLOADED, + exit_code + }); + } SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code); }); @@ -984,7 +1015,7 @@ void server_models::load(const std::string & name) { { auto & old_instance = mapping[name]; // old process should have exited already, but just in case, we clean it up here - if (old_instance.subproc->is_alive()) { + if (old_instance.subproc && old_instance.subproc->is_alive()) { SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str()); old_instance.subproc->terminate(); // force kill } @@ -1001,92 +1032,13 @@ void server_models::load(const std::string & name) { cv.notify_all(); } -// callback for model downloading functionality -struct server_models_download_res : public common_download_callback { - common_params_model model; - common_download_opts opts; - - std::function should_stop; - std::function on_progress; - - bool is_ok = false; - - bool run() { - try { - common_download_model(model, opts); - is_ok = true; - } catch (const std::exception & e) { - auto model_name = model.get_name(); - SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what()); - is_ok = false; - } - return is_ok; - } - void on_start(const common_download_progress & p) override { - on_progress(p); - } - void on_update(const common_download_progress & p) override { - on_progress(p); - } - void on_done(const common_download_progress &, bool ok) override { - is_ok = ok; - } - bool is_cancelled() const override { - return should_stop(); - } -}; - -void server_models::download(common_params_model && model, common_download_opts && opts) { - std::string name = model.get_name(); - 
GGML_ASSERT(name == model.hf_repo); - - std::unique_lock lk(mutex); - if (mapping.find(name) != mapping.end()) { - throw std::runtime_error("model name=" + name + " already exists"); - } - - instance_t inst; - inst.meta.name = name; - inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING; - inst.subproc = std::make_shared(); - - auto dl = std::make_unique(); - dl->model = model; // copy - dl->opts = opts; // copy - - dl->should_stop = [sp = inst.subproc]() { - return sp->stopped.load(std::memory_order_relaxed); - }; - - dl->on_progress = [this, name](const common_download_progress & p) { - update_download_progress(name, p, false); - }; - - inst.th = std::thread([this, dl = std::move(dl)]() { - dl->opts.callback = dl.get(); - bool ok = dl->run(); - auto model_name = dl->model.get_name(); - SRV_INF("download finished for model name=%s with status=%s\n", - model_name.c_str(), ok ? "success" : "failure"); - update_download_progress(model_name, {}, true, ok); - // need_reload is set inside update_download_progress under the mutex; - // the next load_models() call will clean up this instance - }); - - mapping[name] = std::move(inst); - notify_sse("status_update", name, { - {"status", server_model_status_to_string(SERVER_MODEL_STATUS_DOWNLOADING)}, - }); - cv.notify_all(); -} - void server_models::unload(const std::string & name) { std::unique_lock lk(mutex); auto it = mapping.find(name); if (it != mapping.end()) { if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) { SRV_INF("cancelling download for model name=%s\n", name.c_str()); - it->second.subproc->stopped.store(true, std::memory_order_relaxed); + it->second.subproc->request_exit(); // for convenience, we wait the status change here wait(lk, name, [](const server_model_meta & new_meta) { return new_meta.status != SERVER_MODEL_STATUS_DOWNLOADING; @@ -1198,37 +1150,65 @@ void server_models::update_download_progress(const std::string & name, const com } bool server_models::remove(const std::string & name) { - 
auto meta = get_meta(name); + // do everything under one lock acquisition; avoid get_meta() / + // unload() because they can trigger load_models() which erases + // transient DOWNLOADING / DOWNLOADED entries as a side-effect + std::unique_lock lk(mutex); - if (!meta.has_value()) { + auto it = mapping.find(name); + if (it == mapping.end()) { throw std::runtime_error("model name=" + name + " is not found"); } - if (meta->source != SERVER_MODEL_SOURCE_CACHE) { + if (it->second.meta.source != SERVER_MODEL_SOURCE_CACHE) { throw std::runtime_error("model name=" + name + " is not removable (not from cache)"); } - unload(name); // cancel download or stop running instance - { - std::unique_lock lk(mutex); - // a cancelled download lands on DOWNLOADED; a stopped instance lands on UNLOADED - wait(lk, name, [](const server_model_meta & new_meta) { - return new_meta.status == SERVER_MODEL_STATUS_UNLOADED - || new_meta.status == SERVER_MODEL_STATUS_DOWNLOADED; - }); - // join before erasing - after status reaches UNLOADED/DOWNLOADED the thread no - // longer acquires this mutex, so joining while holding it is safe - if (mapping[name].th.joinable()) { - mapping[name].th.join(); + if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) { + // cancel in-flight download + SRV_INF("cancelling download for model name=%s\n", name.c_str()); + it->second.subproc->request_exit(); + } else if (it->second.meta.is_running()) { + // stop running instance + SRV_INF("stopping model instance name=%s\n", name.c_str()); + stopping_models.insert(name); + if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) { + it->second.subproc->terminate(); } - // remove the model from disk (hold lock to prevent concurrent load) - bool ok = common_download_remove(name); - if (ok) { - mapping.erase(name); - } - SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? 
"succeeded" : "failed"); - notify_sse("model_remove", name, {}); - return ok; + cv_stop.notify_all(); } + + // wait until the monitoring thread finishes + wait(lk, name, [](const server_model_meta & meta) { + return meta.status == SERVER_MODEL_STATUS_UNLOADED + || meta.status == SERVER_MODEL_STATUS_DOWNLOADED; + }); + + // re-find after wait - load_models() may have erased the entry during the wait + it = mapping.find(name); + if (it == mapping.end()) { + // load_models() already joined the thread and erased the entry; + // we just need to clean up the cached files on disk + lk.unlock(); + bool ok = common_download_remove(name); + SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial"); + notify_sse("model_remove", name, {}); + return true; + } + + // join before erasing - thread no longer acquires this mutex + if (it->second.th.joinable()) { + it->second.th.join(); + } + + // remove from disk (best-effort: cancelled downloads may have no cached files) + bool ok = common_download_remove(name); + mapping.erase(name); + if (!ok) { + SRV_WRN("removing model name=%s from disk returned false (no cached files?)\n", name.c_str()); + } + SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial"); + notify_sse("model_remove", name, {}); + return true; } void server_models::wait(const std::string & name, std::function predicate) { @@ -1243,7 +1223,9 @@ void server_models::wait(std::unique_lock & lk, const std::string & return predicate(it->second.meta); } - return false; + // model was removed from mapping by another code path (e.g. load_models()). + // nothing left to wait for - tell the caller to proceed. 
+ return true; }); } @@ -1328,6 +1310,31 @@ void server_models::handle_child_state(const std::string & name, const std::stri } switch (state) { + case SERVER_STATE_DOWNLOADING: + { + std::string result = json_value(payload, "result", std::string()); + std::string url = json_value(payload, "url", std::string()); + auto request_exit = [&]() { + std::lock_guard lk(mutex); + auto it = mapping.find(name); + if (it != mapping.end()) { + return it->second.subproc->request_exit(); + } + }; + if (result == "download_finished") { + update_download_progress(name, {}, true, true); + request_exit(); + } else if (result == "download_failed") { + update_download_progress(name, {}, true, false); + request_exit(); + } else if (!url.empty()) { + common_download_progress p; + p.url = url; + p.downloaded = json_value(payload, "downloaded", (size_t)0); + p.total = json_value(payload, "total", (size_t)0); + update_download_progress(name, p, false); + } + } break; case SERVER_STATE_LOADING: { update_status(name, { @@ -1366,6 +1373,90 @@ bool server_child::is_child() { return router_port != nullptr; } +server_child_mode server_child::get_mode() { + const char * mode = std::getenv("LLAMA_SERVER_CHILD_MODE"); + std::string mode_str(mode ? 
mode : ""); + if (mode_str == "download") { + return SERVER_CHILD_MODE_DOWNLOAD; + } else { + return SERVER_CHILD_MODE_NORMAL; + } +} + +struct server_download_state : public common_download_callback { + server_child * self; + std::function should_stop; + std::atomic last_progress_time{0}; // multiple files downloading in different threads + bool is_ok = false; + + server_download_state(server_child * s) : self(s) {} + + bool run(common_params & params) { + try { + common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, this); + is_ok = true; + } catch (const std::exception & e) { + auto model_name = params.model.get_name(); + SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what()); + is_ok = false; + } + return is_ok; + } + void on_progress(const common_download_progress & p) { + json data = { + {"url", p.url}, + {"downloaded", p.downloaded}, + {"total", p.total}, + }; + self->notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), data); + } + void on_start(const common_download_progress & p) override { + on_progress(p); + } + void on_update(const common_download_progress & p) override { + int64_t now = ggml_time_ms(); + // throttle progress updates to avoid flooding logs + if (now - last_progress_time.load(std::memory_order_relaxed) >= 100) { + on_progress(p); + last_progress_time.store(now, std::memory_order_relaxed); + } + } + void on_done(const common_download_progress & p, bool) override { + on_progress(p); + } + bool is_cancelled() const override { + return should_stop ? 
should_stop() : false; + } +}; + +int server_child::run_download(common_params & params) { + auto cancelled = std::make_shared>(false); + + // monitor stdin for cancellation command from the router + std::thread signal_thread = setup([cancelled](int) { + cancelled->store(true, std::memory_order_relaxed); + }); + + server_download_state dl(this); + dl.should_stop = [cancelled]() { + return cancelled->load(std::memory_order_relaxed); + }; + + bool ok = dl.run(params); + + notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), { + {"result", ok ? "download_finished" : "download_failed"}, + }); + + // router should send CMD_ROUTER_TO_CHILD_EXIT after receiving the result + if (signal_thread.joinable()) { + signal_thread.join(); + } + + SRV_INF("download completed %s\n", ok ? "successfully" : "with errors"); + return 0; +} + std::thread server_child::setup(const std::function & shutdown_handler) { // setup thread for monitoring stdin return std::thread([shutdown_handler]() { @@ -1639,7 +1730,7 @@ void server_models_routes::init_routes() { res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST)); return res; } - if (!model->is_running()) { + if (!model->is_running() && model->status != SERVER_MODEL_STATUS_DOWNLOADING) { res_err(res, format_error_response("model is not running", ERROR_TYPE_INVALID_REQUEST)); return res; } @@ -1680,8 +1771,9 @@ void server_models_routes::init_routes() { model.hf_repo = name; opts.bearer_token = params.hf_token; - opts.download_mmproj = true; - opts.download_mtp = true; + // note: we only check main model, no need sidecar here + opts.download_mmproj = false; + opts.download_mtp = false; // first, only check if the model is valid and can be downloaded opts.skip_download = true; @@ -1702,10 +1794,21 @@ void server_models_routes::init_routes() { throw std::invalid_argument("model validation failed, unable to download"); } + // reject if model already exists + if (models.has_model(name)) { + throw 
std::invalid_argument("model '" + name + "' already exists"); + } + // then, proceed with the actual download - opts.skip_download = false; SRV_INF("starting download for model '%s'\n", name.c_str()); - models.download(std::move(model), std::move(opts)); + { + server_models::load_options load_opts; + load_opts.mode = SERVER_CHILD_MODE_DOWNLOAD; + load_opts.custom_meta = server_model_meta{}; + load_opts.custom_meta->source = SERVER_MODEL_SOURCE_CACHE; + load_opts.custom_meta->name = name; + models.load(name, load_opts); + } res_ok(res, {{"success", true}}); return res; @@ -1719,10 +1822,7 @@ void server_models_routes::init_routes() { throw std::invalid_argument("model must be a non-empty string"); } - bool ok = models.remove(name); - if (!ok) { - throw std::runtime_error("failed to remove model '" + name + "'"); - } + models.remove(name); // throws on error res_ok(res, {{"success", true}}); return res; diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 17759b00a5..9ed4aeead0 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -40,6 +40,11 @@ enum server_model_source { SERVER_MODEL_SOURCE_CACHE, }; +enum server_child_mode { + SERVER_CHILD_MODE_NORMAL, // load the model and run normally + SERVER_CHILD_MODE_DOWNLOAD, // download the model and exit +}; + static std::string server_model_status_to_string(server_model_status status) { switch (status) { case SERVER_MODEL_STATUS_DOWNLOADING: return "downloading"; @@ -105,7 +110,6 @@ private: std::shared_ptr subproc; // shared between main thread and monitoring thread std::thread th; server_model_meta meta; - FILE * stdin_file = nullptr; }; std::mutex mutex; @@ -161,16 +165,19 @@ public: // return a copy of all model metadata (thread-safe) std::vector get_all_meta(); + struct load_options { + server_child_mode mode = SERVER_CHILD_MODE_NORMAL; + // used for spawning a downloading child process + std::optional custom_meta = std::nullopt; + }; + // load and unload model 
instances // these functions are thread-safe void load(const std::string & name); + void load(const std::string & name, const load_options & opts); void unload(const std::string & name); void unload_all(); - // download a new model, progress is reported via SSE - // to stop the download, call unload() - void download(common_params_model && model, common_download_opts && opts); - struct update_status_args { server_model_status status; int exit_code = 0; // only valid if status == UNLOADED @@ -213,9 +220,12 @@ public: struct server_child { // serializes the notify_to_router writes std::mutex mtx_stdout; + std::atomic is_finished_downloading = false; // set by run_download // return true if the current process is a child server instance bool is_child(); + server_child_mode get_mode(); + int run_download(common_params & params); // register the shutdown_handler to be called by the router // return the monitoring thread (to be joined by the caller) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index bf3680b9f0..dd4b1c507c 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -134,6 +134,7 @@ int llama_server(int argc, char ** argv) { // // register API routes + server_child child; // only used in non-router mode server_routes routes(params, ctx_server); server_tools tools; @@ -254,11 +255,21 @@ int llama_server(int argc, char ** argv) { ctx_http.post("/tools", ex_wrapper(tools.handle_post)); } + // + // Handle downloading model + // + + if (child.is_child() && child.get_mode() == SERVER_CHILD_MODE_DOWNLOAD) { + return child.run_download(params); + } else if (!is_router_server) { + // single-model mode (NOT spawned by router) + common_params_handle_models(params, LLAMA_EXAMPLE_SERVER); + } + // // Start the server // - server_child child; // only used in non-router mode std::function clean_up; if (is_router_server) { diff --git a/tools/server/tests/unit/test_router.py b/tools/server/tests/unit/test_router.py index 11c77ca7aa..41e95f4c5f 
100644 --- a/tools/server/tests/unit/test_router.py +++ b/tools/server/tests/unit/test_router.py @@ -257,14 +257,25 @@ def test_router_reload_models(): MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16" -MODEL_DOWNLOAD_TIMEOUT = 300 +MODEL_DOWNLOAD_TIMEOUT = 30 -def _listen_sse(server: ServerProcess, collected: list, stop: threading.Event): - """Collect /models/sse events into `collected` until `stop` is set.""" +def _listen_sse( + server: ServerProcess, collected: list, stop: threading.Event, ready: threading.Event | None = None +): + """Collect /models/sse events into `collected` until `stop` is set. + + When `ready` is provided, it is set once the streaming response is open, + i.e. the server has accepted the connection and registered us as a + subscriber. Callers that trigger one-shot events (e.g. download_finished) + must wait on `ready` before acting, otherwise the event can be broadcast + before this client is subscribed and be lost. + """ url = f"http://{server.server_host}:{server.server_port}/models/sse" try: with requests.get(url, stream=True, timeout=MODEL_DOWNLOAD_TIMEOUT) as resp: + if ready is not None: + ready.set() for line_bytes in resp.iter_lines(): if stop.is_set(): break @@ -294,11 +305,17 @@ def test_router_download_model(): sse_events: list = [] stop = threading.Event() + sse_ready = threading.Event() sse_thread = threading.Thread( - target=_listen_sse, args=(server, sse_events, stop), daemon=True + target=_listen_sse, args=(server, sse_events, stop, sse_ready), daemon=True ) sse_thread.start() + # wait for the SSE client to be subscribed before triggering the download, + # otherwise the one-shot download_finished event can be broadcast before + # this client is registered and be lost + assert sse_ready.wait(10), "SSE client failed to connect" + # Trigger the download res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID}) assert res.status_code == 200 @@ -328,13 +345,17 @@ def test_router_delete_model(): # 
Ensure the model exists (download it if needed) if MODEL_DOWNLOAD_ID not in _get_model_ids(is_reload=False): - res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID}) - assert res.status_code == 200 sse_events: list = [] stop = threading.Event() + sse_ready = threading.Event() threading.Thread( - target=_listen_sse, args=(server, sse_events, stop), daemon=True + target=_listen_sse, args=(server, sse_events, stop, sse_ready), daemon=True ).start() + # subscribe before triggering the download so the one-shot + # download_finished event is not lost (see test_router_download_model) + assert sse_ready.wait(10), "SSE client failed to connect" + res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID}) + assert res.status_code == 200 finished = _wait_for_sse_event( sse_events, "download_finished", MODEL_DOWNLOAD_ID, MODEL_DOWNLOAD_TIMEOUT ) From 9c0ac887f35eb4f3cdef57df422de47715b6f0f8 Mon Sep 17 00:00:00 2001 From: Mahdiou Diallo <104755555+mahdiou@users.noreply.github.com> Date: Mon, 22 Jun 2026 21:00:21 +0200 Subject: [PATCH 51/86] ui: Prioritize favorite models in model selection (#24766) Updated model selection prioritization to include favorite models. --- tools/ui/src/lib/stores/models.svelte.ts | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tools/ui/src/lib/stores/models.svelte.ts b/tools/ui/src/lib/stores/models.svelte.ts index 2ce450d423..87fb172d07 100644 --- a/tools/ui/src/lib/stores/models.svelte.ts +++ b/tools/ui/src/lib/stores/models.svelte.ts @@ -545,7 +545,8 @@ class ModelsStore { * 1. Model from active conversation's last assistant response (if loaded) * 2. Model from active conversation's last assistant response (if not loaded) * 3. First loaded model (not from active conversation) - * 4. First available model + * 4. A favorite model + * 5. 
First available model */ async ensureFirstModelSelected(): Promise { if (this.selectedModelName) return; @@ -574,6 +575,13 @@ class ModelsStore { return; } + // Try loading a favorite model + const favorite = this.favoriteModelIds.values().next()?.value + if (favorite) { + await this.selectModelById(favorite); + return; + } + // Fall back to the first available model await this.selectModelById(availableModels[0].id); } From dec5ca5577d6042b4e870fadf4087c5b9b8d3a70 Mon Sep 17 00:00:00 2001 From: Matt Thompson <111157855+boondocklabs@users.noreply.github.com> Date: Mon, 22 Jun 2026 14:03:12 -0700 Subject: [PATCH 52/86] server : Add id to tool call responses api (#24882) --- tools/server/server-task.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 9ba039c8b8..a9ebac013f 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -591,10 +591,11 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp() { for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) { output.push_back(json { + {"id", "fc_" + tool_call.id}, {"type", "function_call"}, {"status", "completed"}, {"arguments", tool_call.arguments}, - {"call_id", "fc_" + tool_call.id}, + {"call_id", "call_" + tool_call.id}, {"name", tool_call.name}, }); } @@ -690,10 +691,11 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() { for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) { const json output_item = { + {"id", "fc_" + tool_call.id}, {"type", "function_call"}, {"status", "completed"}, {"arguments", tool_call.arguments}, - {"call_id", "fc_" + tool_call.id}, + {"call_id", "call_" + tool_call.id}, {"name", tool_call.name} }; server_sent_events.push_back(json { @@ -1277,8 +1279,9 @@ json server_task_result_cmpl_partial::to_json_oaicompat_resp() { {"data", json { {"type", "response.output_item.added"}, {"item", json { + {"id", "fc_" + 
diff.tool_call_delta.id}, {"arguments", ""}, - {"call_id", "fc_" + diff.tool_call_delta.id}, + {"call_id", "call_" + diff.tool_call_delta.id}, {"name", diff.tool_call_delta.name}, {"type", "function_call"}, {"status", "in_progress"}, From 23ee8797e11b1f9502f547a9250bfa0f5c36a63f Mon Sep 17 00:00:00 2001 From: Shawn Gu Date: Mon, 22 Jun 2026 22:25:21 -0700 Subject: [PATCH 53/86] opencl: q8_0 gemv precision improvement (#24923) --- ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl index 9703b693e5..f5c6fb3e84 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl @@ -174,7 +174,7 @@ __kernel void kernel_gemv_noshuffle_q8_0_f32( regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; - dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB); + dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, convert_float(regS), regB); } // reduction in local memory, assumes #wave=4 From 73618f27a801c0b8614ceaf3547d3c2a99baae14 Mon Sep 17 00:00:00 2001 From: Aldehir Rojas Date: Tue, 23 Jun 2026 00:27:28 -0500 Subject: [PATCH 54/86] server: improve user message detection and create checkpoints at every user message (#24176) * server : improve message span logic * cont : cast size_t to int32_t in comparisons * server : create checkpoints before every user msg * chat : remove \n in gemma4 delimiters * chat : merge msg delimiter structs into one * cont : reword comment * cont : initialize tokens in delimiter * cont : add server_tokens::get_raw_tokens() for mtmd * cont : move message finding to server_tokens and skip mtmd tokens * cont : update cohere2moe parser * cont : increase min-step to 8192 and always 
produce a chkpt for last user message --- common/chat.cpp | 156 +++++++++++++++++++++----------- common/chat.h | 71 +++++++++++++-- common/common.h | 2 +- tests/test-chat.cpp | 123 ++++++++++++++++++++----- tools/server/server-common.cpp | 18 ++-- tools/server/server-common.h | 3 + tools/server/server-context.cpp | 110 ++++------------------ tools/server/server-task.h | 6 +- 8 files changed, 302 insertions(+), 187 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index ded8440e66..cee6ad650a 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -90,41 +90,93 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const return text; } -std::vector common_chat_split_by_role(const std::string & prompt, const std::vector & delims) { - if (delims.empty() || prompt.empty()) { - return {}; +common_chat_role common_chat_role_from_string(const std::string & role) { + if (role == "system") { return COMMON_CHAT_ROLE_SYSTEM; } + if (role == "assistant") { return COMMON_CHAT_ROLE_ASSISTANT; } + if (role == "user") { return COMMON_CHAT_ROLE_USER; } + if (role == "tool") { return COMMON_CHAT_ROLE_TOOL; } + return COMMON_CHAT_ROLE_UNKNOWN; +} + +const char * common_chat_role_to_string(common_chat_role role) { + switch (role) { + case COMMON_CHAT_ROLE_SYSTEM: return "system"; + case COMMON_CHAT_ROLE_ASSISTANT: return "assistant"; + case COMMON_CHAT_ROLE_USER: return "user"; + case COMMON_CHAT_ROLE_TOOL: return "tool"; + case COMMON_CHAT_ROLE_UNKNOWN: return ""; + } + return ""; +} + +json common_chat_msg_delimiters::to_json() const { + json result = json::array(); + for (const auto & d : delimiters) { + result.push_back({ + { "role", common_chat_role_to_string(d.role) }, + { "delimiter", d.delimiter }, + }); + } + return result; +} + +common_chat_msg_delimiters common_chat_msg_delimiters_parse(const json & delimiters) { + common_chat_msg_delimiters result; + + if (!delimiters.is_array()) { + return result; } - auto parser = 
build_peg_parser([&](common_peg_parser_builder & p) { - std::vector all_delims; - std::vector tagged_messages; - - all_delims.reserve(delims.size()); - tagged_messages.reserve(delims.size()); - for (const auto & d : delims) { - all_delims.push_back(d.delimiter); + result.delimiters.reserve(delimiters.size()); + for (const auto & d : delimiters) { + if (!d.is_object()) { + continue; } - - auto any_delim = p.until_one_of(all_delims); - for (const auto & d : delims) { - tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim)); - } - - return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end(); - }); - - common_peg_parse_context ctx(prompt); - const auto result = parser.parse(ctx); - if (!result.success()) { - return {}; + result.delimiters.push_back({ + common_chat_role_from_string(d.value("role", std::string())), + d.value("delimiter", std::string()), + }); } - std::vector spans; - ctx.ast.visit(result, [&](const common_peg_ast_node & node) { - if (!node.tag.empty()) { - spans.push_back({ node.tag, node.start, node.end - node.start }); + return result; +} + +void common_chat_msg_delimiters::tokenize(const llama_vocab * vocab) { + for (auto & d : delimiters) { + d.tokens = common_tokenize(vocab, d.delimiter, false, true); + } +} + +common_chat_msg_spans common_chat_msg_delimiters::split(const llama_tokens & tokens, const std::map & skips) const { + std::vector> matches; + + auto skip = skips.begin(); + for (size_t i = 0; i < tokens.size();) { + if (skip != skips.end() && i == skip->first) { + i += skip->second; + ++skip; + continue; } - }); + for (const auto & d : delimiters) { + if (i + d.tokens.size() > tokens.size()) { + continue; + } + if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) { + matches.emplace_back(d.role, i); + break; + } + } + i++; + } + + matches.emplace_back(COMMON_CHAT_ROLE_UNKNOWN, tokens.size()); + + common_chat_msg_spans spans; + for (size_t i = 0; i + 1 < matches.size(); i++) { + const auto & 
curr = matches[i]; + const auto & next = matches[i + 1]; + spans.add(curr.first, curr.second, next.second - curr.second); + } return spans; } @@ -1081,13 +1133,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp data.prompt = prompt; data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages); - data.message_spans = common_chat_split_by_role(prompt, { - { "assistant", "<|start|>assistant" }, - { "user", "<|start|>user" }, - { "system", "<|start|>developer" }, - { "system", "<|start|>system" }, - { "tool", "<|start|>functions" }, - }); + data.message_delimiters = { + { COMMON_CHAT_ROLE_ASSISTANT, "<|start|>assistant" }, + { COMMON_CHAT_ROLE_USER, "<|start|>user" }, + { COMMON_CHAT_ROLE_SYSTEM, "<|start|>developer" }, + { COMMON_CHAT_ROLE_SYSTEM, "<|start|>system" }, + { COMMON_CHAT_ROLE_TOOL, "<|start|>functions" }, + }; data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = true; @@ -1228,10 +1280,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ data.prompt += data.generation_prompt; } - data.message_spans = common_chat_split_by_role(data.prompt, { - { "user", "<|turn>user\n" }, - { "assistant", "<|turn>model\n" }, - }); + data.message_delimiters = { + { COMMON_CHAT_ROLE_USER, "<|turn>user" }, + { COMMON_CHAT_ROLE_ASSISTANT, "<|turn>model" }, + }; data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4; data.supports_thinking = true; @@ -2030,15 +2082,15 @@ static common_chat_params common_chat_params_init_cohere2moe(const common_chat_t RESULT_START, RESULT_END, }; - // Split the rendered prompt into per-role message spans. Tool results are rendered with the + // Declare per-role message delimiters. 
Tool results are rendered with the // system token followed by <|START_TOOL_RESULT|>, so the "tool" delimiter must be listed before // the plain "system" one (it is a strict superset, and the role split tries delimiters in order). - data.message_spans = common_chat_split_by_role(data.prompt, { - { "assistant", GEN_PREFIX }, - { "user", TURN_START + USER }, - { "tool", TURN_START + SYSTEM + RESULT_START }, - { "system", TURN_START + SYSTEM }, - }); + data.message_delimiters = { + { COMMON_CHAT_ROLE_ASSISTANT, GEN_PREFIX }, + { COMMON_CHAT_ROLE_USER, TURN_START + USER }, + { COMMON_CHAT_ROLE_TOOL, TURN_START + SYSTEM + RESULT_START }, + { COMMON_CHAT_ROLE_SYSTEM, TURN_START + SYSTEM }, + }; auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; @@ -2526,17 +2578,15 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ autoparser.analyze_template(tmpl); auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser); - std::vector delimiters; + common_chat_msg_delimiters delimiters; if (!autoparser.assistant_start.empty()) { - delimiters.push_back({ "assistant", autoparser.assistant_start }); + delimiters.add(COMMON_CHAT_ROLE_ASSISTANT, autoparser.assistant_start); } if (!autoparser.user_start.empty()) { - delimiters.push_back({ "user", autoparser.user_start }); + delimiters.add(COMMON_CHAT_ROLE_USER, autoparser.user_start); } - if (!delimiters.empty()) { - auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters); - } + auto_params.message_delimiters = std::move(delimiters); auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE; if (auto_params.supports_thinking) { diff --git a/common/chat.h b/common/chat.h index 5659cd42a0..7898f1623f 100644 --- a/common/chat.h +++ b/common/chat.h @@ -143,15 +143,75 @@ struct common_chat_msg_diff { } }; +enum 
common_chat_role { + COMMON_CHAT_ROLE_UNKNOWN, + COMMON_CHAT_ROLE_SYSTEM, + COMMON_CHAT_ROLE_ASSISTANT, + COMMON_CHAT_ROLE_USER, + COMMON_CHAT_ROLE_TOOL +}; + +common_chat_role common_chat_role_from_string(const std::string & role); +const char * common_chat_role_to_string(common_chat_role role); + struct common_chat_msg_span { - std::string role; + common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN; std::size_t pos = 0; std::size_t len = 0; + + bool valid() const { + return role != COMMON_CHAT_ROLE_UNKNOWN; + } +}; + +struct common_chat_msg_spans { + std::vector spans; + + void add(common_chat_role role, size_t pos, size_t len) { + spans.push_back({ role, pos, len }); + } + + bool is_user_start(int32_t pos) const { + for (auto it = spans.begin(); it != spans.end(); ++it) { + if (it->role == COMMON_CHAT_ROLE_USER && pos == (int32_t) it->pos) { + return true; + } + } + return false; + } + + int32_t last_user_message_pos() const { + for (auto it = spans.rbegin(); it != spans.rend(); ++it) { + if (it->role == COMMON_CHAT_ROLE_USER) { + return (int32_t) it->pos; + } + } + return -1; + } }; struct common_chat_msg_delimiter { - std::string role; - std::string delimiter; + common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN; + std::string delimiter; + llama_tokens tokens = {}; +}; + +struct common_chat_msg_delimiters { + std::vector delimiters; + + common_chat_msg_delimiters() = default; + common_chat_msg_delimiters(std::initializer_list delims) : delimiters(delims) {} + + void add(common_chat_role role, const std::string & delimiter) { + delimiters.push_back({ role, delimiter }); + } + + void tokenize(const llama_vocab * vocab); + + // split tokens into message spans. 
skips maps a start index to a length of a region to jump over without matching + common_chat_msg_spans split(const llama_tokens & tokens, const std::map & skips = {}) const; + + nlohmann::ordered_json to_json() const; }; struct common_chat_tool { @@ -219,7 +279,7 @@ struct common_chat_params { std::vector preserved_tokens; std::vector additional_stops; std::string parser; - std::vector message_spans; + common_chat_msg_delimiters message_delimiters; }; // per-message parsing syntax @@ -325,5 +385,4 @@ struct common_chat_prompt_preset { common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates); -std::vector common_chat_split_by_role(const std::string & prompt, const std::vector & delims); - +common_chat_msg_delimiters common_chat_msg_delimiters_parse(const nlohmann::ordered_json & delimiters); diff --git a/common/common.h b/common/common.h index f2f2202ec2..75a6036a0f 100644 --- a/common/common.h +++ b/common/common.h @@ -609,7 +609,7 @@ struct common_params { bool cache_prompt = true; // whether to enable prompt caching bool cache_idle_slots = true; // save and clear idle slots upon starting a new task int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot - int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints + int32_t checkpoint_min_step = 8192; // minimum spacing between context checkpoints int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. std::string hostname = "127.0.0.1"; diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 30aa35e137..c38aed8cfe 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -1562,37 +1562,112 @@ static void test_msgs_oaicompat_json_conversion() { } } -static void test_split_by_role() { +static void test_msg_token_delimiters_split() { LOG_DBG("%s\n", __func__); + // Delimiters that share a leading token, distinguished by the second token, + // to exercise the per-position token matching. 
+ const common_chat_msg_delimiters delims = { + { { COMMON_CHAT_ROLE_USER, "", { 10, 11 } }, + { COMMON_CHAT_ROLE_ASSISTANT, "", { 10, 12 } } } + }; + // Empty inputs - assert_equals(0, common_chat_split_by_role("", {}).size()); - assert_equals(0, common_chat_split_by_role("hello", {}).size()); - assert_equals(0, common_chat_split_by_role("", { { "user", "<|user|>" } }).size()); + assert_equals(0, common_chat_msg_delimiters{}.split({}).spans.size()); + assert_equals(0, common_chat_msg_delimiters{}.split({ 10, 11 }).spans.size()); + assert_equals(0, delims.split({}).spans.size()); - // Multi-role conversation, no leading/trailing content + // No delimiters match -> no spans + assert_equals(0, delims.split({ 100, 101, 102 }).spans.size()); + + // Multi-role conversation: HiHelloBye { - const std::string prompt = "<|user|>Hi<|assistant|>Hello<|user|>Bye"; - const auto splits = common_chat_split_by_role(prompt, { - { "user", "<|user|>" }, - { "assistant", "<|assistant|>" }, - }); - assert_equals(3, splits.size()); + const llama_tokens tokens = { + 10, 11, // + 100, 101, // Hi + 10, 12, // + 200, 201, 202, // Hello + 10, 11, // + 300, 301, // Bye + }; - assert_equals("user", splits[0].role); - assert_equals(0, splits[0].pos); - assert_equals(10, splits[0].len); - assert_equals("<|user|>Hi", prompt.substr(splits[0].pos, splits[0].len)); + const auto result = delims.split(tokens); + const auto & spans = result.spans; + assert_equals(3, spans.size()); - assert_equals("assistant", splits[1].role); - assert_equals(10, splits[1].pos); - assert_equals(18, splits[1].len); - assert_equals("<|assistant|>Hello", prompt.substr(splits[1].pos, splits[1].len)); + assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role); + assert_equals(0, spans[0].pos); + assert_equals(4, spans[0].len); - assert_equals("user", splits[2].role); - assert_equals(28, splits[2].pos); - assert_equals(11, splits[2].len); - assert_equals("<|user|>Bye", prompt.substr(splits[2].pos, splits[2].len)); + 
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role); + assert_equals(4, spans[1].pos); + assert_equals(5, spans[1].len); + + assert_equals(COMMON_CHAT_ROLE_USER, spans[2].role); + assert_equals(9, spans[2].pos); + assert_equals(4, spans[2].len); + + // is_user_start() is true at the token position where a user span begins + assert_equals(true, result.is_user_start(0)); + assert_equals(false, result.is_user_start(4)); // assistant span + assert_equals(true, result.is_user_start(9)); + } + + // Content before the first delimiter is not captured as a span + { + const llama_tokens tokens = { + 500, 501, // leading content (dropped) + 10, 11, // + 100, // Hi + }; + + const auto spans = delims.split(tokens).spans; + assert_equals(1, spans.size()); + assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role); + assert_equals(2, spans[0].pos); + assert_equals(3, spans[0].len); + } + + // Skipped regions (media chunks) are jumped over but still count as span content + { + const llama_tokens tokens = { + 10, 11, // + LLAMA_TOKEN_NULL, // media chunk (3 tokens) + LLAMA_TOKEN_NULL, + LLAMA_TOKEN_NULL, + 100, // Hi + 10, 12, // + }; + + const std::map skips = { { 2, 3 } }; + + const auto spans = delims.split(tokens, skips).spans; + assert_equals(2, spans.size()); + + assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role); + assert_equals(0, spans[0].pos); + assert_equals(6, spans[0].len); + + assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role); + assert_equals(6, spans[1].pos); + assert_equals(2, spans[1].len); + } + + // A delimiter sequence inside a skipped region is not matched + { + const llama_tokens tokens = { + 10, 11, // + 10, 12, // skipped region that happens to contain delimiter tokens + 100, // Hi + }; + + const std::map skips = { { 2, 2 } }; + + const auto spans = delims.split(tokens, skips).spans; + assert_equals(1, spans.size()); + assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role); + assert_equals(0, spans[0].pos); + assert_equals(5, spans[0].len); } } @@ 
-5857,7 +5932,7 @@ int main(int argc, char ** argv) { { test_msg_diffs_compute(); test_msgs_oaicompat_json_conversion(); - test_split_by_role(); + test_msg_token_delimiters_split(); test_tools_oaicompat_json_conversion(); test_convert_responses_to_chatcmpl(); test_developer_role_to_system_workaround(); diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index e412b94c5c..ac291d359a 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -518,6 +518,14 @@ size_t server_tokens::get_common_prefix(const server_tokens & b) const { return max_idx; // all tokens are equal } +common_chat_msg_spans server_tokens::find_message_spans(const common_chat_msg_delimiters & delims) const { + std::map skips; + for (const auto & it : map_idx_to_media) { + skips[it.first] = mtmd_input_chunk_get_n_tokens(it.second.get()); + } + return delims.split(tokens, skips); +} + bool server_tokens::validate(const struct llama_context * ctx) const { const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); @@ -1104,15 +1112,7 @@ json oaicompat_chat_params_parse( llama_params["chat_parser"] = chat_params.parser; } - llama_params["message_spans"] = json::array(); - - for (const auto & span : chat_params.message_spans) { - llama_params["message_spans"].push_back({ - { "role", span.role }, - { "pos", span.pos }, - { "len", span.len }, - }); - } + llama_params["message_delimiters"] = chat_params.message_delimiters.to_json(); // Reasoning budget: pass parameters through to sampling layer { diff --git a/tools/server/server-common.h b/tools/server/server-common.h index efd31733b0..c0eaec6b02 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -218,6 +218,9 @@ public: size_t get_common_prefix(const server_tokens & b) const; + // split the tokens into message spans, skipping over media chunks + common_chat_msg_spans find_message_spans(const common_chat_msg_delimiters & delims) 
const; + // make sure all text tokens are within the vocab range bool validate(const struct llama_context * ctx) const; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 0a25b414ed..ca91449d26 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -3436,8 +3436,8 @@ private: has_mtmd = true; } - const int32_t n_before_user = slot.task->params.n_before_user; - const bool n_before_user_known = n_before_user > 0; + const auto & spans = slot.task->params.message_spans; + const auto last_user_pos = spans.last_user_message_pos(); // add prompt tokens for processing in the current batch while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.size() < n_batch) { @@ -3466,10 +3466,8 @@ private: slot.n_prompt_tokens_processed++; - // stop the prompt batch exactly before the latest user input, so a checkpoint - // can be created after the previous messages - if (n_before_user_known && - slot.prompt.n_tokens() == n_before_user) { + // stop the prompt batch exactly before a user message + if (spans.is_user_start(slot.prompt.n_tokens())) { break; } @@ -3498,8 +3496,13 @@ private: // the number of tokens added to the batch for the current slot const auto n_tokens_cur = batch.size() - n_tokens_prev; + const auto n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur; + const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch; + const bool is_user_start = spans.is_user_start(n_tokens_start); + const bool is_last_user_message = n_tokens_start == last_user_pos; + // entire prompt has been processed if (slot.prompt.n_tokens() == slot.task->n_tokens()) { slot.state = SLOT_STATE_DONE_PROMPT; @@ -3514,8 +3517,9 @@ private: slot.init_sampler(); } else { - // skip ordinary mid-prompt checkpoints - if (!n_before_user_known && !near_prompt_end) { + // skip ordinary mid-prompt checkpoints, unless the batch starts a user + // message or we are near the end of the prompt + if (!is_user_start 
&& !near_prompt_end) { do_checkpoint = false; } } @@ -3523,29 +3527,6 @@ private: const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id); - // checkpoints are created before the current batch is decoded, so - // their token position is the batch start rather than the prompt end - const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur; - - { - const bool is_on_user = - n_before_user_known && - n_tokens_start == n_before_user; - - const bool is_after_user = - n_before_user_known && - n_tokens_start > n_before_user; - - const bool is_allowed = - !n_before_user_known || - is_on_user || - (is_after_user && near_prompt_end); - - if (do_checkpoint && !is_allowed) { - do_checkpoint = false; - } - } - // nothing to checkpoint yet // TODO: is this check needed? if (do_checkpoint && pos_min < 0) { @@ -3555,8 +3536,8 @@ private: // do not checkpoint after mtmd chunks do_checkpoint = do_checkpoint && !has_mtmd; - // no need to create checkpoints that are too close together - do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step); + // no need to create checkpoints that are too close together, unless it's the last user message + do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || is_last_user_message || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step); SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? 
"yes" : "no", pos_min, pos_max); // note: we create the checkpoint before calling llama_decode(), so the current batch is not @@ -4055,54 +4036,6 @@ void server_context::set_state_callback(server_state_callback_t callback) { }); } -// compute the number of tokens before the last user message in the prompt -static int32_t prompt_get_n_before_user( - const json & message_spans, - const std::string & prompt, - const std::vector & files, - const llama_vocab * vocab, - mtmd_context * mctx) { - int32_t result = -1; - int32_t byte_pos = -1; - - for (const auto & span : message_spans) { - const std::string role = json_value(span, "role", std::string()); - - if (role == "user") { - byte_pos = json_value(span, "pos", -1); - } - } - - if (byte_pos >= 0) { - GGML_ASSERT((size_t) byte_pos <= prompt.size()); - - const std::string prefix = prompt.substr(0, (size_t) byte_pos); - - const std::string marker = get_media_marker(); - size_t n_prefix_media = 0; - for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) { - n_prefix_media++; - } - - GGML_ASSERT(n_prefix_media <= files.size()); - - if (mctx != nullptr && n_prefix_media > 0) { - // TODO: this makes a copy - avoid it - std::vector prefix_files(files.begin(), files.begin() + n_prefix_media); - - result = (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size(); - } else { - result = (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size(); - } - - SRV_TRC("message_spans: last user message: byte_pos=%d, media=%zu, n_before_user=%d\n", - byte_pos, n_prefix_media, result); - } - - return result; -} - - // // server_routes // @@ -4150,6 +4083,10 @@ std::unique_ptr server_routes::handle_completions_impl( // tasks.reserve(inputs.size()); // TODO: this is inaccurate due to child tasks + // message delimiters for checkpointing + auto delimiters = common_chat_msg_delimiters_parse(json_value(data, "message_delimiters", json::array())); + 
delimiters.tokenize(ctx_server.vocab); + for (size_t i = 0; i < inputs.size(); i++) { server_task task = server_task(type); @@ -4163,16 +4100,7 @@ std::unique_ptr server_routes::handle_completions_impl( meta->logit_bias_eog, data); - const auto message_spans = json_value(data, "message_spans", json::array()); - if (prompt.is_string() && message_spans.is_array()) { - task.params.n_before_user = - prompt_get_n_before_user( - message_spans, - prompt.get(), - files, - ctx_server.vocab, - ctx_server.mctx); - } + task.params.message_spans = task.tokens.find_message_spans(delimiters); task.id_slot = json_value(data, "id_slot", -1); diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 299c279d7d..293bdf053a 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -62,9 +62,6 @@ struct task_params { int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled) - // number of prompt tokens before the latest user message - int32_t n_before_user = -1; - int64_t t_max_prompt_ms = -1; // TODO: implement int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit @@ -92,6 +89,9 @@ struct task_params { // per-request parameters for chat parsing common_chat_parser_params chat_parser_params; + // message spans for checkpointing + common_chat_msg_spans message_spans; + // Embeddings int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) From 035cd8f9a6dda9cd0224b07d3df2dc95f7a3a31e Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Tue, 23 Jun 2026 15:19:34 +0900 Subject: [PATCH 55/86] codeowners: add yomaytk to ggml-webgpu (#24930) --- CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CODEOWNERS b/CODEOWNERS index 4b9d901771..46fd518b7e 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -10,7 +10,7 @@ # ggml-org/ggml-rpc : rgerganov # ggml-org/ggml-sycl : arthw # ggml-org/ggml-vulkan : 0cc4m, 
jeffbolznv -# ggml-org/ggml-webgpu : reeselevine +# ggml-org/ggml-webgpu : reeselevine, yomaytk # ggml-org/ggml-zdnn : taronaeo # ggml-org/llama-common : ggerganov, aldehir, angt, danbev, ngxson, pwilkin # ggml-org/llama-mtmd : ngxson From 7c908502ea0868e6ae913f79ba84ba844a5b386a Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Tue, 23 Jun 2026 17:13:55 +0900 Subject: [PATCH 56/86] ggml-webgpu: improve MTP inference by using mat-vec path for small batches (#24811) * ggml-webgpu: improve small batches decoding * Add barrier to the NUM_COLS loop in mul-mat-vec --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 13 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 20 +- .../wgsl-shaders/mul_mat_id_vec.wgsl | 4 +- .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 88 +- .../wgsl-shaders/mul_mat_vec_acc.tmpl | 991 ++++++++++-------- .../wgsl-shaders/mul_mat_vec_q_acc.tmpl | 132 ++- .../ggml-webgpu/wgsl-shaders/quantize_q8.wgsl | 23 +- tests/test-backend-ops.cpp | 2 + 8 files changed, 682 insertions(+), 591 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 6f877f15ce..c00a2e9ee9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -905,11 +905,12 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key { ggml_type src0_type; ggml_type src1_type; int vectorized; + uint32_t num_cols; bool use_mmvq; bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const { return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized && - use_mmvq == other.use_mmvq; + num_cols == other.num_cols && use_mmvq == other.use_mmvq; } }; @@ -919,6 +920,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.src0_type); ggml_webgpu_hash_combine(seed, key.src1_type); ggml_webgpu_hash_combine(seed, key.vectorized); + ggml_webgpu_hash_combine(seed, key.num_cols); 
ggml_webgpu_hash_combine(seed, key.use_mmvq); return seed; } @@ -993,11 +995,12 @@ struct ggml_webgpu_mul_mat_id_pipeline_key { ggml_type src0_type; ggml_type src1_type; uint32_t n_experts; + uint32_t num_cols; int vectorized; bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const { return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts && - vectorized == other.vectorized; + num_cols == other.num_cols && vectorized == other.vectorized; } }; @@ -1007,6 +1010,7 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.src0_type); ggml_webgpu_hash_combine(seed, key.src1_type); ggml_webgpu_hash_combine(seed, key.n_experts); + ggml_webgpu_hash_combine(seed, key.num_cols); ggml_webgpu_hash_combine(seed, key.vectorized); return seed; } @@ -1107,7 +1111,7 @@ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0, const ggml_tensor * src1, bool supports_dot_product, const std::string & vendor) { - if (src1->ne[1] == 1) { + if (src1->ne[1] <= 4) { bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia"; if (supports_dp4a && supports_dot_product) { switch (src1->type) { @@ -1889,6 +1893,7 @@ class ggml_webgpu_shader_lib { (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 
1 : 0; + key.num_cols = context.dst->ne[1]; key.use_mmvq = ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor); @@ -2004,6 +2009,7 @@ class ggml_webgpu_shader_lib { if (key.vectorized) { variant += "_vectorized"; } + defines.push_back(std::string("NUM_COLS=") + std::to_string(key.num_cols)); auto processed = preprocessor.preprocess(shader_src, defines); auto decisions = std::make_shared(); @@ -2421,6 +2427,7 @@ class ggml_webgpu_shader_lib { if (key.vectorized) { variant += "_vectorized"; } + defines.push_back(std::string("NUM_COLS=1")); defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts)); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f71d1aee73..e8eafd185a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1418,15 +1418,17 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context & const size_t dst_offset = ggml_webgpu_tensor_offset(dst); const size_t q8_src1_align_offset = ROUNDUP_POW2( dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - const size_t q8_src1_binding_size = - ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)), - WEBGPU_STORAGE_BUF_BINDING_MULT); + const size_t q8_src1_binding_size = ROUNDUP_POW2( + src1->ne[3] * src1->ne[2] * src1->ne[1] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)), + WEBGPU_STORAGE_BUF_BINDING_MULT); std::vector q8_params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), (uint32_t) src1->ne[0], + (uint32_t) src1->ne[1], (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], }; @@ -1442,7 +1444,7 @@ static void 
ggml_webgpu_quantize_q8_dispatch(webgpu_context & uint32_t q8_wg_x = 1; uint32_t q8_wg_y = 1; const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size; - const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec; + const uint32_t q8_total_wg = src1->ne[1] * src1->ne[2] * src1->ne[3] * wg_per_vec; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y); @@ -1456,7 +1458,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * dst) { // Determine if this is a mat-vec operation - bool is_vec = (dst->ne[1] == 1); + bool use_mat_vec = (dst->ne[1] <= 4); // use MMVQ path for mat-vec bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product, @@ -1482,7 +1484,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, webgpu_pipeline pipeline; std::vector dispatches; - if (is_vec) { + if (use_mat_vec) { if (use_mmvq) { ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches); } @@ -1529,7 +1531,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, uint32_t wg_y = 1; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - if (is_vec) { + if (use_mat_vec) { auto * decisions = static_cast(pipeline.context.get()); uint32_t batches = dst->ne[2] * dst->ne[3]; @@ -3691,8 +3693,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product, ctx->webgpu_global_ctx->vendor); if (use_mmvq) { - const size_t q8_src1_size = - src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)); + const size_t q8_src1_size = src1->ne[3] * src1->ne[2] * src1->ne[1] * + (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 
32)); res = ROUNDUP_POW2(res + q8_src1_size + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment, WEBGPU_STORAGE_BUF_BINDING_MULT); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl index 6ff9bcf2df..78ae955e6b 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl @@ -103,7 +103,7 @@ fn main( #ifdef USE_SUBGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let subgroup_total = subgroupAdd(acc[row]); + let subgroup_total = subgroupAdd(acc[0][row]); if (subgroup_invocation_id == 0u) { partial_sums[partial_index(row, subgroup_id)] = subgroup_total; } @@ -126,7 +126,7 @@ fn main( #ifdef USE_WORKGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - partial_sums[partial_index(row, thread_id)] = acc[row]; + partial_sums[partial_index(row, thread_id)] = acc[0][row]; } workgroupBarrier(); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index f0a7fbd059..ebdf09513e 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -91,61 +91,67 @@ fn main( let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base; #ifdef MMVQ - let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u); + let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * params.n * (params.k / 32u); let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base); #else let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base); #endif + for (var col = 0u;col < NUM_COLS;col += 1) { + #ifdef USE_SUBGROUP_REDUCTION - for (var row = 
0u; row < OUTPUTS_PER_WG; row++) { - let subgroup_total = subgroupAdd(acc[row]); - if (subgroup_invocation_id == 0u) { - partial_sums[partial_index(row, subgroup_id)] = subgroup_total; - } - } + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let subgroup_total = subgroupAdd(acc[col][row]); + if (subgroup_invocation_id == 0u) { + partial_sums[partial_index(row, subgroup_id)] = subgroup_total; + } + } - workgroupBarrier(); + workgroupBarrier(); - for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { - let output_row = row_base + row; - var row_acc = 0.0f; - for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { - row_acc += partial_sums[partial_index(row, k)]; - } - let row_total = subgroupAdd(row_acc); - if (subgroup_invocation_id == 0) { - dst[dst_idx_base + row] = row_total; - } - } + for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { + let output_row = row_base + row; + var row_acc = 0.0f; + for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { + row_acc += partial_sums[partial_index(row, k)]; + } + let row_total = subgroupAdd(row_acc); + if (subgroup_invocation_id == 0) { + dst[dst_idx_base + col * params.m + row] = row_total; + } + } #endif #ifdef USE_WORKGROUP_REDUCTION - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - partial_sums[partial_index(row, thread_id)] = acc[row]; - } + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] = acc[col][row]; + } + + workgroupBarrier(); + + var stride = WG_SIZE / 2u; + + while (stride > 0) { + if (thread_id < stride) { + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; + } + } + + workgroupBarrier(); + stride = stride / 2; + } + + if (thread_id < OUTPUTS_PER_WG) { + let output_row = row_base + thread_id; + if 
(output_row < params.m) { + dst[dst_idx_base + col * params.m + thread_id] = partial_sums[partial_index(thread_id, 0)]; + } + } +#endif workgroupBarrier(); - var stride = WG_SIZE / 2u; - - while (stride > 0) { - if (thread_id < stride) { - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; - } - } - - workgroupBarrier(); - stride = stride / 2; } - - if (thread_id < OUTPUTS_PER_WG) { - let output_row = row_base + thread_id; - if (output_row < params.m) { - dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)]; - } - } -#endif } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl index 08753b9d64..b0703fe906 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl @@ -32,8 +32,8 @@ fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { #endif #ifdef MUL_ACC_FLOAT -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let k_vec = params.k / VEC_SIZE; let src1_idx_base_vec = src1_idx_base / VEC_SIZE; @@ -41,12 +41,18 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src // Each thread walks K, loads from the vector, and updates // a small block of output rows held in registers. 
for (var k = thread_id; k < k_vec; k += WG_SIZE) { - let x = src1[src1_idx_base_vec + k]; + var x_vals: array; + for (var col = 0u;col < NUM_COLS;col += 1) { + x_vals[col] = src1[src1_idx_base_vec + col * (params.stride_11 / VEC_SIZE) + k]; + } for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k; - acc[row] += inner_dot(src0[src0_idx], x); + let w = src0[src0_idx]; + for (var col = 0u;col < NUM_COLS;col += 1) { + acc[col][row] += inner_dot(w, x_vals[col]); + } } } } @@ -60,30 +66,33 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 18 #define THREADS_PER_BLOCK 16 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); let q_byte = 
load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu; - var row_sum = 0.0; - for (var bit = 0u; bit < 8u; bit++) { - let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); - row_sum += w * x_block[bit]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var bit = 0u; bit < 8u; bit++) { + let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + row_sum += w * x_block[col][bit]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -97,35 +106,37 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 18 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % 4; for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - let q_packed = 
load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; - let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; + let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -139,36 +150,38 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 20 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row 
= 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); let m = f32(load_f16_at_src0(block_byte_base + 2u)); - var row_sum = 0.0; - let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(q_byte & 0xFu) * d + m; - let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(q_byte & 0xFu) * d + m; + let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -182,19 +195,20 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 22 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = 
f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -203,18 +217,19 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qh_packed = load_u32_at_src0(block_byte_base + 2u); let q_packed = load_u32_at_src0(block_byte_base + 6u + 4u * thread_within_block); let qh_shift = thread_within_block * 4u; - var row_sum = 0.0; - - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; - let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; - let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; - let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; + let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -228,19 +243,20 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 24 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, 
row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -250,18 +266,19 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qh_packed = load_u32_at_src0(block_byte_base + 4u); let q_packed = load_u32_at_src0(block_byte_base + 8u + 4u * thread_within_block); let qh_shift = thread_within_block * 4u; - var row_sum = 0.0; - - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; - let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; - let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; - let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed 
>> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; + let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -275,33 +292,38 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 34 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - + var q_packed: array; for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); - for (var byte_idx = 0u; byte_idx < 
4u; byte_idx++) { - let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d; - row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; - } + q_packed[packed_idx] = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); + } + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed[packed_idx], byte_idx)) * d; + row_sum += q_val * x_block[col][packed_idx * 4u + byte_idx]; + } + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -315,34 +337,39 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 36 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = 
f32(load_f16_at_src0(block_byte_base)); let m = f32(load_f16_at_src0(block_byte_base + 2u)); - var row_sum = 0.0; - + var q_packed: array; for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { - let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m; - row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; - } + q_packed[packed_idx] = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); + } + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed[packed_idx], byte_idx)) * d + m; + row_sum += q_val * x_block[col][packed_idx * 4u + byte_idx]; + } + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -355,8 +382,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 84 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -379,14 +406,15 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 64u + 
i]); - x_block[i + 12u] = f32(src1[x_base + 96u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 4u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4u] = f32(src1[x_base + col * params.stride_11 + 32u + i]); + x_block[col][i + 8u] = f32(src1[x_base + col * params.stride_11 + 64u + i]); + x_block[col][i + 12u] = f32(src1[x_base + col * params.stride_11 + 96u + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -404,30 +432,32 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qs0 = q_u32 & 0xFFFFu; let qs1 = q_u32 >> 16u; - var sumy = vec4(0.0, 0.0, 0.0, 0.0); - var acc1 = vec4(0.0, 0.0, 0.0, 0.0); - var acc2 = vec4(0.0, 0.0, 0.0, 0.0); + for (var col = 0u;col < NUM_COLS;col += 1) { + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + var acc1 = vec4(0.0, 0.0, 0.0, 0.0); + var acc2 = vec4(0.0, 0.0, 0.0, 0.0); - sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3]; - sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7]; - sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11]; - sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15]; + sumy[0] = x_block[col][0] + x_block[col][1] + x_block[col][2] + x_block[col][3]; + sumy[1] = x_block[col][4] + x_block[col][5] + x_block[col][6] + x_block[col][7]; + sumy[2] = x_block[col][8] + x_block[col][9] + x_block[col][10] + x_block[col][11]; + sumy[3] = x_block[col][12] + x_block[col][13] + x_block[col][14] + x_block[col][15]; - acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u); - acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u); - acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu); - acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u); - acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + 
x_block[10] * f32(qs1 & 0x0030u); - acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u); - acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u); - acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u); + acc1[0] = x_block[col][0] * f32(qs0 & 0x0003u) + x_block[col][2] * f32(qs1 & 0x0003u); + acc2[0] = x_block[col][1] * f32(qs0 & 0x0300u) + x_block[col][3] * f32(qs1 & 0x0300u); + acc1[1] = x_block[col][4] * f32(qs0 & 0x000Cu) + x_block[col][6] * f32(qs1 & 0x000Cu); + acc2[1] = x_block[col][5] * f32(qs0 & 0x0C00u) + x_block[col][7] * f32(qs1 & 0x0C00u); + acc1[2] = x_block[col][8] * f32(qs0 & 0x0030u) + x_block[col][10] * f32(qs1 & 0x0030u); + acc2[2] = x_block[col][9] * f32(qs0 & 0x3000u) + x_block[col][11] * f32(qs1 & 0x3000u); + acc1[3] = x_block[col][12] * f32(qs0 & 0x00C0u) + x_block[col][14] * f32(qs1 & 0x00C0u); + acc2[3] = x_block[col][13] * f32(qs0 & 0xC000u) + x_block[col][15] * f32(qs1 & 0xC000u); - acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + - (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + - (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + - (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) - - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + - sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + acc[col][row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + + (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + + (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + + (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) + - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + + sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + } } } } @@ -440,8 +470,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 110 #define THREADS_PER_BLOCK 16 -fn 
accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -485,12 +515,13 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 8u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 8u] = f32(src1[x_base + 32u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 8u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 8u] = f32(src1[x_base + col * params.stride_11 + 32u + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -516,28 +547,30 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let h_u32_0 = load_u32_at_src0(block_byte_base + h_byte + 0u); let h_u32_1 = load_u32_at_src0(block_byte_base + h_byte + 4u); - var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; - var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; + for (var col = 0u;col < NUM_COLS;col += 1) { + var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; + var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; - for (var l = 0u; l < 8u; l += 2u) { - let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); - let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); - let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); - let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); + for (var l = 0u; l < 8u; l += 2u) { + let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); + let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 
0u); + let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); + let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); - s1 += x_block[l + 0u] * f32(qs & qm0); - s2 += x_block[l + 1u] * f32(qs & qm1); - s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) + - select(0.0, x_block[l + 1u], (hv & hm1) == 0u); - s4 += x_block[l + 8u] * f32(qs & qm2); - s5 += x_block[l + 9u] * f32(qs & qm3); - s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) + - select(0.0, x_block[l + 9u], (hv & hm3) == 0u); + s1 += x_block[col][l + 0u] * f32(qs & qm0); + s2 += x_block[col][l + 1u] * f32(qs & qm1); + s3 += select(0.0, x_block[col][l + 0u], (hv & hm0) == 0u) + + select(0.0, x_block[col][l + 1u], (hv & hm1) == 0u); + s4 += x_block[col][l + 8u] * f32(qs & qm2); + s5 += x_block[col][l + 9u] * f32(qs & qm3); + s6 += select(0.0, x_block[col][l + 8u], (hv & hm2) == 0u) + + select(0.0, x_block[col][l + 9u], (hv & hm3) == 0u); + } + + let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); + let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); + acc[col][row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); } - - let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); - let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); - acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); } } } @@ -550,8 +583,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 144 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -573,12 +606,15 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += 
num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 128u + i]); - x_block[i + 12u] = f32(src1[x_base + 160u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + let col_base = x_base + col * params.stride_11; + for (var i = 0u; i < 4u; i++) { + x_block[col][i] = f32(src1[col_base + i]); + x_block[col][i + 4u] = f32(src1[col_base + 32u + i]); + x_block[col][i + 8u] = f32(src1[col_base + 128u + i]); + x_block[col][i + 12u] = f32(src1[col_base + 160u + i]); + } } for (var row = 0u; row < OUTPUTS_PER_WG; row++) { @@ -613,23 +649,25 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let q1_u32 = load_u32_at_src0_aligned(block_byte_base + 16u + q_offset); let q2_u32 = load_u32_at_src0_aligned(block_byte_base + 80u + q_offset); - var dot = vec4(0.0, 0.0, 0.0, 0.0); - var sumx = vec4(0.0, 0.0, 0.0, 0.0); - for (var i = 0u; i < 4u; i++) { - let q1b = byte_of(q1_u32, i); - let q2b = byte_of(q2_u32, i); - dot[0] += x_block[i] * f32(q1b & 0x0Fu); - dot[1] += x_block[i + 4u] * f32(q1b >> 4u); - dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu); - dot[3] += x_block[i + 12u] * f32(q2b >> 4u); - sumx[0] += x_block[i]; - sumx[1] += x_block[i + 4u]; - sumx[2] += x_block[i + 8u]; - sumx[3] += x_block[i + 12u]; - } + for (var col = 0u;col < NUM_COLS;col += 1) { + var dot = vec4(0.0, 0.0, 0.0, 0.0); + var sumx = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + dot[0] += x_block[col][i] * f32(q1b & 0x0Fu); + dot[1] += x_block[col][i + 4u] * f32(q1b >> 4u); + dot[2] += x_block[col][i + 8u] * f32(q2b & 0x0Fu); + dot[3] += x_block[col][i + 12u] * f32(q2b >> 4u); + sumx[0] += x_block[col][i]; + sumx[1] += x_block[col][i + 4u]; + sumx[2] += 
x_block[col][i + 8u]; + sumx[3] += x_block[col][i + 12u]; + } - acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) - - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + acc[col][row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) + - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + } } } } @@ -642,8 +680,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 176 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -671,14 +709,16 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 128u + i]); - x_block[i + 12u] = f32(src1[x_base + 160u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + let col_base = x_base + col * params.stride_11; + for (var i = 0u; i < 4u; i++) { + x_block[col][i] = f32(src1[col_base + i]); + x_block[col][i + 4u] = f32(src1[col_base + 32u + i]); + x_block[col][i + 8u] = f32(src1[col_base + 128u + i]); + x_block[col][i + 12u] = f32(src1[col_base + 160u + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -712,37 +752,39 @@ fn 
accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let q2_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset + 64u); let qh_u32 = load_u32_at_src0_aligned(block_byte_base + qh_offset); - var vals = vec4(0.0, 0.0, 0.0, 0.0); - var sumy = vec4(0.0, 0.0, 0.0, 0.0); - for (var i = 0u; i < 4u; i++) { - let q1b = byte_of(q1_u32, i); - let q2b = byte_of(q2_u32, i); - let qhb = byte_of(qh_u32, i); + for (var col = 0u;col < NUM_COLS;col += 1) { + var vals = vec4(0.0, 0.0, 0.0, 0.0); + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + let qhb = byte_of(qh_u32, i); - let yl0 = x_block[i]; - let yl8 = x_block[i + 4u]; - let yh0 = x_block[i + 8u]; - let yh8 = x_block[i + 12u]; + let yl0 = x_block[col][i]; + let yl8 = x_block[col][i + 4u]; + let yh0 = x_block[col][i + 8u]; + let yh8 = x_block[col][i + 12u]; - sumy[0] += yl0; - sumy[1] += yl8; - sumy[2] += yh0; - sumy[3] += yh8; + sumy[0] += yl0; + sumy[1] += yl8; + sumy[2] += yh0; + sumy[3] += yh8; - let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); - let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); - let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); - let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); + let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); + let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); + let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); + let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); - vals[0] += yl0 * q0; - vals[1] += yl8 * q1; - vals[2] += yh0 * q2; - vals[3] += yh8 * q3; + vals[0] += yl0 * q0; + vals[1] += yl8 * q1; + vals[2] += yh0 * q2; + vals[3] += yh8 * q3; + } + + acc[col][row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) + - dmin * (sumy[0] * m0 + sumy[1] * m1 + + sumy[2] * m4 + sumy[3] * m5); } - - acc[row] += d * (f0 * 
vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) - - dmin * (sumy[0] * m0 + sumy[1] * m1 + - sumy[2] * m4 + sumy[3] * m5); } } } @@ -755,8 +797,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 210 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -777,14 +819,16 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var l = 0u; l < 4u; l++) { - x_block[l] = f32(src1[x_base + l]); - x_block[l + 4u] = f32(src1[x_base + 32u + l]); - x_block[l + 8u] = f32(src1[x_base + 64u + l]); - x_block[l + 12u] = f32(src1[x_base + 96u + l]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + let col_base = x_base + col * params.stride_11; + for (var l = 0u; l < 4u; l++) { + x_block[col][l] = f32(src1[col_base + l]); + x_block[col][l + 4u] = f32(src1[col_base + 32u + l]); + x_block[col][l + 8u] = f32(src1[col_base + 64u + l]); + x_block[col][l + 12u] = f32(src1[col_base + 96u + l]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -802,26 +846,28 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); - var sums = vec4(0.0, 0.0, 0.0, 0.0); + for (var col = 0u;col < NUM_COLS;col += 1) { + var sums = vec4(0.0, 0.0, 0.0, 0.0); - for (var l = 
0u; l < 4u; l++) { - let q1b = byte_of(ql1_u32, l); - let q2b = byte_of(ql2_u32, l); - let qhb = byte_of(qh_u32, l); + for (var l = 0u; l < 4u; l++) { + let q1b = byte_of(ql1_u32, l); + let q2b = byte_of(ql2_u32, l); + let qhb = byte_of(qh_u32, l); - let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); - let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); - let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); - let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); + let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); + let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); + let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); + let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); - sums[0] += x_block[l] * dq0; - sums[1] += x_block[l + 4u] * dq1; - sums[2] += x_block[l + 8u] * dq2; - sums[3] += x_block[l + 12u] * dq3; + sums[0] += x_block[col][l] * dq0; + sums[1] += x_block[col][l + 4u] * dq1; + sums[2] += x_block[col][l + 8u] * dq2; + sums[3] += x_block[col][l + 12u] * dq3; + } + + acc[col][row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + + sums[2] * f32(sc4) + sums[3] * f32(sc6)); } - - acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + - sums[2] * f32(sc4) + sums[3] * f32(sc6)); } } } @@ -834,8 +880,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 50 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -850,11 +896,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += 
num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -866,20 +913,22 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_byte = get_byte(qs_w, l); - let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; - let gw = iq1_grid[ig / 16u]; - let bit_base = (ig % 16u) * 2u; - for (var j = 0u; j < 8u; j++) { - let g = (gw >> (bit_base + j * 2u)) & 3u; - let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); - row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -892,8 +941,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 56 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; 
+fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -908,11 +957,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -936,26 +986,28 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qh_lo = qh & 0xFFu; let qh_hi = (qh >> 8u) & 0xFFu; - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); - let sub_scale = (sc_u16 >> bit_off) & 0x7u; - let dl = d * f32(2u * sub_scale + 1u); - let qh_byte = select(qh_lo, qh_hi, l >= 2u); - let ll2 = l % 2u; - let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); - let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); - let ig = grid_idx * 8u; - let gw = iq1_grid[ig / 16u]; - let bit_base = (ig % 16u) * 2u; - for (var j = 0u; j < 8u; j++) { - let g = (gw >> (bit_base + j * 2u)) & 3u; - let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); - row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); + let sub_scale = (sc_u16 >> bit_off) & 0x7u; + 
let dl = d * f32(2u * sub_scale + 1u); + let qh_byte = select(qh_lo, qh_hi, l >= 2u); + let ll2 = l % 2u; + let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); + let ig = grid_idx * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -968,8 +1020,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 66 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -984,11 +1036,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -999,22 +1052,24 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let ls = aux_hi >> 28u; let db = d * (0.5 + f32(ls)) * 0.25; - var row_sum = 0.0; - 
for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; - let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let gw_lo = iq2xxs_grid[grid_idx * 2u]; - let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; + let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xxs_grid[grid_idx * 2u]; + let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1027,8 +1082,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 74 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1043,11 +1098,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + 
block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1058,27 +1114,29 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); let scales_byte = get_byte(scales_word, sub_blk % 4u); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let half2 = (l % 2u) * 16u; - let qs_val = (qs_word >> half2) & 0xFFFFu; - let grid_idx = qs_val & 0x1FFu; - let signs_idx = (qs_val >> 9u) & 0x7Fu; - let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; - let db = d * (0.5 + f32(sub_scale)) * 0.25; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let gw_lo = iq2xs_grid[grid_idx * 2u]; - let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let half2 = (l % 2u) * 16u; + let qs_val = (qs_word >> half2) & 0xFFFFu; + let grid_idx = qs_val & 0x1FFu; + let signs_idx = (qs_val >> 9u) & 0x7Fu; + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xs_grid[grid_idx * 2u]; 
+ let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1091,8 +1149,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 82 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1107,11 +1165,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1124,24 +1183,26 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u); let scales_byte = get_byte(sc_word, sub_blk % 4u); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_byte = get_byte(qs_w, l); - let sign_byte = get_byte(sg_w, l); - let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u); - let 
sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; - let db = d * (0.5 + f32(sub_scale)) * 0.25; - let gw_lo = iq2s_grid[grid_idx * 2u]; - let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let sign_byte = get_byte(sg_w, l); + let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u); + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let gw_lo = iq2s_grid[grid_idx * 2u]; + let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1154,8 +1215,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 98 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1170,11 +1231,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 
0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1186,27 +1248,29 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let ls = aux >> 28u; let db = d * (0.5 + f32(ls)) * 0.5; - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let byte_pos = (l % 2u) * 2u; - let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; - let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; - let signs_idx = (aux >> (7u * l)) & 0x7Fu; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let grid1 = iq3xxs_grid[grid_idx_0]; - let grid2 = iq3xxs_grid[grid_idx_1]; - for (var j = 0u; j < 4u; j++) { - let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); - let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); - let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); - row_sum += db * b1 * s1 * x_block[ll * 8u + j]; - row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let signs_idx = (aux >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let grid1 = iq3xxs_grid[grid_idx_0]; + let grid2 = iq3xxs_grid[grid_idx_1]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 
8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[col][ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[col][ll * 8u + j + 4u]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1219,8 +1283,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 110 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1235,11 +1299,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1255,28 +1320,30 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu; let db = d * (1.0 + 2.0 * f32(sub_scale)); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let byte_pos = (l % 2u) * 2u; - let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; - let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; - let 
grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); - let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); - let sign_byte = get_byte(sg_w, l); - let grid1 = iq3s_grid[grid_idx_1]; - let grid2 = iq3s_grid[grid_idx_2]; - for (var j = 0u; j < 4u; j++) { - let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); - let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); - let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); - let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); - row_sum += db * b1 * s1 * x_block[ll * 8u + j]; - row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); + let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); + let sign_byte = get_byte(sg_w, l); + let grid1 = iq3s_grid[grid_idx_1]; + let grid2 = iq3s_grid[grid_idx_2]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[col][ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[col][ll * 8u + j + 4u]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1290,35 +1357,37 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 18 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: 
u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + i + 16u]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4u] = f32(src1[x_base + col * params.stride_11 + i + 16u]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; - let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; + let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1331,8 +1400,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: 
u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 136 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1346,11 +1415,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1370,17 +1440,19 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u); let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u); - var row_sum = 0.0; - for (var i = 0u; i < 16u; i++) { - let q_word = select( - select(q_w0, q_w1, i >= 4u), - select(q_w2, q_w3, i >= 12u), - i >= 8u); - let q_byte = get_byte(q_word, i % 4u); - let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); - row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var i = 0u; i < 16u; i++) { + let q_word = select( + select(q_w0, q_w1, i >= 4u), + select(q_w2, q_w3, i >= 12u), + i >= 8u); + let q_byte = get_byte(q_word, i % 4u); + let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); + row_sum += 
f32(kvalues_iq4nl[nib]) * dl * x_block[col][i]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1394,35 +1466,38 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 17 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % 4; for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0); let e = ldexp(1.0, i32(eu8) - 128); - var row_sum = 0.0; let q_packed = load_u32_at_src0(block_byte_base + 1u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * e; - let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * e; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * 
x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * e; + let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * e; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl index 3ef2f77ebe..6ccaf61a6a 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl @@ -51,10 +51,7 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE { fn get_dm(block_byte_base: u32) -> f32 { return f32(load_f16_at_src0(block_byte_base)); } -fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { - return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK; -} -#endif +#endif // MUL_ACC_Q4_0 #ifdef MUL_ACC_Q4_1 #define BLOCK_SIZE_BYTES 20 @@ -85,10 +82,7 @@ fn get_dm(block_byte_base: u32) -> vec2 { f32(load_f16_at_src0(block_byte_base + 2u)) ); } -fn mul_q8_1(row_sum: i32, dma: vec2, b_ds: B_DS_TYPE) -> f32 { - return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK; -} -#endif +#endif // MUL_ACC_Q4_1 #ifdef MUL_ACC_Q8_0 #define BLOCK_SIZE_BYTES 34 @@ -111,46 +105,48 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE { fn get_dm(block_byte_base: u32) -> f32 { return f32(load_f16_at_src0(block_byte_base)); } -fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { - return f32(row_sum) * (da * b_ds); -} -#endif +#endif // MUL_ACC_Q8_0 -#ifdef LEGACY_QUANTS -fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2, b_ds: B_DS_TYPE) -> f32 { - var row_sum = 0; - let a_repacked = repack_a(a_byte_base, b_inner_id); - - row_sum += dot4I8Packed(a_repacked[0], 
b_repacked[0]); - row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]); - - return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds); -} - -fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array { - var acc: array; +#if defined(LEGACY_QUANTS) +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let b_inner_id = thread_id % THREADS_PER_BLOCK; - let b_block_idx = src1q_idx_base + block; - - let b_repacked = repack_b_qs(b_block_idx, b_inner_id); - let b_ds = repack_b_dm(b_block_idx); - + let inner_id = thread_id % THREADS_PER_BLOCK; for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds); + let a_repacked = repack_a(block_byte_base, inner_id); + let da = get_dm(block_byte_base); + for (var col = 0u;col < NUM_COLS;col += 1) { + let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + block; + let b_repacked = repack_b_qs(src1q_idx, inner_id); + let b_ds = repack_b_dm(src1q_idx); + + let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]); + +#if defined(MUL_ACC_Q4_0) + acc[col][row] += f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK; +#endif // MUL_ACC_Q4_0 + +#if defined(MUL_ACC_Q4_1) + acc[col][row] += f32(row_sum) * (da.x * b_ds.x) + da.y * b_ds.y / THREADS_PER_BLOCK; +#endif // MUL_ACC_Q4_1 + +#if defined(MUL_ACC_Q8_0) + acc[col][row] += f32(row_sum) * (da * b_ds); +#endif // MUL_ACC_Q8_0 + } } } } return acc; } -#endif +#endif // LEGACY_QUANTS #ifdef MUL_ACC_Q2_K 
#define BLOCK_SIZE_BYTES 84 @@ -191,22 +187,7 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2 { let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u); return vec2(f32(scale & 0xFu), f32(scale >> 4u)); } -fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4, b_ds: B_DS_TYPE) -> f32 { - let a_repacked = repack_a(a_byte_base, tid); - let dm = get_dm(a_byte_base); - let scale_min = get_scale_min(a_byte_base, tid); - - let scale_q = i32(scale_min.x); - let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u; - - let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1]) - + dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q; - let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4) - + dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4); - - return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m)); -} -#endif +#endif // MUL_ACC_Q2_K #ifdef MUL_ACC_Q4_K #define BLOCK_SIZE_BYTES 144 @@ -265,39 +246,52 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2 { return vec2(scale, min_val); } -fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4, b_ds: B_DS_TYPE) -> f32 { - let a_repacked = repack_a(a_byte_base, tid); - let dm = get_dm(a_byte_base); - let scale_min = get_scale_min(a_byte_base, tid); - - let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]) - + dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]); - - // Each thread covers half of the Q8_1 block, so add only b_ds.y/2. 
- return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD)); -} -#endif +#endif // MUL_ACC_Q4_K #ifdef K_QUANTS -fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) { - let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE; - let b_repacked = repack_b_qs(src1q_idx, tid); - let b_ds = repack_b_dm(src1q_idx); - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds); + let a_repacked = repack_a(block_byte_base, tid); + let dm = get_dm(block_byte_base); + let scale_min = get_scale_min(block_byte_base, tid); + for (var col = 0u;col < NUM_COLS;col += 1) { + let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE; + let b_repacked = repack_b_qs(src1q_idx, tid); + let b_ds = repack_b_dm(src1q_idx); + +#if defined(MUL_ACC_Q2_K) + let scale_q = i32(scale_min.x); + let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u; + + let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1]) + + dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q; + let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4) + + dot4I8Packed(b_repacked[2], scale_m_i8x4) + 
dot4I8Packed(b_repacked[3], scale_m_i8x4); + + acc[col][row] += b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m)); +#endif // MUL_ACC_Q2_K + +#if defined(MUL_ACC_Q4_K) + let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]) + + dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]); + + // Each thread covers half of the Q8_1 block, so add only b_ds.y/2. + acc[col][row] += b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD)); +#endif // MUL_ACC_Q4_K + + } } } } return acc; } -#endif +#endif // K_QUANTS diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl index b3f1fa04b8..847b27ffad 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl @@ -9,9 +9,11 @@ requires packed_4x8_integer_dot_product; struct Params { offset_src1: u32, + stride_11: u32, stride_12: u32, stride_13: u32, ne0: u32, + ne1: u32, ne2: u32, ne3: u32, }; @@ -57,25 +59,28 @@ fn main( @builtin(num_workgroups) num_wg: vec3 ) { let thread_id = local_id.x; - let num_vec4 = params.ne0 / 4u; + let ne0_vec4 = params.ne0 / 4u; - let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE; - let total_batches = wg_per_vec * params.ne2 * params.ne3; + let wg_per_vec = (ne0_vec4 + (WG_SIZE - 1u)) / WG_SIZE; + let total_batches = wg_per_vec * params.ne1 * params.ne2 * params.ne3; let wg_linear = wg_id.y * num_wg.x + wg_id.x; if (wg_linear >= total_batches) { return; } - let src13_idx = wg_linear / (params.ne2 * wg_per_vec); - let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec; - let src11_wg_idx = wg_linear % wg_per_vec; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let vec_idx = wg_linear / wg_per_vec; + let src13_idx = vec_idx / (params.ne2 * params.ne1); + 
let vec_ne12_num = vec_idx % (params.ne2 * params.ne1); + let src12_idx = vec_ne12_num / params.ne1; + let src11_idx = vec_ne12_num % params.ne1; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + src11_idx * params.stride_11; let src1_idx_vec4_base = src1_idx_base / 4u; let blocks_per_row = params.ne0 / 32u; let blocks_per_wg = (WG_SIZE * 4u) / 32u; - let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row; + let src1q_idx_base = ((src13_idx * params.ne2 + src12_idx) * params.ne1 + src11_idx) * blocks_per_row; + let src11_wg_idx = wg_linear % wg_per_vec; let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u; let qs_idx = thread_id % 8u; @@ -85,7 +90,7 @@ fn main( var thread_amax = 0.0; let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id; - let is_valid = src11_vec4_idx < num_vec4; + let is_valid = src11_vec4_idx < ne0_vec4; #ifdef USE_SUBGROUP_REDUCTION diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 15ae38927c..127c4634c0 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8433,6 +8433,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {1, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {3, 2}, {2, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {2, 1})); @@ -8449,6 +8450,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 1, 3, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 3, 2, 1})); + 
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {2, 3}, {1, 1}, {0, 3, 2, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 1, 3, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 3, 2, 1})); From a3900a669419e38f4e5f13a1a773857a38fd4cb1 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 23 Jun 2026 04:03:31 -0600 Subject: [PATCH 57/86] model: Granite Speech Plus (#24818) * feat: Add conversion support for Granite Speech Plus Branch: GraniteSpeechPlus AI-usage: full (Bob, OpenCode + Qwen3.6-35b) Signed-off-by: Gabe Goodhart * feat: Extend granite_speech to support plus multi-layer concatenation Branch: GraniteSpeechPlus AI-usage: draft (Bob, OpenCode + Qwen3.6-35b) Signed-off-by: Gabe Goodhart * fix(conversion): Fix plural naming for feature_layers for audio Branch: GraniteSpeechPlus AI-usage: none Signed-off-by: Gabe Goodhart * fix(mtmd): Align feature_layer usage and naming everywhere Branch: GraniteSpeechPlus AI-usage: none Signed-off-by: Gabe Goodhart * style: Use fstring for log Signed-off-by: Gabe Goodhart Co-authored-by: Xuan-Son Nguyen --------- Co-authored-by: Xuan-Son Nguyen --- conversion/__init__.py | 2 ++ conversion/granite.py | 28 +++++++++++++++++++++ gguf-py/gguf/constants.py | 1 + gguf-py/gguf/gguf_writer.py | 3 +++ tools/mtmd/clip-impl.h | 2 +- tools/mtmd/clip-model.h | 6 ++--- tools/mtmd/clip.cpp | 19 +++++++------- tools/mtmd/models/granite-speech.cpp | 36 ++++++++++++++++++++++++++- tools/mtmd/models/granite4-vision.cpp | 4 +-- tools/mtmd/models/llava.cpp | 6 ++--- 10 files changed, 87 insertions(+), 20 deletions(-) diff --git a/conversion/__init__.py b/conversion/__init__.py index 00192cf33a..c6af6f7318 100644 --- a/conversion/__init__.py +++ b/conversion/__init__.py @@ -96,6 +96,7 @@ TEXT_MODEL_MAP: dict[str, str] = { 
"GraniteMoeHybridForCausalLM": "granite", "GraniteMoeSharedForCausalLM": "granite", "GraniteSpeechForConditionalGeneration": "granite", + "GraniteSpeechPlusForConditionalGeneration": "granite", "Grok1ForCausalLM": "grok", "GrokForCausalLM": "grok", "GroveMoeForCausalLM": "grovemoe", @@ -261,6 +262,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = { "GlmasrModel": "ultravox", "Granite4VisionForConditionalGeneration": "granite", "GraniteSpeechForConditionalGeneration": "granite", + "GraniteSpeechPlusForConditionalGeneration": "granite", "HunYuanVLForConditionalGeneration": "hunyuan", "Idefics3ForConditionalGeneration": "smolvlm", "InternVisionModel": "internvl", diff --git a/conversion/granite.py b/conversion/granite.py index 53441fe570..8367ed225d 100644 --- a/conversion/granite.py +++ b/conversion/granite.py @@ -348,6 +348,34 @@ class GraniteSpeechMmprojModel(MmprojModel): yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("GraniteSpeechPlusForConditionalGeneration") +class GraniteSpeechPlusMmprojModel(GraniteSpeechMmprojModel): + """Conversion for GraniteSpeechPlus - extends GraniteSpeech with feature layer concatenation""" + has_vision_encoder = False + has_audio_encoder = True + + def set_gguf_parameters(self): + assert self.hparams_audio is not None + super().set_gguf_parameters() + + # Add feature_layer if present in encoder config + if feature_layers := self.hparams_audio.get("cat_hidden_layers"): + self.gguf_writer.add_audio_feature_layers(feature_layers) + logger.info(f"gguf: audio feature_layers = {feature_layers}") + + # Validate projector dimension matches concatenated encoder output + hidden_dim = self.hparams_audio["hidden_dim"] + expected_dim = hidden_dim * (len(feature_layers) + 1) + projector_dim = self.global_config["projector_config"]["encoder_hidden_size"] + + if projector_dim != expected_dim: + raise ValueError( + f"Projector encoder_hidden_size ({projector_dim}) does not match " + f"expected concatenated dimension 
({expected_dim}). " + f"Expected: hidden_dim ({hidden_dim}) * (len(feature_layers) + 1) = {expected_dim}" + ) + + @ModelBase.register("Granite4VisionForConditionalGeneration") class Granite4VisionMmprojModel(MmprojModel): has_vision_encoder = True diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 463963f2ac..1bda9452dd 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -359,6 +359,7 @@ class Keys: CHUNK_SIZE = "clip.audio.chunk_size" CONV_KERNEL_SIZE = "clip.audio.conv_kernel_size" MAX_POS_EMB = "clip.audio.max_pos_emb" + FEATURE_LAYERS = "clip.audio.feature_layer" # Granite Speech Plus class Attention: HEAD_COUNT = "clip.audio.attention.head_count" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index f707f29dc5..a06ec88b32 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1310,6 +1310,9 @@ class GGUFWriter: def add_audio_max_pos_emb(self, value: int) -> None: self.add_uint32(Keys.ClipAudio.MAX_POS_EMB, value) + def add_audio_feature_layers(self, layers: Sequence[int]) -> None: + self.add_array(Keys.ClipAudio.FEATURE_LAYERS, layers) + def add_audio_projector_window_size(self, value: int) -> None: self.add_uint32(Keys.ClipAudio.Projector.WINDOW_SIZE, value) diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index e7b5301445..5b413681f0 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -42,6 +42,7 @@ #define KEY_N_HEAD "clip.%s.attention.head_count" #define KEY_N_HEAD_KV "clip.%s.attention.head_count_kv" #define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon" +#define KEY_FEATURE_LAYERS "clip.%s.feature_layer" // vision-specific #define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities @@ -54,7 +55,6 @@ #define KEY_PATCH_SIZE "clip.vision.patch_size" #define KEY_IMAGE_MEAN "clip.vision.image_mean" #define KEY_IMAGE_STD "clip.vision.image_std" -#define KEY_FEATURE_LAYER "clip.vision.feature_layer" 
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor" #define KEY_PROJ_SAMPLE_QUERY_SIDE "clip.vision.projector.query_side" #define KEY_PROJ_SAMPLE_WINDOW_SIDE "clip.vision.projector.window_side" diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index 48796b6306..f86702eba4 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -91,7 +91,7 @@ struct clip_hparams { float eps = 1e-6; float rope_theta = 0.0; - std::vector vision_feature_layer; + std::vector feature_layers; int32_t attn_window_size = 0; int32_t n_wa_pattern = 0; std::unordered_set wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL) @@ -165,8 +165,8 @@ struct clip_hparams { return false; } - bool is_vision_feature_layer(int32_t layer) const { - return std::find(vision_feature_layer.begin(), vision_feature_layer.end(), layer) != vision_feature_layer.end(); + bool is_feature_layer(int32_t layer) const { + return std::find(feature_layers.begin(), feature_layers.end(), layer) != feature_layers.end(); } }; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 7dd7023c41..7bd486030f 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1264,12 +1264,10 @@ struct clip_model_loader { } } - // Load the vision feature layer indices if they are explicitly provided; - // if multiple vision feature layers are present, the values will be concatenated - // to form the final visual features. + // Load the vision/audio feature layer indices if they are explicitly provided // NOTE: gguf conversions should standardize the values of the vision feature layer to // be non-negative, since we use -1 to mark values as unset here. 
- get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer, false); + get_arr_int(string_format(KEY_FEATURE_LAYERS, prefix), hparams.feature_layers, false); // model-specific params switch (model.proj_type) { @@ -1651,6 +1649,7 @@ struct clip_model_loader { get_u32(KEY_A_PROJ_WINDOW_SIZE, hparams.audio_proj_window_size); get_u32(KEY_A_PROJ_DOWNSAMPLE_RATE, hparams.audio_proj_downsample_rate); get_u32(KEY_A_PROJ_HEAD_COUNT, hparams.audio_proj_head_count); + // NOTE: feature layers loaded above in common path } break; case PROJECTOR_TYPE_JANUS_PRO: { @@ -1663,11 +1662,11 @@ struct clip_model_loader { hparams.image_resize_algo = RESIZE_ALGO_BICUBIC_PILLOW; hparams.image_resize_pad = PAD_CEIL; - get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer); + // NOTE: feature_layers loaded in common path as optional get_arr_int(KEY_PROJ_SPATIAL_OFFSETS, hparams.proj_spatial_offsets); - if (hparams.vision_feature_layer.size() != hparams.proj_spatial_offsets.size()) { - throw std::runtime_error(string_format("%s: vision_feature_layer.size() %d != proj_spatial_offsets.size() %d", - hparams.vision_feature_layer.size(), hparams.proj_spatial_offsets.size())); + if (hparams.feature_layers.size() != hparams.proj_spatial_offsets.size()) { + throw std::runtime_error(string_format("%s: feature_layers.size() %d != proj_spatial_offsets.size() %d", + hparams.feature_layers.size(), hparams.proj_spatial_offsets.size())); } get_u32(KEY_PROJ_SAMPLE_QUERY_SIDE, hparams.downsample_query_side); @@ -2740,7 +2739,7 @@ struct clip_model_loader { model.image_newline = get_tensor(TN_IMAGE_NEWLINE); // Load separate layerwise and spatial projector tensors - const auto projector_count = hparams.vision_feature_layer.size(); + const auto projector_count = hparams.feature_layers.size(); model.qf_proj_blocks.resize(projector_count); for (size_t bid = 0; bid < projector_count; ++bid) { auto & b = model.qf_proj_blocks[bid]; @@ -4388,7 +4387,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, int 
n_threads, const clip_image_f32 // Stage 1b only uses block 0's permutations; future stages // will upload all blocks. - for (size_t bid = 0; bid < hparams.vision_feature_layer.size(); ++bid) { + for (size_t bid = 0; bid < hparams.feature_layers.size(); ++bid) { const std::string prefix = "g4v_blk" + std::to_string(bid) + "_"; upload(prefix + "win_idx", make_win_idx(image_side, window_side)); upload(prefix + "qwin_idx", make_win_idx(new_side, query_side)); diff --git a/tools/mtmd/models/granite-speech.cpp b/tools/mtmd/models/granite-speech.cpp index 0bd4d75ac5..a158a59ce9 100644 --- a/tools/mtmd/models/granite-speech.cpp +++ b/tools/mtmd/models/granite-speech.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include + ggml_cgraph * clip_graph_granite_speech::build() { const int n_frames = img.nx(); const int context_size = hparams.audio_chunk_size; @@ -11,6 +13,10 @@ ggml_cgraph * clip_graph_granite_speech::build() { const int padded_len = num_blocks * context_size; const int remainder = n_frames % context_size; + // Calculate projector input dimension based on feature layers + const int proj_input_dim = n_embd * (hparams.feature_layers.size() + 1); + const bool use_feature_concat = !hparams.feature_layers.empty(); + ggml_tensor * attn_dists = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, context_size * context_size); ggml_set_name(attn_dists, "attn_dists"); ggml_set_input(attn_dists); @@ -31,6 +37,15 @@ ggml_cgraph * clip_graph_granite_speech::build() { cur = ggml_add(ctx0, cur, model.inp_proj_b); cb(cur, "inp_linear", -1); + // Capture layer 0 if requested (after input_linear) + ggml_tensor * concat_result = nullptr; + if (use_feature_concat) { + if (std::find(hparams.feature_layers.begin(), hparams.feature_layers.end(), 0) != hparams.feature_layers.end()) { + concat_result = cur; + cb(concat_result, "feature_layer_0", -1); + } + } + for (int il = 0; il < n_layer; il++) { const auto & layer = model.layers[il]; auto * residual = cur; @@ -168,6 +183,18 @@ ggml_cgraph * 
clip_graph_granite_speech::build() { NORM_TYPE_NORMAL, eps, il); cb(cur, "layer_out", il); + // Capture intermediate layer (il + 1) if requested + if (use_feature_concat) { + if (hparams.is_feature_layer(il + 1)) { + if (concat_result == nullptr) { + concat_result = cur; + } else { + concat_result = ggml_concat(ctx0, concat_result, cur, 0); + } + cb(concat_result, string_format("feature_layer_%d", il + 1).c_str(), il); + } + } + // CTC branch if (il + 1 == ctc_layer) { auto * mid = build_mm(model.ctc_out_w, cur); @@ -180,6 +207,13 @@ ggml_cgraph * clip_graph_granite_speech::build() { } } + // Append final output to concatenated features if using feature concatenation + if (use_feature_concat && concat_result != nullptr) { + concat_result = ggml_concat(ctx0, concat_result, cur, 0); + cb(concat_result, "concat_final", -1); + cur = concat_result; + } + cb(cur, "encoder_out", -1); // QFormer projector @@ -197,7 +231,7 @@ ggml_cgraph * clip_graph_granite_speech::build() { cur = ggml_pad(ctx0, cur, 0, padded_proj - n_frames, 0, 0); } - ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, n_embd, window_size, nblocks_proj); + ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, proj_input_dim, window_size, nblocks_proj); ggml_tensor * queries = build_norm(model.qf_proj_blocks[0].qf_proj_query, model.qf_proj_blocks[0].qf_proj_norm_w, model.qf_proj_blocks[0].qf_proj_norm_b, diff --git a/tools/mtmd/models/granite4-vision.cpp b/tools/mtmd/models/granite4-vision.cpp index 9adb6f0fdb..1b252543c0 100644 --- a/tools/mtmd/models/granite4-vision.cpp +++ b/tools/mtmd/models/granite4-vision.cpp @@ -304,14 +304,14 @@ ggml_cgraph * clip_graph_granite4_vision::build() { } // --- Stage 1b/1c: WindowQFormer blocks --- - const int projector_count = hparams.vision_feature_layer.size(); + const int projector_count = hparams.feature_layers.size(); const float qformer_eps = 1e-12f; ggml_tensor * mmproj = nullptr; for (int bid = 0; bid < projector_count; ++bid) { const auto & blk = 
model.qf_proj_blocks[bid]; - int vlayer = hparams.vision_feature_layer[bid]; + int vlayer = hparams.feature_layers[bid]; GGML_ASSERT(vlayer >= 0 && vlayer < n_layer); ggml_tensor * h = layer_outs[vlayer]; diff --git a/tools/mtmd/models/llava.cpp b/tools/mtmd/models/llava.cpp index 5aa3d2f0fa..47efe68bd8 100644 --- a/tools/mtmd/models/llava.cpp +++ b/tools/mtmd/models/llava.cpp @@ -21,7 +21,7 @@ ggml_cgraph * clip_graph_llava::build() { // If we set explicit vision feature layers, only go up to the deepest one // NOTE: only used by granite-vision models for now - for (const auto & feature_layer : hparams.vision_feature_layer) { + for (const auto & feature_layer : hparams.feature_layers) { if (feature_layer > deepest_feature_layer) { deepest_feature_layer = feature_layer; } @@ -59,7 +59,7 @@ ggml_cgraph * clip_graph_llava::build() { // If this is an embedding feature layer, save the output. // NOTE: 0 index here refers to the input to the encoder. - if (hparams.is_vision_feature_layer(il)) { + if (hparams.is_feature_layer(il)) { embedding_stack.push_back(cur); } @@ -134,7 +134,7 @@ ggml_cgraph * clip_graph_llava::build() { // process vision feature layers (used by granite) { // final layer is a vision feature layer - if (hparams.is_vision_feature_layer(max_feature_layer)) { + if (hparams.is_feature_layer(max_feature_layer)) { embedding_stack.push_back(inpL); } From c926ad09857517978575d6a74d225b463f7417a0 Mon Sep 17 00:00:00 2001 From: Wyatt Caldwell <218154709+Detensable@users.noreply.github.com> Date: Tue, 23 Jun 2026 03:55:46 -0700 Subject: [PATCH 58/86] vulkan: link ggml-cpu when GGML_VULKAN_CHECK_RESULTS / RUN_TESTS are enabled (#24444) The result-checking and test debug paths in ggml-vulkan.cpp call ggml_graph_compute_with_ctx() to compute a CPU reference graph, but that symbol is defined in ggml-cpu, which ggml-vulkan does not link. Enabling -DGGML_VULKAN_CHECK_RESULTS=ON (or -DGGML_VULKAN_RUN_TESTS=ON) therefore fails to link with an unresolved external (e.g. 
LNK2019 on MSVC, undefined reference on GCC/Clang). This regressed after ggml-cpu was split into its own library. Link ggml-cpu under those two options so the debug builds link again. Signed-off-by: Wyatt Caldwell <218154709+Detensable@users.noreply.github.com> --- ggml/src/ggml-vulkan/CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 2d9e85794a..5aeb6e97b1 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -108,6 +108,9 @@ if (Vulkan_FOUND) if (GGML_VULKAN_CHECK_RESULTS) add_compile_definitions(GGML_VULKAN_CHECK_RESULTS) + # the result-checking path computes a CPU reference graph via + # ggml_graph_compute_with_ctx(), which is defined in ggml-cpu + target_link_libraries(ggml-vulkan PRIVATE ggml-cpu) endif() if (GGML_VULKAN_DEBUG) @@ -129,6 +132,8 @@ if (Vulkan_FOUND) if (GGML_VULKAN_RUN_TESTS) add_compile_definitions(GGML_VULKAN_RUN_TESTS) + # the test path also calls ggml_graph_compute_with_ctx() (ggml-cpu) + target_link_libraries(ggml-vulkan PRIVATE ggml-cpu) endif() # Set up toolchain for host compilation whether cross-compiling or not From 75ad0b23ed6dc98ce384953e1f9bc494c3de92ce Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Tue, 23 Jun 2026 13:28:34 +0200 Subject: [PATCH 59/86] server: fix remote preset handling, add test (#24938) * server: add test for remote preset * fix remote preset handling * fix * fix test --- common/arg.cpp | 11 +++++++---- common/arg.h | 7 ++++++- common/download.cpp | 4 +++- common/download.h | 1 + tools/server/server-models.cpp | 6 ++++-- tools/server/server.cpp | 13 ++++++++++++- tools/server/tests/unit/test_router.py | 19 +++++++++++++++++++ 7 files changed, 52 insertions(+), 9 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 5297d90753..276dbec8ba 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -301,6 +301,8 @@ static handle_model_result 
common_params_handle_model(struct common_params_model const common_download_opts & opts) { handle_model_result result; + // TODO @ngxson : refactor this into a new common_model_download_context + if (!model.docker_repo.empty()) { model.path = common_docker_resolve_model(model.docker_repo); } else if (!model.hf_repo.empty()) { @@ -396,7 +398,7 @@ static bool parse_bool_value(const std::string & value) { // CLI argument parsing functions // -bool common_params_handle_models(common_params & params, llama_example curr_ex, common_download_callback * callback) { +bool common_params_handle_models(common_params & params, llama_example curr_ex, const common_params_handle_models_params & handle_params) { const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(), params.speculative.types.end(), COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end(); @@ -407,9 +409,10 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex, opts.skip_download = params.skip_download; opts.download_mtp = spec_type_draft_mtp; opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty(); + opts.preset_only = handle_params.preset_only; - if (callback) { - opts.callback = callback; + if (handle_params.callback) { + opts.callback = handle_params.callback; } // sub-models (draft, mmproj, vocoder) are explicitly specified by the user, @@ -596,7 +599,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context if (!skip_model_download) { // handle model and download - common_params_handle_models(params, ctx_arg.ex); + common_params_handle_models(params, ctx_arg.ex, {}); // model is required (except for server) // TODO @ngxson : maybe show a list of available models in CLI in this case diff --git a/common/arg.h b/common/arg.h index c061fc60f7..fdfc04bc7a 100644 --- a/common/arg.h +++ b/common/arg.h @@ -130,6 +130,11 @@ bool common_params_to_map(int argc, char ** argv, llama_example 
ex, std::map & args); +struct common_params_handle_models_params { + common_download_callback * callback = nullptr; + bool preset_only = false; // if true, only check & download remote preset (for router mode) +}; + // populate model paths (main model, mmproj, etc) from -hf if necessary // return true if the model is ready to use // throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc) @@ -137,7 +142,7 @@ void common_params_add_preset_options(std::vector & args); bool common_params_handle_models( common_params & params, llama_example curr_ex, - common_download_callback * callback = nullptr); + const common_params_handle_models_params & handle_params); // initialize argument parser context - used by test-arg-parser and preset common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); diff --git a/common/download.cpp b/common/download.cpp index f320462753..5b55c76a11 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -799,6 +799,7 @@ common_download_model_result common_download_model(const common_params_model & bool download_mmproj = opts.download_mmproj; bool download_mtp = opts.download_mtp; + bool preset_only = opts.preset_only; bool is_hf = !model.hf_repo.empty(); if (is_hf) { @@ -806,7 +807,8 @@ common_download_model_result common_download_model(const common_params_model & if (!hf.preset.path.empty()) { // if preset.ini exists, only download that file alone tasks.push_back({hf.preset.url, hf.preset.local_path}); - } else { + } else if (!preset_only) { + // only add other files if we're NOT in preset-only mode (normal run, non-router) for (const auto & f : hf.model_files) { tasks.push_back({f.url, f.local_path}); } diff --git a/common/download.h b/common/download.h index 8dbf07836f..755e34ea8c 100644 --- a/common/download.h +++ b/common/download.h @@ -55,6 +55,7 @@ struct common_download_opts { bool 
skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid bool download_mmproj = false; bool download_mtp = false; + bool preset_only = false; // if true, only check & download remote preset (for router mode) common_download_callback * callback = nullptr; }; diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index a87e4e423e..a4df3ef108 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -224,7 +224,7 @@ void server_model_meta::update_caps() { }); params.offline = true; // params.skip_download = true; // TODO: ideally, we should validate the model here, but it takes too much time - common_params_handle_models(params, LLAMA_EXAMPLE_SERVER); + common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {}); if (params.mmproj.path.empty()) { multimodal = { false, false }; } else { @@ -1393,7 +1393,9 @@ struct server_download_state : public common_download_callback { bool run(common_params & params) { try { - common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, this); + common_params_handle_models_params p; + p.callback = this; + common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, p); is_ok = true; } catch (const std::exception & e) { auto model_name = params.model.get_name(); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index dd4b1c507c..4165c1015e 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -89,6 +89,17 @@ int llama_server(int argc, char ** argv) { llama_backend_init(); llama_numa_init(params.numa); + // note: router mode also accepts -hf remote-preset, so we need to check that first + if (!params.model.hf_repo.empty()) { + try { + common_params_handle_models_params handle_params; + handle_params.preset_only = true; + common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, handle_params); + } catch (const std::exception & e) { + // ignored for now + } + } + // router 
server never loads a model and must not touch the GPU const bool is_router_server = params.model.path.empty() && params.model.hf_repo.empty(); @@ -263,7 +274,7 @@ int llama_server(int argc, char ** argv) { return child.run_download(params); } else if (!is_router_server) { // single-model mode (NOT spawned by router) - common_params_handle_models(params, LLAMA_EXAMPLE_SERVER); + common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {}); } // diff --git a/tools/server/tests/unit/test_router.py b/tools/server/tests/unit/test_router.py index 41e95f4c5f..94165e520e 100644 --- a/tools/server/tests/unit/test_router.py +++ b/tools/server/tests/unit/test_router.py @@ -256,6 +256,25 @@ def test_router_reload_models(): os.remove(preset_path) +def test_router_remote_preset(): + global server + server.model_hf_repo = "ggml-org/test-preset-ci" + server.model_hf_file = None + server.offline = False + server.start() + + # Should see preset models in GET /models + res = server.make_request("GET", "/models") + assert res.status_code == 200 + ids = {item["id"] for item in res.body.get("data", [])} + assert "tinygemma3-preset" in ids + assert "stories260K-test" in ids + + # Should be able to load a preset model + model_id = "tinygemma3-preset" + _load_model_and_wait(model_id) + + MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16" MODEL_DOWNLOAD_TIMEOUT = 30 From 0eb874d37445fb25cc268ad0b2f2cb07ce561b66 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 23 Jun 2026 07:26:17 -0500 Subject: [PATCH 60/86] vulkan: make mul_mm ALIGNED a spec constant (#24689) This trims down some of the shader variant explosion and reduces binary size. 
--- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 76 +++++---- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 54 ++++--- .../vulkan-shaders/mul_mm_cm2.comp | 11 +- .../vulkan-shaders/mul_mm_funcs.glsl | 151 ++++++++++-------- .../vulkan-shaders/vulkan-shaders-gen.cpp | 17 +- 5 files changed, 172 insertions(+), 137 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 9a36b45de8..b3c269783e 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4074,19 +4074,35 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { } #endif + auto const &ggml_vk_mul_mm_spec = [](std::vector spec, bool aligned) { + spec.push_back(aligned ? 1u : 0u); + return spec; + }; + const int mul_mat_id_param_count = 5; #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { + auto const &ggml_vk_mul_mm_cm2_spec = [](std::vector spec, bool aligned, bool mul_mat_id) { + if (mul_mat_id && spec.size() > 5) { + spec.insert(spec.begin() + 5, aligned ? 1u : 0u); + } else { + spec.push_back(aligned ? 
1u : 0u); + } + if (mul_mat_id && spec.size() == 6) { + spec.push_back(32); + } + return spec; + }; // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, false, PARAMCOUNT 
== mul_mat_id_param_count), 1, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), l_align, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), m_align, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), s_align, true); \ // Create 2 variants, {f16,f32} accumulator #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ @@ -4161,17 +4177,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ if (device->mul_mat ## ID ## 
_l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, true); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, true); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, true); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC 
"_aligned_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, true); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, true); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, true); \ // Create 2 variants, {f16,f32} accumulator #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ @@ -4284,32 +4300,32 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { // Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? 
NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? 
NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? 
NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? 
NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ // bf16 scalar path promotes to f32, no dot2 variant #define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC 
## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## 
WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l_int[TYPE]) { \ @@ -4474,17 +4490,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## 
WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 
s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ if (device->mul_mat ## ID ## _l_int[TYPE]) \ diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index f39410d74f..57c0410e45 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -38,17 +38,7 @@ #define LOAD_VEC_B 1 #endif -// Load 2 values at once without affecting index calculations through LOAD_VEC -#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED) -#define LOAD_VEC_BATCH_A 2 -#else -#define LOAD_VEC_BATCH_A 1 -#endif -#if !defined(ALIGNED) -#define LOAD_VEC_BATCH_B 2 -#else -#define LOAD_VEC_BATCH_B 1 -#endif +layout (constant_id = 11) const uint ALIGNED = 0; #if !defined(TO_FLOAT_TYPE) #define TO_FLOAT_TYPE FLOAT_TYPE @@ -57,6 +47,13 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(DATA_A_F32) +layout (binding = 0) readonly buffer A_SCALAR {float data_a_scalar[];}; +#elif defined(DATA_A_F16) +layout (binding = 0) readonly buffer A_SCALAR {float16_t data_a_scalar[];}; +#elif defined(DATA_A_BF16) +layout (binding = 0) readonly buffer A_SCALAR {uint16_t data_a_scalar[];}; +#endif #if defined(A_TYPE_PACKED16) layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; #endif @@ -65,6 +62,7 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32 #endif layout (binding = 1) readonly buffer B 
{B_TYPE data_b[];}; +layout (binding = 1) readonly buffer B_SCALAR {B_TYPE_SCALAR data_b_scalar[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #ifdef MUL_MAT_ID @@ -194,13 +192,23 @@ void main() { const uint warp_r = warp_i % (BM / WM); const uint warp_c = warp_i / (BM / WM); - const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A); - const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A); - const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B); - const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B); +#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) + const uint LOAD_VEC_A_EFF = (ALIGNED != 0) ? LOAD_VEC_A : 1; + const uint LOAD_VEC_BATCH_A = (ALIGNED != 0) ? 1 : 2; +#else + const uint LOAD_VEC_A_EFF = LOAD_VEC_A; + const uint LOAD_VEC_BATCH_A = 1; +#endif + const uint LOAD_VEC_B_EFF = (ALIGNED != 0) ? LOAD_VEC_B : 1; + const uint LOAD_VEC_BATCH_B = (ALIGNED != 0) ? 
1 : 2; - const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK; - const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK; + const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A); + const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A); + const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B); + const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B); + + const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A_EFF * LOAD_VEC_BATCH_A / BK; + const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B_EFF * LOAD_VEC_BATCH_B / BK; #ifdef MUL_MAT_ID #ifdef MUL_MAT_ID_USE_SUBGROUPS @@ -239,15 +247,15 @@ void main() { uint pos_a = #ifdef MUL_MAT_ID - expert_idx * (p.batch_stride_a / LOAD_VEC_A) + + expert_idx * (p.batch_stride_a / LOAD_VEC_A_EFF) + #else - batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) + + batch_idx_a * (p.batch_stride_a / LOAD_VEC_A_EFF) + #endif - (ir * BM * p.stride_a + start_k) / LOAD_VEC_A; + (ir * BM * p.stride_a + start_k) / LOAD_VEC_A_EFF; #ifdef MUL_MAT_ID uint pos_b = 0; #else - uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; + uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B_EFF; #endif #ifdef COOPMAT @@ -287,8 +295,8 @@ void main() { barrier(); - pos_a += BK / LOAD_VEC_A; - pos_b += BK / LOAD_VEC_B; + pos_a += BK / LOAD_VEC_A_EFF; + pos_b += BK / LOAD_VEC_B_EFF; #ifdef COOPMAT [[unroll]] for (uint i = 0; i < BK; i += TK) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 2656fe1c3e..a2e15f6f5c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -36,6 +36,7 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working 
wit layout (constant_id = 4) const bool enable_smaller_matrices = false; const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN; const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN; +layout (constant_id = 5) const uint ALIGNED = 0; layout (push_constant) uniform parameter { @@ -111,7 +112,7 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { }; uint _ne1; -layout (constant_id = 5) const uint subgroup_size = 32; +layout (constant_id = 6) const uint subgroup_size = 32; shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size]; B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) @@ -297,12 +298,12 @@ void main() { // Hint to the compiler that values are aligned (want 16B alignment). // Quants are always block-aligned, no alignment needed. -#if ALIGNED + if (ALIGNED != 0) { #if QUANT_K == 1 - stride_a &= ~7; -#endif - stride_b &= ~7; + stride_a &= ~7; #endif + stride_b &= ~7; + } // Create layouts for both clamped and unclamped accesses tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 7359516898..56a8a0f187 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -1,50 +1,57 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) { #if defined(DATA_A_F32) || defined(DATA_A_F16) #if LOAD_VEC_A == 8 - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]); - buf_a[buf_idx ] = aa[0].xy; - buf_a[buf_idx + 1] = aa[0].zw; - buf_a[buf_idx + 2] = aa[1].xy; - buf_a[buf_idx + 3] = aa[1].zw; + if (ALIGNED != 0) { + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint 
buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]); + buf_a[buf_idx ] = aa[0].xy; + buf_a[buf_idx + 1] = aa[0].zw; + buf_a[buf_idx + 2] = aa[1].xy; + buf_a[buf_idx + 3] = aa[1].zw; + return; + } #elif LOAD_VEC_A == 4 - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]); - buf_a[buf_idx ] = aa.xy; - buf_a[buf_idx + 1] = aa.zw; -#else // LOAD_VEC_BATCH_A == 2 + if (ALIGNED != 0) { + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]); + buf_a[buf_idx ] = aa.xy; + buf_a[buf_idx + 1] = aa.zw; + return; + } +#endif const uint idx = pos_a + col * p.stride_a + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_m < p.M && block + row * 2 + 1 < end_k) { - buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], - data_a[idx + 1]); + buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx], + data_a_scalar[idx + 1]); } else if (idx_m < p.M && block + row * 2 < end_k) { - buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx], 0.0f); } else { buf_a[buf_idx] = FLOAT_TYPEV2(0.0f); } -#endif #elif defined(DATA_A_BF16) #if LOAD_VEC_A == 4 - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx])); - buf_a[buf_idx ] = aa.xy; - buf_a[buf_idx + 1] = aa.zw; -#else // LOAD_VEC_BATCH_A == 2 + if (ALIGNED != 0) { + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx])); + buf_a[buf_idx ] = aa.xy; + buf_a[buf_idx + 1] = aa.zw; + return; + } +#endif const uint idx = pos_a + col * p.stride_a + row * 2; 
const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_m < p.M && block + row * 2 + 1 < end_k) { - buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), - TO_FLOAT_TYPE(data_a[idx + 1])); + buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]), + TO_FLOAT_TYPE(data_a_scalar[idx + 1])); } else if (idx_m < p.M && block + row * 2 < end_k) { - buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]), 0.0f); } else { buf_a[buf_idx] = FLOAT_TYPEV2(0.0f); } -#endif #elif defined(DATA_A_Q4_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -526,75 +533,85 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin #if !defined(MUL_MAT_ID) void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) { #if LOAD_VEC_B == 8 - // Not supported for b_type bf16 because bf16mat2x4 does not exist - const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; - FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); - buf_b[buf_idx + 0] = bb[0].xy; - buf_b[buf_idx + 1] = bb[0].zw; - buf_b[buf_idx + 2] = bb[1].xy; - buf_b[buf_idx + 3] = bb[1].zw; + if (ALIGNED != 0) { + // Not supported for b_type bf16 because bf16mat2x4 does not exist + const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); + buf_b[buf_idx + 0] = bb[0].xy; + buf_b[buf_idx + 1] = bb[0].zw; + buf_b[buf_idx + 2] = bb[1].xy; + buf_b[buf_idx + 3] = bb[1].zw; + return; + } #elif LOAD_VEC_B == 4 - const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + if (ALIGNED != 0) { + const uint idx = pos_b + col * p.stride_b / 
LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; #if defined(DATA_B_BF16) - FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); #else - FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); +#endif + buf_b[buf_idx + 0] = bb.xy; + buf_b[buf_idx + 1] = bb.zw; + return; + } #endif - buf_b[buf_idx + 0] = bb.xy; - buf_b[buf_idx + 1] = bb.zw; -#else // LOAD_VEC_BATCH_B == 2 const uint idx = pos_b + col * p.stride_b + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_n < p.N && block + row * 2 + 1 < end_k) { - buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), - TO_FLOAT_TYPE(data_b[idx + 1])); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), + TO_FLOAT_TYPE(data_b_scalar[idx + 1])); } else if (idx_n < p.N && block + row * 2 < end_k) { - buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f); } else { buf_b[buf_idx] = FLOAT_TYPEV2(0.0f); } -#endif } #else void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) { #if LOAD_VEC_B == 8 - // Not supported for b_type bf16 because bf16mat2x4 does not exist - const u16vec2 row_idx = row_ids[col]; - const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; - FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); - buf_b[buf_idx + 0] = bb[0].xy; - buf_b[buf_idx + 1] = bb[0].zw; - buf_b[buf_idx + 2] = bb[1].xy; - buf_b[buf_idx + 3] = bb[1].zw; + if (ALIGNED != 0) { + // Not supported for b_type bf16 because bf16mat2x4 does not exist + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * 
p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); + buf_b[buf_idx + 0] = bb[0].xy; + buf_b[buf_idx + 1] = bb[0].zw; + buf_b[buf_idx + 2] = bb[1].xy; + buf_b[buf_idx + 3] = bb[1].zw; + return; + } #elif LOAD_VEC_B == 4 - const u16vec2 row_idx = row_ids[col]; - const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + if (ALIGNED != 0) { + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; #if defined(DATA_B_BF16) - FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); #else - FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); +#endif + buf_b[buf_idx + 0] = bb.xy; + buf_b[buf_idx + 1] = bb.zw; + return; + } #endif - buf_b[buf_idx + 0] = bb.xy; - buf_b[buf_idx + 1] = bb.zw; -#else // LOAD_VEC_BATCH_B == 2 const uint row_i = ic * BN + col; const uint buf_idx = col * SHMEM_STRIDE + row; if (row_i < _ne1 && block + row * 2 + 1 < end_k) { const u16vec2 row_idx = row_ids[col]; const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; - buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), - TO_FLOAT_TYPE(data_b[idx + 1])); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), + TO_FLOAT_TYPE(data_b_scalar[idx + 1])); } else if (row_i < _ne1 && block + row * 2 < end_k) { const u16vec2 row_idx = row_ids[col]; const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; - buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + buf_b[buf_idx] = 
FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f); } else { buf_b[buf_idx] = FLOAT_TYPEV2(0.0f); } -#endif } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index ca6b444314..f07583b6ab 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -539,11 +539,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c }; // Shaders with f16 B_TYPE - string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, 
{"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); // bf16 { @@ -565,8 +563,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c #endif { if (!dot2) { - string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE_SCALAR", coopmat2 ? 
"bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); } } } @@ -583,8 +580,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } std::string data_a_key = "DATA_A_" + to_uppercase(tname); - // For unaligned, load one at a time for f32/f16, or two at a time for quants - std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant; // For aligned matmul loads std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant; @@ -597,13 +592,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c // don't generate f32 variants for coopmat2 if (!coopmat2) { - string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE_SCALAR", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); } if (tname != "f16" && tname != "f32") { - string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, 
{"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); } #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) From c5606364b2a608c6501f25b80c3e256a355a0ce8 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 23 Jun 2026 08:39:20 -0500 Subject: [PATCH 61/86] vulkan: support CONV_3D (#24612) * vulkan: support CONV_3D This is a pretty direct port of conv2d_mm.comp to CONV_3D, done by codex and cleaned up by me. 
* disable slower perf tests --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 244 +++++++++- .../ggml-vulkan/vulkan-shaders/conv3d_mm.comp | 431 ++++++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 25 + tests/test-backend-ops.cpp | 28 ++ 4 files changed, 725 insertions(+), 3 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/conv3d_mm.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b3c269783e..508d569f20 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -493,6 +493,20 @@ struct vk_conv2d_pipeline_state { } }; +struct vk_conv3d_pipeline_state { + vk_conv3d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t s2, uint32_t p0, uint32_t p1, uint32_t p2, + uint32_t d0, uint32_t d1, uint32_t d2, uint32_t KW, uint32_t KH, uint32_t KD, uint32_t aligned) + : s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), KW(KW), KH(KH), KD(KD), aligned(aligned) {} + + uint32_t s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD; + uint32_t aligned; + + bool operator<(const vk_conv3d_pipeline_state &b) const { + return std::tie(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned) < + std::tie(b.s0, b.s1, b.s2, b.p0, b.p1, b.p2, b.d0, b.d1, b.d2, b.KW, b.KH, b.KD, b.aligned); + } +}; + struct vk_solve_tri_pipeline_state { vk_solve_tri_pipeline_state(uint32_t N, uint32_t K) : N(N), K(K) {} @@ -924,6 +938,8 @@ struct vk_device_struct { std::map pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT]; std::map pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT]; std::map pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT]; + std::map pipeline_conv3d_f32[CONV_SHAPE_COUNT]; + std::map pipeline_conv3d_f16_f32[CONV_SHAPE_COUNT]; vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32; vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32; @@ -1669,6 +1685,41 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) { 
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); } +struct vk_op_conv3d_push_constants { + uint32_t OC; + uint32_t IC; + uint32_t N; + + uint32_t IW; + uint32_t IH; + uint32_t ID; + uint32_t OW; + uint32_t OH; + uint32_t OD; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; + + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; + uint32_t OWOHODmp; uint32_t OWOHODL; +}; + +template <> void init_pushconst_fastdiv(vk_op_conv3d_push_constants &p) { + init_fastdiv_values(p.OW, p.OWmp, p.OWL); + init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); + init_fastdiv_values(p.OW*p.OH*p.OD, p.OWOHODmp, p.OWOHODL); +} + struct vk_op_conv2d_dw_push_constants { uint32_t ne; uint32_t batches; @@ -5330,7 +5381,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - // conv2d, conv_transpose_2d + // conv2d, conv_transpose_2d, conv3d for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) { // smaller WG for the small-tile fallback gives more concurrent WGs per SM uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256; @@ -5393,8 +5444,8 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size; }; - // coopmat1 needs to store the output through shared memory, so check up front - // whether it'll fit and disable it before applying coopmat1 parameters. + // 2D, transpose-2D, and 3D conv use the same KxCRS @ CRSxNPQ shmem + // layout. cm1 needs Csh for output, so check before applying cm1 params. 
if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) { conv2d_use_cm1 = false; } @@ -5486,6 +5537,53 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { } #undef CREATE_CONV #undef CREATE_CONVS + + std::vector conv3d_spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, conv2d_SHMEM_PAD }; +#define CREATE_CONV3D(type_suffix, spv_suffix) \ + for (auto &c : device->pipeline_conv3d##type_suffix[s]) { \ + const vk_conv3d_pipeline_state &state = c.first; \ + std::vector spec_constants_cpy = conv3d_spec_constants; \ + spec_constants_cpy.push_back(state.s0); \ + spec_constants_cpy.push_back(state.s1); \ + spec_constants_cpy.push_back(state.s2); \ + spec_constants_cpy.push_back(state.p0); \ + spec_constants_cpy.push_back(state.p1); \ + spec_constants_cpy.push_back(state.p2); \ + spec_constants_cpy.push_back(state.d0); \ + spec_constants_cpy.push_back(state.d1); \ + spec_constants_cpy.push_back(state.d2); \ + spec_constants_cpy.push_back(state.KW); \ + spec_constants_cpy.push_back(state.KH); \ + spec_constants_cpy.push_back(state.KD); \ + spec_constants_cpy.push_back(state.aligned); \ + spec_constants_cpy.push_back(conv2d_csh_store); \ + spec_constants_cpy.push_back(conv2d_WM); \ + spec_constants_cpy.push_back(conv2d_WN); \ + ggml_vk_create_pipeline( \ + device, c.second, "conv3d" #type_suffix, \ + conv3d##type_suffix##spv_suffix##_len, conv3d##type_suffix##spv_suffix##_data, "main", 3, \ + sizeof(vk_op_conv3d_push_constants), wg_denoms, spec_constants_cpy, 1, true, conv2d_required_subgroup_size != 0, conv2d_required_subgroup_size); \ + } +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + CREATE_CONV3D(_f32, _cm2) + CREATE_CONV3D(_f16_f32, _cm2) + } else +#endif +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (conv2d_use_cm1) { + CREATE_CONV3D(_f32, _cm1) + 
CREATE_CONV3D(_f16_f32, _cm1) + } else +#endif + if (conv2d_UNROLL) { + CREATE_CONV3D(_f32, _unroll) + CREATE_CONV3D(_f16_f32, _unroll) + } else { + CREATE_CONV3D(_f32, ) + CREATE_CONV3D(_f16_f32, ) + } +#undef CREATE_CONV3D } ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); @@ -10901,6 +10999,61 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } } return nullptr; + case GGML_OP_CONV_3D: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + const uint32_t OC = (uint32_t)ggml_get_op_params_i32(dst, 11); + const uint32_t IC = (uint32_t)ggml_get_op_params_i32(dst, 9); + const uint32_t N = (uint32_t)ggml_get_op_params_i32(dst, 10); + const uint32_t NPQ = N * dst->ne[2] * dst->ne[1] * dst->ne[0]; + const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, OC, NPQ); + + const uint32_t KW = (uint32_t)src0->ne[0]; + const uint32_t KH = (uint32_t)src0->ne[1]; + const uint32_t KD = (uint32_t)src0->ne[2]; + const uint32_t s0 = (uint32_t)ggml_get_op_params_i32(dst, 0); + const uint32_t s1 = (uint32_t)ggml_get_op_params_i32(dst, 1); + const uint32_t s2 = (uint32_t)ggml_get_op_params_i32(dst, 2); + const uint32_t p0 = (uint32_t)ggml_get_op_params_i32(dst, 3); + const uint32_t p1 = (uint32_t)ggml_get_op_params_i32(dst, 4); + const uint32_t p2 = (uint32_t)ggml_get_op_params_i32(dst, 5); + const uint32_t d0 = (uint32_t)ggml_get_op_params_i32(dst, 6); + const uint32_t d1 = (uint32_t)ggml_get_op_params_i32(dst, 7); + const uint32_t d2 = (uint32_t)ggml_get_op_params_i32(dst, 8); + + const uint32_t CRS = IC * KW * KH * KD; + const uint32_t BS_K = vk_conv_block_sizes[shape].K; + const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS; + const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ; + const uint32_t aligned = ((OC % BS_K == 0) && + (CRS % BS_CRS == 0) && + (NPQ % 
BS_NPQ == 0)) ? 1u : 0u; + + vk_conv3d_pipeline_state conv3d_pipeline_state(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned); + + std::map *pipelines = nullptr; + if (src0->type == GGML_TYPE_F32) { + pipelines = &ctx->device->pipeline_conv3d_f32[shape]; + } else if (src0->type == GGML_TYPE_F16) { + pipelines = &ctx->device->pipeline_conv3d_f16_f32[shape]; + } else { + return nullptr; + } + + vk_pipeline pipeline = nullptr; + + { + std::lock_guard guard(ctx->device->compile_mutex); + auto it = pipelines->find(conv3d_pipeline_state); + if (it != pipelines->end()) { + pipeline = it->second; + } else { + (*pipelines)[conv3d_pipeline_state] = pipeline = std::make_shared(); + } + } + + return pipeline; + } + return nullptr; case GGML_OP_ADD1: if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { return ctx->device->pipeline_add1_f16_f16; @@ -11236,6 +11389,21 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co GGML_ABORT("invalid push constant type for CONV_2D"); } break; + case GGML_OP_CONV_3D: + if constexpr (std::is_same_v) { + const uint32_t NPQ = pc.N * pc.OD * pc.OH * pc.OW; + const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.OC, NPQ); + const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ); + + elements = { pc.OC, NPQ_blocks, 1 }; + if (elements[1] > 512) { + elements[2] = CEIL_DIV(elements[1], 512); + elements[1] = 512; + } + } else { + GGML_ABORT("invalid push constant type for CONV_3D"); + } + break; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_DIV: @@ -13134,6 +13302,51 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p)); } +static void ggml_vk_conv_3d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0, + const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 || 
src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb0 == sizeof(float)); + + vk_op_conv3d_push_constants p{}; + p.IC = static_cast(ggml_get_op_params_i32(dst, 9)); + p.N = static_cast(ggml_get_op_params_i32(dst, 10)); + p.OC = static_cast(ggml_get_op_params_i32(dst, 11)); + GGML_ASSERT(src0->ne[3] == (int64_t)p.IC * p.OC); + GGML_ASSERT(src1->ne[3] == (int64_t)p.IC * p.N); + GGML_ASSERT(dst->ne[3] == (int64_t)p.OC * p.N); + + p.IW = static_cast(ne10); + p.IH = static_cast(ne11); + p.ID = static_cast(ne12); + p.OW = static_cast(ne0); + p.OH = static_cast(ne1); + p.OD = static_cast(ne2); + + // the shader clamps src addresses to p.IC * p.N * p.IW * p.IH * p.ID - 1 in uint32, so the + // total input element count must fit in a uint32. + GGML_ASSERT((uint64_t)p.IC * p.N * p.IW * p.IH * p.ID <= 0xFFFFFFFFull); + + p.nb01 = static_cast(nb01 / nb00); + p.nb02 = static_cast(nb02 / nb00); + p.nb03 = static_cast(nb03 / nb00); + + p.nb11 = static_cast(nb11 / nb10); + p.nb12 = static_cast(nb12 / nb10); + p.nb13 = static_cast(nb13 / nb10); + + p.nb1 = static_cast(nb1 / nb0); + p.nb2 = static_cast(nb2 / nb0); + p.nb3 = static_cast(nb3 / nb0); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_3D, std::move(p)); +} + static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { vk_op_conv2d_dw_push_constants p{}; p.ne = ggml_nelements(dst); @@ -14531,6 +14744,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_CONV_TRANSPOSE_2D: ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node); + break; + case GGML_OP_CONV_3D: + ggml_vk_conv_3d(ctx, compute_ctx, src0, src1, node); + break; case 
GGML_OP_CONV_2D_DW: ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node); @@ -17301,6 +17518,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm ggml_is_contiguous(op->src[1]) && ggml_is_contiguous(op)); } + case GGML_OP_CONV_3D: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + op->src[1]->type == GGML_TYPE_F32 && + op->type == GGML_TYPE_F32 && + ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]) && + ggml_is_contiguous(op); default: return false; } @@ -18144,6 +18368,20 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * const int32_t d0 = tensor->op_params[4]; const int32_t d1 = tensor->op_params[5]; tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1); + } else if (tensor->op == GGML_OP_CONV_3D) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t s2 = tensor->op_params[2]; + const int32_t p0 = tensor->op_params[3]; + const int32_t p1 = tensor->op_params[4]; + const int32_t p2 = tensor->op_params[5]; + const int32_t d0 = tensor->op_params[6]; + const int32_t d1 = tensor->op_params[7]; + const int32_t d2 = tensor->op_params[8]; + const int32_t IC = tensor->op_params[9]; + const int32_t N = tensor->op_params[10]; + const int32_t OC = tensor->op_params[11]; + tensor_clone = ggml_conv_3d_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, s2, p0, p1, p2, d0, d1, d2, IC, N, OC); } else if (tensor->op == GGML_OP_CONV_2D_DW) { const int32_t s0 = tensor->op_params[0]; const int32_t s1 = tensor->op_params[1]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv3d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv3d_mm.comp new file mode 100644 index 0000000000..a9712eb3ac --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv3d_mm.comp @@ -0,0 +1,431 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#ifdef COOPMAT2 +#extension 
GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_KHR_memory_scope_semantics : enable +#endif + +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_KHR_memory_scope_semantics : enable +#endif + +#include "types.glsl" + +// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j +layout(binding = 0) readonly buffer A { + A_TYPE knl_data[]; +}; // src0 - kernel: [KW, KH, KD, IC*OC] + +layout(binding = 1) readonly buffer B { + B_TYPE src_data[]; +}; // src1 - input: [IW, IH, ID, IC*N] -- channel_first format + +layout(binding = 2) writeonly buffer D { + D_TYPE dst_data[]; +}; // dst - result: [OW, OH, OD, OC*N] + +layout(push_constant) uniform parameter { + // I/O channels, batch size + uint32_t OC; + uint32_t IC; + uint32_t N; + + // Tensor spatial sizes: input, output + uint32_t IW; + uint32_t IH; + uint32_t ID; + uint32_t OW; + uint32_t OH; + uint32_t OD; + + // Strides in elements + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; + + // fastdiv helper values + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; + uint32_t OWOHODmp; uint32_t OWOHODL; +} + +p; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +// Blocktile sizes +layout(constant_id = 1) const uint BS_K = 128; +layout(constant_id = 2) const uint BS_CRS = 16; +layout(constant_id = 3) const uint BS_NPQ = 128; +// Thread-tile sizes +layout(constant_id = 4) const uint TS_K = 8; +layout(constant_id = 5) const uint SHMEM_PAD = 4; +// Stride, padding, dilation +layout(constant_id = 6) const uint s0 = 1; +layout(constant_id = 7) const uint s1 = 1; +layout(constant_id = 8) const uint s2 = 1; +layout(constant_id = 9) const uint p0 = 0; 
+layout(constant_id = 10) const uint p1 = 0; +layout(constant_id = 11) const uint p2 = 0; +layout(constant_id = 12) const uint d0 = 1; +layout(constant_id = 13) const uint d1 = 1; +layout(constant_id = 14) const uint d2 = 1; +// Kernel spatial sizes +layout(constant_id = 15) const uint KW = 1; +layout(constant_id = 16) const uint KH = 1; +layout(constant_id = 17) const uint KD = 1; +// when set, skip bounds checks and address clamps (K/CRS/NPQ are tile-aligned) +layout(constant_id = 18) const uint aligned = 0; +// stage cm2 result through shmem (Csh) for coalesced stores. cm1 always does this. +layout(constant_id = 19) const uint csh_store = 0; + +#ifdef COOPMAT +// cm1 subgroup tile: each subgroup computes a WM x WN region as a grid of +// TM x TN x TK fragments. Requires WM%TM == WN%TN == BS_K%WM == BS_NPQ%WN == +// BS_CRS%TK == 0, and WG_SIZE == (BS_K/WM) * (BS_NPQ/WN) * subgroup_size. +layout(constant_id = 20) const uint WM = 32; +layout(constant_id = 21) const uint WN = 32; +const uint TM = 16; +const uint TN = 16; +const uint TK = 16; +const uint cms_per_row = WM / TM; +const uint cms_per_col = WN / TN; +const uint warps_M = BS_K / WM; +const uint warps_N = BS_NPQ / WN; +#endif + +// without padding, ID_idx/IH_idx/IW_idx are in bounds by construction +const bool dhw_in_bounds = (p0 == 0) && (p1 == 0) && (p2 == 0); + +uint32_t tid = gl_LocalInvocationID.x; +const uint32_t WG_SIZE = gl_WorkGroupSize.x; + +uint splitWork(uint work_size, uint block_size) { + return (block_size + work_size - 1) / block_size; +} + +uint32_t K = p.OC; +uint32_t CRS = p.IC * KD * KH * KW; +uint32_t NPQ = p.N * p.OD * p.OH * p.OW; + +// Number of blocktiles per input +uint32_t NB_CRS = splitWork(CRS, BS_CRS); + +#if defined(COOPMAT2) || defined(COOPMAT) +#define SHMEM_TYPE float16_t +#else +#define SHMEM_TYPE float +#endif + +const uint32_t Ash_stride = BS_CRS + SHMEM_PAD; +const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD; + +const uint32_t Ash_len = BS_K * Ash_stride; +const uint32_t 
Bsh_len = BS_CRS * Bsh_stride; + +shared SHMEM_TYPE Ash[Ash_len]; // K x CRS +shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ + +#if defined(COOPMAT2) || defined(COOPMAT) +// stage matC through shmem so global stores are row-major (NPQ-contiguous) +const uint32_t Csh_stride = BS_NPQ; +#ifdef COOPMAT +const uint32_t Csh_len = BS_K * Csh_stride; +#else +const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1; +#endif +shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ +#endif + +// Threadtile sizes +const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; + +// Number of threadtiles per blocktile +const uint32_t NT_NPQ = BS_NPQ / TS_NPQ; + +/* +Compute +KxCRS @ CRSxNPQ = K x NPQ +K=OC +C=IC +D,R,S=KD,KH,KW +Z,P,Q=OD,OH,OW +*/ + +uint32_t B_idx_K = gl_WorkGroupID.x; +uint32_t B_idx_NPQ = gl_WorkGroupID.y + gl_WorkGroupID.z * 512; + +uint32_t T_y = tid / NT_NPQ; +uint32_t T_x = tid % NT_NPQ; + +uint32_t Ar = tid / BS_CRS; +uint32_t Ac = tid % BS_CRS; +const uint32_t ArpWg = WG_SIZE / BS_CRS; + +uint32_t Br = tid / BS_NPQ; +uint32_t Bc = tid % BS_NPQ; +const uint32_t BrpWg = WG_SIZE / BS_NPQ; + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + +void split_crs(uint32_t crs_idx, out uint32_t ic, out uint32_t kd, out uint32_t kh, out uint32_t kw) { + const uint32_t KHKW = KH * KW; + const uint32_t KDKHKW = KD * KHKW; + ic = crs_idx / KDKHKW; + uint32_t rem = crs_idx - ic * KDKHKW; + kd = rem / KHKW; + rem = rem - kd * KHKW; + kh = rem / KW; + kw = rem - kh * KW; +} + +void split_npq(uint32_t npq_idx, out uint32_t n, out uint32_t od, out uint32_t oh, out uint32_t ow) { + const uint32_t OWOH = p.OW * p.OH; + n = fastdiv(npq_idx, p.OWOHODmp, p.OWOHODL); + uint32_t rem = npq_idx - n * p.OD * OWOH; + od = fastdiv(rem, p.OWOHmp, p.OWOHL); + rem = rem - od * OWOH; + oh = fastdiv(rem, p.OWmp, p.OWL); + ow = rem - oh * p.OW; +} + +#ifdef 
COOPMAT2 +#define ACC_TYPE float16_t + +ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem) +{ + uint32_t K_idx = B_idx_K * BS_K + r; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c; + uint32_t N_idx; + uint32_t OD_idx; + uint32_t OH_idx; + uint32_t OW_idx; + split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx); + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3; + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { + dst_data[dst_idx] = D_TYPE(elem); + } + return elem; +} +#endif + +void main() { + if (B_idx_NPQ * BS_NPQ >= NPQ) { + return; + } + +#ifdef COOPMAT2 + coopmat matC; + matC = coopmat(0.0); +#elif defined(COOPMAT) + coopmat sums[cms_per_row * cms_per_col]; + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat(0.0); + } + const uint warp_r = gl_SubgroupID / warps_N; + const uint warp_c = gl_SubgroupID % warps_N; +#else + float regC[TS_K][TS_NPQ]; + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regC[T_ly][T_lx] = 0.0; + } + } +#endif + /* Advance block in CRS dim */ + [[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) { + uint32_t CRS_idx_a = B_idx_CRS * BS_CRS + Ac; + uint32_t IC_idx_a; + uint32_t KD_idx_a; + uint32_t KH_idx_a; + uint32_t KW_idx_a; + split_crs(CRS_idx_a, IC_idx_a, KD_idx_a, KH_idx_a, KW_idx_a); + + /* Load kernel to A_block: (BS_K x BS_CRS)*/ + UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) { + uint32_t B_ly = r_offset + Ar; + uint32_t B_lx = Ac; + uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ + uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + KD_idx_a * p.nb02 + (K_idx * p.IC + IC_idx_a) * p.nb03; + if (aligned == 0) { + knl_idx = min(knl_idx, K * CRS - 1); + } + float val = knl_data[knl_idx]; + if (aligned == 0 && (K_idx >= K || CRS_idx_a >= CRS)) { + val = 0.0; + } + Ash[B_ly * Ash_stride + 
B_lx] = SHMEM_TYPE(val); + } + /* Load input to B_block: (BS_CRS x BS_NPQ) */ + UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) { + uint32_t B_ly = r_offset + Br; /* Row index of B block */ + uint32_t B_lx = Bc; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */ + uint32_t N_idx; + uint32_t OD_idx; + uint32_t OH_idx; + uint32_t OW_idx; + split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx); + + uint32_t CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; + uint32_t IC_idx_b; + uint32_t KD_idx_b; + uint32_t KH_idx_b; + uint32_t KW_idx_b; + split_crs(CRS_idx_b, IC_idx_b, KD_idx_b, KH_idx_b, KW_idx_b); + + uint32_t ID_idx = OD_idx * s2 + KD_idx_b * d2 - p2; + uint32_t IH_idx = OH_idx * s1 + KH_idx_b * d1 - p1; + uint32_t IW_idx = OW_idx * s0 + KW_idx_b * d0 - p0; + + uint32_t src_idx = IW_idx + IH_idx * p.nb11 + ID_idx * p.nb12 + (N_idx * p.IC + IC_idx_b) * p.nb13; + // skip clamp when address can't go OOB + if (aligned == 0 || !dhw_in_bounds) { + src_idx = min(src_idx, p.IC * p.N * p.IW * p.IH * p.ID - 1); + } + float val = src_data[src_idx]; + bool oob = false; + if (aligned == 0 && (CRS_idx_b >= CRS || NPQ_idx >= NPQ)) { + oob = true; + } + // also catches lower-bound underflow (idx wraps to 0x80000000+) + if (!dhw_in_bounds && (ID_idx >= p.ID || IH_idx >= p.IH || IW_idx >= p.IW)) { + oob = true; + } + if (oob) { + val = 0.0; + } + Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val); + } + barrier(); +#ifdef COOPMAT2 + coopmat matA; + coopmat matB; + + coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); + matC = coopMatMulAdd(matA, matB, matC); +#elif defined(COOPMAT) + // each subgroup multiplies its grid of fragments per TK-sized CRS chunk + [[unroll]] for (uint k_step = 0; k_step < BS_CRS / TK; k_step++) { + coopmat cache_a[cms_per_row]; + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + const uint 
a_off = (warp_r * WM + cm_row * TM) * Ash_stride + k_step * TK; + coopMatLoad(cache_a[cm_row], Ash, a_off, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); + } + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopmat cache_b; + const uint b_off = k_step * TK * Bsh_stride + warp_c * WN + cm_col * TN; + coopMatLoad(cache_b, Bsh, b_off, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row], cache_b, sums[cm_col * cms_per_row + cm_row]); + } + } + } +#else + if (T_y * TS_K < K) { + UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) { + float regA[TS_K]; + float regB[TS_NPQ]; + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx]; + } + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx]; + } + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]); + } + } + } + } +#endif + barrier(); + } + /* Save C* */ +#if defined(COOPMAT2) || defined(COOPMAT) + // stage matC into Csh, then write to dst with coalesced NPQ-contiguous stores +#ifdef COOPMAT + const bool use_staged_store = true; +#else + const bool use_staged_store = (csh_store != 0); +#endif + if (use_staged_store) { +#ifdef COOPMAT + // cm1: each subgroup stores its fragment grid into its Csh slot + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const uint csh_off = (warp_r * WM + cm_row * TM) * Csh_stride + warp_c * WN + cm_col * TN; + coopMatStore(sums[cm_col * cms_per_row + cm_row], Csh, csh_off, Csh_stride, gl_CooperativeMatrixLayoutRowMajor); + } + } +#else + coopMatStore(matC, Csh, 0, Csh_stride, 
gl_CooperativeMatrixLayoutRowMajor); +#endif + barrier(); + + // cooperative shmem->global: WG threads spread across BS_NPQ (the + // contiguous direction of dst), each iter covers store_rows_per_iter K-rows + const uint32_t store_rows_per_iter = WG_SIZE / BS_NPQ; + const uint32_t store_iters = BS_K / store_rows_per_iter; + const uint32_t k_thread_offset = tid / BS_NPQ; + const uint32_t npq_thread = tid % BS_NPQ; + [[unroll]] for (uint32_t i = 0; i < store_iters; i++) { + uint32_t k_local = i * store_rows_per_iter + k_thread_offset; + uint32_t K_idx = B_idx_K * BS_K + k_local; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + npq_thread; + uint32_t N_idx; + uint32_t OD_idx; + uint32_t OH_idx; + uint32_t OW_idx; + split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx); + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3; + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { + dst_data[dst_idx] = D_TYPE(Csh[k_local * Csh_stride + npq_thread]); + } + } + } +#ifdef COOPMAT2 + else { + coopMatPerElementNV(matC, matC, perElemOpStore); + } +#endif +#else + if (T_y * TS_K < K) { + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx; + uint32_t N_idx; + uint32_t OD_idx; + uint32_t OH_idx; + uint32_t OW_idx; + split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx); + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3; + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { + dst_data[dst_idx] = D_TYPE(regC[T_ly][T_lx]); + } + } + } + } +#endif +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index f07583b6ab..2f5661f548 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -1053,6 
+1053,31 @@ void process_shaders() { } } + for (auto unroll : {false, true}) { + for (auto a_f16 : {false, true}) { + std::map defines = { + {"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, + {"UNROLL", unroll ? "[[unroll]]" : ""}, + }; + std::string name = std::string("conv3d") + (a_f16 ? "_f16" : "") + "_f32"; + string_to_spv(name + (unroll ? "_unroll" : ""), "conv3d_mm.comp", defines); +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (unroll) { + auto cm2_defines = defines; + cm2_defines["COOPMAT2"] = "1"; + string_to_spv(name, "conv3d_mm.comp", cm2_defines, true, false, true); + } +#endif +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (unroll) { + auto cm1_defines = defines; + cm1_defines["COOPMAT"] = "1"; + string_to_spv(name, "conv3d_mm.comp", cm1_defines, true, true, false); + } +#endif + } + } + string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 127c4634c0..719ae51cc2 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -9272,6 +9272,34 @@ static std::vector> make_test_cases_perf() { } } + struct conv3d_perf_case { + int N, IC, ID, IH, IW, OC, KD, KH, KW, s0, s1, s2, p0, p1, p2, d0, d1, d2; + }; + + const std::vector conv3d_cases = { + {1, 320, 8, 38, 26, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 1280, 8, 38, 26, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 320, 8, 76, 52, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 1280, 8, 76, 52, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + 
{1, 320, 8, 152, 104, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1}, +#if 0 + // too slow on some devices + {1, 1280, 8, 152, 104, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 320, 4, 304, 208, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 640, 4, 304, 208, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1}, +#endif + }; + + for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) { + for (const conv3d_perf_case & c : conv3d_cases) { + test_cases.emplace_back(new test_conv_3d( + c.N, c.IC, c.ID, c.IH, c.IW, + c.OC, c.KD, c.KH, c.KW, + c.s0, c.s1, c.s2, c.p0, c.p1, c.p2, c.d0, c.d1, c.d2, + kernel_type)); + } + } + test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1})); test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1})); From 92e854ab836254bb7f2eb49babd5613474bdb700 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 23 Jun 2026 08:39:37 -0500 Subject: [PATCH 62/86] vulkan: Support GET_ROWS_BACK (#24883) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 32 +++++++++++++++++++ .../vulkan-shaders/get_rows_back.comp | 25 +++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 1 + 3 files changed, 58 insertions(+) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/get_rows_back.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 508d569f20..d2827ad71f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -791,6 +791,7 @@ struct vk_device_struct { vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_get_rows_back_f32; vk_pipeline pipeline_acc_f32; vk_pipeline pipeline_set_f32; @@ -4946,6 +4947,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", 
get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_back_f32, "get_rows_back_f32", get_rows_back_f32_len, get_rows_back_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {256, 1, 1}, {}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); @@ -10408,6 +10410,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_get_rows_f32[src0->type]; } return nullptr; + case GGML_OP_GET_ROWS_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_get_rows_back_f32; + } + return nullptr; case GGML_OP_ACC: if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_acc_f32; @@ -11304,6 +11311,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); elements[2] = 
std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); break; + case GGML_OP_GET_ROWS_BACK: + elements = { (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], 1 }; + elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + break; case GGML_OP_ARGSORT: GGML_ASSERT(0); break; @@ -11564,6 +11575,21 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, }); } +static void ggml_vk_get_rows_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS_BACK, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2], (uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }); +} + static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); @@ -14476,6 +14502,10 
@@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_GET_ROWS: ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node); + break; + case GGML_OP_GET_ROWS_BACK: + ggml_vk_get_rows_back(ctx, compute_ctx, src0, src1, node); + break; case GGML_OP_ADD: if (ctx->num_additional_fused_ops) { @@ -17197,6 +17227,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } } + case GGML_OP_GET_ROWS_BACK: + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SET_ROWS: { switch (op->type) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_back.comp new file mode 100644 index 0000000000..7e3d8a2819 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_back.comp @@ -0,0 +1,25 @@ +#version 450 + +#include "types.glsl" +#include "generic_binary_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint col = gl_GlobalInvocationID.x; + + if (col >= p.ne20) { + return; + } + + for (uint row = gl_GlobalInvocationID.y; row < p.ne21; row += gl_WorkGroupSize.y * gl_NumWorkGroups.y) { + float sum = 0.0f; + for (uint i = 0; i < p.ne10; ++i) { + if (data_b[get_boffset() + i*p.nb10] == int(row)) { + sum += data_a[get_aoffset() + i*p.nb01 + col*p.nb00]; + } + } + + data_d[get_doffset() + row*p.nb21 + col*p.nb20] = sum; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 2f5661f548..502602f799 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -843,6 +843,7 @@ void process_shaders() { string_to_spv("repeat_i32", "repeat.comp", {{"A_TYPE", "int32_t"}, {"D_TYPE", "int32_t"}}); string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + 
string_to_spv("get_rows_back_f32", "get_rows_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}); string_to_spv("repeat_i16", "repeat.comp", {{"A_TYPE", "int16_t"}, {"D_TYPE", "int16_t"}}); From 72a9269172290829c503c351e2b68ca2f7af2bc7 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 23 Jun 2026 09:48:24 -0500 Subject: [PATCH 63/86] vulkan: support all backend tests for SQR/SQRT/SIN/COS/CLAMP/LEAKY_RELU/NORM (#24582) * vulkan: make SQR/SQRT/SIN/COS/CLAMP/LEAKY_RELU use unary.comp * vulkan: make NORM support noncontig * add noncontiguous row test cases for norm/l2_norm, handle this in the CPU backend and l2_norm.comp * fix supports_op for cuda and webgpu --- ggml/src/ggml-cpu/ops.cpp | 73 ++++++++++++------ ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 76 +++++++++++-------- .../src/ggml-vulkan/vulkan-shaders/clamp.comp | 17 ----- ggml/src/ggml-vulkan/vulkan-shaders/cos.comp | 17 ----- .../ggml-vulkan/vulkan-shaders/l2_norm.comp | 11 +-- .../vulkan-shaders/leaky_relu.comp | 22 ------ ggml/src/ggml-vulkan/vulkan-shaders/norm.comp | 20 ++--- ggml/src/ggml-vulkan/vulkan-shaders/sin.comp | 17 ----- ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp | 17 ----- .../ggml-vulkan/vulkan-shaders/square.comp | 17 ----- .../src/ggml-vulkan/vulkan-shaders/unary.comp | 24 ++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 23 +++--- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 +- tests/test-backend-ops.cpp | 34 +++++++-- 15 files changed, 171 insertions(+), 201 deletions(-) delete mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp delete mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/cos.comp delete mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp delete mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/sin.comp delete mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp delete mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/square.comp diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp 
index 74611dce7f..6724686b8a 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3688,8 +3688,6 @@ static void ggml_compute_forward_norm_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - const int ith = params->ith; const int nth = params->nth; @@ -3703,25 +3701,49 @@ static void ggml_compute_forward_norm_f32( for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = ith; i01 < ne01; i01 += nth) { - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3; - float sum = 0.0; - ggml_vec_sum_f32(ne00, &sum, x); - float mean = sum/ne00; + if (nb00 == sizeof(float) && nb0 == sizeof(float)) { + const float * xf = (const float *) x; - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - float variance = 0; + float sum = 0.0; + ggml_vec_sum_f32(ne00, &sum, xf); + float mean = sum/ne00; + + float * yf = (float *) y; + float variance = 0; #ifdef GGML_USE_ACCELERATE - mean = -mean; - vDSP_vsadd(x, 1, &mean, y, 1, ne00); - vDSP_measqv(y, 1, &variance, ne00); + mean = -mean; + vDSP_vsadd(xf, 1, &mean, yf, 1, ne00); + vDSP_measqv(yf, 1, &variance, ne00); #else - variance = ggml_vec_cvar_f32(ne00, y, x, mean); + variance = ggml_vec_cvar_f32(ne00, yf, xf, mean); #endif //GGML_USE_ACCELERATE - const float scale = 1.0f/sqrtf(variance + eps); - ggml_vec_scale_f32(ne00, y, scale); + const float scale = 1.0f/sqrtf(variance + eps); + ggml_vec_scale_f32(ne00, yf, scale); + } else { + float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += *(const float *) (x + i00*nb00); + } + const float mean = sum/ne00; + + float variance = 0.0f; + for (int64_t i00 = 0; i00 < ne00; i00++) { + const float v = *(const float *) (x + i00*nb00) - mean; + *(float *) (y + i00*nb0) = v; + 
variance += v * v; + } + variance /= ne00; + + const float scale = 1.0f/sqrtf(variance + eps); + for (int64_t i00 = 0; i00 < ne00; i00++) { + *(float *) (y + i00*nb0) *= scale; + } + } } } } @@ -4142,8 +4164,6 @@ static void ggml_compute_forward_l2_norm_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - const int ith = params->ith; const int nth = params->nth; @@ -4158,20 +4178,27 @@ static void ggml_compute_forward_l2_norm_f32( for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = ith; i01 < ne01; i01 += nth) { - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; ggml_float sum = 0.0; for (int64_t i00 = 0; i00 < ne00; i00++) { - sum += (ggml_float)(x[i00] * x[i00]); + const float xi = *(const float *) (x + i00*nb00); + sum += (ggml_float)(xi * xi); } - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - memcpy(y, x, ne00 * sizeof(float)); - const float scale = 1.0f/fmaxf(sqrtf(sum), eps); - ggml_vec_scale_f32(ne00, y, scale); + char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3; + + if (nb00 == sizeof(float) && nb0 == sizeof(float)) { + memcpy(y, x, ne00 * sizeof(float)); + ggml_vec_scale_f32(ne00, (float *) y, scale); + } else { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const float xi = *(const float *) (x + i00*nb00); + *(float *) (y + i00*nb0) = xi * scale; + } + } } } } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 3d4b5f6056..cca70592f8 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -5334,7 +5334,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: - return true; + return ggml_is_contiguous_rows(op->src[0]); case GGML_OP_RMS_NORM_BACK: return 
ggml_is_contiguous(op->src[0]); break; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d2827ad71f..f4a578b893 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -816,14 +816,10 @@ struct vk_device_struct { vk_pipeline pipeline_concat_i8, pipeline_concat_i16, pipeline_concat_i32, pipeline_concat_i64; vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32; vk_pipeline pipeline_scale_f32; - vk_pipeline pipeline_sqr_f32; - vk_pipeline pipeline_sqrt_f32; - vk_pipeline pipeline_sin_f32; - vk_pipeline pipeline_cos_f32; vk_pipeline pipeline_log[2]; vk_pipeline pipeline_tri[2]; vk_pipeline pipeline_diag[2]; - vk_pipeline pipeline_clamp_f32; + vk_pipeline pipeline_clamp[2]; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32; @@ -855,6 +851,10 @@ struct vk_device_struct { vk_pipeline pipeline_gelu_quick[2]; vk_pipeline pipeline_silu[2]; vk_pipeline pipeline_relu[2]; + vk_pipeline pipeline_sqr[2]; + vk_pipeline pipeline_sqrt[2]; + vk_pipeline pipeline_sin[2]; + vk_pipeline pipeline_cos[2]; vk_pipeline pipeline_xielu[2]; vk_pipeline pipeline_neg[2]; vk_pipeline pipeline_tanh[2]; @@ -886,7 +886,7 @@ struct vk_device_struct { vk_pipeline pipeline_geglu_erf[2]; vk_pipeline pipeline_geglu_quick[2]; - vk_pipeline pipeline_leaky_relu_f32; + vk_pipeline pipeline_leaky_relu[2]; vk_pipeline pipeline_silu_back_f32; vk_pipeline pipeline_diag_mask_inf_f32; vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; @@ -4972,7 +4972,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { } ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_nc_push_constants), {1, 1, 
1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); @@ -5092,11 +5092,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -5106,8 
+5101,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -5127,6 +5120,12 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { CREATE_UNARY(gelu_quick) CREATE_UNARY(silu) CREATE_UNARY(relu) + CREATE_UNARY(sqr) + CREATE_UNARY(sqrt) + CREATE_UNARY(sin) + CREATE_UNARY(cos) + CREATE_UNARY(clamp) + CREATE_UNARY(leaky_relu) CREATE_UNARY(xielu) CREATE_UNARY(neg) CREATE_UNARY(tanh) @@ -5166,7 +5165,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { CREATE_GLU(geglu_quick) #undef CREATE_GLU - ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, 
true); @@ -10521,23 +10519,27 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_SQR: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_sqr_f32; + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_sqr[dst->type == GGML_TYPE_F16]; } return nullptr; case GGML_OP_SQRT: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_sqrt_f32; + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_sqrt[dst->type == GGML_TYPE_F16]; } return nullptr; case GGML_OP_SIN: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_sin_f32; + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_sin[dst->type == GGML_TYPE_F16]; } return nullptr; case GGML_OP_COS: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_cos_f32; + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_cos[dst->type == GGML_TYPE_F16]; } return nullptr; case GGML_OP_LOG: @@ -10559,8 +10561,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_CLAMP: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_clamp_f32; + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_clamp[dst->type == GGML_TYPE_F16]; } return nullptr; case GGML_OP_PAD: @@ -10928,8 +10931,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_LEAKY_RELU: - if (src0->type == 
GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_leaky_relu_f32; + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_leaky_relu[dst->type == GGML_TYPE_F16]; } return nullptr; case GGML_OP_CONV_2D: @@ -11431,6 +11435,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_TRI: case GGML_OP_DIAG: case GGML_OP_CLAMP: + case GGML_OP_LEAKY_RELU: case GGML_OP_PAD: case GGML_OP_ROLL: case GGML_OP_REPEAT: @@ -12297,8 +12302,10 @@ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { float * op_params = (float *)dst->op_params; + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = op_params[0]; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, std::move(p)); } static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { @@ -13399,7 +13406,10 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { const float * op_params = (const float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f }); + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = op_params[0]; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, std::move(p)); } #ifdef 
GGML_VULKAN_RUN_TESTS @@ -17325,12 +17335,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_TRANSPOSE: case GGML_OP_RMS_NORM: return true; - case GGML_OP_NORM: case GGML_OP_GROUP_NORM: return ggml_is_contiguous(op->src[0]); + case GGML_OP_NORM: case GGML_OP_L2_NORM: - return ggml_is_contiguous_rows(op->src[0]) && - op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: @@ -17349,8 +17358,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: - return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_LEAKY_RELU: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + op->type == op->src[0]->type; case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp b/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp deleted file mode 100644 index 653431895e..0000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +++ /dev/null @@ -1,17 +0,0 @@ -#version 450 - -#include "types.glsl" -#include "generic_unary_head.glsl" - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); - data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? 
p.param2 : val)); -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp deleted file mode 100644 index db6865db98..0000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +++ /dev/null @@ -1,17 +0,0 @@ -#version 450 - -#include "types.glsl" -#include "generic_unary_head.glsl" - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); - data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val)); -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp index f9af46744d..9039ed1ded 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp @@ -14,16 +14,13 @@ void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; - const uint i3 = row / (p.ne11 * p.ne12); - const uint i3_offset = i3 * p.ne12 * p.ne11; - const uint i2 = (row - i3_offset) / p.ne11; - const uint i2_offset = i2 * p.ne11; - const uint i1 = row - i3_offset - i2_offset; + const uint a_base = get_aoffset() + src0_idx(row * p.ne00); + const uint d_base = get_doffset() + dst_idx(row * p.ne10); sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]); + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_base + i0*p.nb00]); sum[tid] += xi * xi; } @@ -39,6 +36,6 @@ void main() { const FLOAT_TYPE scale = 1.0f / max(sqrt(sum[0]), FLOAT_TYPE(p.param1)); [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { - data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0])); + 
data_d[d_base + i0*p.nb10] = D_TYPE(scale * FLOAT_TYPE(data_a[a_base + i0*p.nb00])); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp deleted file mode 100644 index b281e855cb..0000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +++ /dev/null @@ -1,22 +0,0 @@ -#version 450 - -#include "generic_head.glsl" -#include "types.glsl" - -#extension GL_EXT_control_flow_attributes : enable - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - -void main() { - const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; - - if (i >= p.KX) { - return; - } - - const float val = float(data_a[i]); - data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1); -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp index cc3ea0b760..792012d57e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp @@ -1,26 +1,26 @@ #version 450 -#include "generic_head.glsl" #include "types.glsl" +#include "generic_unary_head.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - shared vec2 sum[BLOCK_SIZE]; void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; + const uint a_base = get_aoffset() + src0_idx(row * p.ne00); + const uint d_base = get_doffset() + dst_idx(row * p.ne10); + sum[tid] = vec2(0.0f, 0.0f); - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - const float xi = 
float(data_a[row*p.KX + col]); + [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { + const float xi = float(data_a[a_base + i0*p.nb00]); sum[tid].x += xi; sum[tid].y += xi * xi; } @@ -34,11 +34,11 @@ void main() { barrier(); } - const float mean = sum[0].x / p.KX; - const float var = sum[0].y / p.KX - mean * mean; + const float mean = sum[0].x / p.ne00; + const float var = sum[0].y / p.ne00 - mean * mean; const float inv_std = inversesqrt(var + p.param1); - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std); + [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { + data_d[d_base + i0*p.nb10] = D_TYPE((float(data_a[a_base + i0*p.nb00]) - mean) * inv_std); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp deleted file mode 100644 index 61f17b2f00..0000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +++ /dev/null @@ -1,17 +0,0 @@ -#version 450 - -#include "types.glsl" -#include "generic_unary_head.glsl" - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); - data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val)); -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp deleted file mode 100644 index 70daad6c5d..0000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +++ /dev/null @@ -1,17 +0,0 @@ -#version 450 - -#include "types.glsl" -#include "generic_unary_head.glsl" - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); - data_d[get_doffset() + dst_idx(idx)] = 
D_TYPE(sqrt(val)); -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/square.comp b/ggml/src/ggml-vulkan/vulkan-shaders/square.comp deleted file mode 100644 index 4eb56afcb1..0000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +++ /dev/null @@ -1,17 +0,0 @@ -#version 450 - -#include "types.glsl" -#include "generic_unary_head.glsl" - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); - data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val); -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/unary.comp b/ggml/src/ggml-vulkan/vulkan-shaders/unary.comp index 47a4573996..c62bce8255 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/unary.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/unary.comp @@ -17,6 +17,30 @@ float op_neg(float x) { return -x; } +float op_sqr(float x) { + return x * x; +} + +float op_sqrt(float x) { + return sqrt(x); +} + +float op_sin(float x) { + return sin(x); +} + +float op_cos(float x) { + return cos(x); +} + +float op_clamp(float x) { + return clamp(x, p.param1, p.param2); +} + +float op_leaky_relu(float x) { + return max(x, 0.0f) + min(x, 0.0f) * p.param1; +} + float op_step(float x) { return x >= 0.0f ? 
1.0f : 0.0f; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 502602f799..3bd93d256c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -849,16 +849,6 @@ void process_shaders() { string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - - string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - - string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - - string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - - string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("concat_i8", "concat.comp", {{"A_TYPE", "uint8_t"}, {"B_TYPE", "uint8_t"}, {"D_TYPE", "uint8_t"}}); @@ -885,6 +875,18 @@ void process_shaders() { string_to_spv("silu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_silu"}}); string_to_spv("relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_relu"}}); string_to_spv("relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_relu"}}); + string_to_spv("sqr_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqr"}}); + string_to_spv("sqr_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqr"}}); + string_to_spv("sqrt_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqrt"}}); + string_to_spv("sqrt_f32", "unary.comp", {{"A_TYPE", "float"}, 
{"D_TYPE", "float"}, {"OP", "op_sqrt"}}); + string_to_spv("sin_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sin"}}); + string_to_spv("sin_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sin"}}); + string_to_spv("cos_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_cos"}}); + string_to_spv("cos_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_cos"}}); + string_to_spv("clamp_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_clamp"}}); + string_to_spv("clamp_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_clamp"}}); + string_to_spv("leaky_relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_leaky_relu"}}); + string_to_spv("leaky_relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_leaky_relu"}}); string_to_spv("neg_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_neg"}}); string_to_spv("neg_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_neg"}}); string_to_spv("tanh_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_tanh"}}); @@ -942,7 +944,6 @@ void process_shaders() { string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e8eafd185a..f0ec18abd9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ 
b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -4270,7 +4270,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_RMS_NORM: case GGML_OP_NORM: case GGML_OP_L2_NORM: - supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + supports_op = (op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32) && ggml_is_contiguous_rows(src0); break; case GGML_OP_ROPE: supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 719ae51cc2..e1d3853e43 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3298,21 +3298,29 @@ struct test_norm : public test_case { const std::array ne; const bool v; // whether a is a non-contiguous view const float eps; + const bool noncontig_rows; std::string vars() override { - return VARS_TO_STR4(type, ne, v, eps); + return VARS_TO_STR5(type, ne, v, eps, noncontig_rows); } test_norm(ggml_type type = GGML_TYPE_F32, std::array ne = {64, 5, 4, 3}, bool v = false, - float eps = 1e-6f) - : type(type), ne(ne), v(v), eps(eps) {} + float eps = 1e-6f, + bool noncontig_rows = false) + : type(type), ne(ne), v(v), eps(eps), noncontig_rows(noncontig_rows) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + const std::array ne_a = noncontig_rows ? 
+ std::array{ ne[1], ne[0], ne[2], ne[3] } : ne; + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); ggml_set_name(a, "a"); + if (noncontig_rows) { + a = ggml_permute(ctx, a, 1, 0, 2, 3); + ggml_set_name(a, "permuted a"); + } if (v) { a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0); ggml_set_name(a, "view of a"); @@ -6193,21 +6201,29 @@ struct test_l2_norm : public test_case { const std::array ne; const float eps; bool v; + bool noncontig_rows; std::string vars() override { - return VARS_TO_STR4(type, ne, eps, v); + return VARS_TO_STR5(type, ne, eps, v, noncontig_rows); } test_l2_norm(ggml_type type = GGML_TYPE_F32, std::array ne = {64, 64, 320, 1}, float eps = 1e-12f, - bool v = false) - : type(type), ne(ne), eps(eps), v(v) {} + bool v = false, + bool noncontig_rows = false) + : type(type), ne(ne), eps(eps), v(v), noncontig_rows(noncontig_rows) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + const std::array ne_a = noncontig_rows ? 
+ std::array{ ne[1], ne[0], ne[2], ne[3] } : ne; + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); ggml_set_name(a, "a"); + if (noncontig_rows) { + a = ggml_permute(ctx, a, 1, 0, 2, 3); + ggml_set_name(a, "permuted a"); + } if (v) { a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0); ggml_set_name(a, "view of a"); @@ -8282,9 +8298,11 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps)); test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps)); } + test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, false, eps, true)); test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps)); test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false)); test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true)); + test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false, true)); } } From be4a6a63eb2b848e19c277bdcf2bd399e8af76d9 Mon Sep 17 00:00:00 2001 From: kononnable Date: Tue, 23 Jun 2026 16:56:50 +0200 Subject: [PATCH 64/86] server : check draft context creation error (#24922) --- tools/server/server-context.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index ca91449d26..39b7eb218e 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -89,7 +89,9 @@ struct server_batch { } ~server_batch() { - llama_batch_free(batch); + if (batch.token != nullptr) { + llama_batch_free(batch); + } } void init(int32_t n_tokens_alloc) { @@ -1215,6 +1217,10 @@ private: cparams.ctx_other = ctx_tgt; ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); + if (ctx_dft == nullptr) { + SRV_ERR("%s", "failed to create draft context\n"); + return false; + } 
params_base.speculative.draft.ctx_tgt = ctx_tgt; params_base.speculative.draft.ctx_dft = ctx_dft.get(); From ac4105d68b2955027115cf9bb50941ccf56974eb Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 23 Jun 2026 22:34:00 -0500 Subject: [PATCH 65/86] vulkan: Apply bias before softmax in FA, to avoid overflow (#24909) --- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp | 1 + ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp | 1 + 2 files changed, 2 insertions(+) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 91fb07c93e..3192130ccf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -463,6 +463,7 @@ void main() { } rowmaxf = max(rowmaxf, float(Sf[r][c])); } + rowmaxf += FATTN_KQ_MAX_OFFSET; float Moldf = Mf[r]; // M = max(rowmax, Mold) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 23ae3833e5..16178e5770 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -352,6 +352,7 @@ void main() { } rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp])); } + rowmaxf += FATTN_KQ_MAX_OFFSET; float Moldf = Mf[r]; // Compute max across the row From 88636e178ff2972e1002cf2024cb39008eda1192 Mon Sep 17 00:00:00 2001 From: Tarek Dakhran Date: Wed, 24 Jun 2026 08:49:46 +0200 Subject: [PATCH 66/86] model : Add LFM2.5-ColBERT-350M and LFM2.5-Embedding-350M (#24913) * model : Add LFM2.5-ColBERT-350M and LFM2.5-Embedding-350M * Restore LFM2 models in README.md --- README.md | 4 +++- conversion/__init__.py | 1 + conversion/lfm2.py | 13 ++++++++++--- src/models/lfm2.cpp | 18 ++++++++++++++---- 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 0652d13f29..e98f2b7f18 100644 --- a/README.md 
+++ b/README.md @@ -142,7 +142,9 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct) - [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview) - [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32) -- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) +- [x] [Liquid LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2) +- [x] [Liquid LFM2.5 models](https://huggingface.co/collections/LiquidAI/lfm25) +- [x] [Liquid Nanos](https://huggingface.co/collections/LiquidAI/liquid-nanos) - [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7) - [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86) - [x] [Mellum models](https://huggingface.co/JetBrains/models?search=mellum) diff --git a/conversion/__init__.py b/conversion/__init__.py index c6af6f7318..2bce1bbd7c 100644 --- a/conversion/__init__.py +++ b/conversion/__init__.py @@ -124,6 +124,7 @@ TEXT_MODEL_MAP: dict[str, str] = { "LLaDAModelLM": "llada", "LLaMAForCausalLM": "llama", "Lfm25AudioTokenizer": "lfm2", + "Lfm2BidirectionalModel": "lfm2", "Lfm2ForCausalLM": "lfm2", "Lfm2Model": "lfm2", "Lfm2MoeForCausalLM": "lfm2", diff --git a/conversion/lfm2.py b/conversion/lfm2.py index f28fccf10f..70ce45658b 100644 --- a/conversion/lfm2.py +++ b/conversion/lfm2.py @@ -64,11 +64,17 @@ class LFM2Model(TextModel): yield from super().modify_tensors(data_torch, name, bid) -@ModelBase.register("Lfm2Model") +@ModelBase.register("Lfm2Model", "Lfm2BidirectionalModel") class LFM2ColBertModel(LFM2Model): model_arch = gguf.MODEL_ARCH.LFM2 dense_tensor_name = "dense_2" + def set_gguf_parameters(self): + super().set_gguf_parameters() + if self.hf_arch == "Lfm2BidirectionalModel": + 
self.gguf_writer.add_causal_attention(False) + self._try_set_pooling_type() + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if not name.startswith(self.dense_tensor_name): name = "model." + name @@ -76,10 +82,11 @@ class LFM2ColBertModel(LFM2Model): yield from super().modify_tensors(data_torch, name, bid) def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: - # dense tensor is stored in a separate safetensors file + # optional dense tensor is stored in a separate safetensors file from safetensors.torch import load_file tensors_file = self.dir_model / "1_Dense" / "model.safetensors" - assert tensors_file.is_file() + if not tensors_file.is_file(): + return tensor = load_file(tensors_file)["linear.weight"] self.gguf_writer.add_embedding_length_out(tensor.shape[0]) yield f"{self.dense_tensor_name}.weight", tensor.clone() diff --git a/src/models/lfm2.cpp b/src/models/lfm2.cpp index 97da8a6abb..07b7346ee4 100644 --- a/src/models/lfm2.cpp +++ b/src/models/lfm2.cpp @@ -190,7 +190,15 @@ llama_model_lfm2::graph::graph(const llama_model & model, const llm_graph_ auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs); auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs); - bx = ggml_concat(ctx0, conv, bx, 0); + // causal prepends the state, non-causal pads symmetrically for a centered window + if (hparams.causal_attn) { + bx = ggml_concat(ctx0, conv, bx, 0); + } else { + const int64_t pad = (hparams.n_shortconv_l_cache - 1) / 2; + auto * left = ggml_cont(ctx0, + ggml_view_3d(ctx0, conv, pad, hparams.n_embd, n_seqs, conv->nb[1], conv->nb[2], (d_conv - pad) * conv->nb[0])); + bx = ggml_pad_ext(ctx0, ggml_concat(ctx0, left, bx, 0), 0, pad, 0, 0, 0, 0, 0, 0); + } GGML_ASSERT(bx->ne[0] > conv->ne[0]); // last d_conv columns is a new conv state @@ -266,10 +274,12 @@ llama_model_lfm2::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); 
res->t_embd = cur; - cur = build_lora_mm(model.output, cur, model.output_s); - cb(cur, "result_output", -1); + if (!cparams.embeddings) { + cur = build_lora_mm(model.output, cur, model.output_s); + cb(cur, "result_output", -1); - res->t_logits = cur; + res->t_logits = cur; + } ggml_build_forward_expand(gf, cur); } From ef9c13d4c27f1ad9804341a0187c33b621609c63 Mon Sep 17 00:00:00 2001 From: Aleksander Grygier Date: Wed, 24 Jun 2026 10:21:33 +0200 Subject: [PATCH 67/86] ui: New Logo + Navigation cleanup & Mobile UI/UX improvements (#24897) * chore: `npm audit fix --force` * feat: Update sidebar toggle to use Logo * refactor: Clean up favicon SVG * feat: Refactor logo component and implement theme-aware favicon generation * feat: Add configurable padding to generated PWA assets * test: Add unit tests for writeThemeFavicons * refactor: Componentization * feat: WIP * feat: WIP * feat: WIP * feat: Mobile UI * feat: add SEARCH route constant * feat: create SidebarNavigationSearchResults component * refactor: use SidebarNavigationSearchResults in conversation list * feat: enable mobile search navigation in sidebar actions * feat: add mobile search route and page * fix: prevent sidebar overflow on mobile viewports * fix: Mobile sidebar * feat: Mobile Search WIP * feat: Mobile WIP * feat: Add PWA standalone detection and refine mobile UI * feat: Improve mobile layout, sidebar handling, and chat scrolling * feat: Improve mobile sidebar visibility and iOS Safari chat spacing * fix: Disable auto-scroll on mobile * chore: Linting * fix: Wrong condition * feat: Mobile chat scroll * refactor: WIP * fix: Desktop initial scroll always working again * fix: Partial fix for mobile auto-scroll / initial scroll * fix: Desktop auto-scroll on initial load and during streaming * fix: Mobile scrolling logic * refactor: Clean up * feat: Improve start UI * feat: Add `delay` to `fadeInView` * feat: Auto-scroll button * refactor: Cleanup * refactor: Extract chat dialogs and alerts into dedicated 
component * refactor: Reorganize ChatScreen component structure and initialization * feat: Improve auto-scroll after sending message * feat: UI improvements * fix: Settings link * feat: UI improvements * fix: better UI spacing * fix: Remove unneeded logic * fix: Chat Processing Info UI rendering * feat: Improve mobile UI * feat: UI improvement * fix: Conditional transition delay for Chat Messages based on route from * fix: Delay mobile sidebar collapse for smoother transitions * fix: Mobile scroll down button + sidebar pointer events * fix: Mobile UI * fix: Auto scrolling * fix: Implement dynamic height calculations for chat auto-scroll positioning and UI elements * fix: Retrieve `autofocus` for Chat Form textarea * fix: Use proper class Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * refactor: extract scroll-to-bottom logic and fix message send flow * fix: update viewport store usage and remove conflicting autofocus * feat: add accessibility labels to scroll down button * fix: correct HTML structure in sidebar empty states * fix: dynamically toggle processing info visibility * chore: remove commented exports and fix formatting * fix * fix: Mobile Chat Form Add Action Sheet interactions --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- tools/ui/.gitignore | 3 +- tools/ui/package-lock.json | 14 +- tools/ui/package.json | 2 +- tools/ui/pwa-assets-dark.config.ts | 9 +- tools/ui/pwa-assets.config.ts | 21 +- tools/ui/scripts/favicon-colorize.ts | 107 ++++ tools/ui/src/app.css | 7 +- .../ui/src/lib/actions/fade-in-view.svelte.ts | 14 +- tools/ui/src/lib/assets/logo.svg | 7 + .../components/app/actions/ActionIcon.svelte | 75 ++- .../app/chat/ChatForm/ChatForm.svelte | 4 +- .../ChatFormActionAddButton.svelte | 2 +- .../ChatFormActionAddSheet.svelte | 19 +- .../ChatFormActionsAdd.svelte | 1 + .../ChatFormActionSubmit.svelte | 2 +- .../app/chat/ChatForm/ChatFormTextarea.svelte 
| 5 +- .../ChatMessage/ChatMessage.svelte | 4 +- .../ChatMessageAssistant.svelte | 64 +- .../ChatMessageUser/ChatMessageUser.svelte | 2 +- .../ChatMessageUserBubble.svelte | 4 +- .../app/chat/ChatMessages/ChatMessages.svelte | 11 +- .../app/chat/ChatScreen/ChatScreen.svelte | 579 +++++++----------- .../ChatScreenActionScrollDown.svelte | 64 +- .../ChatScreenDialogsAndAlerts.svelte | 55 ++ .../app/chat/ChatScreen/ChatScreenForm.svelte | 36 +- .../chat/ChatScreen/ChatScreenGreeting.svelte | 6 +- .../ChatScreenProcessingInfo.svelte | 9 +- tools/ui/src/lib/components/app/chat/index.ts | 7 - .../components/app/forms/SearchInput.svelte | 5 +- .../src/lib/components/app/misc/Logo.svelte | 15 + tools/ui/src/lib/components/app/misc/index.ts | 8 + .../app/navigation/DesktopIconStrip.svelte | 84 --- .../SidebarNavigation.svelte | 529 +++++++--------- .../SidebarNavigationActions.svelte | 214 +++++-- .../SidebarNavigationConversationList.svelte | 135 ++++ .../SidebarNavigationSearch.svelte | 4 +- .../SidebarNavigationSearchResults.svelte | 76 +++ .../lib/components/app/navigation/index.ts | 68 +- .../settings/SettingsChat/SettingsChat.svelte | 5 +- .../SettingsChatDesktopSidebar.svelte | 10 +- .../app/settings/SettingsMcpServers.svelte | 51 +- .../lib/components/ui/button/button.svelte | 2 +- .../lib/components/ui/sidebar/constants.ts | 7 - .../components/ui/sidebar/context.svelte.ts | 79 --- .../ui/src/lib/components/ui/sidebar/index.ts | 75 --- .../ui/sidebar/sidebar-content.svelte | 24 - .../ui/sidebar/sidebar-footer.svelte | 21 - .../ui/sidebar/sidebar-group-action.svelte | 36 -- .../ui/sidebar/sidebar-group-content.svelte | 21 - .../ui/sidebar/sidebar-group-label.svelte | 34 - .../ui/sidebar/sidebar-group.svelte | 21 - .../ui/sidebar/sidebar-header.svelte | 21 - .../ui/sidebar/sidebar-input.svelte | 21 - .../ui/sidebar/sidebar-inset.svelte | 24 - .../ui/sidebar/sidebar-menu-action.svelte | 43 -- .../ui/sidebar/sidebar-menu-badge.svelte | 29 - 
.../ui/sidebar/sidebar-menu-button.svelte | 106 ---- .../ui/sidebar/sidebar-menu-item.svelte | 21 - .../ui/sidebar/sidebar-menu-skeleton.svelte | 36 -- .../ui/sidebar/sidebar-menu-sub-button.svelte | 43 -- .../ui/sidebar/sidebar-menu-sub-item.svelte | 21 - .../ui/sidebar/sidebar-menu-sub.svelte | 25 - .../components/ui/sidebar/sidebar-menu.svelte | 21 - .../ui/sidebar/sidebar-provider.svelte | 51 -- .../components/ui/sidebar/sidebar-rail.svelte | 36 -- .../ui/sidebar/sidebar-separator.svelte | 19 - .../ui/sidebar/sidebar-trigger.svelte | 43 -- .../lib/components/ui/sidebar/sidebar.svelte | 150 ----- tools/ui/src/lib/constants/pwa.ts | 15 +- tools/ui/src/lib/constants/routes.ts | 4 +- tools/ui/src/lib/constants/ui.ts | 7 +- .../src/lib/hooks/use-auto-scroll.svelte.ts | 5 +- .../use-chat-screen-active-model.svelte.ts | 100 +++ .../use-chat-screen-drag-and-drop.svelte.ts | 72 +++ .../use-chat-screen-file-upload.svelte.ts | 104 ++++ .../hooks/use-chat-screen-scroll.svelte.ts | 47 ++ .../lib/hooks/use-models-selector.svelte.ts | 2 +- tools/ui/src/lib/stores/device.svelte.ts | 72 +++ tools/ui/src/lib/stores/models.svelte.ts | 2 +- tools/ui/src/routes/+layout.svelte | 94 +-- tools/ui/src/routes/search/+page.svelte | 95 +++ tools/ui/src/routes/settings/+layout.svelte | 14 +- .../routes/settings/[[section]]/+page.svelte | 5 +- tools/ui/static/favicon-dark.svg | 14 - tools/ui/static/favicon.svg | 14 - .../client/components/TestWrapper.svelte | 7 +- .../stories/SidebarNavigation.stories.svelte | 31 +- tools/ui/tests/unit/favicon-colorize.test.ts | 198 ++++++ 88 files changed, 2122 insertions(+), 2147 deletions(-) create mode 100644 tools/ui/scripts/favicon-colorize.ts create mode 100644 tools/ui/src/lib/assets/logo.svg create mode 100644 tools/ui/src/lib/components/app/chat/ChatScreen/ChatScreenDialogsAndAlerts.svelte create mode 100644 tools/ui/src/lib/components/app/misc/Logo.svelte delete mode 100644 tools/ui/src/lib/components/app/navigation/DesktopIconStrip.svelte create 
mode 100644 tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationConversationList.svelte create mode 100644 tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationSearchResults.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/constants.ts delete mode 100644 tools/ui/src/lib/components/ui/sidebar/context.svelte.ts delete mode 100644 tools/ui/src/lib/components/ui/sidebar/index.ts delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-content.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-footer.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-group-action.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-group-content.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-group-label.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-group.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-header.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-input.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-inset.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-menu-action.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-menu-badge.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-menu-button.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-menu-item.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-menu-skeleton.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-menu-sub-button.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-menu-sub-item.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-menu-sub.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-menu.svelte delete mode 100644 
tools/ui/src/lib/components/ui/sidebar/sidebar-provider.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-rail.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-separator.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar-trigger.svelte delete mode 100644 tools/ui/src/lib/components/ui/sidebar/sidebar.svelte create mode 100644 tools/ui/src/lib/hooks/use-chat-screen-active-model.svelte.ts create mode 100644 tools/ui/src/lib/hooks/use-chat-screen-drag-and-drop.svelte.ts create mode 100644 tools/ui/src/lib/hooks/use-chat-screen-file-upload.svelte.ts create mode 100644 tools/ui/src/lib/hooks/use-chat-screen-scroll.svelte.ts create mode 100644 tools/ui/src/lib/stores/device.svelte.ts create mode 100644 tools/ui/src/routes/search/+page.svelte delete mode 100644 tools/ui/static/favicon-dark.svg delete mode 100644 tools/ui/static/favicon.svg create mode 100644 tools/ui/tests/unit/favicon-colorize.test.ts diff --git a/tools/ui/.gitignore b/tools/ui/.gitignore index ddcfe2e60f..0bb8c9b3c2 100644 --- a/tools/ui/.gitignore +++ b/tools/ui/.gitignore @@ -28,10 +28,9 @@ vite.config.ts.timestamp-* # PWA Artifacts apple-splash-*.png apple-touch-icon-*.png -favicon.ico -favicon-dark.ico maskable-icon-*.png pwa-*.png +static/favicon* # Storybook *storybook.log diff --git a/tools/ui/package-lock.json b/tools/ui/package-lock.json index 9d0cdfea6c..9dce3a0c9d 100644 --- a/tools/ui/package-lock.json +++ b/tools/ui/package-lock.json @@ -35,7 +35,7 @@ "bits-ui": "2.18.1", "clsx": "2.1.1", "dexie": "4.4.3", - "dompurify": "3.4.5", + "dompurify": "3.4.11", "eslint": "9.39.4", "eslint-config-prettier": "10.1.8", "eslint-plugin-storybook": "10.4.2", @@ -8653,9 +8653,9 @@ "peer": true }, "node_modules/dompurify": { - "version": "3.4.5", - "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.4.5.tgz", - "integrity": "sha512-OrwIBKsdNSVEeubdJ1HBv/wNENRM9ytAVCv7YXt//A3vPdVMNuACRqK9mXCGCBW2ln7BT/A4X0jXHo2Gu89miA==", + 
"version": "3.4.11", + "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.4.11.tgz", + "integrity": "sha512-zhlUV12GsaRzMsf9q5M254YhA4+VuF0fG+QFqu6aYpoGlKtz+w8//jBcGVYBgQkR5GHjUomejY84AV+/uPbWdw==", "dev": true, "license": "(MPL-2.0 OR Apache-2.0)", "optionalDependencies": { @@ -10226,9 +10226,9 @@ } }, "node_modules/hono": { - "version": "4.12.23", - "resolved": "https://registry.npmjs.org/hono/-/hono-4.12.23.tgz", - "integrity": "sha512-eIaZ9qDgu7XV0pxOCrg7/WhnQ6Ivm22UcxhXx/A3dcbqbbYgBEkc6e/J/s7j2tS96zoB0S9VBdLwQNCWwUo4LA==", + "version": "4.12.26", + "resolved": "https://registry.npmjs.org/hono/-/hono-4.12.26.tgz", + "integrity": "sha512-uyZtpnYxM9CmQ7QsQknM4zN8EftNqhON1qYeIKM0Se67CCEe2c44xyGURwB0axX2fBDu1dqHrHAc1hmNT8ITkw==", "dev": true, "license": "MIT", "engines": { diff --git a/tools/ui/package.json b/tools/ui/package.json index 4803922889..bcb4165d10 100644 --- a/tools/ui/package.json +++ b/tools/ui/package.json @@ -54,7 +54,7 @@ "bits-ui": "2.18.1", "clsx": "2.1.1", "dexie": "4.4.3", - "dompurify": "3.4.5", + "dompurify": "3.4.11", "eslint": "9.39.4", "eslint-config-prettier": "10.1.8", "eslint-plugin-storybook": "10.4.2", diff --git a/tools/ui/pwa-assets-dark.config.ts b/tools/ui/pwa-assets-dark.config.ts index e9793bca9e..358c0ebc07 100644 --- a/tools/ui/pwa-assets-dark.config.ts +++ b/tools/ui/pwa-assets-dark.config.ts @@ -1,4 +1,10 @@ import { defineConfig } from '@vite-pwa/assets-generator/config'; +import { FAVICON_COLORS, PWA_ASSET_GENERATOR } from './src/lib/constants/pwa'; +import { writeThemeFavicons } from './scripts/favicon-colorize'; + +writeThemeFavicons(FAVICON_COLORS.LIGHT, FAVICON_COLORS.DARK, { + padding: PWA_ASSET_GENERATOR.FAVICON_PADDING +}); export default defineConfig({ headLinkOptions: { @@ -7,7 +13,8 @@ export default defineConfig({ preset: { transparent: { sizes: [], - favicons: [[48, 'favicon-dark.ico']] + favicons: [[48, 'favicon-dark.ico']], + padding: PWA_ASSET_GENERATOR.FAVICON_PADDING }, maskable: { sizes: [] 
diff --git a/tools/ui/pwa-assets.config.ts b/tools/ui/pwa-assets.config.ts index 54928eeb4a..b69884d94a 100644 --- a/tools/ui/pwa-assets.config.ts +++ b/tools/ui/pwa-assets.config.ts @@ -5,15 +5,32 @@ import { } from '@vite-pwa/assets-generator/config'; import { readFileSync } from 'node:fs'; import { resolve } from 'node:path'; -import { THEME_COLORS, PWA_GENERATOR_DEVICES, PWA_ASSET_GENERATOR } from './src/lib/constants/pwa'; +import { + THEME_COLORS, + PWA_GENERATOR_DEVICES, + PWA_ASSET_GENERATOR, + FAVICON_COLORS +} from './src/lib/constants/pwa'; import { SplashOrientation } from './src/lib/enums/splash.enums'; +import { writeThemeFavicons } from './scripts/favicon-colorize'; + +writeThemeFavicons(FAVICON_COLORS.LIGHT, FAVICON_COLORS.DARK, { + padding: PWA_ASSET_GENERATOR.FAVICON_PADDING +}); export default defineConfig({ headLinkOptions: { preset: PWA_ASSET_GENERATOR.LINK_PRESET }, preset: combinePresetAndAppleSplashScreens( - minimal2023Preset, + { + ...minimal2023Preset, + // tiny margin so favicon.ico / pwa-*.png breathe inside the canvas + transparent: { + ...minimal2023Preset.transparent, + padding: PWA_ASSET_GENERATOR.FAVICON_PADDING + } + }, { padding: PWA_ASSET_GENERATOR.SPLASH_PADDING, resizeOptions: { diff --git a/tools/ui/scripts/favicon-colorize.ts b/tools/ui/scripts/favicon-colorize.ts new file mode 100644 index 0000000000..e1872b7774 --- /dev/null +++ b/tools/ui/scripts/favicon-colorize.ts @@ -0,0 +1,107 @@ +import { mkdirSync, readFileSync, writeFileSync } from 'node:fs'; +import { dirname, resolve } from 'node:path'; +import { fileURLToPath } from 'node:url'; + +const HERE = dirname(fileURLToPath(import.meta.url)); +const PROJECT_ROOT = resolve(HERE, '..'); + +const DEFAULT_LOGO = resolve(PROJECT_ROOT, 'src/lib/assets/logo.svg'); +const DEFAULT_OUT_DIR = resolve(PROJECT_ROOT, 'static'); +const DEFAULT_OUT_LIGHT = resolve(DEFAULT_OUT_DIR, 'favicon.svg'); +const DEFAULT_OUT_DARK = resolve(DEFAULT_OUT_DIR, 'favicon-dark.svg'); + +const 
CURRENT_COLOR = 'currentColor'; + +export interface ColorizedFavicon { + light: string; + dark: string; +} + +export interface WriteThemeFaviconsOptions { + sourcePath?: string; + lightOutPath?: string; + darkOutPath?: string; + /** + * Fraction of the icon (0..1) to leave as an even margin on each side. + * Applied by wrapping the inner content in a `` so the + * source `src/lib/assets/logo.svg` is not modified. Pass 0 to disable. + */ + padding?: number; +} + +/** + * Replace every `currentColor` occurrence in the SVG with the given color. + * Pure: no filesystem access, so it is straightforward to unit-test. + */ +export function colorizeFaviconSvg( + svg: string, + lightColor: string, + darkColor: string +): ColorizedFavicon { + return { + light: svg.replaceAll(CURRENT_COLOR, lightColor), + dark: svg.replaceAll(CURRENT_COLOR, darkColor) + }; +} + +/** + * Shrink the inner SVG content uniformly and re-center it so `padding` (a + * 0..1 fraction) is reserved as equal margin on each side. Returns the input + * unchanged for non-positive padding, missing/invalid `viewBox`, or unexpected + * markup so the caller always gets a renderable SVG. 
+ */ +export function padFaviconSvg(svg: string, padding: number): string { + if (!(padding > 0) || padding >= 1) return svg; + + const viewBoxMatch = svg.match(/viewBox\s*=\s*["']([^"']+)["']/i); + if (!viewBoxMatch) return svg; + + const parts = viewBoxMatch[1] + .trim() + .split(/[\s,]+/) + .map(Number); + if (parts.length !== 4 || parts.some((n) => !Number.isFinite(n))) return svg; + + const [, , width, height] = parts; + if (width <= 0 || height <= 0) return svg; + + const scale = 1 - padding; + const translateX = (padding * width) / 2; + const translateY = (padding * height) / 2; + + const openTagStart = svg.search(/', openTagStart); + if (openTagEnd === -1) return svg; + const closeStart = svg.lastIndexOf('`; + return `${openTag}${group}${inner}${closeTag}`; +} + +/** + * Read `src/lib/assets/logo.svg`, colorize it for both themes, and write + * the results to the static directory so the PWA asset generator can consume + * them. Paths can be overridden for tests. + */ +export function writeThemeFavicons( + lightColor: string, + darkColor: string, + { + sourcePath = DEFAULT_LOGO, + lightOutPath = DEFAULT_OUT_LIGHT, + darkOutPath = DEFAULT_OUT_DARK, + padding = 0 + }: WriteThemeFaviconsOptions = {} +): void { + const source = readFileSync(sourcePath, 'utf-8'); + const { light, dark } = colorizeFaviconSvg(source, lightColor, darkColor); + mkdirSync(dirname(lightOutPath), { recursive: true }); + writeFileSync(lightOutPath, padFaviconSvg(light, padding)); + writeFileSync(darkOutPath, padFaviconSvg(dark, padding)); +} diff --git a/tools/ui/src/app.css b/tools/ui/src/app.css index 9254a96df7..8c4056477d 100644 --- a/tools/ui/src/app.css +++ b/tools/ui/src/app.css @@ -48,6 +48,7 @@ --chat-form-area-height: 8rem; --chat-form-area-offset: 2rem; + --chat-form-padding-top: 6rem; --max-message-height: max(24rem, min(80dvh, calc(100dvh - var(--chat-form-area-height) - 12rem))); } @@ -55,6 +56,7 @@ :root { --chat-form-area-height: 24rem; --chat-form-area-offset: 12rem; + 
--chat-form-padding-top: 6rem; } } @@ -141,7 +143,6 @@ @apply bg-background text-foreground; scrollbar-width: thin; scrollbar-gutter: stable; - overflow: hidden; /* Added due to Mermaid rendering somehow causing the double scrollbar */ } /* Global scrollbar styling - visible only on hover */ @@ -193,3 +194,7 @@ scrollbar-width: none; } } + +.mermaidTooltip { + display: none !important; +} diff --git a/tools/ui/src/lib/actions/fade-in-view.svelte.ts b/tools/ui/src/lib/actions/fade-in-view.svelte.ts index d930448050..9a5918131a 100644 --- a/tools/ui/src/lib/actions/fade-in-view.svelte.ts +++ b/tools/ui/src/lib/actions/fade-in-view.svelte.ts @@ -10,9 +10,9 @@ import { isElementInViewport } from '$lib/utils/viewport'; */ export function fadeInView( node: HTMLElement, - options: { duration?: number; y?: number; skipIfVisible?: boolean } = {} + options: { duration?: number; y?: number; delay?: number; skipIfVisible?: boolean } = {} ) { - const { duration = 300, y = 0, skipIfVisible = false } = options; + const { duration = 300, y = 0, delay = 0, skipIfVisible = false } = options; if (skipIfVisible && isElementInViewport(node)) { return; @@ -27,10 +27,12 @@ export function fadeInView( (entries) => { for (const entry of entries) { if (entry.isIntersecting) { - requestAnimationFrame(() => { - node.style.opacity = '1'; - node.style.transform = 'translateY(0)'; - }); + setTimeout(() => { + requestAnimationFrame(() => { + node.style.opacity = '1'; + node.style.transform = 'translateY(0)'; + }); + }, delay); observer.disconnect(); } } diff --git a/tools/ui/src/lib/assets/logo.svg b/tools/ui/src/lib/assets/logo.svg new file mode 100644 index 0000000000..05424790af --- /dev/null +++ b/tools/ui/src/lib/assets/logo.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/tools/ui/src/lib/components/app/actions/ActionIcon.svelte b/tools/ui/src/lib/components/app/actions/ActionIcon.svelte index f156df6699..8a86557bb9 100644 --- a/tools/ui/src/lib/components/app/actions/ActionIcon.svelte +++ 
b/tools/ui/src/lib/components/app/actions/ActionIcon.svelte @@ -8,12 +8,13 @@ ariaLabel?: string; class?: string; disabled?: boolean; + href?: string; icon: Component; iconSize?: string; - onclick: (e?: MouseEvent) => void; + onclick?: (e?: MouseEvent) => void; size?: ButtonSize; stopPropagationOnClick?: boolean; - tooltip: string; + tooltip?: string; variant?: ButtonVariant; tooltipSide?: TooltipSide; } @@ -22,6 +23,7 @@ icon, tooltip, variant = 'ghost', + href = '', size = 'sm', class: className = '', disabled = false, @@ -31,34 +33,49 @@ onclick, ariaLabel }: Props = $props(); + + let innerWidth = $state(0); + const showTooltip = $derived(!!tooltip && innerWidth > 768); - - - - {#snippet child({ props })} - - {/snippet} - + onclick?.(e); + }} + class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!" + aria-label={ariaLabel || tooltip} + > + {#if icon} + {@const IconComponent = icon} - -

{tooltip}

-
-
+ + {/if} + +{/snippet} + +{#if showTooltip} + + + + {#snippet child({ props })} + {@render button(props)} + {/snippet} + + + +

{tooltip}

+
+
+{:else} + {@render button({ href })} +{/if} + + diff --git a/tools/ui/src/lib/components/app/chat/ChatForm/ChatForm.svelte b/tools/ui/src/lib/components/app/chat/ChatForm/ChatForm.svelte index ed26f9ea58..9b2077b8dc 100644 --- a/tools/ui/src/lib/components/app/chat/ChatForm/ChatForm.svelte +++ b/tools/ui/src/lib/components/app/chat/ChatForm/ChatForm.svelte @@ -494,7 +494,7 @@ />
{#if hasMcpPromptsSupport} - +{#if innerWidth > 768 || (!page.url.hash.includes(ROUTES.SETTINGS) && !page.url.hash.includes(ROUTES.MCP_SERVERS) && !page.url.hash.includes(ROUTES.SEARCH))} + +{/if} + + diff --git a/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationActions.svelte b/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationActions.svelte index f0d63970ee..f118a68dfc 100644 --- a/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationActions.svelte +++ b/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationActions.svelte @@ -1,39 +1,86 @@ @@ -41,56 +88,109 @@ {/snippet} -
- {#if isSearchModeActive} +{#if isSearchModeActive} +
e.key === 'Escape' && handleSearchModeDeactivate()} placeholder="Search conversations..." - {isCancelAlwaysVisible} /> - {:else} - {#each SIDEBAR_ACTIONS_ITEMS as item (item.route)} - {#if !item.route} - - {:else} - + {#if item.keys} + + {/if} + +
{/if} {/each} - {/if} -
+
+{:else} + +{/if} diff --git a/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationConversationList.svelte b/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationConversationList.svelte new file mode 100644 index 0000000000..488e96bbcf --- /dev/null +++ b/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationConversationList.svelte @@ -0,0 +1,135 @@ + + +{#if isSearchModeActive} + +{:else} + {#if pinnedConversations.length > 0} +
+
+ + + Pinned +
+
+ +
    + {#each pinnedConversations as { conversation, depth } (conversation.id)} +
  • + +
  • + {/each} +
+ {/if} + +
+ {#if filteredConversations.length > 0} +
+ Recent conversations +
+ {/if} + +
+
    + {#each unpinnedConversations as { conversation, depth } (conversation.id)} +
  • + +
  • + {/each} + + {#if unpinnedConversations.length === 0} +
  • +

    + {recentEmptyMessage} +

    +
  • + {/if} +
+
+
+{/if} diff --git a/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationSearch.svelte b/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationSearch.svelte index afc9847028..491e7c3479 100644 --- a/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationSearch.svelte +++ b/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationSearch.svelte @@ -16,4 +16,6 @@ }: Props = $props(); - +
+ +
diff --git a/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationSearchResults.svelte b/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationSearchResults.svelte new file mode 100644 index 0000000000..92d8fd0bda --- /dev/null +++ b/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationSearchResults.svelte @@ -0,0 +1,76 @@ + + +
+ {#if showHeader} +
+ Search results +
+ {/if} + +
+
    + {#each tree as { conversation, depth } (conversation.id)} +
  • + +
  • + {/each} + + {#if tree.length === 0} +
  • +

    + {emptyMessage} +

    +
  • + {/if} +
+
+
diff --git a/tools/ui/src/lib/components/app/navigation/index.ts b/tools/ui/src/lib/components/app/navigation/index.ts index d4ca914594..e07dde6bc9 100644 --- a/tools/ui/src/lib/components/app/navigation/index.ts +++ b/tools/ui/src/lib/components/app/navigation/index.ts @@ -63,15 +63,6 @@ export { default as DropdownMenuSearchable } from './DropdownMenuSearchable.svel * ``` */ export { default as DropdownMenuActions } from './DropdownMenuActions.svelte'; - -/** - * **DesktopIconStrip** - Fixed icon strip for desktop sidebar - * - * Vertical icon strip shown on desktop when the sidebar is collapsed. - * Contains navigation shortcuts for new chat, search, MCP, import/export, and settings. - */ -export { default as DesktopIconStrip } from './DesktopIconStrip.svelte'; - /** * **SidebarNavigation** - Sidebar with actions menu and conversation list * @@ -115,13 +106,6 @@ export { default as DesktopIconStrip } from './DesktopIconStrip.svelte'; */ export { default as SidebarNavigation } from './SidebarNavigation/SidebarNavigation.svelte'; -/** - * Action buttons for sidebar header. Contains new chat button, settings button, - * and delete all conversations button. Manages dialog states for settings and - * delete confirmation. - */ -export { default as SidebarNavigationActions } from './SidebarNavigation/SidebarNavigationActions.svelte'; - /** * Single conversation item in sidebar. Displays conversation title (truncated), * last message preview, and timestamp. Shows context menu on right-click with @@ -130,6 +114,58 @@ export { default as SidebarNavigationActions } from './SidebarNavigation/Sidebar */ export { default as SidebarNavigationConversationItem } from './SidebarNavigation/SidebarNavigationConversationItem.svelte'; +/** + * **SidebarNavigationConversationList** - Grouped conversation list + * + * Pure-presentational list of conversations. 
Splits items into a Pinned + * section (when not in search mode) and a Recent Conversations / Search + * Results section with the unpinned items. Item selection, edit, delete, + * and stop-generation are delegated to the caller via callbacks. + * + * @example + * ```svelte + * + * ``` + */ +export { default as SidebarNavigationConversationList } from './SidebarNavigation/SidebarNavigationConversationList.svelte'; +export { default as SidebarNavigationActions } from './SidebarNavigation/SidebarNavigationActions.svelte'; + +/** + * **SidebarNavigationSearchResults** - Filtered conversation list for search. + * + * Pure-presentational rendering of the search-mode subtree: "Search results" + * header, the matching items rendered through {@link SidebarNavigationConversationItem}, + * and contextual empty-state messages. Used both inline inside + * {@link SidebarNavigationConversationList} (when search mode is active in the + * sidebar) and as the body of the mobile `/search` route. + * + * The caller is expected to provide an already-filtered list via + * `filteredConversations` and a `searchQuery` for the empty-state messages. + * + * @example + * ```svelte + * + * ``` + */ +export { default as SidebarNavigationSearchResults } from './SidebarNavigation/SidebarNavigationSearchResults.svelte'; + /** * Search input for filtering conversations in sidebar. Filters conversation * list by title as user types. Shows clear button when query is not empty. diff --git a/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChat.svelte b/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChat.svelte index 62e73d8579..41105baa5e 100644 --- a/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChat.svelte +++ b/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChat.svelte @@ -126,10 +126,7 @@ }); -
+
-